diff options
Diffstat (limited to 'third_party/highway/hwy')
121 files changed, 74016 insertions, 0 deletions
diff --git a/third_party/highway/hwy/aligned_allocator.cc b/third_party/highway/hwy/aligned_allocator.cc new file mode 100644 index 0000000000..7b9947970e --- /dev/null +++ b/third_party/highway/hwy/aligned_allocator.cc @@ -0,0 +1,152 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/aligned_allocator.h" + +#include <stdarg.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> // malloc + +#include <atomic> +#include <limits> + +#include "hwy/base.h" + +namespace hwy { +namespace { + +#if HWY_ARCH_RVV && defined(__riscv_vector) +// Not actually an upper bound on the size, but this value prevents crossing a +// 4K boundary (relevant on Andes). +constexpr size_t kAlignment = HWY_MAX(HWY_ALIGNMENT, 4096); +#else +constexpr size_t kAlignment = HWY_ALIGNMENT; +#endif + +#if HWY_ARCH_X86 +// On x86, aliasing can only occur at multiples of 2K, but that's too wasteful +// if this is used for single-vector allocations. 256 is more reasonable. +constexpr size_t kAlias = kAlignment * 4; +#else +constexpr size_t kAlias = kAlignment; +#endif + +#pragma pack(push, 1) +struct AllocationHeader { + void* allocated; + size_t payload_size; +}; +#pragma pack(pop) + +// Returns a 'random' (cyclical) offset for AllocateAlignedBytes. +size_t NextAlignedOffset() { + static std::atomic<uint32_t> next{0}; + constexpr uint32_t kGroups = kAlias / kAlignment; + const uint32_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups; + const size_t offset = kAlignment * group; + HWY_DASSERT((offset % kAlignment == 0) && offset <= kAlias); + return offset; +} + +} // namespace + +HWY_DLLEXPORT void* AllocateAlignedBytes(const size_t payload_size, + AllocPtr alloc_ptr, void* opaque_ptr) { + HWY_ASSERT(payload_size != 0); // likely a bug in caller + if (payload_size >= std::numeric_limits<size_t>::max() / 2) { + HWY_DASSERT(false && "payload_size too large"); + return nullptr; + } + + size_t offset = NextAlignedOffset(); + + // What: | misalign | unused | AllocationHeader |payload + // Size: |<= kAlias | offset |payload_size + // ^allocated.^aligned.^header............^payload + // The header must immediately precede payload, which must remain aligned. + // To avoid wasting space, the header resides at the end of `unused`, + // which therefore cannot be empty (offset == 0). + if (offset == 0) { + offset = kAlignment; // = RoundUpTo(sizeof(AllocationHeader), kAlignment) + static_assert(sizeof(AllocationHeader) <= kAlignment, "Else: round up"); + } + + const size_t allocated_size = kAlias + offset + payload_size; + void* allocated; + if (alloc_ptr == nullptr) { + allocated = malloc(allocated_size); + } else { + allocated = (*alloc_ptr)(opaque_ptr, allocated_size); + } + if (allocated == nullptr) return nullptr; + // Always round up even if already aligned - we already asked for kAlias + // extra bytes and there's no way to give them back. + uintptr_t aligned = reinterpret_cast<uintptr_t>(allocated) + kAlias; + static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2"); + static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias"); + aligned &= ~(kAlias - 1); + + const uintptr_t payload = aligned + offset; // still aligned + + // Stash `allocated` and payload_size inside header for FreeAlignedBytes(). + // The allocated_size can be reconstructed from the payload_size. + AllocationHeader* header = reinterpret_cast<AllocationHeader*>(payload) - 1; + header->allocated = allocated; + header->payload_size = payload_size; + + return HWY_ASSUME_ALIGNED(reinterpret_cast<void*>(payload), kAlignment); +} + +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr) { + if (aligned_pointer == nullptr) return; + + const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer); + HWY_DASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast<const AllocationHeader*>(payload) - 1; + + if (free_ptr == nullptr) { + free(header->allocated); + } else { + (*free_ptr)(opaque_ptr, header->allocated); + } +} + +// static +HWY_DLLEXPORT void AlignedDeleter::DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter) { + if (aligned_pointer == nullptr) return; + + const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer); + HWY_DASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast<const AllocationHeader*>(payload) - 1; + + if (deleter) { + (*deleter)(aligned_pointer, header->payload_size); + } + + if (free_ptr == nullptr) { + free(header->allocated); + } else { + (*free_ptr)(opaque_ptr, header->allocated); + } +} + +} // namespace hwy diff --git a/third_party/highway/hwy/aligned_allocator.h b/third_party/highway/hwy/aligned_allocator.h new file mode 100644 index 0000000000..f6bfca11ee --- /dev/null +++ b/third_party/highway/hwy/aligned_allocator.h @@ -0,0 +1,212 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ +#define HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ + +// Memory allocator with support for alignment and offsets. + +#include <stddef.h> + +#include <memory> + +#include "hwy/highway_export.h" + +namespace hwy { + +// Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which +// requires a literal. This matches typical L1 cache line sizes, which prevents +// false sharing. +#define HWY_ALIGNMENT 64 + +// Pointers to functions equivalent to malloc/free with an opaque void* passed +// to them. +using AllocPtr = void* (*)(void* opaque, size_t bytes); +using FreePtr = void (*)(void* opaque, void* memory); + +// Returns null or a pointer to at least `payload_size` (which can be zero) +// bytes of newly allocated memory, aligned to the larger of HWY_ALIGNMENT and +// the vector size. Calls `alloc` with the passed `opaque` pointer to obtain +// memory or malloc() if it is null. +HWY_DLLEXPORT void* AllocateAlignedBytes(size_t payload_size, + AllocPtr alloc_ptr, void* opaque_ptr); + +// Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it +// must have been returned from a previous call to `AllocateAlignedBytes`. +// Calls `free_ptr` with the passed `opaque_ptr` pointer to free the memory; if +// `free_ptr` function is null, uses the default free(). +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr); + +// Class that deletes the aligned pointer passed to operator() calling the +// destructor before freeing the pointer. This is equivalent to the +// std::default_delete but for aligned objects. For a similar deleter equivalent +// to free() for aligned memory see AlignedFreer(). +class AlignedDeleter { + public: + AlignedDeleter() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedDeleter(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template <typename T> + void operator()(T* aligned_pointer) const { + return DeleteAlignedArray(aligned_pointer, free_, opaque_ptr_, + TypedArrayDeleter<T>); + } + + private: + template <typename T> + static void TypedArrayDeleter(void* ptr, size_t size_in_bytes) { + size_t elems = size_in_bytes / sizeof(T); + for (size_t i = 0; i < elems; i++) { + // Explicitly call the destructor on each element. + (static_cast<T*>(ptr) + i)->~T(); + } + } + + // Function prototype that calls the destructor for each element in a typed + // array. TypeArrayDeleter<T> would match this prototype. + using ArrayDeleter = void (*)(void* t_ptr, size_t t_size); + + HWY_DLLEXPORT static void DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter); + + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to T with custom aligned deleter. This can be a single +// element U or an array of element if T is a U[]. The custom aligned deleter +// will call the destructor on U or each element of a U[] in the array case. +template <typename T> +using AlignedUniquePtr = std::unique_ptr<T, AlignedDeleter>; + +// Aligned memory equivalent of make_unique<T> using the custom allocators +// alloc/free with the passed `opaque` pointer. This function calls the +// constructor with the passed Args... and calls the destructor of the object +// when the AlignedUniquePtr is destroyed. +template <typename T, typename... Args> +AlignedUniquePtr<T> MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free, + void* opaque, Args&&... args) { + T* ptr = static_cast<T*>(AllocateAlignedBytes(sizeof(T), alloc, opaque)); + return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...), + AlignedDeleter(free, opaque)); +} + +// Similar to MakeUniqueAlignedWithAlloc but using the default alloc/free +// functions. +template <typename T, typename... Args> +AlignedUniquePtr<T> MakeUniqueAligned(Args&&... args) { + T* ptr = static_cast<T*>(AllocateAlignedBytes( + sizeof(T), /*alloc_ptr=*/nullptr, /*opaque_ptr=*/nullptr)); + return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...), + AlignedDeleter()); +} + +// Helpers for array allocators (avoids overflow) +namespace detail { + +// Returns x such that 1u << x == n (if n is a power of two). +static inline constexpr size_t ShiftCount(size_t n) { + return (n <= 1) ? 0 : 1 + ShiftCount(n / 2); +} + +template <typename T> +T* AllocateAlignedItems(size_t items, AllocPtr alloc_ptr, void* opaque_ptr) { + constexpr size_t size = sizeof(T); + + constexpr bool is_pow2 = (size & (size - 1)) == 0; + constexpr size_t bits = ShiftCount(size); + static_assert(!is_pow2 || (1ull << bits) == size, "ShiftCount is incorrect"); + + const size_t bytes = is_pow2 ? items << bits : items * size; + const size_t check = is_pow2 ? bytes >> bits : bytes / size; + if (check != items) { + return nullptr; // overflowed + } + return static_cast<T*>(AllocateAlignedBytes(bytes, alloc_ptr, opaque_ptr)); +} + +} // namespace detail + +// Aligned memory equivalent of make_unique<T[]> for array types using the +// custom allocators alloc/free. This function calls the constructor with the +// passed Args... on every created item. The destructor of each element will be +// called when the AlignedUniquePtr is destroyed. +template <typename T, typename... Args> +AlignedUniquePtr<T[]> MakeUniqueAlignedArrayWithAlloc( + size_t items, AllocPtr alloc, FreePtr free, void* opaque, Args&&... args) { + T* ptr = detail::AllocateAlignedItems<T>(items, alloc, opaque); + if (ptr != nullptr) { + for (size_t i = 0; i < items; i++) { + new (ptr + i) T(std::forward<Args>(args)...); + } + } + return AlignedUniquePtr<T[]>(ptr, AlignedDeleter(free, opaque)); +} + +template <typename T, typename... Args> +AlignedUniquePtr<T[]> MakeUniqueAlignedArray(size_t items, Args&&... args) { + return MakeUniqueAlignedArrayWithAlloc<T, Args...>( + items, nullptr, nullptr, nullptr, std::forward<Args>(args)...); +} + +// Custom deleter for std::unique_ptr equivalent to using free() as a deleter +// but for aligned memory. +class AlignedFreer { + public: + // Pass address of this to ctor to skip deleting externally-owned memory. + static void DoNothing(void* /*opaque*/, void* /*aligned_pointer*/) {} + + AlignedFreer() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedFreer(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template <typename T> + void operator()(T* aligned_pointer) const { + // TODO(deymo): assert that we are using a POD type T. + FreeAlignedBytes(aligned_pointer, free_, opaque_ptr_); + } + + private: + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to single POD, or (if T is U[]) an array of POD. For non POD +// data use AlignedUniquePtr. +template <typename T> +using AlignedFreeUniquePtr = std::unique_ptr<T, AlignedFreer>; + +// Allocate an aligned and uninitialized array of POD values as a unique_ptr. +// Upon destruction of the unique_ptr the aligned array will be freed. +template <typename T> +AlignedFreeUniquePtr<T[]> AllocateAligned(const size_t items, AllocPtr alloc, + FreePtr free, void* opaque) { + return AlignedFreeUniquePtr<T[]>( + detail::AllocateAlignedItems<T>(items, alloc, opaque), + AlignedFreer(free, opaque)); +} + +// Same as previous AllocateAligned(), using default allocate/free functions. +template <typename T> +AlignedFreeUniquePtr<T[]> AllocateAligned(const size_t items) { + return AllocateAligned<T>(items, nullptr, nullptr, nullptr); +} + +} // namespace hwy +#endif // HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ diff --git a/third_party/highway/hwy/aligned_allocator_test.cc b/third_party/highway/hwy/aligned_allocator_test.cc new file mode 100644 index 0000000000..e8948b4e9b --- /dev/null +++ b/third_party/highway/hwy/aligned_allocator_test.cc @@ -0,0 +1,278 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/aligned_allocator.h" + +#include <stddef.h> + +#include <array> +#include <new> +#include <random> +#include <vector> + +#include "gtest/gtest.h" + +namespace { + +// Sample object that keeps track on an external counter of how many times was +// the explicit constructor and destructor called. +template <size_t N> +class SampleObject { + public: + SampleObject() { data_[0] = 'a'; } + explicit SampleObject(int* counter) : counter_(counter) { + if (counter) (*counter)++; + data_[0] = 'b'; + } + + ~SampleObject() { + if (counter_) (*counter_)--; + } + + static_assert(N > sizeof(int*), "SampleObject size too small."); + int* counter_ = nullptr; + char data_[N - sizeof(int*)]; +}; + +class FakeAllocator { + public: + // static AllocPtr and FreePtr member to be used with the aligned + // allocator. These functions calls the private non-static members. + static void* StaticAlloc(void* opaque, size_t bytes) { + return reinterpret_cast<FakeAllocator*>(opaque)->Alloc(bytes); + } + static void StaticFree(void* opaque, void* memory) { + return reinterpret_cast<FakeAllocator*>(opaque)->Free(memory); + } + + // Returns the number of pending allocations to be freed. + size_t PendingAllocs() { return allocs_.size(); } + + private: + void* Alloc(size_t bytes) { + void* ret = malloc(bytes); + allocs_.insert(ret); + return ret; + } + void Free(void* memory) { + if (!memory) return; + EXPECT_NE(allocs_.end(), allocs_.find(memory)); + allocs_.erase(memory); + free(memory); + } + + std::set<void*> allocs_; +}; + +} // namespace + +namespace hwy { + +class AlignedAllocatorTest : public testing::Test {}; + +TEST(AlignedAllocatorTest, FreeNullptr) { + // Calling free with a nullptr is always ok. + FreeAlignedBytes(/*aligned_pointer=*/nullptr, /*free_ptr=*/nullptr, + /*opaque_ptr=*/nullptr); +} + +TEST(AlignedAllocatorTest, Log2) { + EXPECT_EQ(0u, detail::ShiftCount(1)); + EXPECT_EQ(1u, detail::ShiftCount(2)); + EXPECT_EQ(3u, detail::ShiftCount(8)); +} + +// Allocator returns null when it detects overflow of items * sizeof(T). +TEST(AlignedAllocatorTest, Overflow) { + constexpr size_t max = ~size_t(0); + constexpr size_t msb = (max >> 1) + 1; + using Size5 = std::array<uint8_t, 5>; + using Size10 = std::array<uint8_t, 10>; + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems<uint32_t>(max / 2, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems<uint32_t>(max / 3, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems<Size5>(max / 4, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems<uint16_t>(msb, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems<double>(msb + 1, nullptr, nullptr)); + EXPECT_EQ(nullptr, + detail::AllocateAlignedItems<Size10>(msb / 4, nullptr, nullptr)); +} + +TEST(AlignedAllocatorTest, AllocDefaultPointers) { + const size_t kSize = 7777; + void* ptr = AllocateAlignedBytes(kSize, /*alloc_ptr=*/nullptr, + /*opaque_ptr=*/nullptr); + ASSERT_NE(nullptr, ptr); + // Make sure the pointer is actually aligned. + EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr) % HWY_ALIGNMENT); + char* p = static_cast<char*>(ptr); + size_t ret = 0; + for (size_t i = 0; i < kSize; i++) { + // Performs a computation using p[] to prevent it being optimized away. + p[i] = static_cast<char>(i & 0x7F); + if (i) ret += static_cast<size_t>(p[i] * p[i - 1]); + } + EXPECT_NE(0U, ret); + FreeAlignedBytes(ptr, /*free_ptr=*/nullptr, /*opaque_ptr=*/nullptr); +} + +TEST(AlignedAllocatorTest, EmptyAlignedUniquePtr) { + AlignedUniquePtr<SampleObject<32>> ptr(nullptr, AlignedDeleter()); + AlignedUniquePtr<SampleObject<32>[]> arr(nullptr, AlignedDeleter()); +} + +TEST(AlignedAllocatorTest, EmptyAlignedFreeUniquePtr) { + AlignedFreeUniquePtr<SampleObject<32>> ptr(nullptr, AlignedFreer()); + AlignedFreeUniquePtr<SampleObject<32>[]> arr(nullptr, AlignedFreer()); +} + +TEST(AlignedAllocatorTest, CustomAlloc) { + FakeAllocator fake_alloc; + + const size_t kSize = 7777; + void* ptr = + AllocateAlignedBytes(kSize, &FakeAllocator::StaticAlloc, &fake_alloc); + ASSERT_NE(nullptr, ptr); + // We should have only requested one alloc from the allocator. + EXPECT_EQ(1U, fake_alloc.PendingAllocs()); + // Make sure the pointer is actually aligned. + EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr) % HWY_ALIGNMENT); + FreeAlignedBytes(ptr, &FakeAllocator::StaticFree, &fake_alloc); + EXPECT_EQ(0U, fake_alloc.PendingAllocs()); +} + +TEST(AlignedAllocatorTest, MakeUniqueAlignedDefaultConstructor) { + { + auto ptr = MakeUniqueAligned<SampleObject<24>>(); + // Default constructor sets the data_[0] to 'a'. + EXPECT_EQ('a', ptr->data_[0]); + EXPECT_EQ(nullptr, ptr->counter_); + } +} + +TEST(AlignedAllocatorTest, MakeUniqueAligned) { + int counter = 0; + { + // Creates the object, initializes it with the explicit constructor and + // returns an unique_ptr to it. + auto ptr = MakeUniqueAligned<SampleObject<24>>(&counter); + EXPECT_EQ(1, counter); + // Custom constructor sets the data_[0] to 'b'. + EXPECT_EQ('b', ptr->data_[0]); + } + EXPECT_EQ(0, counter); +} + +TEST(AlignedAllocatorTest, MakeUniqueAlignedArray) { + int counter = 0; + { + // Creates the array of objects and initializes them with the explicit + // constructor. + auto arr = MakeUniqueAlignedArray<SampleObject<24>>(7, &counter); + EXPECT_EQ(7, counter); + for (size_t i = 0; i < 7; i++) { + // Custom constructor sets the data_[0] to 'b'. + EXPECT_EQ('b', arr[i].data_[0]) << "Where i = " << i; + } + } + EXPECT_EQ(0, counter); +} + +TEST(AlignedAllocatorTest, AllocSingleInt) { + auto ptr = AllocateAligned<uint32_t>(1); + ASSERT_NE(nullptr, ptr.get()); + EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr.get()) % HWY_ALIGNMENT); + // Force delete of the unique_ptr now to check that it doesn't crash. + ptr.reset(nullptr); + EXPECT_EQ(nullptr, ptr.get()); +} + +TEST(AlignedAllocatorTest, AllocMultipleInt) { + const size_t kSize = 7777; + auto ptr = AllocateAligned<uint32_t>(kSize); + ASSERT_NE(nullptr, ptr.get()); + EXPECT_EQ(0U, reinterpret_cast<uintptr_t>(ptr.get()) % HWY_ALIGNMENT); + // ptr[i] is actually (*ptr.get())[i] which will use the operator[] of the + // underlying type chosen by AllocateAligned() for the std::unique_ptr. + EXPECT_EQ(&(ptr[0]) + 1, &(ptr[1])); + + size_t ret = 0; + for (size_t i = 0; i < kSize; i++) { + // Performs a computation using ptr[] to prevent it being optimized away. + ptr[i] = static_cast<uint32_t>(i); + if (i) ret += ptr[i] * ptr[i - 1]; + } + EXPECT_NE(0U, ret); +} + +TEST(AlignedAllocatorTest, AllocateAlignedObjectWithoutDestructor) { + int counter = 0; + { + // This doesn't call the constructor. + auto obj = AllocateAligned<SampleObject<24>>(1); + obj[0].counter_ = &counter; + } + // Destroying the unique_ptr shouldn't have called the destructor of the + // SampleObject<24>. + EXPECT_EQ(0, counter); +} + +TEST(AlignedAllocatorTest, MakeUniqueAlignedArrayWithCustomAlloc) { + FakeAllocator fake_alloc; + int counter = 0; + { + // Creates the array of objects and initializes them with the explicit + // constructor. + auto arr = MakeUniqueAlignedArrayWithAlloc<SampleObject<24>>( + 7, FakeAllocator::StaticAlloc, FakeAllocator::StaticFree, &fake_alloc, + &counter); + ASSERT_NE(nullptr, arr.get()); + // An array should still only call a single allocation. + EXPECT_EQ(1u, fake_alloc.PendingAllocs()); + EXPECT_EQ(7, counter); + for (size_t i = 0; i < 7; i++) { + // Custom constructor sets the data_[0] to 'b'. + EXPECT_EQ('b', arr[i].data_[0]) << "Where i = " << i; + } + } + EXPECT_EQ(0, counter); + EXPECT_EQ(0u, fake_alloc.PendingAllocs()); +} + +TEST(AlignedAllocatorTest, DefaultInit) { + // The test is whether this compiles. Default-init is useful for output params + // and per-thread storage. + std::vector<AlignedUniquePtr<int[]>> ptrs; + std::vector<AlignedFreeUniquePtr<double[]>> free_ptrs; + ptrs.resize(128); + free_ptrs.resize(128); + // The following is to prevent elision of the pointers. + std::mt19937 rng(129); // Emscripten lacks random_device. + std::uniform_int_distribution<size_t> dist(0, 127); + ptrs[dist(rng)] = MakeUniqueAlignedArray<int>(123); + free_ptrs[dist(rng)] = AllocateAligned<double>(456); + // "Use" pointer without resorting to printf. 0 == 0. Can't shift by 64. + const auto addr1 = reinterpret_cast<uintptr_t>(ptrs[dist(rng)].get()); + const auto addr2 = reinterpret_cast<uintptr_t>(free_ptrs[dist(rng)].get()); + constexpr size_t kBits = sizeof(uintptr_t) * 8; + EXPECT_EQ((addr1 >> (kBits - 1)) >> (kBits - 1), + (addr2 >> (kBits - 1)) >> (kBits - 1)); +} + +} // namespace hwy diff --git a/third_party/highway/hwy/base.h b/third_party/highway/hwy/base.h new file mode 100644 index 0000000000..3075856cb7 --- /dev/null +++ b/third_party/highway/hwy/base.h @@ -0,0 +1,996 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_BASE_H_ +#define HIGHWAY_HWY_BASE_H_ + +// For SIMD module implementations and their callers, target-independent. + +#include <stddef.h> +#include <stdint.h> + +#include "hwy/detect_compiler_arch.h" +#include "hwy/highway_export.h" + +#if HWY_COMPILER_MSVC +#include <string.h> // memcpy +#endif +#if HWY_ARCH_X86 +#include <atomic> +#endif + +//------------------------------------------------------------------------------ +// Compiler-specific definitions + +#define HWY_STR_IMPL(macro) #macro +#define HWY_STR(macro) HWY_STR_IMPL(macro) + +#if HWY_COMPILER_MSVC + +#include <intrin.h> + +#define HWY_RESTRICT __restrict +#define HWY_INLINE __forceinline +#define HWY_NOINLINE __declspec(noinline) +#define HWY_FLATTEN +#define HWY_NORETURN __declspec(noreturn) +#define HWY_LIKELY(expr) (expr) +#define HWY_UNLIKELY(expr) (expr) +#define HWY_PRAGMA(tokens) __pragma(tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(warning(tokens)) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(msc) +#define HWY_MAYBE_UNUSED +#define HWY_HAS_ASSUME_ALIGNED 0 +#if (_MSC_VER >= 1700) +#define HWY_MUST_USE_RESULT _Check_return_ +#else +#define HWY_MUST_USE_RESULT +#endif + +#else + +#define HWY_RESTRICT __restrict__ +// force inlining without optimization enabled creates very inefficient code +// that can cause compiler timeout +#ifdef __OPTIMIZE__ +#define HWY_INLINE inline __attribute__((always_inline)) +#else +#define HWY_INLINE inline +#endif +#define HWY_NOINLINE __attribute__((noinline)) +#define HWY_FLATTEN __attribute__((flatten)) +#define HWY_NORETURN __attribute__((noreturn)) +#define HWY_LIKELY(expr) __builtin_expect(!!(expr), 1) +#define HWY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#define HWY_PRAGMA(tokens) _Pragma(#tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(GCC diagnostic tokens) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(gcc) +// Encountered "attribute list cannot appear here" when using the C++17 +// [[maybe_unused]], so only use the old style attribute for now. +#define HWY_MAYBE_UNUSED __attribute__((unused)) +#define HWY_MUST_USE_RESULT __attribute__((warn_unused_result)) + +#endif // !HWY_COMPILER_MSVC + +//------------------------------------------------------------------------------ +// Builtin/attributes + +// Enables error-checking of format strings. +#if HWY_HAS_ATTRIBUTE(__format__) +#define HWY_FORMAT(idx_fmt, idx_arg) \ + __attribute__((__format__(__printf__, idx_fmt, idx_arg))) +#else +#define HWY_FORMAT(idx_fmt, idx_arg) +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* HWY_RESTRICT aligned = (float*)HWY_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if HWY_HAS_BUILTIN(__builtin_assume_aligned) +#define HWY_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align)) +#else +#define HWY_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */ +#endif + +// Clang and GCC require attributes on each function into which SIMD intrinsics +// are inlined. Support both per-function annotation (HWY_ATTR) for lambdas and +// automatic annotation via pragmas. +#if HWY_COMPILER_CLANG +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(clang attribute push(__attribute__((target(targets_str))), \ + apply_to = function)) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(clang attribute pop) +#elif HWY_COMPILER_GCC +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(GCC push_options) HWY_PRAGMA(GCC target targets_str) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(GCC pop_options) +#else +#define HWY_PUSH_ATTRIBUTES(targets_str) +#define HWY_POP_ATTRIBUTES +#endif + +//------------------------------------------------------------------------------ +// Macros + +#define HWY_API static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED + +#define HWY_CONCAT_IMPL(a, b) a##b +#define HWY_CONCAT(a, b) HWY_CONCAT_IMPL(a, b) + +#define HWY_MIN(a, b) ((a) < (b) ? (a) : (b)) +#define HWY_MAX(a, b) ((a) > (b) ? (a) : (b)) + +#if HWY_COMPILER_GCC_ACTUAL +// nielskm: GCC does not support '#pragma GCC unroll' without the factor. +#define HWY_UNROLL(factor) HWY_PRAGMA(GCC unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL(4) +#elif HWY_COMPILER_CLANG || HWY_COMPILER_ICC || HWY_COMPILER_ICX +#define HWY_UNROLL(factor) HWY_PRAGMA(unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL() +#else +#define HWY_UNROLL(factor) +#define HWY_DEFAULT_UNROLL +#endif + +// Tell a compiler that the expression always evaluates to true. +// The expression should be free from any side effects. +// Some older compilers may have trouble with complex expressions, therefore +// it is advisable to split multiple conditions into separate assume statements, +// and manually check the generated code. +// OK but could fail: +// HWY_ASSUME(x == 2 && y == 3); +// Better: +// HWY_ASSUME(x == 2); +// HWY_ASSUME(y == 3); +#if defined(__has_cpp_attribute) && __has_cpp_attribute(assume) +#define HWY_ASSUME(expr) [[assume(expr)]] +#elif HWY_COMPILER_MSVC || HWY_COMPILER_ICC +#define HWY_ASSUME(expr) __assume(expr) +// __builtin_assume() was added in clang 3.6. +#elif HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_assume) +#define HWY_ASSUME(expr) __builtin_assume(expr) +// __builtin_unreachable() was added in GCC 4.5, but __has_builtin() was added +// later, so check for the compiler version directly. +#elif HWY_COMPILER_GCC_ACTUAL >= 405 +#define HWY_ASSUME(expr) \ + ((expr) ? static_cast<void>(0) : __builtin_unreachable()) +#else +#define HWY_ASSUME(expr) static_cast<void>(0) +#endif + +// Compile-time fence to prevent undesirable code reordering. On Clang x86, the +// typical asm volatile("" : : : "memory") has no effect, whereas atomic fence +// does, without generating code. +#if HWY_ARCH_X86 +#define HWY_FENCE std::atomic_thread_fence(std::memory_order_acq_rel) +#else +// TODO(janwas): investigate alternatives. On ARM, the above generates barriers. +#define HWY_FENCE +#endif + +// 4 instances of a given literal value, useful as input to LoadDup128. +#define HWY_REP4(literal) literal, literal, literal, literal + +#define HWY_ABORT(format, ...) \ + ::hwy::Abort(__FILE__, __LINE__, format, ##__VA_ARGS__) + +// Always enabled. +#define HWY_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + HWY_ABORT("Assert %s", #condition); \ + } \ + } while (0) + +#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) +#define HWY_IS_MSAN 1 +#else +#define HWY_IS_MSAN 0 +#endif + +#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) +#define HWY_IS_ASAN 1 +#else +#define HWY_IS_ASAN 0 +#endif + +#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) +#define HWY_IS_TSAN 1 +#else +#define HWY_IS_TSAN 0 +#endif + +// MSAN may cause lengthy build times or false positives e.g. in AVX3 DemoteTo. +// You can disable MSAN by adding this attribute to the function that fails. +#if HWY_IS_MSAN +#define HWY_ATTR_NO_MSAN __attribute__((no_sanitize_memory)) +#else +#define HWY_ATTR_NO_MSAN +#endif + +// For enabling HWY_DASSERT and shortening tests in slower debug builds +#if !defined(HWY_IS_DEBUG_BUILD) +// Clang does not define NDEBUG, but it and GCC define __OPTIMIZE__, and recent +// MSVC defines NDEBUG (if not, could instead check _DEBUG). +#if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || HWY_IS_ASAN || \ + HWY_IS_MSAN || HWY_IS_TSAN || defined(__clang_analyzer__) +#define HWY_IS_DEBUG_BUILD 1 +#else +#define HWY_IS_DEBUG_BUILD 0 +#endif +#endif // HWY_IS_DEBUG_BUILD + +#if HWY_IS_DEBUG_BUILD +#define HWY_DASSERT(condition) HWY_ASSERT(condition) +#else +#define HWY_DASSERT(condition) \ + do { \ + } while (0) +#endif + +namespace hwy { + +//------------------------------------------------------------------------------ +// kMaxVectorSize (undocumented, pending removal) + +#if HWY_ARCH_X86 +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 64; // AVX-512 +#elif HWY_ARCH_RVV && defined(__riscv_vector) +// Not actually an upper bound on the size. +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 4096; +#else +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; +#endif + +//------------------------------------------------------------------------------ +// Alignment + +// Potentially useful for LoadDup128 and capped vectors. In other cases, arrays +// should be allocated dynamically via aligned_allocator.h because Lanes() may +// exceed the stack size. +#if HWY_ARCH_X86 +#define HWY_ALIGN_MAX alignas(64) +#elif HWY_ARCH_RVV && defined(__riscv_vector) +#define HWY_ALIGN_MAX alignas(8) // only elements need be aligned +#else +#define HWY_ALIGN_MAX alignas(16) +#endif + +//------------------------------------------------------------------------------ +// Lane types + +// Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name +// by concatenating base type and bits. + +#pragma pack(push, 1) + +// ACLE (https://gcc.gnu.org/onlinedocs/gcc/Half-Precision.html): +// always supported on aarch64, for v7 only if -mfp16-format is given. +#if ((HWY_ARCH_ARM_A64 || (__ARM_FP & 2)) && HWY_COMPILER_GCC) +using float16_t = __fp16; +// C11 extension ISO/IEC TS 18661-3:2015 but not supported on all targets. +// Required for Clang RVV if the float16 extension is used. +#elif HWY_ARCH_RVV && HWY_COMPILER_CLANG && defined(__riscv_zvfh) +using float16_t = _Float16; +// Otherwise emulate +#else +struct float16_t { + uint16_t bits; +}; +#endif + +struct bfloat16_t { + uint16_t bits; +}; + +#pragma pack(pop) + +using float32_t = float; +using float64_t = double; + +#pragma pack(push, 1) + +// Aligned 128-bit type. Cannot use __int128 because clang doesn't yet align it: +// https://reviews.llvm.org/D86310 +struct alignas(16) uint128_t { + uint64_t lo; // little-endian layout + uint64_t hi; +}; + +// 64 bit key plus 64 bit value. Faster than using uint128_t when only the key +// field is to be compared (Lt128Upper instead of Lt128). +struct alignas(16) K64V64 { + uint64_t value; // little-endian layout + uint64_t key; +}; + +// 32 bit key plus 32 bit value. Allows vqsort recursions to terminate earlier +// than when considering both to be a 64-bit key. +struct alignas(8) K32V32 { + uint32_t value; // little-endian layout + uint32_t key; +}; + +#pragma pack(pop) + +static inline HWY_MAYBE_UNUSED bool operator<(const uint128_t& a, + const uint128_t& b) { + return (a.hi == b.hi) ? a.lo < b.lo : a.hi < b.hi; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const uint128_t& a, + const uint128_t& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const uint128_t& a, + const uint128_t& b) { + return a.lo == b.lo && a.hi == b.hi; +} + +static inline HWY_MAYBE_UNUSED bool operator<(const K64V64& a, + const K64V64& b) { + return a.key < b.key; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const K64V64& a, + const K64V64& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const K64V64& a, + const K64V64& b) { + return a.key == b.key; +} + +static inline HWY_MAYBE_UNUSED bool operator<(const K32V32& a, + const K32V32& b) { + return a.key < b.key; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const K32V32& a, + const K32V32& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const K32V32& a, + const K32V32& b) { + return a.key == b.key; +} + +//------------------------------------------------------------------------------ +// Controlling overload resolution (SFINAE) + +template <bool Condition> +struct EnableIfT {}; +template <> +struct EnableIfT<true> { + using type = void; +}; + +template <bool Condition> +using EnableIf = typename EnableIfT<Condition>::type; + +template <typename T, typename U> +struct IsSameT { + enum { value = 0 }; +}; + +template <typename T> +struct IsSameT<T, T> { + enum { value = 1 }; +}; + +template <typename T, typename U> +HWY_API constexpr bool IsSame() { + return IsSameT<T, U>::value; +} + +// Insert into template/function arguments to enable this overload only for +// vectors of AT MOST this many bits. +// +// Note that enabling for exactly 128 bits is unnecessary because a function can +// simply be overloaded with Vec128<T> and/or Full128<T> tag. Enabling for other +// sizes (e.g. 64 bit) can be achieved via Simd<T, 8 / sizeof(T), 0>. +#define HWY_IF_LE128(T, N) hwy::EnableIf<N * sizeof(T) <= 16>* = nullptr +#define HWY_IF_LE64(T, N) hwy::EnableIf<N * sizeof(T) <= 8>* = nullptr +#define HWY_IF_LE32(T, N) hwy::EnableIf<N * sizeof(T) <= 4>* = nullptr +#define HWY_IF_GE32(T, N) hwy::EnableIf<N * sizeof(T) >= 4>* = nullptr +#define HWY_IF_GE64(T, N) hwy::EnableIf<N * sizeof(T) >= 8>* = nullptr +#define HWY_IF_GE128(T, N) hwy::EnableIf<N * sizeof(T) >= 16>* = nullptr +#define HWY_IF_GT128(T, N) hwy::EnableIf<(N * sizeof(T) > 16)>* = nullptr + +#define HWY_IF_UNSIGNED(T) hwy::EnableIf<!IsSigned<T>()>* = nullptr +#define HWY_IF_SIGNED(T) \ + hwy::EnableIf<IsSigned<T>() && !IsFloat<T>()>* = nullptr +#define HWY_IF_FLOAT(T) hwy::EnableIf<hwy::IsFloat<T>()>* = nullptr +#define HWY_IF_NOT_FLOAT(T) hwy::EnableIf<!hwy::IsFloat<T>()>* = nullptr + +#define HWY_IF_LANE_SIZE(T, bytes) \ + hwy::EnableIf<sizeof(T) == (bytes)>* = nullptr +#define HWY_IF_NOT_LANE_SIZE(T, bytes) \ + hwy::EnableIf<sizeof(T) != (bytes)>* = nullptr +// bit_array = 0x102 means 1 or 8 bytes. There is no NONE_OF because it sounds +// too similar. If you want the opposite of this (2 or 4 bytes), ask for those +// bits explicitly (0x14) instead of attempting to 'negate' 0x102. +#define HWY_IF_LANE_SIZE_ONE_OF(T, bit_array) \ + hwy::EnableIf<((size_t{1} << sizeof(T)) & (bit_array)) != 0>* = nullptr + +#define HWY_IF_LANES_PER_BLOCK(T, N, LANES) \ + hwy::EnableIf<HWY_MIN(sizeof(T) * N, 16) / sizeof(T) == (LANES)>* = nullptr + +// Empty struct used as a size tag type. +template <size_t N> +struct SizeTag {}; + +template <class T> +struct RemoveConstT { + using type = T; +}; +template <class T> +struct RemoveConstT<const T> { + using type = T; +}; + +template <class T> +using RemoveConst = typename RemoveConstT<T>::type; + +//------------------------------------------------------------------------------ +// Type relations + +namespace detail { + +template <typename T> +struct Relations; +template <> +struct Relations<uint8_t> { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = uint16_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations<int8_t> { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = int16_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations<uint16_t> { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = uint32_t; + using Narrow = uint8_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations<int16_t> { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = int32_t; + using Narrow = int8_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations<uint32_t> { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = uint64_t; + using Narrow = uint16_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations<int32_t> { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = int64_t; + using Narrow = int16_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations<uint64_t> { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Wide = uint128_t; + using Narrow = uint32_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations<int64_t> { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = int32_t; + enum { is_signed = 1, is_float = 0 }; +}; +template <> +struct Relations<uint128_t> { + using Unsigned = uint128_t; + using Narrow = uint64_t; + enum { is_signed = 0, is_float = 0 }; +}; +template <> +struct Relations<float16_t> { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1 }; +}; +template <> +struct Relations<bfloat16_t> { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1 }; +}; +template <> +struct Relations<float> { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = double; + using Narrow = float16_t; + enum { is_signed = 1, is_float = 1 }; +}; +template <> +struct Relations<double> { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = float; + enum { is_signed = 1, is_float = 1 }; +}; + +template <size_t N> +struct TypeFromSize; +template <> +struct TypeFromSize<1> { + using Unsigned = uint8_t; + using Signed = int8_t; +}; +template <> +struct TypeFromSize<2> { + using Unsigned = uint16_t; + using Signed = int16_t; +}; +template <> +struct TypeFromSize<4> { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; +}; +template <> +struct TypeFromSize<8> { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; +}; +template <> +struct TypeFromSize<16> { + using Unsigned = uint128_t; +}; + +} // namespace detail + +// Aliases for types of a different category, but the same size. +template <typename T> +using MakeUnsigned = typename detail::Relations<T>::Unsigned; +template <typename T> +using MakeSigned = typename detail::Relations<T>::Signed; +template <typename T> +using MakeFloat = typename detail::Relations<T>::Float; + +// Aliases for types of the same category, but different size. +template <typename T> +using MakeWide = typename detail::Relations<T>::Wide; +template <typename T> +using MakeNarrow = typename detail::Relations<T>::Narrow; + +// Obtain type from its size [bytes]. +template <size_t N> +using UnsignedFromSize = typename detail::TypeFromSize<N>::Unsigned; +template <size_t N> +using SignedFromSize = typename detail::TypeFromSize<N>::Signed; +template <size_t N> +using FloatFromSize = typename detail::TypeFromSize<N>::Float; + +// Avoid confusion with SizeTag where the parameter is a lane size. +using UnsignedTag = SizeTag<0>; +using SignedTag = SizeTag<0x100>; // integer +using FloatTag = SizeTag<0x200>; + +template <typename T, class R = detail::Relations<T>> +constexpr auto TypeTag() -> hwy::SizeTag<((R::is_signed + R::is_float) << 8)> { + return hwy::SizeTag<((R::is_signed + R::is_float) << 8)>(); +} + +// For when we only want to distinguish FloatTag from everything else. +using NonFloatTag = SizeTag<0x400>; + +template <typename T, class R = detail::Relations<T>> +constexpr auto IsFloatTag() -> hwy::SizeTag<(R::is_float ? 0x200 : 0x400)> { + return hwy::SizeTag<(R::is_float ? 0x200 : 0x400)>(); +} + +//------------------------------------------------------------------------------ +// Type traits + +template <typename T> +HWY_API constexpr bool IsFloat() { + // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or + // from a float, not compared. + return IsSame<T, float>() || IsSame<T, double>(); +} + +template <typename T> +HWY_API constexpr bool IsSigned() { + return T(0) > T(-1); +} +template <> +constexpr bool IsSigned<float16_t>() { + return true; +} +template <> +constexpr bool IsSigned<bfloat16_t>() { + return true; +} + +// Largest/smallest representable integer values. +template <typename T> +HWY_API constexpr T LimitsMax() { + static_assert(!IsFloat<T>(), "Only for integer types"); + using TU = MakeUnsigned<T>; + return static_cast<T>(IsSigned<T>() ? (static_cast<TU>(~0ull) >> 1) + : static_cast<TU>(~0ull)); +} +template <typename T> +HWY_API constexpr T LimitsMin() { + static_assert(!IsFloat<T>(), "Only for integer types"); + return IsSigned<T>() ? T(-1) - LimitsMax<T>() : T(0); +} + +// Largest/smallest representable value (integer or float). This naming avoids +// confusion with numeric_limits<float>::min() (the smallest positive value). +template <typename T> +HWY_API constexpr T LowestValue() { + return LimitsMin<T>(); +} +template <> +constexpr float LowestValue<float>() { + return -3.402823466e+38F; +} +template <> +constexpr double LowestValue<double>() { + return -1.7976931348623158e+308; +} + +template <typename T> +HWY_API constexpr T HighestValue() { + return LimitsMax<T>(); +} +template <> +constexpr float HighestValue<float>() { + return 3.402823466e+38F; +} +template <> +constexpr double HighestValue<double>() { + return 1.7976931348623158e+308; +} + +// Difference between 1.0 and the next representable value. +template <typename T> +HWY_API constexpr T Epsilon() { + return 1; +} +template <> +constexpr float Epsilon<float>() { + return 1.192092896e-7f; +} +template <> +constexpr double Epsilon<double>() { + return 2.2204460492503131e-16; +} + +// Returns width in bits of the mantissa field in IEEE binary32/64. +template <typename T> +constexpr int MantissaBits() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr int MantissaBits<float>() { + return 23; +} +template <> +constexpr int MantissaBits<double>() { + return 52; +} + +// Returns the (left-shifted by one bit) IEEE binary32/64 representation with +// the largest possible (biased) exponent field. Used by IsInf. +template <typename T> +constexpr MakeSigned<T> MaxExponentTimes2() { + return -(MakeSigned<T>{1} << (MantissaBits<T>() + 1)); +} + +// Returns bitmask of the sign bit in IEEE binary32/64. +template <typename T> +constexpr MakeUnsigned<T> SignMask() { + return MakeUnsigned<T>{1} << (sizeof(T) * 8 - 1); +} + +// Returns bitmask of the exponent field in IEEE binary32/64. +template <typename T> +constexpr MakeUnsigned<T> ExponentMask() { + return (~(MakeUnsigned<T>{1} << MantissaBits<T>()) + 1) & ~SignMask<T>(); +} + +// Returns bitmask of the mantissa field in IEEE binary32/64. +template <typename T> +constexpr MakeUnsigned<T> MantissaMask() { + return (MakeUnsigned<T>{1} << MantissaBits<T>()) - 1; +} + +// Returns 1 << mantissa_bits as a floating-point number. All integers whose +// absolute value are less than this can be represented exactly. +template <typename T> +constexpr T MantissaEnd() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr float MantissaEnd<float>() { + return 8388608.0f; // 1 << 23 +} +template <> +constexpr double MantissaEnd<double>() { + // floating point literal with p52 requires C++17. + return 4503599627370496.0; // 1 << 52 +} + +// Returns width in bits of the exponent field in IEEE binary32/64. +template <typename T> +constexpr int ExponentBits() { + // Exponent := remaining bits after deducting sign and mantissa. + return 8 * sizeof(T) - 1 - MantissaBits<T>(); +} + +// Returns largest value of the biased exponent field in IEEE binary32/64, +// right-shifted so that the LSB is bit zero. Example: 0xFF for float. +// This is expressed as a signed integer for more efficient comparison. +template <typename T> +constexpr MakeSigned<T> MaxExponentField() { + return (MakeSigned<T>{1} << ExponentBits<T>()) - 1; +} + +//------------------------------------------------------------------------------ +// Helper functions + +template <typename T1, typename T2> +constexpr inline T1 DivCeil(T1 a, T2 b) { + return (a + b - 1) / b; +} + +// Works for any `align`; if a power of two, compiler emits ADD+AND. +constexpr inline size_t RoundUpTo(size_t what, size_t align) { + return DivCeil(what, align) * align; +} + +// Undefined results for x == 0. +HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) { +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanForward(&index, x); + return index; +#else // HWY_COMPILER_MSVC + return static_cast<size_t>(__builtin_ctz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x) { +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanForward64(&index, x); + return index; +#else // HWY_ARCH_X86_64 + // _BitScanForward64 not available + uint32_t lsb = static_cast<uint32_t>(x & 0xFFFFFFFF); + unsigned long index; // NOLINT + if (lsb == 0) { + uint32_t msb = static_cast<uint32_t>(x >> 32u); + _BitScanForward(&index, msb); + return 32 + index; + } else { + _BitScanForward(&index, lsb); + return index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast<size_t>(__builtin_ctzll(x)); +#endif // HWY_COMPILER_MSVC +} + +// Undefined results for x == 0. +HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x) { +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanReverse(&index, x); + return 31 - index; +#else // HWY_COMPILER_MSVC + return static_cast<size_t>(__builtin_clz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x) { +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanReverse64(&index, x); + return 63 - index; +#else // HWY_ARCH_X86_64 + // _BitScanReverse64 not available + const uint32_t msb = static_cast<uint32_t>(x >> 32u); + unsigned long index; // NOLINT + if (msb == 0) { + const uint32_t lsb = static_cast<uint32_t>(x & 0xFFFFFFFF); + _BitScanReverse(&index, lsb); + return 63 - index; + } else { + _BitScanReverse(&index, msb); + return 31 - index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast<size_t>(__builtin_clzll(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t PopCount(uint64_t x) { +#if HWY_COMPILER_GCC // includes clang + return static_cast<size_t>(__builtin_popcountll(x)); + // This instruction has a separate feature flag, but is often called from + // non-SIMD code, so we don't want to require dynamic dispatch. It was first + // supported by Intel in Nehalem (SSE4.2), but MSVC only predefines a macro + // for AVX, so check for that. +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 && defined(__AVX__) + return _mm_popcnt_u64(x); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) + return _mm_popcnt_u32(static_cast<uint32_t>(x & 0xFFFFFFFFu)) + + _mm_popcnt_u32(static_cast<uint32_t>(x >> 32)); +#else + x -= ((x >> 1) & 0x5555555555555555ULL); + x = (((x >> 2) & 0x3333333333333333ULL) + (x & 0x3333333333333333ULL)); + x = (((x >> 4) + x) & 0x0F0F0F0F0F0F0F0FULL); + x += (x >> 8); + x += (x >> 16); + x += (x >> 32); + return static_cast<size_t>(x & 0x7Fu); +#endif +} + +// Skip HWY_API due to GCC "function not considered for inlining". Previously +// such errors were caused by underlying type mismatches, but it's not clear +// what is still mismatched despite all the casts. +template <typename TI> +/*HWY_API*/ constexpr size_t FloorLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast<size_t>(FloorLog2(static_cast<TI>(x >> 1)) + 1); +} + +template <typename TI> +/*HWY_API*/ constexpr size_t CeilLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast<size_t>(FloorLog2(static_cast<TI>(x - 1)) + 1); +} + +template <typename T> +HWY_INLINE constexpr T AddWithWraparound(hwy::FloatTag /*tag*/, T t, size_t n) { + return t + static_cast<T>(n); +} + +template <typename T> +HWY_INLINE constexpr T AddWithWraparound(hwy::NonFloatTag /*tag*/, T t, + size_t n) { + using TU = MakeUnsigned<T>; + return static_cast<T>( + static_cast<TU>(static_cast<TU>(t) + static_cast<TU>(n)) & + hwy::LimitsMax<TU>()); +} + +#if HWY_COMPILER_MSVC && HWY_ARCH_X86_64 +#pragma intrinsic(_umul128) +#endif + +// 64 x 64 = 128 bit multiplication +HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) { +#if defined(__SIZEOF_INT128__) + __uint128_t product = (__uint128_t)a * (__uint128_t)b; + *upper = (uint64_t)(product >> 64); + return (uint64_t)(product & 0xFFFFFFFFFFFFFFFFULL); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 + return _umul128(a, b, upper); +#else + constexpr uint64_t kLo32 = 0xFFFFFFFFU; + const uint64_t lo_lo = (a & kLo32) * (b & kLo32); + const uint64_t hi_lo = (a >> 32) * (b & kLo32); + const uint64_t lo_hi = (a & kLo32) * (b >> 32); + const uint64_t hi_hi = (a >> 32) * (b >> 32); + const uint64_t t = (lo_lo >> 32) + (hi_lo & kLo32) + lo_hi; + *upper = (hi_lo >> 32) + (t >> 32) + hi_hi; + return (t << 32) | (lo_lo & kLo32); +#endif +} + +#if HWY_COMPILER_MSVC +#pragma intrinsic(memcpy) +#pragma intrinsic(memset) +#endif + +// The source/destination must not overlap/alias. +template <size_t kBytes, typename From, typename To> +HWY_API void CopyBytes(const From* from, To* to) { +#if HWY_COMPILER_MSVC + memcpy(to, from, kBytes); +#else + __builtin_memcpy( + static_cast<void*>(to), static_cast<const void*>(from), kBytes); +#endif +} + +// Same as CopyBytes, but for same-sized objects; avoids a size argument. +template <typename From, typename To> +HWY_API void CopySameSize(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { + static_assert(sizeof(From) == sizeof(To), ""); + CopyBytes<sizeof(From)>(from, to); +} + +template <size_t kBytes, typename To> +HWY_API void ZeroBytes(To* to) { +#if HWY_COMPILER_MSVC + memset(to, 0, kBytes); +#else + __builtin_memset(to, 0, kBytes); +#endif +} + +HWY_API float F32FromBF16(bfloat16_t bf) { + uint32_t bits = bf.bits; + bits <<= 16; + float f; + CopySameSize(&bits, &f); + return f; +} + +HWY_API bfloat16_t BF16FromF32(float f) { + uint32_t bits; + CopySameSize(&f, &bits); + bfloat16_t bf; + bf.bits = static_cast<uint16_t>(bits >> 16); + return bf; +} + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...); + +} // namespace hwy + +#endif // HIGHWAY_HWY_BASE_H_ diff --git a/third_party/highway/hwy/base_test.cc b/third_party/highway/hwy/base_test.cc new file mode 100644 index 0000000000..baca70b6f1 --- /dev/null +++ b/third_party/highway/hwy/base_test.cc @@ -0,0 +1,178 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#include <limits> + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "base_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +HWY_NOINLINE void TestAllLimits() { + HWY_ASSERT_EQ(uint8_t{0}, LimitsMin<uint8_t>()); + HWY_ASSERT_EQ(uint16_t{0}, LimitsMin<uint16_t>()); + HWY_ASSERT_EQ(uint32_t{0}, LimitsMin<uint32_t>()); + HWY_ASSERT_EQ(uint64_t{0}, LimitsMin<uint64_t>()); + + HWY_ASSERT_EQ(int8_t{-128}, LimitsMin<int8_t>()); + HWY_ASSERT_EQ(int16_t{-32768}, LimitsMin<int16_t>()); + HWY_ASSERT_EQ(static_cast<int32_t>(0x80000000u), LimitsMin<int32_t>()); + HWY_ASSERT_EQ(static_cast<int64_t>(0x8000000000000000ull), + LimitsMin<int64_t>()); + + HWY_ASSERT_EQ(uint8_t{0xFF}, LimitsMax<uint8_t>()); + HWY_ASSERT_EQ(uint16_t{0xFFFF}, LimitsMax<uint16_t>()); + HWY_ASSERT_EQ(uint32_t{0xFFFFFFFFu}, LimitsMax<uint32_t>()); + HWY_ASSERT_EQ(uint64_t{0xFFFFFFFFFFFFFFFFull}, LimitsMax<uint64_t>()); + + HWY_ASSERT_EQ(int8_t{0x7F}, LimitsMax<int8_t>()); + HWY_ASSERT_EQ(int16_t{0x7FFF}, LimitsMax<int16_t>()); + HWY_ASSERT_EQ(int32_t{0x7FFFFFFFu}, LimitsMax<int32_t>()); + HWY_ASSERT_EQ(int64_t{0x7FFFFFFFFFFFFFFFull}, LimitsMax<int64_t>()); +} + +struct TestLowestHighest { + template <class T> + HWY_NOINLINE void operator()(T /*unused*/) const { + HWY_ASSERT_EQ(std::numeric_limits<T>::lowest(), LowestValue<T>()); + HWY_ASSERT_EQ(std::numeric_limits<T>::max(), HighestValue<T>()); + } +}; + +HWY_NOINLINE void TestAllLowestHighest() { ForAllTypes(TestLowestHighest()); } +struct TestIsUnsigned { + template <class T> + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(!IsFloat<T>(), "Expected !IsFloat"); + static_assert(!IsSigned<T>(), "Expected !IsSigned"); + } +}; + +struct TestIsSigned { + template <class T> + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(!IsFloat<T>(), "Expected !IsFloat"); + static_assert(IsSigned<T>(), "Expected IsSigned"); + } +}; + +struct TestIsFloat { + template <class T> + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(IsFloat<T>(), "Expected IsFloat"); + static_assert(IsSigned<T>(), "Floats are also considered signed"); + } +}; + +HWY_NOINLINE void TestAllType() { + ForUnsignedTypes(TestIsUnsigned()); + ForSignedTypes(TestIsSigned()); + ForFloatTypes(TestIsFloat()); + + static_assert(sizeof(MakeUnsigned<hwy::uint128_t>) == 16, ""); + static_assert(sizeof(MakeWide<uint64_t>) == 16, "Expected uint128_t"); + static_assert(sizeof(MakeNarrow<hwy::uint128_t>) == 8, "Expected uint64_t"); +} + +struct TestIsSame { + template <class T> + HWY_NOINLINE void operator()(T /*unused*/) const { + static_assert(IsSame<T, T>(), "T == T"); + static_assert(!IsSame<MakeSigned<T>, MakeUnsigned<T>>(), "S != U"); + static_assert(!IsSame<MakeUnsigned<T>, MakeSigned<T>>(), "U != S"); + } +}; + +HWY_NOINLINE void TestAllIsSame() { ForAllTypes(TestIsSame()); } + +HWY_NOINLINE void TestAllBitScan() { + HWY_ASSERT_EQ(size_t{0}, Num0BitsAboveMS1Bit_Nonzero32(0x80000000u)); + HWY_ASSERT_EQ(size_t{0}, Num0BitsAboveMS1Bit_Nonzero32(0xFFFFFFFFu)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsAboveMS1Bit_Nonzero32(0x40000000u)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsAboveMS1Bit_Nonzero32(0x40108210u)); + HWY_ASSERT_EQ(size_t{30}, Num0BitsAboveMS1Bit_Nonzero32(2u)); + HWY_ASSERT_EQ(size_t{30}, Num0BitsAboveMS1Bit_Nonzero32(3u)); + HWY_ASSERT_EQ(size_t{31}, Num0BitsAboveMS1Bit_Nonzero32(1u)); + + HWY_ASSERT_EQ(size_t{0}, + Num0BitsAboveMS1Bit_Nonzero64(0x8000000000000000ull)); + HWY_ASSERT_EQ(size_t{0}, + Num0BitsAboveMS1Bit_Nonzero64(0xFFFFFFFFFFFFFFFFull)); + HWY_ASSERT_EQ(size_t{1}, + Num0BitsAboveMS1Bit_Nonzero64(0x4000000000000000ull)); + HWY_ASSERT_EQ(size_t{1}, + Num0BitsAboveMS1Bit_Nonzero64(0x4010821004200011ull)); + HWY_ASSERT_EQ(size_t{62}, Num0BitsAboveMS1Bit_Nonzero64(2ull)); + HWY_ASSERT_EQ(size_t{62}, Num0BitsAboveMS1Bit_Nonzero64(3ull)); + HWY_ASSERT_EQ(size_t{63}, Num0BitsAboveMS1Bit_Nonzero64(1ull)); + + HWY_ASSERT_EQ(size_t{0}, Num0BitsBelowLS1Bit_Nonzero32(1u)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsBelowLS1Bit_Nonzero32(2u)); + HWY_ASSERT_EQ(size_t{30}, Num0BitsBelowLS1Bit_Nonzero32(0xC0000000u)); + HWY_ASSERT_EQ(size_t{31}, Num0BitsBelowLS1Bit_Nonzero32(0x80000000u)); + + HWY_ASSERT_EQ(size_t{0}, Num0BitsBelowLS1Bit_Nonzero64(1ull)); + HWY_ASSERT_EQ(size_t{1}, Num0BitsBelowLS1Bit_Nonzero64(2ull)); + HWY_ASSERT_EQ(size_t{62}, + Num0BitsBelowLS1Bit_Nonzero64(0xC000000000000000ull)); + HWY_ASSERT_EQ(size_t{63}, + Num0BitsBelowLS1Bit_Nonzero64(0x8000000000000000ull)); +} + +HWY_NOINLINE void TestAllPopCount() { + HWY_ASSERT_EQ(size_t{0}, PopCount(0u)); + HWY_ASSERT_EQ(size_t{1}, PopCount(1u)); + HWY_ASSERT_EQ(size_t{1}, PopCount(2u)); + HWY_ASSERT_EQ(size_t{2}, PopCount(3u)); + HWY_ASSERT_EQ(size_t{1}, PopCount(0x80000000u)); + HWY_ASSERT_EQ(size_t{31}, PopCount(0x7FFFFFFFu)); + HWY_ASSERT_EQ(size_t{32}, PopCount(0xFFFFFFFFu)); + + HWY_ASSERT_EQ(size_t{1}, PopCount(0x80000000ull)); + HWY_ASSERT_EQ(size_t{31}, PopCount(0x7FFFFFFFull)); + HWY_ASSERT_EQ(size_t{32}, PopCount(0xFFFFFFFFull)); + HWY_ASSERT_EQ(size_t{33}, PopCount(0x10FFFFFFFFull)); + HWY_ASSERT_EQ(size_t{63}, PopCount(0xFFFEFFFFFFFFFFFFull)); + HWY_ASSERT_EQ(size_t{64}, PopCount(0xFFFFFFFFFFFFFFFFull)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(BaseTest); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllLimits); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllLowestHighest); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllType); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllIsSame); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllBitScan); +HWY_EXPORT_AND_TEST_P(BaseTest, TestAllPopCount); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/cache_control.h b/third_party/highway/hwy/cache_control.h new file mode 100644 index 0000000000..b124e5707e --- /dev/null +++ b/third_party/highway/hwy/cache_control.h @@ -0,0 +1,110 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CACHE_CONTROL_H_ +#define HIGHWAY_HWY_CACHE_CONTROL_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "hwy/base.h" + +// Requires SSE2; fails to compile on 32-bit Clang 7 (see +// https://github.com/gperftools/gperftools/issues/946). +#if !defined(__SSE2__) || (HWY_COMPILER_CLANG && HWY_ARCH_X86_32) +#undef HWY_DISABLE_CACHE_CONTROL +#define HWY_DISABLE_CACHE_CONTROL +#endif + +// intrin.h is sufficient on MSVC and already included by base.h. +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC +#include <emmintrin.h> // SSE2 +#endif + +// Windows.h #defines these, which causes infinite recursion. Temporarily +// undefine them in this header; these functions are anyway deprecated. +// TODO(janwas): remove when these functions are removed. +#pragma push_macro("LoadFence") +#undef LoadFence + +namespace hwy { + +// Even if N*sizeof(T) is smaller, Stream may write a multiple of this size. +#define HWY_STREAM_MULTIPLE 16 + +// The following functions may also require an attribute. +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC +#define HWY_ATTR_CACHE __attribute__((target("sse2"))) +#else +#define HWY_ATTR_CACHE +#endif + +// Delays subsequent loads until prior loads are visible. Beware of potentially +// differing behavior across architectures and vendors: on Intel but not +// AMD CPUs, also serves as a full fence (waits for all prior instructions to +// complete). +HWY_INLINE HWY_ATTR_CACHE void LoadFence() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_lfence(); +#endif +} + +// Ensures values written by previous `Stream` calls are visible on the current +// core. This is NOT sufficient for synchronizing across cores; when `Stream` +// outputs are to be consumed by other core(s), the producer must publish +// availability (e.g. via mutex or atomic_flag) after `FlushStream`. +HWY_INLINE HWY_ATTR_CACHE void FlushStream() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_sfence(); +#endif +} + +// Optionally begins loading the cache line containing "p" to reduce latency of +// subsequent actual loads. +template <typename T> +HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_prefetch(reinterpret_cast<const char*>(p), _MM_HINT_T0); +#elif HWY_COMPILER_GCC // includes clang + // Hint=0 (NTA) behavior differs, but skipping outer caches is probably not + // desirable, so use the default 3 (keep in caches). + __builtin_prefetch(p, /*write=*/0, /*hint=*/3); +#else + (void)p; +#endif +} + +// Invalidates and flushes the cache line containing "p", if possible. +HWY_INLINE HWY_ATTR_CACHE void FlushCacheline(const void* p) { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_clflush(p); +#else + (void)p; +#endif +} + +// When called inside a spin-loop, may reduce power consumption. +HWY_INLINE HWY_ATTR_CACHE void Pause() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_pause(); +#endif +} + +} // namespace hwy + +// TODO(janwas): remove when these functions are removed. (See above.) +#pragma pop_macro("LoadFence") + +#endif // HIGHWAY_HWY_CACHE_CONTROL_H_ diff --git a/third_party/highway/hwy/contrib/algo/copy-inl.h b/third_party/highway/hwy/contrib/algo/copy-inl.h new file mode 100644 index 0000000000..033cf8a626 --- /dev/null +++ b/third_party/highway/hwy/contrib/algo/copy-inl.h @@ -0,0 +1,136 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a CopyAlignedPadded because it +// would be more verbose than such a loop. + +// Fills `to`[0, `count`) with `value`. +template <class D, typename T = TFromD<D>> +void Fill(D d, T value, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + const Vec<D> v = Set(d, value); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + StoreU(v, d, to + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeFillN(remaining, value, d, to + idx); +} + +// Copies `from`[0, `count`) to `to`, which must not overlap `from`. +template <class D, typename T = TFromD<D>> +void Copy(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec<D> v = LoadU(d, from + idx); + StoreU(v, d, to + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeCopyN(remaining, d, from + idx, to + idx); +} + +// For idx in [0, count) in ascending order, appends `from[idx]` to `to` if the +// corresponding mask element of `func(d, v)` is true. Returns the STL-style end +// of the newly written elements in `to`. +// +// `func` is either a functor with a templated operator()(d, v) returning a +// mask, or a generic lambda if using C++14. Due to apparent limitations of +// Clang on Windows, it is currently necessary to add HWY_ATTR before the +// opening { of the lambda to avoid errors about "function .. requires target". +// +// NOTE: this is only supported for 16-, 32- or 64-bit types. +// NOTE: Func may be called a second time for elements it has already seen, but +// these elements will not be written to `to` again. +template <class D, class Func, typename T = TFromD<D>> +T* CopyIf(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec<D> v = LoadU(d, from + idx); + to += CompressBlendedStore(v, func(d, v), d, to); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return to; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag<T, 1> d1; + for (; idx < count; ++idx) { + using V1 = Vec<decltype(d1)>; + // Workaround for -Waggressive-loop-optimizations on GCC 8 + // (iteration 2305843009213693951 invokes undefined behavior for T=i64) + const uintptr_t addr = reinterpret_cast<uintptr_t>(from); + const T* HWY_RESTRICT from_idx = + reinterpret_cast<const T * HWY_RESTRICT>(addr + (idx * sizeof(T))); + const V1 v = LoadU(d1, from_idx); + // Avoid storing to `to` unless we know it should be kept - otherwise, we + // might overrun the end if it was allocated for the exact count. + if (CountTrue(d1, func(d1, v)) == 0) continue; + StoreU(v, d1, to); + to += 1; + } +#else + // Start index of the last unaligned whole vector, ending at the array end. + const size_t last = count - N; + // Number of elements before `from` or already written. + const size_t invalid = idx - last; + HWY_DASSERT(0 != invalid && invalid < N); + const Mask<D> mask = Not(FirstN(d, invalid)); + const Vec<D> v = MaskedLoad(mask, d, from + last); + to += CompressBlendedStore(v, And(mask, func(d, v)), d, to); +#endif + return to; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ diff --git a/third_party/highway/hwy/contrib/algo/copy_test.cc b/third_party/highway/hwy/contrib/algo/copy_test.cc new file mode 100644 index 0000000000..e2675a39d7 --- /dev/null +++ b/third_party/highway/hwy/contrib/algo/copy_test.cc @@ -0,0 +1,199 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/aligned_allocator.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/algo/copy_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/algo/copy-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// If your project requires C++14 or later, you can ignore this and pass lambdas +// directly to Transform, without requiring an lvalue as we do here for C++11. +#if __cplusplus < 201402L +#define HWY_GENERIC_LAMBDA 0 +#else +#define HWY_GENERIC_LAMBDA 1 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns random integer in [0, 128), which fits in any lane type. +template <typename T> +T Random7Bit(RandomState& rng) { + return static_cast<T>(Random32(&rng) & 127); +} + +// In C++14, we can instead define these as generic lambdas next to where they +// are invoked. +#if !HWY_GENERIC_LAMBDA + +struct IsOdd { + template <class D, class V> + Mask<D> operator()(D d, V v) const { + return TestBit(v, Set(d, TFromD<D>{1})); + } +}; + +#endif // !HWY_GENERIC_LAMBDA + +// Invokes Test (e.g. TestCopyIf) with all arg combinations. T comes from +// ForFloatTypes. +template <class Test> +struct ForeachCountAndMisalign { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + for (size_t count = 0; count < 2 * N; ++count) { + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test()(d, count, ma, mb, rng); + } + } + } + } +}; + +struct TestFill { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD<D>; + // HWY_MAX prevents error when misalign == count == 0. + AlignedFreeUniquePtr<T[]> pa = + AllocateAligned<T>(HWY_MAX(1, misalign_a + count)); + T* expected = pa.get() + misalign_a; + const T value = Random7Bit<T>(rng); + for (size_t i = 0; i < count; ++i) { + expected[i] = value; + } + AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + count + 1); + T* actual = pb.get() + misalign_b; + + actual[count] = T{0}; // sentinel + Fill(d, value, count, actual); + HWY_ASSERT_EQ(T{0}, actual[count]); // did not write past end + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected, actual, count, target_name, + __FILE__, __LINE__); + } +}; + +void TestAllFill() { + ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFill>>()); +} + +struct TestCopy { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD<D>; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr<T[]> pa = + AllocateAligned<T>(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random7Bit<T>(rng); + } + AlignedFreeUniquePtr<T[]> pb = + AllocateAligned<T>(HWY_MAX(1, misalign_b + count)); + T* b = pb.get() + misalign_b; + + Copy(d, a, count, b); + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, a, b, count, target_name, __FILE__, + __LINE__); + } +}; + +void TestAllCopy() { + ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestCopy>>()); +} + +struct TestCopyIf { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD<D>; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr<T[]> pa = + AllocateAligned<T>(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random7Bit<T>(rng); + } + const size_t padding = Lanes(ScalableTag<T>()); + AlignedFreeUniquePtr<T[]> pb = + AllocateAligned<T>(HWY_MAX(1, misalign_b + count + padding)); + T* b = pb.get() + misalign_b; + + AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count)); + size_t num_odd = 0; + for (size_t i = 0; i < count; ++i) { + if (a[i] & 1) { + expected[num_odd++] = a[i]; + } + } + +#if HWY_GENERIC_LAMBDA + const auto is_odd = [](const auto d, const auto v) HWY_ATTR { + return TestBit(v, Set(d, TFromD<decltype(d)>{1})); + }; +#else + const IsOdd is_odd; +#endif + T* end = CopyIf(d, a, count, b, is_odd); + const size_t num_written = static_cast<size_t>(end - b); + HWY_ASSERT_EQ(num_odd, num_written); + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), b, num_odd, target_name, + __FILE__, __LINE__); + } +}; + +void TestAllCopyIf() { + ForUI163264(ForPartialVectors<ForeachCountAndMisalign<TestCopyIf>>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(CopyTest); +HWY_EXPORT_AND_TEST_P(CopyTest, TestAllFill); +HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopy); +HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopyIf); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/contrib/algo/find-inl.h b/third_party/highway/hwy/contrib/algo/find-inl.h new file mode 100644 index 0000000000..388842e988 --- /dev/null +++ b/third_party/highway/hwy/contrib/algo/find-inl.h @@ -0,0 +1,109 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns index of the first element equal to `value` in `in[0, count)`, or +// `count` if not found. +template <class D, typename T = TFromD<D>> +size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) { + const size_t N = Lanes(d); + const Vec<D> broadcasted = Set(d, value); + + size_t i = 0; + for (; i + N <= count; i += N) { + const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast<size_t>(pos); + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag<T, 1> d1; + using V1 = Vec<decltype(d1)>; + const V1 broadcasted1 = Set(d1, GetLane(broadcasted)); + for (; i < count; ++i) { + if (AllTrue(d1, Eq(broadcasted1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + const Vec<D> v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. + const intptr_t pos = FindFirstTrue(d, And(Eq(broadcasted, v), mask)); + if (pos >= 0) return i + static_cast<size_t>(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// Returns index of the first element in `in[0, count)` for which `func(d, vec)` +// returns true, otherwise `count`. +template <class D, class Func, typename T = TFromD<D>> +size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t i = 0; + for (; i + N <= count; i += N) { + const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast<size_t>(pos); + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag<T, 1> d1; + for (; i < count; ++i) { + if (AllTrue(d1, func(d1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + const Vec<D> v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. + const intptr_t pos = FindFirstTrue(d, And(func(d, v), mask)); + if (pos >= 0) return i + static_cast<size_t>(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ diff --git a/third_party/highway/hwy/contrib/algo/find_test.cc b/third_party/highway/hwy/contrib/algo/find_test.cc new file mode 100644 index 0000000000..f438a18ba0 --- /dev/null +++ b/third_party/highway/hwy/contrib/algo/find_test.cc @@ -0,0 +1,219 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> // std::find_if +#include <vector> + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/print.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/algo/find_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/algo/find-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// If your project requires C++14 or later, you can ignore this and pass lambdas +// directly to FindIf, without requiring an lvalue as we do here for C++11. +#if __cplusplus < 201402L +#define HWY_GENERIC_LAMBDA 0 +#else +#define HWY_GENERIC_LAMBDA 1 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns random number in [-8, 8) - we use knowledge of the range to Find() +// values we know are not present. +template <typename T> +T Random(RandomState& rng) { + const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023; + const double val = (bits - 512) / 64.0; + // Clamp negative to zero for unsigned types. + return static_cast<T>(HWY_MAX(hwy::LowestValue<T>(), val)); +} + +// In C++14, we can instead define these as generic lambdas next to where they +// are invoked. +#if !HWY_GENERIC_LAMBDA + +class GreaterThan { + public: + GreaterThan(int val) : val_(val) {} + template <class D, class V> + Mask<D> operator()(D d, V v) const { + return Gt(v, Set(d, static_cast<TFromD<D>>(val_))); + } + + private: + int val_; +}; + +#endif // !HWY_GENERIC_LAMBDA + +// Invokes Test (e.g. TestFind) with all arg combinations. +template <class Test> +struct ForeachCountAndMisalign { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + // Find() checks 8 vectors at a time, so we want to cover a fairly large + // range without oversampling (checking every possible count). + std::vector<size_t> counts(AdjustedReps(512)); + for (size_t& count : counts) { + count = static_cast<size_t>(rng()) % (16 * N + 1); + } + counts[0] = 0; // ensure we test count=0. + + for (size_t count : counts) { + for (size_t m : misalignments) { + Test()(d, count, m, rng); + } + } + } +}; + +struct TestFind { + template <class D> + void operator()(D d, size_t count, size_t misalign, RandomState& rng) { + using T = TFromD<D>; + // Must allocate at least one even if count is zero. + AlignedFreeUniquePtr<T[]> storage = + AllocateAligned<T>(HWY_MAX(1, misalign + count)); + T* in = storage.get() + misalign; + for (size_t i = 0; i < count; ++i) { + in[i] = Random<T>(rng); + } + + // For each position, search for that element (which we know is there) + for (size_t pos = 0; pos < count; ++pos) { + const size_t actual = Find(d, in[pos], in, count); + + // We may have found an earlier occurrence of the same value; ensure the + // value is the same, and that it is the first. + if (!IsEqual(in[pos], in[actual])) { + fprintf(stderr, "%s count %d, found %.15f at %d but wanted %.15f\n", + hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count), + static_cast<double>(in[actual]), static_cast<int>(actual), + static_cast<double>(in[pos])); + HWY_ASSERT(false); + } + for (size_t i = 0; i < actual; ++i) { + if (IsEqual(in[i], in[pos])) { + fprintf(stderr, "%s count %d, found %f at %d but Find returned %d\n", + hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count), + static_cast<double>(in[i]), static_cast<int>(i), + static_cast<int>(actual)); + HWY_ASSERT(false); + } + } + } + + // Also search for values we know not to be present (out of range) + HWY_ASSERT_EQ(count, Find(d, T{9}, in, count)); + HWY_ASSERT_EQ(count, Find(d, static_cast<T>(-9), in, count)); + } +}; + +void TestAllFind() { + ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFind>>()); +} + +struct TestFindIf { + template <class D> + void operator()(D d, size_t count, size_t misalign, RandomState& rng) { + using T = TFromD<D>; + using TI = MakeSigned<T>; + // Must allocate at least one even if count is zero. + AlignedFreeUniquePtr<T[]> storage = + AllocateAligned<T>(HWY_MAX(1, misalign + count)); + T* in = storage.get() + misalign; + for (size_t i = 0; i < count; ++i) { + in[i] = Random<T>(rng); + HWY_ASSERT(in[i] < 8); + HWY_ASSERT(!hwy::IsSigned<T>() || static_cast<TI>(in[i]) >= -8); + } + + bool found_any = false; + bool not_found_any = false; + + // unsigned T would be promoted to signed and compare greater than any + // negative val, whereas Set() would just cast to an unsigned value and the + // comparison remains unsigned, so avoid negative numbers there. + const int min_val = IsSigned<T>() ? -9 : 0; + // Includes out-of-range value 9 to test the not-found path. + for (int val = min_val; val <= 9; ++val) { +#if HWY_GENERIC_LAMBDA + const auto greater = [val](const auto d, const auto v) HWY_ATTR { + return Gt(v, Set(d, static_cast<T>(val))); + }; +#else + const GreaterThan greater(val); +#endif + const size_t actual = FindIf(d, in, count, greater); + found_any |= actual < count; + not_found_any |= actual == count; + + const auto pos = std::find_if( + in, in + count, [val](T x) { return x > static_cast<T>(val); }); + // Convert returned iterator to index. + const size_t expected = static_cast<size_t>(pos - in); + if (expected != actual) { + fprintf(stderr, "%s count %d val %d, expected %d actual %d\n", + hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count), + val, static_cast<int>(expected), static_cast<int>(actual)); + hwy::detail::PrintArray(hwy::detail::MakeTypeInfo<T>(), "in", in, count, + 0, count); + HWY_ASSERT(false); + } + } + + // We will always not-find something due to val=9. + HWY_ASSERT(not_found_any); + // We'll find something unless the input is empty or {0} - because 0 > i + // is false for all i=[0,9]. + if (count != 0 && in[0] != 0) { + HWY_ASSERT(found_any); + } + } +}; + +void TestAllFindIf() { + ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFindIf>>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(FindTest); +HWY_EXPORT_AND_TEST_P(FindTest, TestAllFind); +HWY_EXPORT_AND_TEST_P(FindTest, TestAllFindIf); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/contrib/algo/transform-inl.h b/third_party/highway/hwy/contrib/algo/transform-inl.h new file mode 100644 index 0000000000..3e830acb47 --- /dev/null +++ b/third_party/highway/hwy/contrib/algo/transform-inl.h @@ -0,0 +1,262 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a TransformAlignedPadded because it +// would be more verbose than such a loop. +// +// Func is either a functor with a templated operator()(d, v[, v1[, v2]]), or a +// generic lambda if using C++14. Due to apparent limitations of Clang on +// Windows, it is currently necessary to add HWY_ATTR before the opening { of +// the lambda to avoid errors about "always_inline function .. requires target". +// +// If HWY_MEM_OPS_MIGHT_FAULT, we use scalar code instead of masking. Otherwise, +// we used `MaskedLoad` and `BlendedStore` to read/write the final partial +// vector. + +// Fills `out[0, count)` with the vectors returned by `func(d, index_vec)`, +// where `index_vec` is `Vec<RebindToUnsigned<D>>`. On the first call to `func`, +// the value of its lane i is i, and increases by `Lanes(d)` after every call. +// Note that some of these indices may be `>= count`, but the elements that +// `func` returns in those lanes will not be written to `out`. +template <class D, class Func, typename T = TFromD<D>> +void Generate(D d, T* HWY_RESTRICT out, size_t count, const Func& func) { + const RebindToUnsigned<D> du; + using TU = TFromD<decltype(du)>; + const size_t N = Lanes(d); + + size_t idx = 0; + Vec<decltype(du)> vidx = Iota(du, 0); + for (; idx + N <= count; idx += N) { + StoreU(func(d, vidx), d, out + idx); + vidx = Add(vidx, Set(du, static_cast<TU>(N))); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag<T, 1> d1; + const RebindToUnsigned<decltype(d1)> du1; + for (; idx < count; ++idx) { + StoreU(func(d1, Set(du1, static_cast<TU>(idx))), d1, out + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + BlendedStore(func(d, vidx), mask, d, out + idx); +#endif +} + +// Replaces `inout[idx]` with `func(d, inout[idx])`. Example usage: multiplying +// array elements by a constant. +template <class D, class Func, typename T = TFromD<D>> +void Transform(D d, T* HWY_RESTRICT inout, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec<D> v = LoadU(d, inout + idx); + StoreU(func(d, v), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag<T, 1> d1; + for (; idx < count; ++idx) { + using V1 = Vec<decltype(d1)>; + const V1 v = LoadU(d1, inout + idx); + StoreU(func(d1, v), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + const Vec<D> v = MaskedLoad(mask, d, inout + idx); + BlendedStore(func(d, v), mask, d, inout + idx); +#endif +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx])`. Example usage: +// multiplying array elements by those of another array. +template <class D, class Func, typename T = TFromD<D>> +void Transform1(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec<D> v = LoadU(d, inout + idx); + const Vec<D> v1 = LoadU(d, in1 + idx); + StoreU(func(d, v, v1), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag<T, 1> d1; + for (; idx < count; ++idx) { + using V1 = Vec<decltype(d1)>; + const V1 v = LoadU(d1, inout + idx); + const V1 v1 = LoadU(d1, in1 + idx); + StoreU(func(d1, v, v1), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + const Vec<D> v = MaskedLoad(mask, d, inout + idx); + const Vec<D> v1 = MaskedLoad(mask, d, in1 + idx); + BlendedStore(func(d, v, v1), mask, d, inout + idx); +#endif +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx], in2[idx])`. Example +// usage: FMA of elements from three arrays, stored into the first array. +template <class D, class Func, typename T = TFromD<D>> +void Transform2(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const T* HWY_RESTRICT in2, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + const Vec<D> v = LoadU(d, inout + idx); + const Vec<D> v1 = LoadU(d, in1 + idx); + const Vec<D> v2 = LoadU(d, in2 + idx); + StoreU(func(d, v, v1, v2), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag<T, 1> d1; + for (; idx < count; ++idx) { + using V1 = Vec<decltype(d1)>; + const V1 v = LoadU(d1, inout + idx); + const V1 v1 = LoadU(d1, in1 + idx); + const V1 v2 = LoadU(d1, in2 + idx); + StoreU(func(d1, v, v1, v2), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + const Vec<D> v = MaskedLoad(mask, d, inout + idx); + const Vec<D> v1 = MaskedLoad(mask, d, in1 + idx); + const Vec<D> v2 = MaskedLoad(mask, d, in2 + idx); + BlendedStore(func(d, v, v1, v2), mask, d, inout + idx); +#endif +} + +template <class D, typename T = TFromD<D>> +void Replace(D d, T* HWY_RESTRICT inout, size_t count, T new_t, T old_t) { + const size_t N = Lanes(d); + const Vec<D> old_v = Set(d, old_t); + const Vec<D> new_v = Set(d, new_t); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + Vec<D> v = LoadU(d, inout + idx); + StoreU(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag<T, 1> d1; + const Vec<decltype(d1)> old_v1 = Set(d1, old_t); + const Vec<decltype(d1)> new_v1 = Set(d1, new_t); + for (; idx < count; ++idx) { + using V1 = Vec<decltype(d1)>; + const V1 v1 = LoadU(d1, inout + idx); + StoreU(IfThenElse(Eq(v1, old_v1), new_v1, v1), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + const Vec<D> v = MaskedLoad(mask, d, inout + idx); + BlendedStore(IfThenElse(Eq(v, old_v), new_v, v), mask, d, inout + idx); +#endif +} + +template <class D, class Func, typename T = TFromD<D>> +void ReplaceIf(D d, T* HWY_RESTRICT inout, size_t count, T new_t, + const Func& func) { + const size_t N = Lanes(d); + const Vec<D> new_v = Set(d, new_t); + + size_t idx = 0; + for (; idx + N <= count; idx += N) { + Vec<D> v = LoadU(d, inout + idx); + StoreU(IfThenElse(func(d, v), new_v, v), d, inout + idx); + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag<T, 1> d1; + const Vec<decltype(d1)> new_v1 = Set(d1, new_t); + for (; idx < count; ++idx) { + using V1 = Vec<decltype(d1)>; + const V1 v = LoadU(d1, inout + idx); + StoreU(IfThenElse(func(d1, v), new_v1, v), d1, inout + idx); + } +#else + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask<D> mask = FirstN(d, remaining); + const Vec<D> v = MaskedLoad(mask, d, inout + idx); + BlendedStore(IfThenElse(func(d, v), new_v, v), mask, d, inout + idx); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ diff --git a/third_party/highway/hwy/contrib/algo/transform_test.cc b/third_party/highway/hwy/contrib/algo/transform_test.cc new file mode 100644 index 0000000000..335607ccfb --- /dev/null +++ b/third_party/highway/hwy/contrib/algo/transform_test.cc @@ -0,0 +1,372 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <string.h> // memcpy + +#include "hwy/aligned_allocator.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/algo/transform_test.cc" //NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/algo/transform-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// If your project requires C++14 or later, you can ignore this and pass lambdas +// directly to Transform, without requiring an lvalue as we do here for C++11. +#if __cplusplus < 201402L +#define HWY_GENERIC_LAMBDA 0 +#else +#define HWY_GENERIC_LAMBDA 1 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <typename T> +T Alpha() { + return static_cast<T>(1.5); // arbitrary scalar +} + +// Returns random floating-point number in [-8, 8) to ensure computations do +// not exceed float32 precision. +template <typename T> +T Random(RandomState& rng) { + const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023; + const double val = (bits - 512) / 64.0; + // Clamp negative to zero for unsigned types. + return static_cast<T>(HWY_MAX(hwy::LowestValue<T>(), val)); +} + +// SCAL, AXPY names are from BLAS. +template <typename T> +HWY_NOINLINE void SimpleSCAL(const T* x, T* out, size_t count) { + for (size_t i = 0; i < count; ++i) { + out[i] = Alpha<T>() * x[i]; + } +} + +template <typename T> +HWY_NOINLINE void SimpleAXPY(const T* x, const T* y, T* out, size_t count) { + for (size_t i = 0; i < count; ++i) { + out[i] = Alpha<T>() * x[i] + y[i]; + } +} + +template <typename T> +HWY_NOINLINE void SimpleFMA4(const T* x, const T* y, const T* z, T* out, + size_t count) { + for (size_t i = 0; i < count; ++i) { + out[i] = x[i] * y[i] + z[i]; + } +} + +// In C++14, we can instead define these as generic lambdas next to where they +// are invoked. +#if !HWY_GENERIC_LAMBDA + +// Generator that returns even numbers by doubling the output indices. +struct Gen2 { + template <class D, class VU> + Vec<D> operator()(D d, VU vidx) const { + return BitCast(d, Add(vidx, vidx)); + } +}; + +struct SCAL { + template <class D, class V> + Vec<D> operator()(D d, V v) const { + using T = TFromD<D>; + return Mul(Set(d, Alpha<T>()), v); + } +}; + +struct AXPY { + template <class D, class V> + Vec<D> operator()(D d, V v, V v1) const { + using T = TFromD<D>; + return MulAdd(Set(d, Alpha<T>()), v, v1); + } +}; + +struct FMA4 { + template <class D, class V> + Vec<D> operator()(D /*d*/, V v, V v1, V v2) const { + return MulAdd(v, v1, v2); + } +}; + +#endif // !HWY_GENERIC_LAMBDA + +// Invokes Test (e.g. TestTransform1) with all arg combinations. T comes from +// ForFloatTypes. +template <class Test> +struct ForeachCountAndMisalign { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + for (size_t count = 0; count < 2 * N; ++count) { + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test()(d, count, ma, mb, rng); + } + } + } + } +}; + +// Output-only, no loads +struct TestGenerate { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t /*misalign_b*/, + RandomState& /*rng*/) { + using T = TFromD<D>; + AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + count + 1); + T* actual = pa.get() + misalign_a; + + AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count)); + for (size_t i = 0; i < count; ++i) { + expected[i] = static_cast<T>(2 * i); + } + + // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that + // the attribute also applies to lambdas? If so, remove HWY_ATTR. +#if HWY_GENERIC_LAMBDA + const auto gen2 = [](const auto d, const auto vidx) + HWY_ATTR { return BitCast(d, Add(vidx, vidx)); }; +#else + const Gen2 gen2; +#endif + actual[count] = T{0}; // sentinel + Generate(d, actual, count, gen2); + HWY_ASSERT_EQ(T{0}, actual[count]); // did not write past end + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), actual, count, + target_name, __FILE__, __LINE__); + } +}; + +// Zero extra input arrays +struct TestTransform { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + if (misalign_b != 0) return; + using T = TFromD<D>; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr<T[]> pa = + AllocateAligned<T>(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random<T>(rng); + } + + AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count)); + SimpleSCAL(a, expected.get(), count); + + // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that + // the attribute also applies to lambdas? If so, remove HWY_ATTR. +#if HWY_GENERIC_LAMBDA + const auto scal = [](const auto d, const auto v) + HWY_ATTR { return Mul(Set(d, Alpha<T>()), v); }; +#else + const SCAL scal; +#endif + Transform(d, a, count, scal); + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, + __FILE__, __LINE__); + } +}; + +// One extra input array +struct TestTransform1 { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD<D>; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr<T[]> pa = + AllocateAligned<T>(HWY_MAX(1, misalign_a + count)); + AlignedFreeUniquePtr<T[]> pb = + AllocateAligned<T>(HWY_MAX(1, misalign_b + count)); + T* a = pa.get() + misalign_a; + T* b = pb.get() + misalign_b; + for (size_t i = 0; i < count; ++i) { + a[i] = Random<T>(rng); + b[i] = Random<T>(rng); + } + + AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count)); + SimpleAXPY(a, b, expected.get(), count); + +#if HWY_GENERIC_LAMBDA + const auto axpy = [](const auto d, const auto v, const auto v1) HWY_ATTR { + return MulAdd(Set(d, Alpha<T>()), v, v1); + }; +#else + const AXPY axpy; +#endif + Transform1(d, a, count, b, axpy); + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, + __FILE__, __LINE__); + } +}; + +// Two extra input arrays +struct TestTransform2 { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD<D>; + // Prevents error if size to allocate is zero. + AlignedFreeUniquePtr<T[]> pa = + AllocateAligned<T>(HWY_MAX(1, misalign_a + count)); + AlignedFreeUniquePtr<T[]> pb = + AllocateAligned<T>(HWY_MAX(1, misalign_b + count)); + AlignedFreeUniquePtr<T[]> pc = + AllocateAligned<T>(HWY_MAX(1, misalign_a + count)); + T* a = pa.get() + misalign_a; + T* b = pb.get() + misalign_b; + T* c = pc.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random<T>(rng); + b[i] = Random<T>(rng); + c[i] = Random<T>(rng); + } + + AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count)); + SimpleFMA4(a, b, c, expected.get(), count); + +#if HWY_GENERIC_LAMBDA + const auto fma4 = [](auto /*d*/, auto v, auto v1, auto v2) + HWY_ATTR { return MulAdd(v, v1, v2); }; +#else + const FMA4 fma4; +#endif + Transform2(d, a, count, b, c, fma4); + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name, + __FILE__, __LINE__); + } +}; + +template <typename T> +class IfEq { + public: + IfEq(T val) : val_(val) {} + + template <class D, class V> + Mask<D> operator()(D d, V v) const { + return Eq(v, Set(d, val_)); + } + + private: + T val_; +}; + +struct TestReplace { + template <class D> + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + if (misalign_b != 0) return; + if (count == 0) return; + using T = TFromD<D>; + AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + count); + T* a = pa.get() + misalign_a; + for (size_t i = 0; i < count; ++i) { + a[i] = Random<T>(rng); + } + AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(count); + + AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(count); + + std::vector<size_t> positions(AdjustedReps(count)); + for (size_t& pos : positions) { + pos = static_cast<size_t>(rng()) % count; + } + + for (size_t pos = 0; pos < count; ++pos) { + const T old_t = a[pos]; + const T new_t = Random<T>(rng); + for (size_t i = 0; i < count; ++i) { + expected[i] = IsEqual(a[i], old_t) ? new_t : a[i]; + } + + // Copy so ReplaceIf gets the same input (and thus also outputs expected) + memcpy(pb.get(), a, count * sizeof(T)); + + Replace(d, a, count, new_t, old_t); + HWY_ASSERT_ARRAY_EQ(expected.get(), a, count); + + ReplaceIf(d, pb.get(), count, new_t, IfEq<T>(old_t)); + HWY_ASSERT_ARRAY_EQ(expected.get(), pb.get(), count); + } + } +}; + +void TestAllGenerate() { + // The test BitCast-s the indices, which does not work for floats. + ForIntegerTypes(ForPartialVectors<ForeachCountAndMisalign<TestGenerate>>()); +} + +void TestAllTransform() { + ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestTransform>>()); +} + +void TestAllTransform1() { + ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestTransform1>>()); +} + +void TestAllTransform2() { + ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestTransform2>>()); +} + +void TestAllReplace() { + ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestReplace>>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(TransformTest); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllGenerate); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform1); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform2); +HWY_EXPORT_AND_TEST_P(TransformTest, TestAllReplace); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h b/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h new file mode 100644 index 0000000000..04d015453b --- /dev/null +++ b/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h @@ -0,0 +1,2599 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The entry points are class templates specialized below for each number of +// bits. Each provides Pack and Unpack member functions which load (Pack) or +// store (Unpack) B raw vectors, and store (Pack) or load (Unpack) a number of +// packed vectors equal to kBits. B denotes the bits per lane: 8 for Pack8, 16 +// for Pack16, which is also the upper bound for kBits. +template <size_t kBits> // <= 8 +struct Pack8 {}; +template <size_t kBits> // <= 16 +struct Pack16 {}; + +template <> +struct Pack8<1> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed = + Xor3(Or(ShiftLeft<7>(raw7), ShiftLeft<6>(raw6)), + Xor3(ShiftLeft<5>(raw5), ShiftLeft<4>(raw4), ShiftLeft<3>(raw3)), + Xor3(ShiftLeft<2>(raw2), ShiftLeft<1>(raw1), raw0)); + StoreU(BitCast(d8, packed), d8, packed_out); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0101u); // LSB in each byte + + const VU16 packed = BitCast(d16, LoadU(d8, packed_in)); + + const VU16 raw0 = And(packed, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(ShiftRight<1>(packed), mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<2>(packed), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<3>(packed), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(ShiftRight<4>(packed), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<5>(packed), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<6>(packed), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<7>(packed), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<1> + +template <> +struct Pack8<2> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed0 = Xor3(ShiftLeft<6>(raw6), ShiftLeft<4>(raw4), + Or(ShiftLeft<2>(raw2), raw0)); + const VU16 packed1 = Xor3(ShiftLeft<6>(raw7), ShiftLeft<4>(raw5), + Or(ShiftLeft<2>(raw3), raw1)); + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0303u); // Lowest 2 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<2>(packed0), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<2>(packed1), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(ShiftRight<4>(packed0), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<4>(packed1), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<6>(packed0), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<6>(packed1), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<2> + +template <> +struct Pack8<3> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + // The upper two bits of these three will be filled with packed3 (6 bits). + VU16 packed0 = Or(ShiftLeft<3>(raw4), raw0); + VU16 packed1 = Or(ShiftLeft<3>(raw5), raw1); + VU16 packed2 = Or(ShiftLeft<3>(raw6), raw2); + const VU16 packed3 = Or(ShiftLeft<3>(raw7), raw3); + + const VU16 hi2 = Set(d16, 0xC0C0u); + packed0 = OrAnd(packed0, ShiftLeft<2>(packed3), hi2); + packed1 = OrAnd(packed1, ShiftLeft<4>(packed3), hi2); + packed2 = OrAnd(packed2, ShiftLeft<6>(packed3), hi2); + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0707u); // Lowest 3 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw4 = And(ShiftRight<3>(packed0), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<3>(packed1), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<3>(packed2), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + // raw73 is the concatenation of the upper two bits in packed0..2. + const VU16 hi2 = Set(d16, 0xC0C0u); + const VU16 raw73 = Xor3(ShiftRight<6>(And(packed2, hi2)), // + ShiftRight<4>(And(packed1, hi2)), + ShiftRight<2>(And(packed0, hi2))); + + const VU16 raw3 = And(mask, raw73); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw7 = And(mask, ShiftRight<3>(raw73)); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<3> + +template <> +struct Pack8<4> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed0 = Or(ShiftLeft<4>(raw2), raw0); + const VU16 packed1 = Or(ShiftLeft<4>(raw3), raw1); + const VU16 packed2 = Or(ShiftLeft<4>(raw6), raw4); + const VU16 packed3 = Or(ShiftLeft<4>(raw7), raw5); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0F0Fu); // Lowest 4 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<4>(packed0), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<4>(packed1), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(packed2, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed3, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<4>(packed2), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<4>(packed3), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<4> + +template <> +struct Pack8<5> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + // Fill upper three bits with upper bits from raw4..7. + const VU16 hi3 = Set(d16, 0xE0E0u); + const VU16 packed0 = OrAnd(raw0, ShiftLeft<3>(raw4), hi3); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<3>(raw5), hi3); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<3>(raw6), hi3); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<3>(raw7), hi3); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + + // Combine lower two bits of raw4..7 into packed4. + const VU16 lo2 = Set(d16, 0x0303u); + const VU16 packed4 = Or(And(raw4, lo2), Xor3(ShiftLeft<2>(And(raw5, lo2)), + ShiftLeft<4>(And(raw6, lo2)), + ShiftLeft<6>(And(raw7, lo2)))); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + + const VU16 mask = Set(d16, 0x1F1Fu); // Lowest 5 bits per byte + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(packed3, mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + // The upper bits are the top 3 bits shifted right by three. + const VU16 top4 = ShiftRight<3>(AndNot(mask, packed0)); + const VU16 top5 = ShiftRight<3>(AndNot(mask, packed1)); + const VU16 top6 = ShiftRight<3>(AndNot(mask, packed2)); + const VU16 top7 = ShiftRight<3>(AndNot(mask, packed3)); + + // Insert the lower 2 bits, which were concatenated into a byte. + const VU16 lo2 = Set(d16, 0x0303u); + const VU16 raw4 = OrAnd(top4, lo2, packed4); + const VU16 raw5 = OrAnd(top5, lo2, ShiftRight<2>(packed4)); + const VU16 raw6 = OrAnd(top6, lo2, ShiftRight<4>(packed4)); + const VU16 raw7 = OrAnd(top7, lo2, ShiftRight<6>(packed4)); + + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<5> + +template <> +struct Pack8<6> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 hi2 = Set(d16, 0xC0C0u); + // Each triplet of these stores raw3/raw7 (6 bits) in the upper 2 bits. + const VU16 packed0 = OrAnd(raw0, ShiftLeft<2>(raw3), hi2); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<4>(raw3), hi2); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<6>(raw3), hi2); + const VU16 packed3 = OrAnd(raw4, ShiftLeft<2>(raw7), hi2); + const VU16 packed4 = OrAnd(raw5, ShiftLeft<4>(raw7), hi2); + const VU16 packed5 = OrAnd(raw6, ShiftLeft<6>(raw7), hi2); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + StoreU(BitCast(d8, packed5), d8, packed_out + 5 * N8); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x3F3Fu); // Lowest 6 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + const VU16 packed5 = BitCast(d16, LoadU(d8, packed_in + 5 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw4 = And(packed3, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed4, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(packed5, mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + // raw3/7 are the concatenation of the upper two bits in packed0..2. + const VU16 raw3 = Xor3(ShiftRight<6>(AndNot(mask, packed2)), + ShiftRight<4>(AndNot(mask, packed1)), + ShiftRight<2>(AndNot(mask, packed0))); + const VU16 raw7 = Xor3(ShiftRight<6>(AndNot(mask, packed5)), + ShiftRight<4>(AndNot(mask, packed4)), + ShiftRight<2>(AndNot(mask, packed3))); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<6> + +template <> +struct Pack8<7> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + // Inserted into top bit of packed0..6. + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 hi1 = Set(d16, 0x8080u); + const VU16 packed0 = OrAnd(raw0, Add(raw7, raw7), hi1); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<2>(raw7), hi1); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<3>(raw7), hi1); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<4>(raw7), hi1); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<5>(raw7), hi1); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<6>(raw7), hi1); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<7>(raw7), hi1); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + StoreU(BitCast(d8, packed5), d8, packed_out + 5 * N8); + StoreU(BitCast(d8, packed6), d8, packed_out + 6 * N8); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide<decltype(d8)> d16; + using VU16 = Vec<decltype(d16)>; + const size_t N8 = Lanes(d8); + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + const VU16 packed5 = BitCast(d16, LoadU(d8, packed_in + 5 * N8)); + const VU16 packed6 = BitCast(d16, LoadU(d8, packed_in + 6 * N8)); + + const VU16 mask = Set(d16, 0x7F7Fu); // Lowest 7 bits per byte + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(packed3, mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(packed4, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed5, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(packed6, mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 p0 = Xor3(ShiftRight<7>(AndNot(mask, packed6)), + ShiftRight<6>(AndNot(mask, packed5)), + ShiftRight<5>(AndNot(mask, packed4))); + const VU16 p1 = Xor3(ShiftRight<4>(AndNot(mask, packed3)), + ShiftRight<3>(AndNot(mask, packed2)), + ShiftRight<2>(AndNot(mask, packed1))); + const VU16 raw7 = Xor3(ShiftRight<1>(AndNot(mask, packed0)), p0, p1); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<7> + +template <> +struct Pack8<8> { + template <class D8> + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + using VU8 = Vec<decltype(d8)>; + const size_t N8 = Lanes(d8); + const VU8 raw0 = LoadU(d8, raw + 0 * N8); + const VU8 raw1 = LoadU(d8, raw + 1 * N8); + const VU8 raw2 = LoadU(d8, raw + 2 * N8); + const VU8 raw3 = LoadU(d8, raw + 3 * N8); + const VU8 raw4 = LoadU(d8, raw + 4 * N8); + const VU8 raw5 = LoadU(d8, raw + 5 * N8); + const VU8 raw6 = LoadU(d8, raw + 6 * N8); + const VU8 raw7 = LoadU(d8, raw + 7 * N8); + + StoreU(raw0, d8, packed_out + 0 * N8); + StoreU(raw1, d8, packed_out + 1 * N8); + StoreU(raw2, d8, packed_out + 2 * N8); + StoreU(raw3, d8, packed_out + 3 * N8); + StoreU(raw4, d8, packed_out + 4 * N8); + StoreU(raw5, d8, packed_out + 5 * N8); + StoreU(raw6, d8, packed_out + 6 * N8); + StoreU(raw7, d8, packed_out + 7 * N8); + } + + template <class D8> + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + using VU8 = Vec<decltype(d8)>; + const size_t N8 = Lanes(d8); + const VU8 raw0 = LoadU(d8, packed_in + 0 * N8); + const VU8 raw1 = LoadU(d8, packed_in + 1 * N8); + const VU8 raw2 = LoadU(d8, packed_in + 2 * N8); + const VU8 raw3 = LoadU(d8, packed_in + 3 * N8); + const VU8 raw4 = LoadU(d8, packed_in + 4 * N8); + const VU8 raw5 = LoadU(d8, packed_in + 5 * N8); + const VU8 raw6 = LoadU(d8, packed_in + 6 * N8); + const VU8 raw7 = LoadU(d8, packed_in + 7 * N8); + + StoreU(raw0, d8, raw + 0 * N8); + StoreU(raw1, d8, raw + 1 * N8); + StoreU(raw2, d8, raw + 2 * N8); + StoreU(raw3, d8, raw + 3 * N8); + StoreU(raw4, d8, raw + 4 * N8); + StoreU(raw5, d8, raw + 5 * N8); + StoreU(raw6, d8, raw + 6 * N8); + StoreU(raw7, d8, raw + 7 * N8); + } +}; // Pack8<8> + +template <> +struct Pack16<1> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 p0 = Xor3(ShiftLeft<2>(raw2), Add(raw1, raw1), raw0); + const VU16 p1 = + Xor3(ShiftLeft<5>(raw5), ShiftLeft<4>(raw4), ShiftLeft<3>(raw3)); + const VU16 p2 = + Xor3(ShiftLeft<8>(raw8), ShiftLeft<7>(raw7), ShiftLeft<6>(raw6)); + const VU16 p3 = + Xor3(ShiftLeft<0xB>(rawB), ShiftLeft<0xA>(rawA), ShiftLeft<9>(raw9)); + const VU16 p4 = + Xor3(ShiftLeft<0xE>(rawE), ShiftLeft<0xD>(rawD), ShiftLeft<0xC>(rawC)); + const VU16 packed = + Or(Xor3(ShiftLeft<0xF>(rawF), p0, p1), Xor3(p2, p3, p4)); + StoreU(packed, d, packed_out); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 1u); // Lowest bit + + const VU16 packed = LoadU(d, packed_in); + + const VU16 raw0 = And(packed, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(ShiftRight<1>(packed), mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<2>(packed), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<3>(packed), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<4>(packed), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<5>(packed), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<7>(packed), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<8>(packed), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<9>(packed), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<0xA>(packed), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<0xB>(packed), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<0xC>(packed), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<0xD>(packed), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<0xE>(packed), mask); + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<0xF>(packed); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<1> + +template <> +struct Pack16<2> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + VU16 packed0 = Xor3(ShiftLeft<4>(raw4), ShiftLeft<2>(raw2), raw0); + VU16 packed1 = Xor3(ShiftLeft<4>(raw5), ShiftLeft<2>(raw3), raw1); + packed0 = Xor3(packed0, ShiftLeft<8>(raw8), ShiftLeft<6>(raw6)); + packed1 = Xor3(packed1, ShiftLeft<8>(raw9), ShiftLeft<6>(raw7)); + + packed0 = Xor3(packed0, ShiftLeft<12>(rawC), ShiftLeft<10>(rawA)); + packed1 = Xor3(packed1, ShiftLeft<12>(rawD), ShiftLeft<10>(rawB)); + + packed0 = Or(packed0, ShiftLeft<14>(rawE)); + packed1 = Or(packed1, ShiftLeft<14>(rawF)); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x3u); // Lowest 2 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<2>(packed0), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<2>(packed1), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<4>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<4>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed0), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<6>(packed1), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<8>(packed0), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<8>(packed1), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<0xA>(packed0), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<0xA>(packed1), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<0xC>(packed0), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<0xC>(packed1), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<0xE>(packed0); + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<0xE>(packed1); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<2> + +template <> +struct Pack16<3> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // We can fit 15 raw vectors in three packed vectors (five each). + VU16 packed0 = Xor3(ShiftLeft<6>(raw6), ShiftLeft<3>(raw3), raw0); + VU16 packed1 = Xor3(ShiftLeft<6>(raw7), ShiftLeft<3>(raw4), raw1); + VU16 packed2 = Xor3(ShiftLeft<6>(raw8), ShiftLeft<3>(raw5), raw2); + + // rawF will be scattered into the upper bit of these three. + packed0 = Xor3(packed0, ShiftLeft<12>(rawC), ShiftLeft<9>(raw9)); + packed1 = Xor3(packed1, ShiftLeft<12>(rawD), ShiftLeft<9>(rawA)); + packed2 = Xor3(packed2, ShiftLeft<12>(rawE), ShiftLeft<9>(rawB)); + + const VU16 hi1 = Set(d, 0x8000u); + packed0 = Or(packed0, ShiftLeft<15>(rawF)); // MSB only, no mask + packed1 = OrAnd(packed1, ShiftLeft<14>(rawF), hi1); + packed2 = OrAnd(packed2, ShiftLeft<13>(rawF), hi1); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x7u); // Lowest 3 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + + const VU16 raw0 = And(mask, packed0); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(mask, packed1); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(mask, packed2); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(mask, ShiftRight<3>(packed0)); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(mask, ShiftRight<3>(packed1)); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(mask, ShiftRight<3>(packed2)); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(mask, ShiftRight<6>(packed0)); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(mask, ShiftRight<6>(packed1)); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(mask, ShiftRight<6>(packed2)); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(mask, ShiftRight<9>(packed0)); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(mask, ShiftRight<9>(packed1)); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(mask, ShiftRight<9>(packed2)); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(mask, ShiftRight<12>(packed0)); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(mask, ShiftRight<12>(packed1)); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(mask, ShiftRight<12>(packed2)); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the upper bit of packed0..2. + const VU16 down0 = ShiftRight<15>(packed0); + const VU16 down1 = ShiftRight<15>(packed1); + const VU16 down2 = ShiftRight<15>(packed2); + const VU16 rawF = Xor3(ShiftLeft<2>(down2), Add(down1, down1), down0); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<3> + +template <> +struct Pack16<4> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + VU16 packed0 = Xor3(ShiftLeft<8>(raw4), ShiftLeft<4>(raw2), raw0); + VU16 packed1 = Xor3(ShiftLeft<8>(raw5), ShiftLeft<4>(raw3), raw1); + packed0 = Or(packed0, ShiftLeft<12>(raw6)); + packed1 = Or(packed1, ShiftLeft<12>(raw7)); + VU16 packed2 = Xor3(ShiftLeft<8>(rawC), ShiftLeft<4>(rawA), raw8); + VU16 packed3 = Xor3(ShiftLeft<8>(rawD), ShiftLeft<4>(rawB), raw9); + packed2 = Or(packed2, ShiftLeft<12>(rawE)); + packed3 = Or(packed3, ShiftLeft<12>(rawF)); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0xFu); // Lowest 4 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed3 = LoadU(d, packed_in + 3 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<4>(packed0), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<4>(packed1), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<8>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<8>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = ShiftRight<12>(packed0); // no mask required + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = ShiftRight<12>(packed1); // no mask required + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed2, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed3, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<4>(packed2), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<4>(packed3), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<8>(packed2), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<8>(packed3), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<12>(packed2); // no mask required + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<12>(packed3); // no mask required + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<4> + +template <> +struct Pack16<5> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // We can fit 15 raw vectors in five packed vectors (three each). + VU16 packed0 = Xor3(ShiftLeft<10>(rawA), ShiftLeft<5>(raw5), raw0); + VU16 packed1 = Xor3(ShiftLeft<10>(rawB), ShiftLeft<5>(raw6), raw1); + VU16 packed2 = Xor3(ShiftLeft<10>(rawC), ShiftLeft<5>(raw7), raw2); + VU16 packed3 = Xor3(ShiftLeft<10>(rawD), ShiftLeft<5>(raw8), raw3); + VU16 packed4 = Xor3(ShiftLeft<10>(rawE), ShiftLeft<5>(raw9), raw4); + + // rawF will be scattered into the upper bits of these five. + const VU16 hi1 = Set(d, 0x8000u); + packed0 = Or(packed0, ShiftLeft<15>(rawF)); // MSB only, no mask + packed1 = OrAnd(packed1, ShiftLeft<14>(rawF), hi1); + packed2 = OrAnd(packed2, ShiftLeft<13>(rawF), hi1); + packed3 = OrAnd(packed3, ShiftLeft<12>(rawF), hi1); + packed4 = OrAnd(packed4, ShiftLeft<11>(rawF), hi1); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed3 = LoadU(d, packed_in + 3 * N); + const VU16 packed4 = LoadU(d, packed_in + 4 * N); + + const VU16 mask = Set(d, 0x1Fu); // Lowest 5 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<5>(packed0), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<5>(packed1), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<5>(packed2), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<5>(packed3), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<5>(packed4), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<10>(packed0), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<10>(packed1), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<10>(packed2), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<10>(packed3), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<10>(packed4), mask); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the lower bit of packed0..4. + const VU16 down0 = ShiftRight<15>(packed0); + const VU16 down1 = ShiftRight<15>(packed1); + const VU16 hi1 = Set(d, 0x8000u); + const VU16 p0 = + Xor3(ShiftRight<13>(And(packed2, hi1)), Add(down1, down1), down0); + const VU16 rawF = Xor3(ShiftRight<11>(And(packed4, hi1)), + ShiftRight<12>(And(packed3, hi1)), p0); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<5> + +template <> +struct Pack16<6> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 packed3 = Or(ShiftLeft<6>(raw7), raw3); + const VU16 packed7 = Or(ShiftLeft<6>(rawF), rawB); + // Three vectors, two 6-bit raw each; packed3 (12 bits) is spread over the + // four remainder bits at the top of each vector. + const VU16 packed0 = Xor3(ShiftLeft<12>(packed3), ShiftLeft<6>(raw4), raw0); + VU16 packed1 = Or(ShiftLeft<6>(raw5), raw1); + VU16 packed2 = Or(ShiftLeft<6>(raw6), raw2); + const VU16 packed4 = Xor3(ShiftLeft<12>(packed7), ShiftLeft<6>(rawC), raw8); + VU16 packed5 = Or(ShiftLeft<6>(rawD), raw9); + VU16 packed6 = Or(ShiftLeft<6>(rawE), rawA); + + const VU16 hi4 = Set(d, 0xF000u); + packed1 = OrAnd(packed1, ShiftLeft<8>(packed3), hi4); + packed2 = OrAnd(packed2, ShiftLeft<4>(packed3), hi4); + packed5 = OrAnd(packed5, ShiftLeft<8>(packed7), hi4); + packed6 = OrAnd(packed6, ShiftLeft<4>(packed7), hi4); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed4, d, packed_out + 3 * N); + StoreU(packed5, d, packed_out + 4 * N); + StoreU(packed6, d, packed_out + 5 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x3Fu); // Lowest 6 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed4 = LoadU(d, packed_in + 3 * N); + const VU16 packed5 = LoadU(d, packed_in + 4 * N); + const VU16 packed6 = LoadU(d, packed_in + 5 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw4 = And(ShiftRight<6>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<6>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed2), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw8 = And(packed4, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed5, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packed6, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawC = And(ShiftRight<6>(packed4), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<6>(packed5), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<6>(packed6), mask); + StoreU(rawE, d, raw + 0xE * N); + + // packed3 is the concatenation of the four upper bits in packed0..2. + const VU16 down0 = ShiftRight<12>(packed0); + const VU16 down4 = ShiftRight<12>(packed4); + const VU16 hi4 = Set(d, 0xF000u); + const VU16 packed3 = Xor3(ShiftRight<4>(And(packed2, hi4)), + ShiftRight<8>(And(packed1, hi4)), down0); + const VU16 packed7 = Xor3(ShiftRight<4>(And(packed6, hi4)), + ShiftRight<8>(And(packed5, hi4)), down4); + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 rawB = And(packed7, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 raw7 = ShiftRight<6>(packed3); // upper bits already zero + StoreU(raw7, d, raw + 7 * N); + + const VU16 rawF = ShiftRight<6>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<6> + +template <> +struct Pack16<7> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 packed7 = Or(ShiftLeft<7>(rawF), raw7); + // Seven vectors, two 7-bit raw each; packed7 (14 bits) is spread over the + // two remainder bits at the top of each vector. + const VU16 packed0 = Xor3(ShiftLeft<14>(packed7), ShiftLeft<7>(raw8), raw0); + VU16 packed1 = Or(ShiftLeft<7>(raw9), raw1); + VU16 packed2 = Or(ShiftLeft<7>(rawA), raw2); + VU16 packed3 = Or(ShiftLeft<7>(rawB), raw3); + VU16 packed4 = Or(ShiftLeft<7>(rawC), raw4); + VU16 packed5 = Or(ShiftLeft<7>(rawD), raw5); + VU16 packed6 = Or(ShiftLeft<7>(rawE), raw6); + + const VU16 hi2 = Set(d, 0xC000u); + packed1 = OrAnd(packed1, ShiftLeft<12>(packed7), hi2); + packed2 = OrAnd(packed2, ShiftLeft<10>(packed7), hi2); + packed3 = OrAnd(packed3, ShiftLeft<8>(packed7), hi2); + packed4 = OrAnd(packed4, ShiftLeft<6>(packed7), hi2); + packed5 = OrAnd(packed5, ShiftLeft<4>(packed7), hi2); + packed6 = OrAnd(packed6, ShiftLeft<2>(packed7), hi2); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + + const VU16 mask = Set(d, 0x7Fu); // Lowest 7 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw8 = And(ShiftRight<7>(packed0), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<7>(packed1), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<7>(packed2), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<7>(packed3), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<7>(packed4), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<7>(packed5), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<7>(packed6), mask); + StoreU(rawE, d, raw + 0xE * N); + + // packed7 is the concatenation of the two upper bits in packed0..6. + const VU16 down0 = ShiftRight<14>(packed0); + const VU16 hi2 = Set(d, 0xC000u); + const VU16 p0 = Xor3(ShiftRight<12>(And(packed1, hi2)), + ShiftRight<10>(And(packed2, hi2)), down0); + const VU16 p1 = Xor3(ShiftRight<8>(And(packed3, hi2)), // + ShiftRight<6>(And(packed4, hi2)), + ShiftRight<4>(And(packed5, hi2))); + const VU16 packed7 = Xor3(ShiftRight<2>(And(packed6, hi2)), p1, p0); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 rawF = ShiftRight<7>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<7> + +template <> +struct Pack16<8> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // This is equivalent to ConcatEven with 8-bit lanes, but much more + // efficient on RVV and slightly less efficient on SVE2. + const VU16 packed0 = Or(ShiftLeft<8>(raw2), raw0); + const VU16 packed1 = Or(ShiftLeft<8>(raw3), raw1); + const VU16 packed2 = Or(ShiftLeft<8>(raw6), raw4); + const VU16 packed3 = Or(ShiftLeft<8>(raw7), raw5); + const VU16 packed4 = Or(ShiftLeft<8>(rawA), raw8); + const VU16 packed5 = Or(ShiftLeft<8>(rawB), raw9); + const VU16 packed6 = Or(ShiftLeft<8>(rawE), rawC); + const VU16 packed7 = Or(ShiftLeft<8>(rawF), rawD); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 mask = Set(d, 0xFFu); // Lowest 8 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = ShiftRight<8>(packed0); // upper bits already zero + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = ShiftRight<8>(packed1); // upper bits already zero + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed2, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed3, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = ShiftRight<8>(packed2); // upper bits already zero + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = ShiftRight<8>(packed3); // upper bits already zero + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed4, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed5, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = ShiftRight<8>(packed4); // upper bits already zero + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = ShiftRight<8>(packed5); // upper bits already zero + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packed6, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packed7, mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<8>(packed6); // upper bits already zero + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<8>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<8> + +template <> +struct Pack16<9> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + // 8 vectors, each with 9+7 bits; top 2 bits are concatenated into packed8. + const VU16 packed0 = Or(ShiftLeft<9>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<9>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<9>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<9>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<9>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<9>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<9>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<9>(rawF), raw7); + + // We could shift down, OR and shift up, but two shifts are typically more + // expensive than AND, shift into position, and OR (which can be further + // reduced via Xor3). + const VU16 mid2 = Set(d, 0x180u); // top 2 in lower 9 + const VU16 part8 = ShiftRight<7>(And(raw8, mid2)); + const VU16 part9 = ShiftRight<5>(And(raw9, mid2)); + const VU16 partA = ShiftRight<3>(And(rawA, mid2)); + const VU16 partB = ShiftRight<1>(And(rawB, mid2)); + const VU16 partC = ShiftLeft<1>(And(rawC, mid2)); + const VU16 partD = ShiftLeft<3>(And(rawD, mid2)); + const VU16 partE = ShiftLeft<5>(And(rawE, mid2)); + const VU16 partF = ShiftLeft<7>(And(rawF, mid2)); + const VU16 packed8 = Xor3(Xor3(part8, part9, partA), + Xor3(partB, partC, partD), Or(partE, partF)); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + + const VU16 mask = Set(d, 0x1FFu); // Lowest 9 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid2 = Set(d, 0x180u); // top 2 in lower 9 + const VU16 raw8 = + OrAnd(ShiftRight<9>(packed0), ShiftLeft<7>(packed8), mid2); + const VU16 raw9 = + OrAnd(ShiftRight<9>(packed1), ShiftLeft<5>(packed8), mid2); + const VU16 rawA = + OrAnd(ShiftRight<9>(packed2), ShiftLeft<3>(packed8), mid2); + const VU16 rawB = + OrAnd(ShiftRight<9>(packed3), ShiftLeft<1>(packed8), mid2); + const VU16 rawC = + OrAnd(ShiftRight<9>(packed4), ShiftRight<1>(packed8), mid2); + const VU16 rawD = + OrAnd(ShiftRight<9>(packed5), ShiftRight<3>(packed8), mid2); + const VU16 rawE = + OrAnd(ShiftRight<9>(packed6), ShiftRight<5>(packed8), mid2); + const VU16 rawF = + OrAnd(ShiftRight<9>(packed7), ShiftRight<7>(packed8), mid2); + + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<9> + +template <> +struct Pack16<10> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 8 vectors, each with 10+6 bits; top 4 bits are concatenated into + // packed8 and packed9. + const VU16 packed0 = Or(ShiftLeft<10>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<10>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<10>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<10>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<10>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<10>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<10>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<10>(rawF), raw7); + + // We could shift down, OR and shift up, but two shifts are typically more + // expensive than AND, shift into position, and OR (which can be further + // reduced via Xor3). + const VU16 mid4 = Set(d, 0x3C0u); // top 4 in lower 10 + const VU16 part8 = ShiftRight<6>(And(raw8, mid4)); + const VU16 part9 = ShiftRight<2>(And(raw9, mid4)); + const VU16 partA = ShiftLeft<2>(And(rawA, mid4)); + const VU16 partB = ShiftLeft<6>(And(rawB, mid4)); + const VU16 partC = ShiftRight<6>(And(rawC, mid4)); + const VU16 partD = ShiftRight<2>(And(rawD, mid4)); + const VU16 partE = ShiftLeft<2>(And(rawE, mid4)); + const VU16 partF = ShiftLeft<6>(And(rawF, mid4)); + const VU16 packed8 = Or(Xor3(part8, part9, partA), partB); + const VU16 packed9 = Or(Xor3(partC, partD, partE), partF); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + + const VU16 mask = Set(d, 0x3FFu); // Lowest 10 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid4 = Set(d, 0x3C0u); // top 4 in lower 10 + const VU16 raw8 = + OrAnd(ShiftRight<10>(packed0), ShiftLeft<6>(packed8), mid4); + const VU16 raw9 = + OrAnd(ShiftRight<10>(packed1), ShiftLeft<2>(packed8), mid4); + const VU16 rawA = + OrAnd(ShiftRight<10>(packed2), ShiftRight<2>(packed8), mid4); + const VU16 rawB = + OrAnd(ShiftRight<10>(packed3), ShiftRight<6>(packed8), mid4); + const VU16 rawC = + OrAnd(ShiftRight<10>(packed4), ShiftLeft<6>(packed9), mid4); + const VU16 rawD = + OrAnd(ShiftRight<10>(packed5), ShiftLeft<2>(packed9), mid4); + const VU16 rawE = + OrAnd(ShiftRight<10>(packed6), ShiftRight<2>(packed9), mid4); + const VU16 rawF = + OrAnd(ShiftRight<10>(packed7), ShiftRight<6>(packed9), mid4); + + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<10> + +template <> +struct Pack16<11> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // It is not obvious what the optimal partitioning looks like. To reduce the + // number of constants, we want to minimize the number of distinct bit + // lengths. 11+5 also requires 6-bit remnants with 4-bit leftovers. + // 8+3 seems better: it is easier to scatter 3 bits into the MSBs. + const VU16 lo8 = Set(d, 0xFFu); + + // Lower 8 bits of all raw + const VU16 packed0 = OrAnd(ShiftLeft<8>(raw1), raw0, lo8); + const VU16 packed1 = OrAnd(ShiftLeft<8>(raw3), raw2, lo8); + const VU16 packed2 = OrAnd(ShiftLeft<8>(raw5), raw4, lo8); + const VU16 packed3 = OrAnd(ShiftLeft<8>(raw7), raw6, lo8); + const VU16 packed4 = OrAnd(ShiftLeft<8>(raw9), raw8, lo8); + const VU16 packed5 = OrAnd(ShiftLeft<8>(rawB), rawA, lo8); + const VU16 packed6 = OrAnd(ShiftLeft<8>(rawD), rawC, lo8); + const VU16 packed7 = OrAnd(ShiftLeft<8>(rawF), rawE, lo8); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + + // Three vectors, five 3bit remnants each, plus one 3bit in their MSB. + const VU16 top0 = ShiftRight<8>(raw0); + const VU16 top1 = ShiftRight<8>(raw1); + const VU16 top2 = ShiftRight<8>(raw2); + // Insert top raw bits into 3-bit groups within packed8..A. Moving the + // mask along avoids masking each of raw0..E and enables OrAnd. + VU16 next = Set(d, 0x38u); // 0x7 << 3 + VU16 packed8 = OrAnd(top0, ShiftRight<5>(raw3), next); + VU16 packed9 = OrAnd(top1, ShiftRight<5>(raw4), next); + VU16 packedA = OrAnd(top2, ShiftRight<5>(raw5), next); + next = ShiftLeft<3>(next); + packed8 = OrAnd(packed8, ShiftRight<2>(raw6), next); + packed9 = OrAnd(packed9, ShiftRight<2>(raw7), next); + packedA = OrAnd(packedA, ShiftRight<2>(raw8), next); + next = ShiftLeft<3>(next); + packed8 = OrAnd(packed8, Add(raw9, raw9), next); + packed9 = OrAnd(packed9, Add(rawA, rawA), next); + packedA = OrAnd(packedA, Add(rawB, rawB), next); + next = ShiftLeft<3>(next); + packed8 = OrAnd(packed8, ShiftLeft<4>(rawC), next); + packed9 = OrAnd(packed9, ShiftLeft<4>(rawD), next); + packedA = OrAnd(packedA, ShiftLeft<4>(rawE), next); + + // Scatter upper 3 bits of rawF into the upper bits. + next = ShiftLeft<3>(next); // = 0x8000u + packed8 = OrAnd(packed8, ShiftLeft<7>(rawF), next); + packed9 = OrAnd(packed9, ShiftLeft<6>(rawF), next); + packedA = OrAnd(packedA, ShiftLeft<5>(rawF), next); + + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + + const VU16 mask = Set(d, 0xFFu); // Lowest 8 bits + + const VU16 down0 = And(packed0, mask); + const VU16 down1 = ShiftRight<8>(packed0); + const VU16 down2 = And(packed1, mask); + const VU16 down3 = ShiftRight<8>(packed1); + const VU16 down4 = And(packed2, mask); + const VU16 down5 = ShiftRight<8>(packed2); + const VU16 down6 = And(packed3, mask); + const VU16 down7 = ShiftRight<8>(packed3); + const VU16 down8 = And(packed4, mask); + const VU16 down9 = ShiftRight<8>(packed4); + const VU16 downA = And(packed5, mask); + const VU16 downB = ShiftRight<8>(packed5); + const VU16 downC = And(packed6, mask); + const VU16 downD = ShiftRight<8>(packed6); + const VU16 downE = And(packed7, mask); + const VU16 downF = ShiftRight<8>(packed7); + + // Three bits from packed8..A, eight bits from down0..F. + const VU16 hi3 = Set(d, 0x700u); + const VU16 raw0 = OrAnd(down0, ShiftLeft<8>(packed8), hi3); + const VU16 raw1 = OrAnd(down1, ShiftLeft<8>(packed9), hi3); + const VU16 raw2 = OrAnd(down2, ShiftLeft<8>(packedA), hi3); + + const VU16 raw3 = OrAnd(down3, ShiftLeft<5>(packed8), hi3); + const VU16 raw4 = OrAnd(down4, ShiftLeft<5>(packed9), hi3); + const VU16 raw5 = OrAnd(down5, ShiftLeft<5>(packedA), hi3); + + const VU16 raw6 = OrAnd(down6, ShiftLeft<2>(packed8), hi3); + const VU16 raw7 = OrAnd(down7, ShiftLeft<2>(packed9), hi3); + const VU16 raw8 = OrAnd(down8, ShiftLeft<2>(packedA), hi3); + + const VU16 raw9 = OrAnd(down9, ShiftRight<1>(packed8), hi3); + const VU16 rawA = OrAnd(downA, ShiftRight<1>(packed9), hi3); + const VU16 rawB = OrAnd(downB, ShiftRight<1>(packedA), hi3); + + const VU16 rawC = OrAnd(downC, ShiftRight<4>(packed8), hi3); + const VU16 rawD = OrAnd(downD, ShiftRight<4>(packed9), hi3); + const VU16 rawE = OrAnd(downE, ShiftRight<4>(packedA), hi3); + + // Shift MSB into the top 3-of-11 and mask. + const VU16 rawF = Or(downF, Xor3(And(ShiftRight<7>(packed8), hi3), + And(ShiftRight<6>(packed9), hi3), + And(ShiftRight<5>(packedA), hi3))); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<11> + +template <> +struct Pack16<12> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 8 vectors, each with 12+4 bits; top 8 bits are concatenated into + // packed8 to packedB. + const VU16 packed0 = Or(ShiftLeft<12>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<12>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<12>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<12>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<12>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<12>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<12>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<12>(rawF), raw7); + + // Masking after shifting left enables OrAnd. + const VU16 hi8 = Set(d, 0xFF00u); + const VU16 packed8 = OrAnd(ShiftRight<4>(raw8), ShiftLeft<4>(raw9), hi8); + const VU16 packed9 = OrAnd(ShiftRight<4>(rawA), ShiftLeft<4>(rawB), hi8); + const VU16 packedA = OrAnd(ShiftRight<4>(rawC), ShiftLeft<4>(rawD), hi8); + const VU16 packedB = OrAnd(ShiftRight<4>(rawE), ShiftLeft<4>(rawF), hi8); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + + const VU16 mask = Set(d, 0xFFFu); // Lowest 12 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid8 = Set(d, 0xFF0u); // upper 8 in lower 12 + const VU16 raw8 = + OrAnd(ShiftRight<12>(packed0), ShiftLeft<4>(packed8), mid8); + const VU16 raw9 = + OrAnd(ShiftRight<12>(packed1), ShiftRight<4>(packed8), mid8); + const VU16 rawA = + OrAnd(ShiftRight<12>(packed2), ShiftLeft<4>(packed9), mid8); + const VU16 rawB = + OrAnd(ShiftRight<12>(packed3), ShiftRight<4>(packed9), mid8); + const VU16 rawC = + OrAnd(ShiftRight<12>(packed4), ShiftLeft<4>(packedA), mid8); + const VU16 rawD = + OrAnd(ShiftRight<12>(packed5), ShiftRight<4>(packedA), mid8); + const VU16 rawE = + OrAnd(ShiftRight<12>(packed6), ShiftLeft<4>(packedB), mid8); + const VU16 rawF = + OrAnd(ShiftRight<12>(packed7), ShiftRight<4>(packedB), mid8); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<12> + +template <> +struct Pack16<13> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // As with 11 bits, it is not obvious what the optimal partitioning looks + // like. We similarly go with an 8+5 split. + const VU16 lo8 = Set(d, 0xFFu); + + // Lower 8 bits of all raw + const VU16 packed0 = OrAnd(ShiftLeft<8>(raw1), raw0, lo8); + const VU16 packed1 = OrAnd(ShiftLeft<8>(raw3), raw2, lo8); + const VU16 packed2 = OrAnd(ShiftLeft<8>(raw5), raw4, lo8); + const VU16 packed3 = OrAnd(ShiftLeft<8>(raw7), raw6, lo8); + const VU16 packed4 = OrAnd(ShiftLeft<8>(raw9), raw8, lo8); + const VU16 packed5 = OrAnd(ShiftLeft<8>(rawB), rawA, lo8); + const VU16 packed6 = OrAnd(ShiftLeft<8>(rawD), rawC, lo8); + const VU16 packed7 = OrAnd(ShiftLeft<8>(rawF), rawE, lo8); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + + // Five vectors, three 5bit remnants each, plus one 5bit in their MSB. + const VU16 top0 = ShiftRight<8>(raw0); + const VU16 top1 = ShiftRight<8>(raw1); + const VU16 top2 = ShiftRight<8>(raw2); + const VU16 top3 = ShiftRight<8>(raw3); + const VU16 top4 = ShiftRight<8>(raw4); + + // Insert top raw bits into 5-bit groups within packed8..C. Moving the + // mask along avoids masking each of raw0..E and enables OrAnd. + VU16 next = Set(d, 0x3E0u); // 0x1F << 5 + VU16 packed8 = OrAnd(top0, ShiftRight<3>(raw5), next); + VU16 packed9 = OrAnd(top1, ShiftRight<3>(raw6), next); + VU16 packedA = OrAnd(top2, ShiftRight<3>(raw7), next); + VU16 packedB = OrAnd(top3, ShiftRight<3>(raw8), next); + VU16 packedC = OrAnd(top4, ShiftRight<3>(raw9), next); + next = ShiftLeft<5>(next); + packed8 = OrAnd(packed8, ShiftLeft<2>(rawA), next); + packed9 = OrAnd(packed9, ShiftLeft<2>(rawB), next); + packedA = OrAnd(packedA, ShiftLeft<2>(rawC), next); + packedB = OrAnd(packedB, ShiftLeft<2>(rawD), next); + packedC = OrAnd(packedC, ShiftLeft<2>(rawE), next); + + // Scatter upper 5 bits of rawF into the upper bits. + next = ShiftLeft<3>(next); // = 0x8000u + packed8 = OrAnd(packed8, ShiftLeft<7>(rawF), next); + packed9 = OrAnd(packed9, ShiftLeft<6>(rawF), next); + packedA = OrAnd(packedA, ShiftLeft<5>(rawF), next); + packedB = OrAnd(packedB, ShiftLeft<4>(rawF), next); + packedC = OrAnd(packedC, ShiftLeft<3>(rawF), next); + + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + + const VU16 mask = Set(d, 0xFFu); // Lowest 8 bits + + const VU16 down0 = And(packed0, mask); + const VU16 down1 = ShiftRight<8>(packed0); + const VU16 down2 = And(packed1, mask); + const VU16 down3 = ShiftRight<8>(packed1); + const VU16 down4 = And(packed2, mask); + const VU16 down5 = ShiftRight<8>(packed2); + const VU16 down6 = And(packed3, mask); + const VU16 down7 = ShiftRight<8>(packed3); + const VU16 down8 = And(packed4, mask); + const VU16 down9 = ShiftRight<8>(packed4); + const VU16 downA = And(packed5, mask); + const VU16 downB = ShiftRight<8>(packed5); + const VU16 downC = And(packed6, mask); + const VU16 downD = ShiftRight<8>(packed6); + const VU16 downE = And(packed7, mask); + const VU16 downF = ShiftRight<8>(packed7); + + // Upper five bits from packed8..C, eight bits from down0..F. + const VU16 hi5 = Set(d, 0x1F00u); + const VU16 raw0 = OrAnd(down0, ShiftLeft<8>(packed8), hi5); + const VU16 raw1 = OrAnd(down1, ShiftLeft<8>(packed9), hi5); + const VU16 raw2 = OrAnd(down2, ShiftLeft<8>(packedA), hi5); + const VU16 raw3 = OrAnd(down3, ShiftLeft<8>(packedB), hi5); + const VU16 raw4 = OrAnd(down4, ShiftLeft<8>(packedC), hi5); + + const VU16 raw5 = OrAnd(down5, ShiftLeft<3>(packed8), hi5); + const VU16 raw6 = OrAnd(down6, ShiftLeft<3>(packed9), hi5); + const VU16 raw7 = OrAnd(down7, ShiftLeft<3>(packedA), hi5); + const VU16 raw8 = OrAnd(down8, ShiftLeft<3>(packed9), hi5); + const VU16 raw9 = OrAnd(down9, ShiftLeft<3>(packedA), hi5); + + const VU16 rawA = OrAnd(downA, ShiftRight<2>(packed8), hi5); + const VU16 rawB = OrAnd(downB, ShiftRight<2>(packed9), hi5); + const VU16 rawC = OrAnd(downC, ShiftRight<2>(packedA), hi5); + const VU16 rawD = OrAnd(downD, ShiftRight<2>(packed9), hi5); + const VU16 rawE = OrAnd(downE, ShiftRight<2>(packedA), hi5); + + // Shift MSB into the top 5-of-11 and mask. + const VU16 p0 = Xor3(And(ShiftRight<7>(packed8), hi5), // + And(ShiftRight<6>(packed9), hi5), + And(ShiftRight<5>(packedA), hi5)); + const VU16 p1 = Xor3(And(ShiftRight<4>(packedB), hi5), + And(ShiftRight<3>(packedC), hi5), downF); + const VU16 rawF = Or(p0, p1); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<13> + +template <> +struct Pack16<14> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 14 vectors, each with 14+2 bits; two raw vectors are scattered + // across the upper 2 bits. + const VU16 hi2 = Set(d, 0xC000u); + const VU16 packed0 = Or(raw0, ShiftLeft<14>(rawE)); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<12>(rawE), hi2); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<10>(rawE), hi2); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<8>(rawE), hi2); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<6>(rawE), hi2); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<4>(rawE), hi2); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<2>(rawE), hi2); + const VU16 packed7 = Or(raw7, ShiftLeft<14>(rawF)); + const VU16 packed8 = OrAnd(raw8, ShiftLeft<12>(rawF), hi2); + const VU16 packed9 = OrAnd(raw9, ShiftLeft<10>(rawF), hi2); + const VU16 packedA = OrAnd(rawA, ShiftLeft<8>(rawF), hi2); + const VU16 packedB = OrAnd(rawB, ShiftLeft<6>(rawF), hi2); + const VU16 packedC = OrAnd(rawC, ShiftLeft<4>(rawF), hi2); + const VU16 packedD = OrAnd(rawD, ShiftLeft<2>(rawF), hi2); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + StoreU(packedD, d, packed_out + 0xD * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 packedD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + + const VU16 mask = Set(d, 0x3FFFu); // Lowest 14 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed8, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed9, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packedA, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(packedB, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packedC, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packedD, mask); + StoreU(rawD, d, raw + 0xD * N); + + // rawE is the concatenation of the top two bits in packed0..6. + const VU16 E0 = Xor3(ShiftRight<14>(packed0), // + ShiftRight<12>(AndNot(mask, packed1)), + ShiftRight<10>(AndNot(mask, packed2))); + const VU16 E1 = Xor3(ShiftRight<8>(AndNot(mask, packed3)), + ShiftRight<6>(AndNot(mask, packed4)), + ShiftRight<4>(AndNot(mask, packed5))); + const VU16 rawE = Xor3(ShiftRight<2>(AndNot(mask, packed6)), E0, E1); + const VU16 F0 = Xor3(ShiftRight<14>(AndNot(mask, packed7)), + ShiftRight<12>(AndNot(mask, packed8)), + ShiftRight<10>(AndNot(mask, packed9))); + const VU16 F1 = Xor3(ShiftRight<8>(AndNot(mask, packedA)), + ShiftRight<6>(AndNot(mask, packedB)), + ShiftRight<4>(AndNot(mask, packedC))); + const VU16 rawF = Xor3(ShiftRight<2>(AndNot(mask, packedD)), F0, F1); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<14> + +template <> +struct Pack16<15> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 15 vectors, each with 15+1 bits; one packed vector is scattered + // across the upper bit. + const VU16 hi1 = Set(d, 0x8000u); + const VU16 packed0 = Or(raw0, ShiftLeft<15>(rawF)); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<14>(rawF), hi1); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<13>(rawF), hi1); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<12>(rawF), hi1); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<11>(rawF), hi1); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<10>(rawF), hi1); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<9>(rawF), hi1); + const VU16 packed7 = OrAnd(raw7, ShiftLeft<8>(rawF), hi1); + const VU16 packed8 = OrAnd(raw8, ShiftLeft<7>(rawF), hi1); + const VU16 packed9 = OrAnd(raw9, ShiftLeft<6>(rawF), hi1); + const VU16 packedA = OrAnd(rawA, ShiftLeft<5>(rawF), hi1); + const VU16 packedB = OrAnd(rawB, ShiftLeft<4>(rawF), hi1); + const VU16 packedC = OrAnd(rawC, ShiftLeft<3>(rawF), hi1); + const VU16 packedD = OrAnd(rawD, ShiftLeft<2>(rawF), hi1); + const VU16 packedE = OrAnd(rawE, ShiftLeft<1>(rawF), hi1); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + StoreU(packedD, d, packed_out + 0xD * N); + StoreU(packedE, d, packed_out + 0xE * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 packedD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + const VU16 packedE = BitCast(d, LoadU(d, packed_in + 0xE * N)); + + const VU16 mask = Set(d, 0x7FFFu); // Lowest 15 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed8, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed9, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packedA, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(packedB, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packedC, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packedD, mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(packedE, mask); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the top bit in packed0..E. + const VU16 F0 = Xor3(ShiftRight<15>(packed0), // + ShiftRight<14>(AndNot(mask, packed1)), + ShiftRight<13>(AndNot(mask, packed2))); + const VU16 F1 = Xor3(ShiftRight<12>(AndNot(mask, packed3)), + ShiftRight<11>(AndNot(mask, packed4)), + ShiftRight<10>(AndNot(mask, packed5))); + const VU16 F2 = Xor3(ShiftRight<9>(AndNot(mask, packed6)), + ShiftRight<8>(AndNot(mask, packed7)), + ShiftRight<7>(AndNot(mask, packed8))); + const VU16 F3 = Xor3(ShiftRight<6>(AndNot(mask, packed9)), + ShiftRight<5>(AndNot(mask, packedA)), + ShiftRight<4>(AndNot(mask, packedB))); + const VU16 F4 = Xor3(ShiftRight<3>(AndNot(mask, packedC)), + ShiftRight<2>(AndNot(mask, packedD)), + ShiftRight<1>(AndNot(mask, packedE))); + const VU16 rawF = Xor3(F0, F1, Xor3(F2, F3, F4)); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<15> + +template <> +struct Pack16<16> { + template <class D> + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + StoreU(raw0, d, packed_out + 0 * N); + StoreU(raw1, d, packed_out + 1 * N); + StoreU(raw2, d, packed_out + 2 * N); + StoreU(raw3, d, packed_out + 3 * N); + StoreU(raw4, d, packed_out + 4 * N); + StoreU(raw5, d, packed_out + 5 * N); + StoreU(raw6, d, packed_out + 6 * N); + StoreU(raw7, d, packed_out + 7 * N); + StoreU(raw8, d, packed_out + 8 * N); + StoreU(raw9, d, packed_out + 9 * N); + StoreU(rawA, d, packed_out + 0xA * N); + StoreU(rawB, d, packed_out + 0xB * N); + StoreU(rawC, d, packed_out + 0xC * N); + StoreU(rawD, d, packed_out + 0xD * N); + StoreU(rawE, d, packed_out + 0xE * N); + StoreU(rawF, d, packed_out + 0xF * N); + } + + template <class D> + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec<decltype(d)>; + const size_t N = Lanes(d); + + const VU16 raw0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 raw1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 raw2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 raw3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 raw4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 raw5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 raw6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 raw7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 raw8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 raw9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 rawA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 rawB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 rawC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 rawD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + const VU16 rawE = BitCast(d, LoadU(d, packed_in + 0xE * N)); + const VU16 rawF = BitCast(d, LoadU(d, packed_in + 0xF * N)); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<16> + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ diff --git a/third_party/highway/hwy/contrib/bit_pack/bit_pack_test.cc b/third_party/highway/hwy/contrib/bit_pack/bit_pack_test.cc new file mode 100644 index 0000000000..a239da9cf6 --- /dev/null +++ b/third_party/highway/hwy/contrib/bit_pack/bit_pack_test.cc @@ -0,0 +1,205 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdio.h> + +#include <vector> + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/nanobenchmark.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/bit_pack/bit_pack_test.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/bit_pack/bit_pack-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +#ifndef HWY_BIT_PACK_BENCHMARK +#define HWY_BIT_PACK_BENCHMARK 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +// Used to prevent running benchmark (slow) for partial vectors and targets +// except the best available. Global, not per-target, hence must be outside +// HWY_NAMESPACE. Declare first because HWY_ONCE is only true after some code +// has been re-included. +extern size_t last_bits; +extern uint64_t best_target; +#if HWY_ONCE +size_t last_bits = 0; +uint64_t best_target = ~0ull; +#endif +namespace HWY_NAMESPACE { + +template <size_t kBits, typename T> +T Random(RandomState& rng) { + return static_cast<T>(Random32(&rng) & kBits); +} + +template <typename T> +class Checker { + public: + explicit Checker(size_t num) { raw_.reserve(num); } + void NotifyRaw(T raw) { raw_.push_back(raw); } + + void NotifyRawOutput(size_t bits, T raw) { + if (raw_[num_verified_] != raw) { + HWY_ABORT("%zu bits: pos %zu of %zu, expected %.0f actual %.0f\n", bits, + num_verified_, raw_.size(), + static_cast<double>(raw_[num_verified_]), + static_cast<double>(raw)); + } + ++num_verified_; + } + + private: + std::vector<T> raw_; + size_t num_verified_ = 0; +}; + +template <template <size_t> class PackT, size_t kVectors, size_t kBits> +struct TestPack { + template <typename T, class D> + void operator()(T /* t */, D d) { + constexpr size_t kLoops = 16; // working set slightly larger than L1 + const size_t N = Lanes(d); + RandomState rng(N * 129); + static_assert(kBits <= kVectors, ""); + const size_t num_per_loop = N * kVectors; + const size_t num = num_per_loop * kLoops; + const size_t num_packed_per_loop = N * kBits; + const size_t num_packed = num_packed_per_loop * kLoops; + Checker<T> checker(num); + AlignedFreeUniquePtr<T[]> raw = hwy::AllocateAligned<T>(num); + AlignedFreeUniquePtr<T[]> raw2 = hwy::AllocateAligned<T>(num); + AlignedFreeUniquePtr<T[]> packed = hwy::AllocateAligned<T>(num_packed); + + for (size_t i = 0; i < num; ++i) { + raw[i] = Random<kBits, T>(rng); + checker.NotifyRaw(raw[i]); + } + + best_target = HWY_MIN(best_target, HWY_TARGET); + const bool run_bench = HWY_BIT_PACK_BENCHMARK && (kBits != last_bits) && + (HWY_TARGET == best_target); + last_bits = kBits; + + const PackT<kBits> func; + + if (run_bench) { + const size_t kNumInputs = 1; + const size_t num_items = num * size_t(Unpredictable1()); + const FuncInput inputs[kNumInputs] = {num_items}; + Result results[kNumInputs]; + + Params p; + p.verbose = false; + p.max_evals = 7; + p.target_rel_mad = 0.002; + const size_t num_results = MeasureClosure( + [&](FuncInput) HWY_ATTR { + for (size_t i = 0, pi = 0; i < num; + i += num_per_loop, pi += num_packed_per_loop) { + func.Pack(d, raw.get() + i, packed.get() + pi); + } + packed.get()[Random32(&rng) % num_packed] += Unpredictable1() - 1; + for (size_t i = 0, pi = 0; i < num; + i += num_per_loop, pi += num_packed_per_loop) { + func.Unpack(d, packed.get() + pi, raw2.get() + i); + } + return raw2[Random32(&rng) % num]; + }, + inputs, kNumInputs, results, p); + if (num_results != kNumInputs) { + fprintf(stderr, "MeasureClosure failed.\n"); + return; + } + // Print throughput for pack+unpack round trip + for (size_t i = 0; i < num_results; ++i) { + const size_t bytes_per_element = (kBits + 7) / 8; + const double bytes = results[i].input * bytes_per_element; + const double seconds = + results[i].ticks / platform::InvariantTicksPerSecond(); + printf("Bits:%2d elements:%3d GB/s:%4.1f (+/-%3.1f%%)\n", + static_cast<int>(kBits), static_cast<int>(results[i].input), + 1E-9 * bytes / seconds, results[i].variability * 100.0); + } + } else { + for (size_t i = 0, pi = 0; i < num; + i += num_per_loop, pi += num_packed_per_loop) { + func.Pack(d, raw.get() + i, packed.get() + pi); + } + packed.get()[Random32(&rng) % num_packed] += Unpredictable1() - 1; + for (size_t i = 0, pi = 0; i < num; + i += num_per_loop, pi += num_packed_per_loop) { + func.Unpack(d, packed.get() + pi, raw2.get() + i); + } + } + + for (size_t i = 0; i < num; ++i) { + checker.NotifyRawOutput(kBits, raw2[i]); + } + } +}; + +void TestAllPack8() { + ForShrinkableVectors<TestPack<Pack8, 8, 1>>()(uint8_t()); + ForShrinkableVectors<TestPack<Pack8, 8, 2>>()(uint8_t()); + ForShrinkableVectors<TestPack<Pack8, 8, 3>>()(uint8_t()); + ForShrinkableVectors<TestPack<Pack8, 8, 4>>()(uint8_t()); + ForShrinkableVectors<TestPack<Pack8, 8, 5>>()(uint8_t()); + ForShrinkableVectors<TestPack<Pack8, 8, 6>>()(uint8_t()); + ForShrinkableVectors<TestPack<Pack8, 8, 7>>()(uint8_t()); + ForShrinkableVectors<TestPack<Pack8, 8, 8>>()(uint8_t()); +} + +void TestAllPack16() { + ForShrinkableVectors<TestPack<Pack16, 16, 1>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 2>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 3>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 4>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 5>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 6>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 7>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 8>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 9>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 10>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 11>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 12>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 13>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 14>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 15>>()(uint16_t()); + ForShrinkableVectors<TestPack<Pack16, 16, 16>>()(uint16_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(BitPackTest); +HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack8); +HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack16); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/contrib/dot/dot-inl.h b/third_party/highway/hwy/contrib/dot/dot-inl.h new file mode 100644 index 0000000000..e04636f1b8 --- /dev/null +++ b/third_party/highway/hwy/contrib/dot/dot-inl.h @@ -0,0 +1,252 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#include <cmath> + +#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct Dot { + // Specify zero or more of these, ORed together, as the kAssumptions template + // argument to Compute. Each one may improve performance or reduce code size, + // at the cost of additional requirements on the arguments. + enum Assumptions { + // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T). + kAtLeastOneVector = 1, + // num_elements is divisible by N (a power of two, so this can be used if + // the problem size is known to be a power of two >= HWY_MAX_BYTES / + // sizeof(T)). + kMultipleOfVector = 2, + // RoundUpTo(num_elements, N) elements are accessible; their value does not + // matter (will be treated as if they were zero). + kPaddedToVector = 4, + }; + + // Returns sum{pa[i] * pb[i]} for float or double inputs. Aligning the + // pointers to a multiple of N elements is helpful but not required. + template <int kAssumptions, class D, typename T = TFromD<D>, + HWY_IF_NOT_LANE_SIZE_D(D, 2)> + static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa, + const T* const HWY_RESTRICT pb, + const size_t num_elements) { + static_assert(IsFloat<T>(), "MulAdd requires float type"); + using V = decltype(Zero(d)); + + const size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + // Only 2x unroll to avoid excessive code size. + T sum0 = T(0); + T sum1 = T(0); + for (; i + 2 <= num_elements; i += 2) { + sum0 += pa[i + 0] * pb[i + 0]; + sum1 += pa[i + 1] * pb[i + 1]; + } + if (i < num_elements) { + sum1 += pa[i] * pb[i]; + } + return sum0 + sum1; + } + + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive + // for unaligned inputs (each unaligned pointer halves the throughput + // because it occupies both L1 load ports for a cycle). We cannot have + // arrays of vectors on RVV/SVE, so always unroll 4x. + V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); + + // Main loop: unrolled + for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = LoadU(d, pa + i); + const auto b2 = LoadU(d, pb + i); + i += N; + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = LoadU(d, pa + i); + const auto b3 = LoadU(d, pb + i); + i += N; + sum3 = MulAdd(a3, b3, sum3); + } + + // Up to 3 iterations of whole vectors + for (; i + N <= num_elements; i += N) { + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum0 = MulAdd(a, b, sum0); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(d, remaining); + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1); + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(d, N - remaining); + const auto a = LoadU(d, pa + i); // always unaligned + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return GetLane(SumOfLanes(d, sum0)); + } + + // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a + // multiple of N elements is helpful but not required. + template <int kAssumptions, class D> + static HWY_INLINE float Compute(const D d, + const bfloat16_t* const HWY_RESTRICT pa, + const bfloat16_t* const HWY_RESTRICT pb, + const size_t num_elements) { + const RebindToUnsigned<D> du16; + const Repartition<float, D> df32; + + using V = decltype(Zero(df32)); + const size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for.. + float sum1 = 0.0f; // this unlikely(?) case. + for (; i + 2 <= num_elements; i += 2) { + sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]); + sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]); + } + if (i < num_elements) { + sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum0 + sum1; + } + + // See comment in the other Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. + V sum0 = Zero(df32); + V sum1 = Zero(df32); + V sum2 = Zero(df32); + V sum3 = Zero(df32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(du16, remaining); + const auto va = LoadU(d, pa + i); + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(du16, N - remaining); + const auto va = LoadU(d, pa + i); // always unaligned + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return GetLane(SumOfLanes(df32, sum0)); + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ diff --git a/third_party/highway/hwy/contrib/dot/dot_test.cc b/third_party/highway/hwy/contrib/dot/dot_test.cc new file mode 100644 index 0000000000..12d7ab270d --- /dev/null +++ b/third_party/highway/hwy/contrib/dot/dot_test.cc @@ -0,0 +1,167 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> + +#include "hwy/aligned_allocator.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/dot/dot-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <typename T> +HWY_NOINLINE T SimpleDot(const T* pa, const T* pb, size_t num) { + double sum = 0.0; + for (size_t i = 0; i < num; ++i) { + sum += pa[i] * pb[i]; + } + return static_cast<T>(sum); +} + +HWY_NOINLINE float SimpleDot(const bfloat16_t* pa, const bfloat16_t* pb, + size_t num) { + float sum = 0.0f; + for (size_t i = 0; i < num; ++i) { + sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum; +} + +template <typename T> +void SetValue(const float value, T* HWY_RESTRICT ptr) { + *ptr = static_cast<T>(value); +} +void SetValue(const float value, bfloat16_t* HWY_RESTRICT ptr) { + *ptr = BF16FromF32(value); +} + +class TestDot { + // Computes/verifies one dot product. + template <int kAssumptions, class D> + void Test(D d, size_t num, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD<D>; + const size_t N = Lanes(d); + const auto random_t = [&rng]() { + const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023; + return static_cast<float>(bits - 512) * (1.0f / 64); + }; + + const size_t padded = + (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num; + AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded); + AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + padded); + T* a = pa.get() + misalign_a; + T* b = pb.get() + misalign_b; + size_t i = 0; + for (; i < num; ++i) { + SetValue(random_t(), a + i); + SetValue(random_t(), b + i); + } + // Fill padding with NaN - the values are not used, but avoids MSAN errors. + for (; i < padded; ++i) { + ScalableTag<float> df1; + SetValue(GetLane(NaN(df1)), a + i); + SetValue(GetLane(NaN(df1)), b + i); + } + + const auto expected = SimpleDot(a, b, num); + const auto actual = Dot::Compute<kAssumptions>(d, a, b, num); + const auto max = static_cast<decltype(actual)>(8 * 8 * num); + HWY_ASSERT(-max <= actual && actual <= max); + HWY_ASSERT(expected - 1E-4 <= actual && actual <= expected + 1E-4); + } + + // Runs tests with various alignments. + template <int kAssumptions, class D> + void ForeachMisalign(D d, size_t num, RandomState& rng) { + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test<kAssumptions>(d, num, ma, mb, rng); + } + } + } + + // Runs tests with various lengths compatible with the given assumptions. + template <int kAssumptions, class D> + void ForeachCount(D d, RandomState& rng) { + const size_t N = Lanes(d); + const size_t counts[] = {1, + 3, + 7, + 16, + HWY_MAX(N / 2, 1), + HWY_MAX(2 * N / 3, 1), + N, + N + 1, + 4 * N / 3, + 3 * N, + 8 * N, + 8 * N + 2}; + for (size_t num : counts) { + if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue; + if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue; + ForeachMisalign<kAssumptions>(d, num, rng); + } + } + + public: + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + // All 8 combinations of the three length-related flags: + ForeachCount<0>(d, rng); + ForeachCount<Dot::kAtLeastOneVector>(d, rng); + ForeachCount<Dot::kMultipleOfVector>(d, rng); + ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng); + ForeachCount<Dot::kPaddedToVector>(d, rng); + ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng); + ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng); + ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector | + Dot::kAtLeastOneVector>(d, rng); + } +}; + +void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); } +void TestAllDotBF16() { ForShrinkableVectors<TestDot>()(bfloat16_t()); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(DotTest); +HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot); +HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/contrib/image/image.cc b/third_party/highway/hwy/contrib/image/image.cc new file mode 100644 index 0000000000..67b37d2711 --- /dev/null +++ b/third_party/highway/hwy/contrib/image/image.cc @@ -0,0 +1,145 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/image/image.h" + +#include <algorithm> // std::swap +#include <cstddef> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/image/image.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +size_t GetVectorSize() { return Lanes(ScalableTag<uint8_t>()); } +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(GetVectorSize); // Local function. +} // namespace + +size_t ImageBase::VectorSize() { + // Do not cache result - must return the current value, which may be greater + // than the first call if it was subject to DisableTargets! + return HWY_DYNAMIC_DISPATCH(GetVectorSize)(); +} + +size_t ImageBase::BytesPerRow(const size_t xsize, const size_t sizeof_t) { + const size_t vec_size = VectorSize(); + size_t valid_bytes = xsize * sizeof_t; + + // Allow unaligned accesses starting at the last valid value - this may raise + // msan errors unless the user calls InitializePaddingForUnalignedAccesses. + // Skip for the scalar case because no extra lanes will be loaded. + if (vec_size != 1) { + HWY_DASSERT(vec_size >= sizeof_t); + valid_bytes += vec_size - sizeof_t; + } + + // Round up to vector and cache line size. + const size_t align = HWY_MAX(vec_size, HWY_ALIGNMENT); + size_t bytes_per_row = RoundUpTo(valid_bytes, align); + + // During the lengthy window before writes are committed to memory, CPUs + // guard against read after write hazards by checking the address, but + // only the lower 11 bits. We avoid a false dependency between writes to + // consecutive rows by ensuring their sizes are not multiples of 2 KiB. + // Avoid2K prevents the same problem for the planes of an Image3. + if (bytes_per_row % HWY_ALIGNMENT == 0) { + bytes_per_row += align; + } + + HWY_DASSERT(bytes_per_row % align == 0); + return bytes_per_row; +} + +ImageBase::ImageBase(const size_t xsize, const size_t ysize, + const size_t sizeof_t) + : xsize_(static_cast<uint32_t>(xsize)), + ysize_(static_cast<uint32_t>(ysize)), + bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) { + HWY_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8); + + bytes_per_row_ = 0; + // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate + // if nonzero, because "zero" bytes still have padding/bookkeeping overhead. + if (xsize != 0 && ysize != 0) { + bytes_per_row_ = BytesPerRow(xsize, sizeof_t); + bytes_ = AllocateAligned<uint8_t>(bytes_per_row_ * ysize); + HWY_ASSERT(bytes_.get() != nullptr); + InitializePadding(sizeof_t, Padding::kRoundUp); + } +} + +ImageBase::ImageBase(const size_t xsize, const size_t ysize, + const size_t bytes_per_row, void* const aligned) + : xsize_(static_cast<uint32_t>(xsize)), + ysize_(static_cast<uint32_t>(ysize)), + bytes_per_row_(bytes_per_row), + bytes_(static_cast<uint8_t*>(aligned), + AlignedFreer(&AlignedFreer::DoNothing, nullptr)) { + const size_t vec_size = VectorSize(); + HWY_ASSERT(bytes_per_row % vec_size == 0); + HWY_ASSERT(reinterpret_cast<uintptr_t>(aligned) % vec_size == 0); +} + +void ImageBase::InitializePadding(const size_t sizeof_t, Padding padding) { +#if HWY_IS_MSAN || HWY_IDE + if (xsize_ == 0 || ysize_ == 0) return; + + const size_t vec_size = VectorSize(); // Bytes, independent of sizeof_t! + if (vec_size == 1) return; // Scalar mode: no padding needed + + const size_t valid_size = xsize_ * sizeof_t; + const size_t initialize_size = padding == Padding::kRoundUp + ? RoundUpTo(valid_size, vec_size) + : valid_size + vec_size - sizeof_t; + if (valid_size == initialize_size) return; + + for (size_t y = 0; y < ysize_; ++y) { + uint8_t* HWY_RESTRICT row = static_cast<uint8_t*>(VoidRow(y)); +#if defined(__clang__) && (__clang_major__ <= 6) + // There's a bug in msan in clang-6 when handling AVX2 operations. This + // workaround allows tests to pass on msan, although it is slower and + // prevents msan warnings from uninitialized images. + memset(row, 0, initialize_size); +#else + memset(row + valid_size, 0, initialize_size - valid_size); +#endif // clang6 + } +#else + (void)sizeof_t; + (void)padding; +#endif // HWY_IS_MSAN +} + +void ImageBase::Swap(ImageBase& other) { + std::swap(xsize_, other.xsize_); + std::swap(ysize_, other.ysize_); + std::swap(bytes_per_row_, other.bytes_per_row_); + std::swap(bytes_, other.bytes_); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/image/image.h b/third_party/highway/hwy/contrib/image/image.h new file mode 100644 index 0000000000..c99863b06c --- /dev/null +++ b/third_party/highway/hwy/contrib/image/image.h @@ -0,0 +1,470 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ +#define HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ + +// SIMD/multicore-friendly planar image representation with row accessors. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <utility> // std::move + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/highway_export.h" + +namespace hwy { + +// Type-independent parts of Image<> - reduces code duplication and facilitates +// moving member function implementations to cc file. +struct HWY_CONTRIB_DLLEXPORT ImageBase { + // Returns required alignment in bytes for externally allocated memory. + static size_t VectorSize(); + + // Returns distance [bytes] between the start of two consecutive rows, a + // multiple of VectorSize but NOT kAlias (see implementation). + static size_t BytesPerRow(const size_t xsize, const size_t sizeof_t); + + // No allocation (for output params or unused images) + ImageBase() + : xsize_(0), + ysize_(0), + bytes_per_row_(0), + bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {} + + // Allocates memory (this is the common case) + ImageBase(size_t xsize, size_t ysize, size_t sizeof_t); + + // References but does not take ownership of external memory. Useful for + // interoperability with other libraries. `aligned` must be aligned to a + // multiple of VectorSize() and `bytes_per_row` must also be a multiple of + // VectorSize() or preferably equal to BytesPerRow(). + ImageBase(size_t xsize, size_t ysize, size_t bytes_per_row, void* aligned); + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo() instead. + ImageBase(const ImageBase& other) = delete; + ImageBase& operator=(const ImageBase& other) = delete; + + // Move constructor (required for returning Image from function) + ImageBase(ImageBase&& other) noexcept = default; + + // Move assignment (required for std::vector) + ImageBase& operator=(ImageBase&& other) noexcept = default; + + void Swap(ImageBase& other); + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. Caller is responsible + // for ensuring xsize/ysize are <= the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + xsize_ = static_cast<uint32_t>(xsize); + ysize_ = static_cast<uint32_t>(ysize); + // NOTE: we can't recompute bytes_per_row for more compact storage and + // better locality because that would invalidate the image contents. + } + + // How many pixels. + HWY_INLINE size_t xsize() const { return xsize_; } + HWY_INLINE size_t ysize() const { return ysize_; } + + // NOTE: do not use this for copying rows - the valid xsize may be much less. + HWY_INLINE size_t bytes_per_row() const { return bytes_per_row_; } + + // Raw access to byte contents, for interfacing with other libraries. + // Unsigned char instead of char to avoid surprises (sign extension). + HWY_INLINE uint8_t* bytes() { + void* p = bytes_.get(); + return static_cast<uint8_t * HWY_RESTRICT>(HWY_ASSUME_ALIGNED(p, 64)); + } + HWY_INLINE const uint8_t* bytes() const { + const void* p = bytes_.get(); + return static_cast<const uint8_t * HWY_RESTRICT>(HWY_ASSUME_ALIGNED(p, 64)); + } + + protected: + // Returns pointer to the start of a row. + HWY_INLINE void* VoidRow(const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (y >= ysize_) { + HWY_ABORT("Row(%d) >= %u\n", static_cast<int>(y), ysize_); + } +#endif + + void* row = bytes_.get() + y * bytes_per_row_; + return HWY_ASSUME_ALIGNED(row, 64); + } + + enum class Padding { + // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default. + kRoundUp, + // Allow LoadU(d, row + x) for x <= xsize() - 1. This requires an extra + // vector to be initialized. If done by default, this would suppress + // legitimate msan warnings. We therefore require users to explicitly call + // InitializePadding before using unaligned loads (e.g. convolution). + kUnaligned + }; + + // Initializes the minimum bytes required to suppress msan warnings from + // legitimate (according to Padding mode) vector loads/stores on the right + // border, where some lanes are uninitialized and assumed to be unused. + void InitializePadding(size_t sizeof_t, Padding padding); + + // (Members are non-const to enable assignment during move-assignment.) + uint32_t xsize_; // In valid pixels, not including any padding. + uint32_t ysize_; + size_t bytes_per_row_; // Includes padding. + AlignedFreeUniquePtr<uint8_t[]> bytes_; +}; + +// Single channel, aligned rows separated by padding. T must be POD. +// +// 'Single channel' (one 2D array per channel) simplifies vectorization +// (repeating the same operation on multiple adjacent components) without the +// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients +// can easily iterate over all components in a row and Image requires no +// knowledge of the pixel format beyond the component type "T". +// +// 'Aligned' means each row is aligned to the L1 cache line size. This prevents +// false sharing between two threads operating on adjacent rows. +// +// 'Padding' is still relevant because vectors could potentially be larger than +// a cache line. By rounding up row sizes to the vector size, we allow +// reading/writing ALIGNED vectors whose first lane is a valid sample. This +// avoids needing a separate loop to handle remaining unaligned lanes. +// +// This image layout could also be achieved with a vector and a row accessor +// function, but a class wrapper with support for "deleter" allows wrapping +// existing memory allocated by clients without copying the pixels. It also +// provides convenient accessors for xsize/ysize, which shortens function +// argument lists. Supports move-construction so it can be stored in containers. +template <typename ComponentType> +class Image : public ImageBase { + public: + using T = ComponentType; + + Image() = default; + Image(const size_t xsize, const size_t ysize) + : ImageBase(xsize, ysize, sizeof(T)) {} + Image(const size_t xsize, const size_t ysize, size_t bytes_per_row, + void* aligned) + : ImageBase(xsize, ysize, bytes_per_row, aligned) {} + + void InitializePaddingForUnalignedAccesses() { + InitializePadding(sizeof(T), Padding::kUnaligned); + } + + HWY_INLINE const T* ConstRow(const size_t y) const { + return static_cast<const T*>(VoidRow(y)); + } + HWY_INLINE const T* ConstRow(const size_t y) { + return static_cast<const T*>(VoidRow(y)); + } + + // Returns pointer to non-const. This allows passing const Image* parameters + // when the callee is only supposed to fill the pixels, as opposed to + // allocating or resizing the image. + HWY_INLINE T* MutableRow(const size_t y) const { + return static_cast<T*>(VoidRow(y)); + } + HWY_INLINE T* MutableRow(const size_t y) { + return static_cast<T*>(VoidRow(y)); + } + + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must + // NOT be used to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { + return static_cast<intptr_t>(bytes_per_row_ / sizeof(T)); + } +}; + +using ImageF = Image<float>; + +// A bundle of 3 same-sized images. To fill an existing Image3 using +// single-channel producers, we also need access to each const Image*. Const +// prevents breaking the same-size invariant, while still allowing pixels to be +// changed via MutableRow. +template <typename ComponentType> +class Image3 { + public: + using T = ComponentType; + using ImageT = Image<T>; + static constexpr size_t kNumPlanes = 3; + + Image3() : planes_{ImageT(), ImageT(), ImageT()} {} + + Image3(const size_t xsize, const size_t ysize) + : planes_{ImageT(xsize, ysize), ImageT(xsize, ysize), + ImageT(xsize, ysize)} {} + + Image3(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + } + + Image3(ImageT&& plane0, ImageT&& plane1, ImageT&& plane2) { + if (!SameSize(plane0, plane1) || !SameSize(plane0, plane2)) { + HWY_ABORT( + "Not same size: %d x %d, %d x %d, %d x %d\n", + static_cast<int>(plane0.xsize()), static_cast<int>(plane0.ysize()), + static_cast<int>(plane1.xsize()), static_cast<int>(plane1.ysize()), + static_cast<int>(plane2.xsize()), static_cast<int>(plane2.ysize())); + } + planes_[0] = std::move(plane0); + planes_[1] = std::move(plane1); + planes_[2] = std::move(plane2); + } + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo instead. + Image3(const Image3& other) = delete; + Image3& operator=(const Image3& other) = delete; + + Image3& operator=(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + return *this; + } + + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const { + return static_cast<const T*>(VoidPlaneRow(c, y)); + } + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) { + return static_cast<const T*>(VoidPlaneRow(c, y)); + } + + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) const { + return static_cast<T*>(VoidPlaneRow(c, y)); + } + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) { + return static_cast<T*>(VoidPlaneRow(c, y)); + } + + HWY_INLINE const ImageT& Plane(size_t idx) const { return planes_[idx]; } + + void Swap(Image3& other) { + for (size_t c = 0; c < 3; ++c) { + other.planes_[c].Swap(planes_[c]); + } + } + + void ShrinkTo(const size_t xsize, const size_t ysize) { + for (ImageT& plane : planes_) { + plane.ShrinkTo(xsize, ysize); + } + } + + // Sizes of all three images are guaranteed to be equal. + HWY_INLINE size_t xsize() const { return planes_[0].xsize(); } + HWY_INLINE size_t ysize() const { return planes_[0].ysize(); } + // Returns offset [bytes] from one row to the next row of the same plane. + // WARNING: this must NOT be used to determine xsize, nor for copying rows - + // the valid xsize may be much less. + HWY_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); } + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must NOT be used + // to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); } + + private: + // Returns pointer to the start of a row. + HWY_INLINE void* VoidPlaneRow(const size_t c, const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (c >= kNumPlanes || y >= ysize()) { + HWY_ABORT("PlaneRow(%d, %d) >= %d\n", static_cast<int>(c), + static_cast<int>(y), static_cast<int>(ysize())); + } +#endif + // Use the first plane's stride because the compiler might not realize they + // are all equal. Thus we only need a single multiplication for all planes. + const size_t row_offset = y * planes_[0].bytes_per_row(); + const void* row = planes_[c].bytes() + row_offset; + return static_cast<const T * HWY_RESTRICT>( + HWY_ASSUME_ALIGNED(row, HWY_ALIGNMENT)); + } + + private: + ImageT planes_[kNumPlanes]; +}; + +using Image3F = Image3<float>; + +// Rectangular region in image(s). Factoring this out of Image instead of +// shifting the pointer by x0/y0 allows this to apply to multiple images with +// different resolutions. Can compare size via SameSize(rect1, rect2). +class Rect { + public: + // Most windows are xsize_max * ysize_max, except those on the borders where + // begin + size_max > end. + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max, size_t xend, size_t yend) + : x0_(xbegin), + y0_(ybegin), + xsize_(ClampedSize(xbegin, xsize_max, xend)), + ysize_(ClampedSize(ybegin, ysize_max, yend)) {} + + // Construct with origin and known size (typically from another Rect). + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize, size_t ysize) + : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {} + + // Construct a rect that covers a whole image. + template <typename Image> + explicit Rect(const Image& image) + : Rect(0, 0, image.xsize(), image.ysize()) {} + + Rect() : Rect(0, 0, 0, 0) {} + + Rect(const Rect&) = default; + Rect& operator=(const Rect&) = default; + + Rect Subrect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max) { + return Rect(x0_ + xbegin, y0_ + ybegin, xsize_max, ysize_max, x0_ + xsize_, + y0_ + ysize_); + } + + template <typename T> + const T* ConstRow(const Image<T>* image, size_t y) const { + return image->ConstRow(y + y0_) + x0_; + } + + template <typename T> + T* MutableRow(const Image<T>* image, size_t y) const { + return image->MutableRow(y + y0_) + x0_; + } + + template <typename T> + const T* ConstPlaneRow(const Image3<T>& image, size_t c, size_t y) const { + return image.ConstPlaneRow(c, y + y0_) + x0_; + } + + template <typename T> + T* MutablePlaneRow(Image3<T>* image, const size_t c, size_t y) const { + return image->MutablePlaneRow(c, y + y0_) + x0_; + } + + // Returns true if this Rect fully resides in the given image. ImageT could be + // Image<T> or Image3<T>; however if ImageT is Rect, results are nonsensical. + template <class ImageT> + bool IsInside(const ImageT& image) const { + return (x0_ + xsize_ <= image.xsize()) && (y0_ + ysize_ <= image.ysize()); + } + + size_t x0() const { return x0_; } + size_t y0() const { return y0_; } + size_t xsize() const { return xsize_; } + size_t ysize() const { return ysize_; } + + private: + // Returns size_max, or whatever is left in [begin, end). + static constexpr size_t ClampedSize(size_t begin, size_t size_max, + size_t end) { + return (begin + size_max <= end) ? size_max + : (end > begin ? end - begin : 0); + } + + size_t x0_; + size_t y0_; + + size_t xsize_; + size_t ysize_; +}; + +// Works for any image-like input type(s). +template <class Image1, class Image2> +HWY_MAYBE_UNUSED bool SameSize(const Image1& image1, const Image2& image2) { + return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize(); +} + +// Mirrors out of bounds coordinates and returns valid coordinates unchanged. +// We assume the radius (distance outside the image) is small compared to the +// image size, otherwise this might not terminate. +// The mirror is outside the last column (border pixel is also replicated). +static HWY_INLINE HWY_MAYBE_UNUSED size_t Mirror(int64_t x, + const int64_t xsize) { + HWY_DASSERT(xsize != 0); + + // TODO(janwas): replace with branchless version + while (x < 0 || x >= xsize) { + if (x < 0) { + x = -x - 1; + } else { + x = 2 * xsize - 1 - x; + } + } + return static_cast<size_t>(x); +} + +// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size): + +// Mirrors (repeating the edge pixel once). Useful for convolutions. +struct WrapMirror { + HWY_INLINE size_t operator()(const int64_t coord, const size_t size) const { + return Mirror(coord, static_cast<int64_t>(size)); + } +}; + +// Returns the same coordinate, for when we know "coord" is already valid (e.g. +// interior of an image). +struct WrapUnchanged { + HWY_INLINE size_t operator()(const int64_t coord, size_t /*size*/) const { + return static_cast<size_t>(coord); + } +}; + +// Similar to Wrap* but for row pointers (reduces Row() multiplications). + +class WrapRowMirror { + public: + template <class View> + WrapRowMirror(const View& image, size_t ysize) + : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {} + + const float* operator()(const float* const HWY_RESTRICT row, + const int64_t stride) const { + if (row < first_row_) { + const int64_t num_before = first_row_ - row; + // Mirrored; one row before => row 0, two before = row 1, ... + return first_row_ + num_before - stride; + } + if (row > last_row_) { + const int64_t num_after = row - last_row_; + // Mirrored; one row after => last row, two after = last - 1, ... + return last_row_ - num_after + stride; + } + return row; + } + + private: + const float* const HWY_RESTRICT first_row_; + const float* const HWY_RESTRICT last_row_; +}; + +struct WrapRowUnchanged { + HWY_INLINE const float* operator()(const float* const HWY_RESTRICT row, + int64_t /*stride*/) const { + return row; + } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ diff --git a/third_party/highway/hwy/contrib/image/image_test.cc b/third_party/highway/hwy/contrib/image/image_test.cc new file mode 100644 index 0000000000..6886577a46 --- /dev/null +++ b/third_party/highway/hwy/contrib/image/image_test.cc @@ -0,0 +1,152 @@ +// Copyright (c) the JPEG XL Project +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/image/image.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> + +#include <random> +#include <utility> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/image/image_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target: +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Ensure we can always write full aligned vectors. +struct TestAlignedT { + template <typename T> + void operator()(T /*unused*/) const { + std::mt19937 rng(129); + std::uniform_int_distribution<int> dist(0, 16); + const ScalableTag<T> d; + + for (size_t ysize = 1; ysize < 4; ++ysize) { + for (size_t xsize = 1; xsize < 64; ++xsize) { + Image<T> img(xsize, ysize); + + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto values = Iota(d, static_cast<T>(dist(rng))); + Store(values, d, row + x); + } + } + + // Sanity check to prevent optimizing out the writes + const auto x = std::uniform_int_distribution<size_t>(0, xsize - 1)(rng); + const auto y = std::uniform_int_distribution<size_t>(0, ysize - 1)(rng); + HWY_ASSERT(img.ConstRow(y)[x] < 16 + Lanes(d)); + } + } + } +}; + +void TestAligned() { ForUnsignedTypes(TestAlignedT()); } + +// Ensure we can write an unaligned vector starting at the last valid value. +struct TestUnalignedT { + template <typename T> + void operator()(T /*unused*/) const { + std::mt19937 rng(129); + std::uniform_int_distribution<int> dist(0, 3); + const ScalableTag<T> d; + + for (size_t ysize = 1; ysize < 4; ++ysize) { + for (size_t xsize = 1; xsize < 128; ++xsize) { + Image<T> img(xsize, ysize); + img.InitializePaddingForUnalignedAccesses(); + +// This test reads padding, which only works if it was initialized, +// which only happens in MSAN builds. +#if HWY_IS_MSAN || HWY_IDE + // Initialize only the valid samples + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; ++x) { + row[x] = static_cast<T>(1u << dist(rng)); + } + } + + // Read padding bits + auto accum = Zero(d); + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; ++x) { + accum = Or(accum, LoadU(d, row + x)); + } + } + + // Ensure padding was zero + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + Store(accum, d, lanes.get()); + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(lanes[i] < 16); + } +#else // Check that writing padding does not overwrite valid samples + // Initialize only the valid samples + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize; ++x) { + row[x] = static_cast<T>(x); + } + } + + // Zero padding and rightmost sample + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + StoreU(Zero(d), d, row + xsize - 1); + } + + // Ensure no samples except the rightmost were overwritten + for (size_t y = 0; y < ysize; ++y) { + T* HWY_RESTRICT row = img.MutableRow(y); + for (size_t x = 0; x < xsize - 1; ++x) { + HWY_ASSERT_EQ(static_cast<T>(x), row[x]); + } + } +#endif + } + } + } +}; + +void TestUnaligned() { ForUnsignedTypes(TestUnalignedT()); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(ImageTest); +HWY_EXPORT_AND_TEST_P(ImageTest, TestAligned); +HWY_EXPORT_AND_TEST_P(ImageTest, TestUnaligned); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/contrib/math/math-inl.h b/third_party/highway/hwy/contrib/math/math-inl.h new file mode 100644 index 0000000000..b4cbb5d119 --- /dev/null +++ b/third_party/highway/hwy/contrib/math/math-inl.h @@ -0,0 +1,1242 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +/** + * Highway SIMD version of std::acos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc cosine of 'x' + */ +template <class D, class V> +HWY_INLINE V Acos(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallAcos(const D d, VecArg<V> x) { + return Acos(d, x); +} + +/** + * Highway SIMD version of std::acosh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[1, +FLT_MAX], float64[1, +DBL_MAX] + * @return hyperbolic arc cosine of 'x' + */ +template <class D, class V> +HWY_INLINE V Acosh(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallAcosh(const D d, VecArg<V> x) { + return Acosh(d, x); +} + +/** + * Highway SIMD version of std::asin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc sine of 'x' + */ +template <class D, class V> +HWY_INLINE V Asin(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallAsin(const D d, VecArg<V> x) { + return Asin(d, x); +} + +/** + * Highway SIMD version of std::asinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic arc sine of 'x' + */ +template <class D, class V> +HWY_INLINE V Asinh(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallAsinh(const D d, VecArg<V> x) { + return Asinh(d, x); +} + +/** + * Highway SIMD version of std::atan(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return arc tangent of 'x' + */ +template <class D, class V> +HWY_INLINE V Atan(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallAtan(const D d, VecArg<V> x) { + return Atan(d, x); +} + +/** + * Highway SIMD version of std::atanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: (-1, +1) + * @return hyperbolic arc tangent of 'x' + */ +template <class D, class V> +HWY_INLINE V Atanh(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallAtanh(const D d, VecArg<V> x) { + return Atanh(d, x); +} + +/** + * Highway SIMD version of std::cos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return cosine of 'x' + */ +template <class D, class V> +HWY_INLINE V Cos(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallCos(const D d, VecArg<V> x) { + return Cos(d, x); +} + +/** + * Highway SIMD version of std::exp(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 1 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x + */ +template <class D, class V> +HWY_INLINE V Exp(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallExp(const D d, VecArg<V> x) { + return Exp(d, x); +} + +/** + * Highway SIMD version of std::expm1(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x - 1 + */ +template <class D, class V> +HWY_INLINE V Expm1(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallExpm1(const D d, VecArg<V> x) { + return Expm1(d, x); +} + +/** + * Highway SIMD version of std::log(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return natural logarithm of 'x' + */ +template <class D, class V> +HWY_INLINE V Log(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallLog(const D d, VecArg<V> x) { + return Log(d, x); +} + +/** + * Highway SIMD version of std::log10(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 10 logarithm of 'x' + */ +template <class D, class V> +HWY_INLINE V Log10(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallLog10(const D d, VecArg<V> x) { + return Log10(d, x); +} + +/** + * Highway SIMD version of std::log1p(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[0, +FLT_MAX], float64[0, +DBL_MAX] + * @return log(1 + x) + */ +template <class D, class V> +HWY_INLINE V Log1p(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallLog1p(const D d, VecArg<V> x) { + return Log1p(d, x); +} + +/** + * Highway SIMD version of std::log2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 2 logarithm of 'x' + */ +template <class D, class V> +HWY_INLINE V Log2(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallLog2(const D d, VecArg<V> x) { + return Log2(d, x); +} + +/** + * Highway SIMD version of std::sin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return sine of 'x' + */ +template <class D, class V> +HWY_INLINE V Sin(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallSin(const D d, VecArg<V> x) { + return Sin(d, x); +} + +/** + * Highway SIMD version of std::sinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-88.7228, +88.7228], float64[-709, +709] + * @return hyperbolic sine of 'x' + */ +template <class D, class V> +HWY_INLINE V Sinh(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallSinh(const D d, VecArg<V> x) { + return Sinh(d, x); +} + +/** + * Highway SIMD version of std::tanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic tangent of 'x' + */ +template <class D, class V> +HWY_INLINE V Tanh(const D d, V x); +template <class D, class V> +HWY_NOINLINE V CallTanh(const D d, VecArg<V> x) { + return Tanh(d, x); +} + +//////////////////////////////////////////////////////////////////////////////// +// Implementation +//////////////////////////////////////////////////////////////////////////////// +namespace impl { + +// Estrin's Scheme is a faster method for evaluating large polynomials on +// super scalar architectures. It works by factoring the Horner's Method +// polynomial into power of two sub-trees that can be evaluated in parallel. +// Wikipedia Link: https://en.wikipedia.org/wiki/Estrin%27s_scheme +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1) { + return MulAdd(c1, x, c0); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2) { + T x2 = Mul(x, x); + return MulAdd(x2, c2, MulAdd(c1, x, c0)); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3) { + T x2 = Mul(x, x); + return MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, c4, MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(c5, x, c4), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, c6, MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, c8, + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(c9, x, c8), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, c10, MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd( + x8, MulAdd(x4, c12, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(c13, x, c12), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, c14, MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, c16, + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(c17, x, c16), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template <class T> +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17, + T c18) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(x2, c18, MulAdd(c17, x, c16)), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} + +template <class FloatOrDouble> +struct AsinImpl {}; +template <class FloatOrDouble> +struct AtanImpl {}; +template <class FloatOrDouble> +struct CosSinImpl {}; +template <class FloatOrDouble> +struct ExpImpl {}; +template <class FloatOrDouble> +struct LogImpl {}; + +template <> +struct AsinImpl<float> { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template <class D, class V> + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666677296f); + const auto k1 = Set(d, +0.07495029271f); + const auto k2 = Set(d, +0.04547423869f); + const auto k3 = Set(d, +0.02424046025f); + const auto k4 = Set(d, +0.04197454825f); + + return Estrin(x2, k0, k1, k2, k3, k4); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AsinImpl<double> { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template <class D, class V> + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666666666666497543); + const auto k1 = Set(d, +0.07500000000378581611); + const auto k2 = Set(d, +0.04464285681377102438); + const auto k3 = Set(d, +0.03038195928038132237); + const auto k4 = Set(d, +0.02237176181932048341); + const auto k5 = Set(d, +0.01735956991223614604); + const auto k6 = Set(d, +0.01388715184501609218); + const auto k7 = Set(d, +0.01215360525577377331); + const auto k8 = Set(d, +0.006606077476277170610); + const auto k9 = Set(d, +0.01929045477267910674); + const auto k10 = Set(d, -0.01581918243329996643); + const auto k11 = Set(d, +0.03161587650653934628); + + return Estrin(x2, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11); + } +}; + +#endif + +template <> +struct AtanImpl<float> { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template <class D, class V> + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333331018686294555664062f); + const auto k1 = Set(d, +0.199926957488059997558594f); + const auto k2 = Set(d, -0.142027363181114196777344f); + const auto k3 = Set(d, +0.106347933411598205566406f); + const auto k4 = Set(d, -0.0748900920152664184570312f); + const auto k5 = Set(d, +0.0425049886107444763183594f); + const auto k6 = Set(d, -0.0159569028764963150024414f); + const auto k7 = Set(d, +0.00282363896258175373077393f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7), Mul(y, x), x); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AtanImpl<double> { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template <class D, class V> + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333333333333311110369124); + const auto k1 = Set(d, +0.199999999996591265594148); + const auto k2 = Set(d, -0.14285714266771329383765); + const auto k3 = Set(d, +0.111111105648261418443745); + const auto k4 = Set(d, -0.090908995008245008229153); + const auto k5 = Set(d, +0.0769219538311769618355029); + const auto k6 = Set(d, -0.0666573579361080525984562); + const auto k7 = Set(d, +0.0587666392926673580854313); + const auto k8 = Set(d, -0.0523674852303482457616113); + const auto k9 = Set(d, +0.0466667150077840625632675); + const auto k10 = Set(d, -0.0407629191276836500001934); + const auto k11 = Set(d, +0.0337852580001353069993897); + const auto k12 = Set(d, -0.0254517624932312641616861); + const auto k13 = Set(d, +0.016599329773529201970117); + const auto k14 = Set(d, -0.00889896195887655491740809); + const auto k15 = Set(d, +0.00370026744188713119232403); + const auto k16 = Set(d, -0.00110611831486672482563471); + const auto k17 = Set(d, +0.000209850076645816976906797); + const auto k18 = Set(d, -1.88796008463073496563746e-5); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, + k12, k13, k14, k15, k16, k17, k18), + Mul(y, x), x); + } +}; + +#endif + +template <> +struct CosSinImpl<float> { + // Rounds float toward zero and returns as int32_t. + template <class D, class V> + HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind<int32_t, D>(), x); + } + + template <class D, class V> + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -1.66666597127914428710938e-1f); + const auto k1 = Set(d, +8.33307858556509017944336e-3f); + const auto k2 = Set(d, -1.981069071916863322258e-4f); + const auto k3 = Set(d, +2.6083159809786593541503e-6f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3), Mul(y, x), x); + } + + template <class D, class V, class VI32> + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0f + kHalfPiPart1f + kHalfPiPart2f + kHalfPiPart3f ~= -pi/2 + const V kHalfPiPart0f = Set(d, -0.5f * 3.140625f); + const V kHalfPiPart1f = Set(d, -0.5f * 0.0009670257568359375f); + const V kHalfPiPart2f = Set(d, -0.5f * 6.2771141529083251953e-7f); + const V kHalfPiPart3f = Set(d, -0.5f * 1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kHalfPiPart0f, x); + x = MulAdd(qf, kHalfPiPart1f, x); + x = MulAdd(qf, kHalfPiPart2f, x); + x = MulAdd(qf, kHalfPiPart3f, x); + return x; + } + + template <class D, class V, class VI32> + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0f + kPiPart1f + kPiPart2f + kPiPart3f ~= -pi + const V kPiPart0f = Set(d, -3.140625f); + const V kPiPart1f = Set(d, -0.0009670257568359375f); + const V kPiPart2f = Set(d, -6.2771141529083251953e-7f); + const V kPiPart3f = Set(d, -1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kPiPart0f, x); + x = MulAdd(qf, kPiPart1f, x); + x = MulAdd(qf, kPiPart2f, x); + x = MulAdd(qf, kPiPart3f, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template <class D, class VI32> + HWY_INLINE Vec<Rebind<float, D>> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind<int32_t, D>(), 2); + return BitCast(d, ShiftLeft<30>(AndNot(q, kTwo))); + } + + // ((q & 1) ? -0.0 : +0.0) + template <class D, class VI32> + HWY_INLINE Vec<Rebind<float, D>> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind<int32_t, D>(), 1); + return BitCast(d, ShiftLeft<31>(And(q, kOne))); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct CosSinImpl<double> { + // Rounds double toward zero and returns as int32_t. + template <class D, class V> + HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind<int32_t, D>(), x); + } + + template <class D, class V> + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -0.166666666666666657414808); + const auto k1 = Set(d, +0.00833333333333332974823815); + const auto k2 = Set(d, -0.000198412698412696162806809); + const auto k3 = Set(d, +2.75573192239198747630416e-6); + const auto k4 = Set(d, -2.50521083763502045810755e-8); + const auto k5 = Set(d, +1.60590430605664501629054e-10); + const auto k6 = Set(d, -7.64712219118158833288484e-13); + const auto k7 = Set(d, +2.81009972710863200091251e-15); + const auto k8 = Set(d, -7.97255955009037868891952e-18); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8), Mul(y, x), x); + } + + template <class D, class V, class VI32> + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0d + kHalfPiPart1d + kHalfPiPart2d + kHalfPiPart3d ~= -pi/2 + const V kHalfPiPart0d = Set(d, -0.5 * 3.1415926218032836914); + const V kHalfPiPart1d = Set(d, -0.5 * 3.1786509424591713469e-8); + const V kHalfPiPart2d = Set(d, -0.5 * 1.2246467864107188502e-16); + const V kHalfPiPart3d = Set(d, -0.5 * 1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kHalfPiPart0d, x); + x = MulAdd(qf, kHalfPiPart1d, x); + x = MulAdd(qf, kHalfPiPart2d, x); + x = MulAdd(qf, kHalfPiPart3d, x); + return x; + } + + template <class D, class V, class VI32> + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0d + kPiPart1d + kPiPart2d + kPiPart3d ~= -pi + const V kPiPart0d = Set(d, -3.1415926218032836914); + const V kPiPart1d = Set(d, -3.1786509424591713469e-8); + const V kPiPart2d = Set(d, -1.2246467864107188502e-16); + const V kPiPart3d = Set(d, -1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kPiPart0d, x); + x = MulAdd(qf, kPiPart1d, x); + x = MulAdd(qf, kPiPart2d, x); + x = MulAdd(qf, kPiPart3d, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template <class D, class VI32> + HWY_INLINE Vec<Rebind<double, D>> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind<int32_t, D>(), 2); + return BitCast( + d, ShiftLeft<62>(PromoteTo(Rebind<int64_t, D>(), AndNot(q, kTwo)))); + } + + // ((q & 1) ? -0.0 : +0.0) + template <class D, class VI32> + HWY_INLINE Vec<Rebind<double, D>> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind<int32_t, D>(), 1); + return BitCast( + d, ShiftLeft<63>(PromoteTo(Rebind<int64_t, D>(), And(q, kOne)))); + } +}; + +#endif + +template <> +struct ExpImpl<float> { + // Rounds float toward zero and returns as int32_t. + template <class D, class V> + HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind<int32_t, D>(), x); + } + + template <class D, class V> + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5f); + const auto k1 = Set(d, +0.166666671633720397949219f); + const auto k2 = Set(d, +0.0416664853692054748535156f); + const auto k3 = Set(d, +0.00833336077630519866943359f); + const auto k4 = Set(d, +0.00139304355252534151077271f); + const auto k5 = Set(d, +0.000198527617612853646278381f); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5), Mul(x, x), x); + } + + // Computes 2^x, where x is an integer. + template <class D, class VI32> + HWY_INLINE Vec<D> Pow2I(D d, VI32 x) { + const Rebind<int32_t, D> di32; + const VI32 kOffset = Set(di32, 0x7F); + return BitCast(d, ShiftLeft<23>(Add(x, kOffset))); + } + + // Sets the exponent of 'x' to 2^e. + template <class D, class V, class VI32> + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template <class D, class V, class VI32> + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0f + kLn2Part1f ~= -ln(2) + const V kLn2Part0f = Set(d, -0.693145751953125f); + const V kLn2Part1f = Set(d, -1.428606765330187045e-6f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kLn2Part0f, x); + x = MulAdd(qf, kLn2Part1f, x); + return x; + } +}; + +template <> +struct LogImpl<float> { + template <class D, class V> + HWY_INLINE Vec<Rebind<int32_t, D>> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind<int32_t, D> di32; + const Rebind<uint32_t, D> du32; + const auto kBias = Set(di32, 0x7F); + return Sub(BitCast(di32, ShiftRight<23>(BitCast(du32, x))), kBias); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template <class D, class V> + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.66666662693f); + const V k1 = Set(d, 0.40000972152f); + const V k2 = Set(d, 0.28498786688f); + const V k3 = Set(d, 0.24279078841f); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(k2, x4, k0), x2, Mul(MulAdd(k3, x4, k1), x4)); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct ExpImpl<double> { + // Rounds double toward zero and returns as int32_t. + template <class D, class V> + HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind<int32_t, D>(), x); + } + + template <class D, class V> + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5); + const auto k1 = Set(d, +0.166666666666666851703837); + const auto k2 = Set(d, +0.0416666666666665047591422); + const auto k3 = Set(d, +0.00833333333331652721664984); + const auto k4 = Set(d, +0.00138888888889774492207962); + const auto k5 = Set(d, +0.000198412698960509205564975); + const auto k6 = Set(d, +2.4801587159235472998791e-5); + const auto k7 = Set(d, +2.75572362911928827629423e-6); + const auto k8 = Set(d, +2.75573911234900471893338e-7); + const auto k9 = Set(d, +2.51112930892876518610661e-8); + const auto k10 = Set(d, +2.08860621107283687536341e-9); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10), + Mul(x, x), x); + } + + // Computes 2^x, where x is an integer. + template <class D, class VI32> + HWY_INLINE Vec<D> Pow2I(D d, VI32 x) { + const Rebind<int32_t, D> di32; + const Rebind<int64_t, D> di64; + const VI32 kOffset = Set(di32, 0x3FF); + return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset)))); + } + + // Sets the exponent of 'x' to 2^e. + template <class D, class V, class VI32> + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template <class D, class V, class VI32> + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0d + kLn2Part1d ~= -ln(2) + const V kLn2Part0d = Set(d, -0.6931471805596629565116018); + const V kLn2Part1d = Set(d, -0.28235290563031577122588448175e-12); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kLn2Part0d, x); + x = MulAdd(qf, kLn2Part1d, x); + return x; + } +}; + +template <> +struct LogImpl<double> { + template <class D, class V> + HWY_INLINE Vec<Rebind<int64_t, D>> Log2p1NoSubnormal(D /*d*/, V x) { + const Rebind<int64_t, D> di64; + const Rebind<uint64_t, D> du64; + return Sub(BitCast(di64, ShiftRight<52>(BitCast(du64, x))), + Set(di64, 0x3FF)); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template <class D, class V> + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.6666666666666735130); + const V k1 = Set(d, 0.3999999999940941908); + const V k2 = Set(d, 0.2857142874366239149); + const V k3 = Set(d, 0.2222219843214978396); + const V k4 = Set(d, 0.1818357216161805012); + const V k5 = Set(d, 0.1531383769920937332); + const V k6 = Set(d, 0.1479819860511658591); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(MulAdd(MulAdd(k6, x4, k4), x4, k2), x4, k0), x2, + (Mul(MulAdd(MulAdd(k5, x4, k3), x4, k1), x4))); + } +}; + +#endif + +template <class D, class V, bool kAllowSubnormals = true> +HWY_INLINE V Log(const D d, V x) { + // http://git.musl-libc.org/cgit/musl/tree/src/math/log.c for more info. + using T = TFromD<D>; + impl::LogImpl<T> impl; + + constexpr bool kIsF32 = (sizeof(T) == 4); + + // Float Constants + const V kLn2Hi = Set(d, kIsF32 ? static_cast<T>(0.69313812256f) + : static_cast<T>(0.693147180369123816490)); + const V kLn2Lo = Set(d, kIsF32 ? static_cast<T>(9.0580006145e-6f) + : static_cast<T>(1.90821492927058770002e-10)); + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kMinNormal = Set(d, kIsF32 ? static_cast<T>(1.175494351e-38f) + : static_cast<T>(2.2250738585072014e-308)); + const V kScale = Set(d, kIsF32 ? static_cast<T>(3.355443200e+7f) + : static_cast<T>(1.8014398509481984e+16)); + + // Integer Constants + using TI = MakeSigned<T>; + const Rebind<TI, D> di; + using VI = decltype(Zero(di)); + const VI kLowerBits = Set(di, kIsF32 ? static_cast<TI>(0x00000000L) + : static_cast<TI>(0xFFFFFFFFLL)); + const VI kMagic = Set(di, kIsF32 ? static_cast<TI>(0x3F3504F3L) + : static_cast<TI>(0x3FE6A09E00000000LL)); + const VI kExpMask = Set(di, kIsF32 ? static_cast<TI>(0x3F800000L) + : static_cast<TI>(0x3FF0000000000000LL)); + const VI kExpScale = + Set(di, kIsF32 ? static_cast<TI>(-25) : static_cast<TI>(-54)); + const VI kManMask = Set(di, kIsF32 ? static_cast<TI>(0x7FFFFFL) + : static_cast<TI>(0xFFFFF00000000LL)); + + // Scale up 'x' so that it is no longer denormalized. + VI exp_bits; + V exp; + if (kAllowSubnormals == true) { + const auto is_denormal = Lt(x, kMinNormal); + x = IfThenElse(is_denormal, Mul(x, kScale), x); + + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic)); + const VI exp_scale = + BitCast(di, IfThenElseZero(is_denormal, BitCast(d, kExpScale))); + exp = ConvertTo( + d, Add(exp_scale, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits)))); + } else { + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic)); + exp = ConvertTo(d, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits))); + } + + // Renormalize. + const V y = Or(And(x, BitCast(d, kLowerBits)), + BitCast(d, Add(And(exp_bits, kManMask), kMagic))); + + // Approximate and reconstruct. + const V ym1 = Sub(y, kOne); + const V z = Div(ym1, Add(y, kOne)); + + return MulSub( + exp, kLn2Hi, + Sub(MulSub(z, Sub(ym1, impl.LogPoly(d, z)), Mul(exp, kLn2Lo)), ym1)); +} + +} // namespace impl + +template <class D, class V> +HWY_INLINE V Acos(const D d, V x) { + using T = TFromD<D>; + + const V kZero = Zero(d); + const V kHalf = Set(d, static_cast<T>(+0.5)); + const V kPi = Set(d, static_cast<T>(+3.14159265358979323846264)); + const V kPiOverTwo = Set(d, static_cast<T>(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl<T> impl; + const V t = Mul(impl.AsinPoly(d, yy, y), Mul(y, yy)); + + const V t_plus_y = Add(t, y); + const V z = + IfThenElse(mask, Sub(kPiOverTwo, Add(Xor(y, sign_x), Xor(t, sign_x))), + Add(t_plus_y, t_plus_y)); + return IfThenElse(Or(mask, Ge(x, kZero)), z, Sub(kPi, z)); +} + +template <class D, class V> +HWY_INLINE V Acosh(const D d, V x) { + using T = TFromD<D>; + + const V kLarge = Set(d, static_cast<T>(268435456.0)); + const V kLog2 = Set(d, static_cast<T>(0.693147180559945286227)); + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kTwo = Set(d, static_cast<T>(+2.0)); + + const auto is_x_large = Gt(x, kLarge); + const auto is_x_gt_2 = Gt(x, kTwo); + + const V x_minus_1 = Sub(x, kOne); + const V y0 = MulSub(kTwo, x, Div(kOne, Add(Sqrt(MulSub(x, x, kOne)), x))); + const V y1 = + Add(Sqrt(MulAdd(x_minus_1, kTwo, Mul(x_minus_1, x_minus_1))), x_minus_1); + const V y2 = + IfThenElse(is_x_gt_2, IfThenElse(is_x_large, x, y0), Add(y1, kOne)); + const V z = impl::Log<D, V, /*kAllowSubnormals=*/false>(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + return Add(IfThenElse(is_x_gt_2, z, + IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor))), + IfThenElseZero(is_x_large, kLog2)); +} + +template <class D, class V> +HWY_INLINE V Asin(const D d, V x) { + using T = TFromD<D>; + + const V kHalf = Set(d, static_cast<T>(+0.5)); + const V kTwo = Set(d, static_cast<T>(+2.0)); + const V kPiOverTwo = Set(d, static_cast<T>(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl<T> impl; + const V z0 = MulAdd(impl.AsinPoly(d, yy, y), Mul(yy, y), y); + const V z1 = NegMulAdd(z0, kTwo, kPiOverTwo); + return Or(IfThenElse(mask, z0, z1), sign_x); +} + +template <class D, class V> +HWY_INLINE V Asinh(const D d, V x) { + using T = TFromD<D>; + + const V kSmall = Set(d, static_cast<T>(1.0 / 268435456.0)); + const V kLarge = Set(d, static_cast<T>(268435456.0)); + const V kLog2 = Set(d, static_cast<T>(0.693147180559945286227)); + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kTwo = Set(d, static_cast<T>(+2.0)); + + const V sign_x = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign_x); + + const auto is_x_large = Gt(abs_x, kLarge); + const auto is_x_lt_2 = Lt(abs_x, kTwo); + + const V x2 = Mul(x, x); + const V sqrt_x2_plus_1 = Sqrt(Add(x2, kOne)); + + const V y0 = MulAdd(abs_x, kTwo, Div(kOne, Add(sqrt_x2_plus_1, abs_x))); + const V y1 = Add(Div(x2, Add(sqrt_x2_plus_1, kOne)), abs_x); + const V y2 = + IfThenElse(is_x_lt_2, Add(y1, kOne), IfThenElse(is_x_large, abs_x, y0)); + const V z = impl::Log<D, V, /*kAllowSubnormals=*/false>(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + const auto large = IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor)); + const V y = IfThenElse(Lt(abs_x, kSmall), x, large); + return Or(Add(IfThenElse(is_x_lt_2, y, z), IfThenElseZero(is_x_large, kLog2)), + sign_x); +} + +template <class D, class V> +HWY_INLINE V Atan(const D d, V x) { + using T = TFromD<D>; + + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kPiOverTwo = Set(d, static_cast<T>(+1.57079632679489661923132169)); + + const V sign = And(SignBit(d), x); + const V abs_x = Xor(x, sign); + const auto mask = Gt(abs_x, kOne); + + impl::AtanImpl<T> impl; + const auto divisor = IfThenElse(mask, abs_x, kOne); + const V y = impl.AtanPoly(d, IfThenElse(mask, Div(kOne, divisor), abs_x)); + return Or(IfThenElse(mask, Sub(kPiOverTwo, y), y), sign); +} + +template <class D, class V> +HWY_INLINE V Atanh(const D d, V x) { + using T = TFromD<D>; + + const V kHalf = Set(d, static_cast<T>(+0.5)); + const V kOne = Set(d, static_cast<T>(+1.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + return Mul(Log1p(d, Div(Add(abs_x, abs_x), Sub(kOne, abs_x))), + Xor(kHalf, sign)); +} + +template <class D, class V> +HWY_INLINE V Cos(const D d, V x) { + using T = TFromD<D>; + impl::CosSinImpl<T> impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast<T>(0.31830988618379067153)); + + // Integer Constants + const Rebind<int32_t, D> di32; + using VI32 = decltype(Zero(di32)); + const VI32 kOne = Set(di32, 1); + + const V y = Abs(x); // cos(x) == cos(|x|) + + // Compute the quadrant, q = int(|x| / pi) * 2 + 1 + const VI32 q = Add(ShiftLeft<1>(impl.ToInt32(d, Mul(y, kOneOverPi))), kOne); + + // Reduce range, apply sign, and approximate. + return impl.Poly( + d, Xor(impl.CosReduce(d, y, q), impl.CosSignFromQuadrant(d, q))); +} + +template <class D, class V> +HWY_INLINE V Exp(const D d, V x) { + using T = TFromD<D>; + + const V kHalf = Set(d, static_cast<T>(+0.5)); + const V kLowerBound = + Set(d, static_cast<T>((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kNegZero = Set(d, static_cast<T>(-0.0)); + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kOneOverLog2 = Set(d, static_cast<T>(+1.442695040888963407359924681)); + + impl::ExpImpl<T> impl; + + // q = static_cast<int32>((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.LoadExpShortRange( + d, Add(impl.ExpPoly(d, impl.ExpReduce(d, x, q)), kOne), q); + return IfThenElseZero(Ge(x, kLowerBound), y); +} + +template <class D, class V> +HWY_INLINE V Expm1(const D d, V x) { + using T = TFromD<D>; + + const V kHalf = Set(d, static_cast<T>(+0.5)); + const V kLowerBound = + Set(d, static_cast<T>((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kLn2Over2 = Set(d, static_cast<T>(+0.346573590279972654708616)); + const V kNegOne = Set(d, static_cast<T>(-1.0)); + const V kNegZero = Set(d, static_cast<T>(-0.0)); + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kOneOverLog2 = Set(d, static_cast<T>(+1.442695040888963407359924681)); + + impl::ExpImpl<T> impl; + + // q = static_cast<int32>((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.ExpPoly(d, impl.ExpReduce(d, x, q)); + const V z = IfThenElse(Lt(Abs(x), kLn2Over2), y, + Sub(impl.LoadExpShortRange(d, Add(y, kOne), q), kOne)); + return IfThenElse(Lt(x, kLowerBound), kNegOne, z); +} + +template <class D, class V> +HWY_INLINE V Log(const D d, V x) { + return impl::Log<D, V, /*kAllowSubnormals=*/true>(d, x); +} + +template <class D, class V> +HWY_INLINE V Log10(const D d, V x) { + using T = TFromD<D>; + return Mul(Log(d, x), Set(d, static_cast<T>(0.4342944819032518276511))); +} + +template <class D, class V> +HWY_INLINE V Log1p(const D d, V x) { + using T = TFromD<D>; + const V kOne = Set(d, static_cast<T>(+1.0)); + + const V y = Add(x, kOne); + const auto is_pole = Eq(y, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y), kOne); + const auto non_pole = + Mul(impl::Log<D, V, /*kAllowSubnormals=*/false>(d, y), Div(x, divisor)); + return IfThenElse(is_pole, x, non_pole); +} + +template <class D, class V> +HWY_INLINE V Log2(const D d, V x) { + using T = TFromD<D>; + return Mul(Log(d, x), Set(d, static_cast<T>(1.44269504088896340735992))); +} + +template <class D, class V> +HWY_INLINE V Sin(const D d, V x) { + using T = TFromD<D>; + impl::CosSinImpl<T> impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast<T>(0.31830988618379067153)); + const V kHalf = Set(d, static_cast<T>(0.5)); + + // Integer Constants + const Rebind<int32_t, D> di32; + using VI32 = decltype(Zero(di32)); + + const V abs_x = Abs(x); + const V sign_x = Xor(abs_x, x); + + // Compute the quadrant, q = int((|x| / pi) + 0.5) + const VI32 q = impl.ToInt32(d, MulAdd(abs_x, kOneOverPi, kHalf)); + + // Reduce range, apply sign, and approximate. + return impl.Poly(d, Xor(impl.SinReduce(d, abs_x, q), + Xor(impl.SinSignFromQuadrant(d, q), sign_x))); +} + +template <class D, class V> +HWY_INLINE V Sinh(const D d, V x) { + using T = TFromD<D>; + const V kHalf = Set(d, static_cast<T>(+0.5)); + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kTwo = Set(d, static_cast<T>(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, abs_x); + const V z = Mul(Div(Add(y, kTwo), Add(y, kOne)), Mul(y, kHalf)); + return Xor(z, sign); // Reapply the sign bit +} + +template <class D, class V> +HWY_INLINE V Tanh(const D d, V x) { + using T = TFromD<D>; + const V kLimit = Set(d, static_cast<T>(18.714973875)); + const V kOne = Set(d, static_cast<T>(+1.0)); + const V kTwo = Set(d, static_cast<T>(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, Mul(abs_x, kTwo)); + const V z = IfThenElse(Gt(abs_x, kLimit), kOne, Div(y, Add(y, kTwo))); + return Xor(z, sign); // Reapply the sign bit +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ diff --git a/third_party/highway/hwy/contrib/math/math_test.cc b/third_party/highway/hwy/contrib/math/math_test.cc new file mode 100644 index 0000000000..2cc58c6106 --- /dev/null +++ b/third_party/highway/hwy/contrib/math/math_test.cc @@ -0,0 +1,228 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> +#include <stdio.h> + +#include <cfloat> // FLT_MAX +#include <cmath> // std::abs +#include <type_traits> + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/math/math-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <class Out, class In> +inline Out BitCast(const In& in) { + static_assert(sizeof(Out) == sizeof(In), ""); + Out out; + CopyBytes<sizeof(out)>(&in, &out); + return out; +} + +template <class T, class D> +HWY_NOINLINE void TestMath(const std::string name, T (*fx1)(T), + Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min, T max, + uint64_t max_error_ulp) { + using UintT = MakeUnsigned<T>; + + const UintT min_bits = BitCast<UintT>(min); + const UintT max_bits = BitCast<UintT>(max); + + // If min is negative and max is positive, the range needs to be broken into + // two pieces, [+0, max] and [-0, min], otherwise [min, max]. + int range_count = 1; + UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}}; + if ((min < 0.0) && (max > 0.0)) { + ranges[0][0] = BitCast<UintT>(static_cast<T>(+0.0)); + ranges[0][1] = max_bits; + ranges[1][0] = BitCast<UintT>(static_cast<T>(-0.0)); + ranges[1][1] = min_bits; + range_count = 2; + } + + uint64_t max_ulp = 0; + // Emulation is slower, so cannot afford as many. + constexpr UintT kSamplesPerRange = static_cast<UintT>(AdjustedReps(4000)); + for (int range_index = 0; range_index < range_count; ++range_index) { + const UintT start = ranges[range_index][0]; + const UintT stop = ranges[range_index][1]; + const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange)); + for (UintT value_bits = start; value_bits <= stop; value_bits += step) { + // For reasons unknown, the HWY_MAX is necessary on RVV, otherwise + // value_bits can be less than start, and thus possibly NaN. + const T value = BitCast<T>(HWY_MIN(HWY_MAX(start, value_bits), stop)); + const T actual = GetLane(fxN(d, Set(d, value))); + const T expected = fx1(value); + + // Skip small inputs and outputs on armv7, it flushes subnormals to zero. +#if HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7 + if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) { + continue; + } +#endif + + const auto ulp = hwy::detail::ComputeUlpDelta(actual, expected); + max_ulp = HWY_MAX(max_ulp, ulp); + if (ulp > max_error_ulp) { + fprintf(stderr, + "%s: %s(%f) expected %f actual %f ulp %" PRIu64 " max ulp %u\n", + hwy::TypeName(T(), Lanes(d)).c_str(), name.c_str(), value, + expected, actual, static_cast<uint64_t>(ulp), + static_cast<uint32_t>(max_error_ulp)); + } + } + } + fprintf(stderr, "%s: %s max_ulp %" PRIu64 "\n", + hwy::TypeName(T(), Lanes(d)).c_str(), name.c_str(), max_ulp); + HWY_ASSERT(max_ulp <= max_error_ulp); +} + +#define DEFINE_MATH_TEST_FUNC(NAME) \ + HWY_NOINLINE void TestAll##NAME() { \ + ForFloatTypes(ForPartialVectors<Test##NAME>()); \ + } + +#undef DEFINE_MATH_TEST +#define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \ + F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR) \ + struct Test##NAME { \ + template <class T, class D> \ + HWY_NOINLINE void operator()(T, D d) { \ + if (sizeof(T) == 4) { \ + TestMath<T, D>(HWY_STR(NAME), F32x1, F32xN, d, F32_MIN, F32_MAX, \ + F32_ERROR); \ + } else { \ + TestMath<T, D>(HWY_STR(NAME), F64x1, F64xN, d, \ + static_cast<T>(F64_MIN), static_cast<T>(F64_MAX), \ + F64_ERROR); \ + } \ + } \ + }; \ + DEFINE_MATH_TEST_FUNC(NAME) + +// Floating point values closest to but less than 1.0 +const float kNearOneF = BitCast<float>(0x3F7FFFFF); +const double kNearOneD = BitCast<double>(0x3FEFFFFFFFFFFFFFULL); + +// The discrepancy is unacceptably large for MSYS2 (less accurate libm?), so +// only increase the error tolerance there. +constexpr uint64_t Cos64ULP() { +#if defined(__MINGW32__) + return 23; +#else + return 3; +#endif +} + +constexpr uint64_t ACosh32ULP() { +#if defined(__MINGW32__) + return 8; +#else + return 3; +#endif +} + +// clang-format off +DEFINE_MATH_TEST(Acos, + std::acos, CallAcos, -1.0f, +1.0f, 3, // NEON is 3 instead of 2 + std::acos, CallAcos, -1.0, +1.0, 2) +DEFINE_MATH_TEST(Acosh, + std::acosh, CallAcosh, +1.0f, +FLT_MAX, ACosh32ULP(), + std::acosh, CallAcosh, +1.0, +DBL_MAX, 3) +DEFINE_MATH_TEST(Asin, + std::asin, CallAsin, -1.0f, +1.0f, 4, // ARMv7 is 4 instead of 2 + std::asin, CallAsin, -1.0, +1.0, 2) +DEFINE_MATH_TEST(Asinh, + std::asinh, CallAsinh, -FLT_MAX, +FLT_MAX, 3, + std::asinh, CallAsinh, -DBL_MAX, +DBL_MAX, 3) +DEFINE_MATH_TEST(Atan, + std::atan, CallAtan, -FLT_MAX, +FLT_MAX, 3, + std::atan, CallAtan, -DBL_MAX, +DBL_MAX, 3) +DEFINE_MATH_TEST(Atanh, + std::atanh, CallAtanh, -kNearOneF, +kNearOneF, 4, // NEON is 4 instead of 3 + std::atanh, CallAtanh, -kNearOneD, +kNearOneD, 3) +DEFINE_MATH_TEST(Cos, + std::cos, CallCos, -39000.0f, +39000.0f, 3, + std::cos, CallCos, -39000.0, +39000.0, Cos64ULP()) +DEFINE_MATH_TEST(Exp, + std::exp, CallExp, -FLT_MAX, +104.0f, 1, + std::exp, CallExp, -DBL_MAX, +104.0, 1) +DEFINE_MATH_TEST(Expm1, + std::expm1, CallExpm1, -FLT_MAX, +104.0f, 4, + std::expm1, CallExpm1, -DBL_MAX, +104.0, 4) +DEFINE_MATH_TEST(Log, + std::log, CallLog, +FLT_MIN, +FLT_MAX, 1, + std::log, CallLog, +DBL_MIN, +DBL_MAX, 1) +DEFINE_MATH_TEST(Log10, + std::log10, CallLog10, +FLT_MIN, +FLT_MAX, 2, + std::log10, CallLog10, +DBL_MIN, +DBL_MAX, 2) +DEFINE_MATH_TEST(Log1p, + std::log1p, CallLog1p, +0.0f, +1e37f, 3, // NEON is 3 instead of 2 + std::log1p, CallLog1p, +0.0, +DBL_MAX, 2) +DEFINE_MATH_TEST(Log2, + std::log2, CallLog2, +FLT_MIN, +FLT_MAX, 2, + std::log2, CallLog2, +DBL_MIN, +DBL_MAX, 2) +DEFINE_MATH_TEST(Sin, + std::sin, CallSin, -39000.0f, +39000.0f, 3, + std::sin, CallSin, -39000.0, +39000.0, 4) // MSYS is 4 instead of 3 +DEFINE_MATH_TEST(Sinh, + std::sinh, CallSinh, -80.0f, +80.0f, 4, + std::sinh, CallSinh, -709.0, +709.0, 4) +DEFINE_MATH_TEST(Tanh, + std::tanh, CallTanh, -FLT_MAX, +FLT_MAX, 4, + std::tanh, CallTanh, -DBL_MAX, +DBL_MAX, 4) +// clang-format on + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMathTest); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcos); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAcosh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAsin); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAsinh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtan); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllAtanh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllCos); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExpm1); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog1p); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog2); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllSin); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllSinh); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllTanh); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/contrib/sort/BUILD b/third_party/highway/hwy/contrib/sort/BUILD new file mode 100644 index 0000000000..af4ed78837 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/BUILD @@ -0,0 +1,193 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +# Unused on Bazel builds, where this is not defined/known; Copybara replaces +# usages with an empty list. +COMPAT = [ + "//buildenv/target:non_prod", # includes mobile/vendor. +] + +# cc_library( +# name = "vxsort", +# srcs = [ +# "vxsort/isa_detection.cpp", +# "vxsort/isa_detection_msvc.cpp", +# "vxsort/isa_detection_sane.cpp", +# "vxsort/machine_traits.avx2.cpp", +# "vxsort/smallsort/avx2_load_mask_tables.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.double.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.float.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.double.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.float.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.cpp", +# "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.cpp", +# "vxsort/vxsort_stats.cpp", +# ], +# hdrs = [ +# "vxsort/alignment.h", +# "vxsort/defs.h", +# "vxsort/isa_detection.h", +# "vxsort/machine_traits.avx2.h", +# "vxsort/machine_traits.avx512.h", +# "vxsort/machine_traits.h", +# "vxsort/packer.h", +# "vxsort/smallsort/bitonic_sort.AVX2.double.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.float.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.double.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.float.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.h", +# "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.h", +# "vxsort/smallsort/bitonic_sort.h", +# "vxsort/vxsort.h", +# "vxsort/vxsort_stats.h", +# ], +# compatible_with = [], +# textual_hdrs = [ +# "vxsort/vxsort_targets_disable.h", +# "vxsort/vxsort_targets_enable_avx2.h", +# "vxsort/vxsort_targets_enable_avx512.h", +# ], +# ) + +cc_library( + name = "vqsort", + srcs = [ + # Split into separate files to reduce MSVC build time. + "vqsort.cc", + "vqsort_128a.cc", + "vqsort_128d.cc", + "vqsort_f32a.cc", + "vqsort_f32d.cc", + "vqsort_f64a.cc", + "vqsort_f64d.cc", + "vqsort_i16a.cc", + "vqsort_i16d.cc", + "vqsort_i32a.cc", + "vqsort_i32d.cc", + "vqsort_i64a.cc", + "vqsort_i64d.cc", + "vqsort_kv64a.cc", + "vqsort_kv64d.cc", + "vqsort_kv128a.cc", + "vqsort_kv128d.cc", + "vqsort_u16a.cc", + "vqsort_u16d.cc", + "vqsort_u32a.cc", + "vqsort_u32d.cc", + "vqsort_u64a.cc", + "vqsort_u64d.cc", + ], + hdrs = [ + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = ["hwy_contrib_EXPORTS"], + textual_hdrs = [ + "shared-inl.h", + "sorting_networks-inl.h", + "traits-inl.h", + "traits128-inl.h", + "vqsort-inl.h", + # Placeholder for internal instrumentation. Do not remove. + ], + deps = [ + # Only if VQSORT_SECURE_RNG is set. + # "//third_party/absl/random", + "//:hwy", + # ":vxsort", # required if HAVE_VXSORT + ], +) + +# ----------------------------------------------------------------------------- +# Internal-only targets + +cc_library( + name = "helpers", + testonly = 1, + textual_hdrs = [ + "algo-inl.h", + "result-inl.h", + ], + deps = [ + ":vqsort", + "//:nanobenchmark", + # Required for HAVE_PDQSORT, but that is unused and this is + # unavailable to Bazel builds, hence commented out. + # "//third_party/boost/allowed", + # Avoid ips4o and thus TBB to work around hwloc build failure. + ], +) + +cc_binary( + name = "print_network", + testonly = 1, + srcs = ["print_network.cc"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + ], +) + +cc_test( + name = "sort_test", + size = "medium", + srcs = ["sort_test.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) + +cc_binary( + name = "bench_sort", + testonly = 1, + srcs = ["bench_sort.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) + +cc_binary( + name = "bench_parallel", + testonly = 1, + srcs = ["bench_parallel.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "@com_google_googletest//:gtest_main", + "//:hwy", + "//:hwy_test_util", + ], +) diff --git a/third_party/highway/hwy/contrib/sort/README.md b/third_party/highway/hwy/contrib/sort/README.md new file mode 100644 index 0000000000..a0051414d3 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/README.md @@ -0,0 +1,87 @@ +# Vectorized and performance-portable Quicksort + +## Introduction + +As of 2022-06-07 this sorts large arrays of built-in types about ten times as +fast as `std::sort`. See also our +[blog post](https://opensource.googleblog.com/2022/06/Vectorized%20and%20performance%20portable%20Quicksort.html) +and [paper](https://arxiv.org/abs/2205.05982). + +## Instructions + +Here are instructions for reproducing our results on Linux and AWS (SVE, NEON). + +### Linux + +Please first ensure golang, and Clang (tested with 13.0.1) are installed via +your system's package manager. + +``` +go install github.com/bazelbuild/bazelisk@latest +git clone https://github.com/google/highway +cd highway +CC=clang CXX=clang++ ~/go/bin/bazelisk build -c opt hwy/contrib/sort:all +bazel-bin/hwy/contrib/sort/sort_test +bazel-bin/hwy/contrib/sort/bench_sort +``` + +### AWS Graviton3 + +Instance config: amazon linux 5.10 arm64, c7g.8xlarge (largest allowed config is +32 vCPU). Initial launch will fail. Wait a few minutes for an email saying the +config is verified, then re-launch. See IPv4 hostname in list of instances. + +`ssh -i /path/key.pem ec2-user@hostname` + +Note that the AWS CMake package is too old for llvm, so we build it first: +``` +wget https://cmake.org/files/v3.23/cmake-3.23.2.tar.gz +tar -xvzf cmake-3.23.2.tar.gz && cd cmake-3.23.2/ +./bootstrap -- -DCMAKE_USE_OPENSSL=OFF +make -j8 && sudo make install +cd .. +``` + +AWS clang is at version 11.1, which generates unnecessary `AND` instructions +which slow down the sort by 1.15x. We tested with clang trunk as of June 13 +(which reports Git hash 8f6512fea000c3a0d394864bb94e524bee375069). To build: + +``` +git clone --depth 1 https://github.com/llvm/llvm-project.git +cd llvm-project +mkdir -p build && cd build +/usr/local/bin/cmake ../llvm -DLLVM_ENABLE_PROJECTS="clang" -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi" -DCMAKE_BUILD_TYPE=Release +make -j32 && sudo make install +``` + +``` +sudo yum install go +go install github.com/bazelbuild/bazelisk@latest +git clone https://github.com/google/highway +cd highway +CC=/usr/local/bin/clang CXX=/usr/local/bin/clang++ ~/go/bin/bazelisk build -c opt --copt=-march=armv8.2-a+sve hwy/contrib/sort:all +bazel-bin/hwy/contrib/sort/sort_test +bazel-bin/hwy/contrib/sort/bench_sort +``` + +The above command line enables SVE, which is currently only available on +Graviton 3. You can also test NEON on the same processor, or other Arm CPUs, by +changing the `-march=` option to `--copt=-march=armv8.2-a+crypto`. Note that +such flags will be unnecessary once Clang supports `#pragma target` for NEON and +SVE intrinsics, as it does for x86. + +## Results + +`bench_sort` outputs the instruction set (AVX3 refers to AVX-512), the sort +algorithm (std for `std::sort`, vq for our vqsort), the type of keys being +sorted (f32 is float), the distribution of keys (uniform32 for uniform random +with range 0-2^32), the number of keys, then the throughput of sorted keys (i.e. +number of key bytes output per second). + +Example excerpt from Xeon 6154 (Skylake-X) CPU clocked at 3 GHz: + +``` +[ RUN ] BenchSortGroup/BenchSort.BenchAllSort/AVX3 + AVX3: std: f32: uniform32: 1.00E+06 54 MB/s ( 1 threads) + AVX3: vq: f32: uniform32: 1.00E+06 1143 MB/s ( 1 threads) +``` diff --git a/third_party/highway/hwy/contrib/sort/algo-inl.h b/third_party/highway/hwy/contrib/sort/algo-inl.h new file mode 100644 index 0000000000..1ebbbd5745 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/algo-inl.h @@ -0,0 +1,513 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +#include <stdint.h> +#include <string.h> // memcpy + +#include <algorithm> // std::sort, std::min, std::max +#include <functional> // std::less, std::greater +#include <thread> // NOLINT +#include <vector> + +#include "hwy/base.h" +#include "hwy/contrib/sort/vqsort.h" + +// Third-party algorithms +#define HAVE_AVX2SORT 0 +#define HAVE_IPS4O 0 +// When enabling, consider changing max_threads (required for Table 1a) +#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1) +#define HAVE_PDQSORT 0 +#define HAVE_SORT512 0 +#define HAVE_VXSORT 0 + +#if HAVE_AVX2SORT +HWY_PUSH_ATTRIBUTES("avx2,avx") +#include "avx2sort.h" //NOLINT +HWY_POP_ATTRIBUTES +#endif +#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O +#include "third_party/ips4o/include/ips4o.hpp" +#include "third_party/ips4o/include/ips4o/thread_pool.hpp" +#endif +#if HAVE_PDQSORT +#include "third_party/boost/allowed/sort/sort.hpp" +#endif +#if HAVE_SORT512 +#include "sort512.h" //NOLINT +#endif + +// vxsort is difficult to compile for multiple targets because it also uses +// .cpp files, and we'd also have to #undef its include guards. Instead, compile +// only for AVX2 or AVX3 depending on this macro. +#define VXSORT_AVX3 1 +#if HAVE_VXSORT +// inlined from vxsort_targets_enable_avx512 (must close before end of header) +#ifdef __GNUC__ +#ifdef __clang__ +#if VXSORT_AVX3 +#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \ + apply_to = any(function)) +#else +#pragma clang attribute push(__attribute__((target("avx2"))), \ + apply_to = any(function)) +#endif // VXSORT_AVX3 + +#else +#pragma GCC push_options +#if VXSORT_AVX3 +#pragma GCC target("avx512f,avx512dq") +#else +#pragma GCC target("avx2") +#endif // VXSORT_AVX3 +#endif +#endif + +#if VXSORT_AVX3 +#include "vxsort/machine_traits.avx512.h" +#else +#include "vxsort/machine_traits.avx2.h" +#endif // VXSORT_AVX3 +#include "vxsort/vxsort.h" +#ifdef __GNUC__ +#ifdef __clang__ +#pragma clang attribute pop +#else +#pragma GCC pop_options +#endif +#endif +#endif // HAVE_VXSORT + +namespace hwy { + +enum class Dist { kUniform8, kUniform16, kUniform32 }; + +static inline std::vector<Dist> AllDist() { + return {/*Dist::kUniform8, Dist::kUniform16,*/ Dist::kUniform32}; +} + +static inline const char* DistName(Dist dist) { + switch (dist) { + case Dist::kUniform8: + return "uniform8"; + case Dist::kUniform16: + return "uniform16"; + case Dist::kUniform32: + return "uniform32"; + } + return "unreachable"; +} + +template <typename T> +class InputStats { + public: + void Notify(T value) { + min_ = std::min(min_, value); + max_ = std::max(max_, value); + // Converting to integer would truncate floats, multiplying to save digits + // risks overflow especially when casting, so instead take the sum of the + // bit representations as the checksum. + uint64_t bits = 0; + static_assert(sizeof(T) <= 8, "Expected a built-in type"); + CopyBytes<sizeof(T)>(&value, &bits); // not same size + sum_ += bits; + count_ += 1; + } + + bool operator==(const InputStats& other) const { + if (count_ != other.count_) { + HWY_ABORT("count %d vs %d\n", static_cast<int>(count_), + static_cast<int>(other.count_)); + } + + if (min_ != other.min_ || max_ != other.max_) { + HWY_ABORT("minmax %f/%f vs %f/%f\n", static_cast<double>(min_), + static_cast<double>(max_), static_cast<double>(other.min_), + static_cast<double>(other.max_)); + } + + // Sum helps detect duplicated/lost values + if (sum_ != other.sum_) { + HWY_ABORT("Sum mismatch %g %g; min %g max %g\n", + static_cast<double>(sum_), static_cast<double>(other.sum_), + static_cast<double>(min_), static_cast<double>(max_)); + } + + return true; + } + + private: + T min_ = hwy::HighestValue<T>(); + T max_ = hwy::LowestValue<T>(); + uint64_t sum_ = 0; + size_t count_ = 0; +}; + +enum class Algo { +#if HAVE_AVX2SORT + kSEA, +#endif +#if HAVE_IPS4O + kIPS4O, +#endif +#if HAVE_PARALLEL_IPS4O + kParallelIPS4O, +#endif +#if HAVE_PDQSORT + kPDQ, +#endif +#if HAVE_SORT512 + kSort512, +#endif +#if HAVE_VXSORT + kVXSort, +#endif + kStd, + kVQSort, + kHeap, +}; + +static inline const char* AlgoName(Algo algo) { + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return "sea"; +#endif +#if HAVE_IPS4O + case Algo::kIPS4O: + return "ips4o"; +#endif +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + return "par_ips4o"; +#endif +#if HAVE_PDQSORT + case Algo::kPDQ: + return "pdq"; +#endif +#if HAVE_SORT512 + case Algo::kSort512: + return "sort512"; +#endif +#if HAVE_VXSORT + case Algo::kVXSort: + return "vxsort"; +#endif + case Algo::kStd: + return "std"; + case Algo::kVQSort: + return "vq"; + case Algo::kHeap: + return "heap"; + } + return "unreachable"; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#endif + +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // HeapSort +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +class Xorshift128Plus { + static HWY_INLINE uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + public: + // Generates two vectors of 64-bit seeds via SplitMix64 and stores into + // `seeds`. Generating these afresh in each ChoosePivot is too expensive. + template <class DU64> + static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) { + seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < 2 * Lanes(du64); ++i) { + seeds[i] = SplitMix64(seeds[i - 1]); + } + } + + // Need to pass in the state because vector cannot be class members. + template <class VU64> + static VU64 RandomBits(VU64& state0, VU64& state1) { + VU64 s1 = state0; + VU64 s0 = state1; + const VU64 bits = Add(s1, s0); + state0 = s0; + s1 = Xor(s1, ShiftLeft<23>(s1)); + state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + return bits; + } +}; + +template <class D, class VU64, HWY_IF_NOT_FLOAT_D(D)> +Vec<D> RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) { + const VU64 bits = Xorshift128Plus::RandomBits(s0, s1); + return BitCast(d, And(bits, mask)); +} + +// It is important to avoid denormals, which are flushed to zero by SIMD but not +// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD. +template <class DF, class VU64, HWY_IF_FLOAT_D(DF)> +Vec<DF> RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) { + using TF = TFromD<DF>; + const RebindToUnsigned<decltype(df)> du; + using VU = Vec<decltype(du)>; + + const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask); + +#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to smaller types + using TU = MakeUnsigned<TF>; + const VU bits = Set(du, static_cast<TU>(GetLane(bits64) & LimitsMax<TU>())); +#else + const VU bits = BitCast(du, bits64); +#endif + // Avoid NaN/denormal by only generating values in [1, 2), i.e. random + // mantissas with the exponent taken from the representation of 1.0. + const VU k1 = BitCast(du, Set(df, TF{1.0})); + const VU mantissa_mask = Set(du, MantissaMask<TF>()); + const VU representation = OrAnd(k1, bits, mantissa_mask); + return BitCast(df, representation); +} + +template <class DU64> +Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) { + switch (sizeof_t) { + case 2: + return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull + : 0xFFFFFFFFFFFFFFFFull); + case 4: + return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull + : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull + : 0xFFFFFFFFFFFFFFFFull); + case 8: + return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull + : (dist == Dist::kUniform16) ? 0x000000000000FFFFull + : 0x00000000FFFFFFFFull); + default: + HWY_ABORT("Logic error"); + return Zero(du64); + } +} + +template <typename T> +InputStats<T> GenerateInput(const Dist dist, T* v, size_t num) { + SortTag<uint64_t> du64; + using VU64 = Vec<decltype(du64)>; + const size_t N64 = Lanes(du64); + auto seeds = hwy::AllocateAligned<uint64_t>(2 * N64); + Xorshift128Plus::GenerateSeeds(du64, seeds.get()); + VU64 s0 = Load(du64, seeds.get()); + VU64 s1 = Load(du64, seeds.get() + N64); + +#if HWY_TARGET == HWY_SCALAR + const Sisd<T> d; +#else + const Repartition<T, decltype(du64)> d; +#endif + using V = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU64 mask = MaskForDist(du64, dist, sizeof(T)); + auto buf = hwy::AllocateAligned<T>(N); + + size_t i = 0; + for (; i + N <= num; i += N) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, v + i); + } + if (i < num) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, buf.get()); + memcpy(v + i, buf.get(), (num - i) * sizeof(T)); + } + + InputStats<T> input_stats; + for (size_t i = 0; i < num; ++i) { + input_stats.Notify(v[i]); + } + return input_stats; +} + +struct ThreadLocal { + Sorter sorter; +}; + +struct SharedState { +#if HAVE_PARALLEL_IPS4O + const unsigned max_threads = hwy::LimitsMax<unsigned>(); // 16 for Table 1a + ips4o::StdThreadPool pool{static_cast<int>( + HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))}; +#endif + std::vector<ThreadLocal> tls{1}; +}; + +// Bridge from keys (passed to Run) to lanes as expected by HeapSort. For +// non-128-bit keys they are the same: +template <class Order, typename KeyType, HWY_IF_NOT_LANE_SIZE(KeyType, 16)> +void CallHeapSort(KeyType* HWY_RESTRICT keys, const size_t num_keys) { + using detail::TraitsLane; + using detail::SharedTraits; + if (Order().IsAscending()) { + const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st; + return detail::HeapSort(st, keys, num_keys); + } else { + const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st; + return detail::HeapSort(st, keys, num_keys); + } +} + +#if VQSORT_ENABLED +template <class Order> +void CallHeapSort(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys) { + using detail::SharedTraits; + using detail::Traits128; + uint64_t* lanes = reinterpret_cast<uint64_t*>(keys); + const size_t num_lanes = num_keys * 2; + if (Order().IsAscending()) { + const SharedTraits<Traits128<detail::OrderAscending128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } else { + const SharedTraits<Traits128<detail::OrderDescending128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } +} + +template <class Order> +void CallHeapSort(K64V64* HWY_RESTRICT keys, const size_t num_keys) { + using detail::SharedTraits; + using detail::Traits128; + uint64_t* lanes = reinterpret_cast<uint64_t*>(keys); + const size_t num_lanes = num_keys * 2; + if (Order().IsAscending()) { + const SharedTraits<Traits128<detail::OrderAscendingKV128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } else { + const SharedTraits<Traits128<detail::OrderDescendingKV128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } +} +#endif // VQSORT_ENABLED + +template <class Order, typename KeyType> +void Run(Algo algo, KeyType* HWY_RESTRICT inout, size_t num, + SharedState& shared, size_t thread) { + const std::less<KeyType> less; + const std::greater<KeyType> greater; + + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return avx2::quicksort(inout, static_cast<int>(num)); +#endif + +#if HAVE_IPS4O + case Algo::kIPS4O: + if (Order().IsAscending()) { + return ips4o::sort(inout, inout + num, less); + } else { + return ips4o::sort(inout, inout + num, greater); + } +#endif + +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + if (Order().IsAscending()) { + return ips4o::parallel::sort(inout, inout + num, less, shared.pool); + } else { + return ips4o::parallel::sort(inout, inout + num, greater, shared.pool); + } +#endif + +#if HAVE_SORT512 + case Algo::kSort512: + HWY_ABORT("not supported"); + // return Sort512::Sort(inout, num); +#endif + +#if HAVE_PDQSORT + case Algo::kPDQ: + if (Order().IsAscending()) { + return boost::sort::pdqsort_branchless(inout, inout + num, less); + } else { + return boost::sort::pdqsort_branchless(inout, inout + num, greater); + } +#endif + +#if HAVE_VXSORT + case Algo::kVXSort: { +#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \ + (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2) + fprintf(stderr, "Do not call for target %s\n", + hwy::TargetName(HWY_TARGET)); + return; +#else +#if VXSORT_AVX3 + vxsort::vxsort<KeyType, vxsort::AVX512> vx; +#else + vxsort::vxsort<KeyType, vxsort::AVX2> vx; +#endif + if (Order().IsAscending()) { + return vx.sort(inout, inout + num - 1); + } else { + fprintf(stderr, "Skipping VX - does not support descending order\n"); + return; + } +#endif // enabled for this target + } +#endif // HAVE_VXSORT + + case Algo::kStd: + if (Order().IsAscending()) { + return std::sort(inout, inout + num, less); + } else { + return std::sort(inout, inout + num, greater); + } + + case Algo::kVQSort: + return shared.tls[thread].sorter(inout, num, Order()); + + case Algo::kHeap: + return CallHeapSort<Order>(inout, num); + + default: + HWY_ABORT("Not implemented"); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/bench_parallel.cc b/third_party/highway/hwy/contrib/sort/bench_parallel.cc new file mode 100644 index 0000000000..1c8c928e21 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/bench_parallel.cc @@ -0,0 +1,238 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Concurrent, independent sorts for generating more memory traffic and testing +// scalability. + +#include <stdint.h> +#include <stdio.h> + +#include <condition_variable> //NOLINT +#include <functional> +#include <memory> +#include <mutex> //NOLINT +#include <thread> //NOLINT +#include <utility> +#include <vector> + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_parallel.cc" //NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/aligned_allocator.h" +// Last +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +class ThreadPool { + public: + // Starts the given number of worker threads and blocks until they are ready. + explicit ThreadPool( + const size_t num_threads = std::thread::hardware_concurrency()) + : num_threads_(num_threads) { + HWY_ASSERT(num_threads_ > 0); + threads_.reserve(num_threads_); + for (size_t i = 0; i < num_threads_; ++i) { + threads_.emplace_back(ThreadFunc, this, i); + } + + WorkersReadyBarrier(); + } + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator&(const ThreadPool&) = delete; + + // Waits for all threads to exit. + ~ThreadPool() { + StartWorkers(kWorkerExit); + + for (std::thread& thread : threads_) { + thread.join(); + } + } + + size_t NumThreads() const { return threads_.size(); } + + template <class Func> + void RunOnThreads(size_t max_threads, const Func& func) { + task_ = &CallClosure<Func>; + data_ = &func; + StartWorkers(max_threads); + WorkersReadyBarrier(); + } + + private: + // After construction and between calls to Run, workers are "ready", i.e. + // waiting on worker_start_cv_. They are "started" by sending a "command" + // and notifying all worker_start_cv_ waiters. (That is why all workers + // must be ready/waiting - otherwise, the notification will not reach all of + // them and the main thread waits in vain for them to report readiness.) + using WorkerCommand = uint64_t; + + static constexpr WorkerCommand kWorkerWait = ~1ULL; + static constexpr WorkerCommand kWorkerExit = ~2ULL; + + // Calls a closure (lambda with captures). + template <class Closure> + static void CallClosure(const void* f, size_t thread) { + (*reinterpret_cast<const Closure*>(f))(thread); + } + + void WorkersReadyBarrier() { + std::unique_lock<std::mutex> lock(mutex_); + // Typically only a single iteration. + while (workers_ready_ != threads_.size()) { + workers_ready_cv_.wait(lock); + } + workers_ready_ = 0; + + // Safely handle spurious worker wakeups. + worker_start_command_ = kWorkerWait; + } + + // Precondition: all workers are ready. + void StartWorkers(const WorkerCommand worker_command) { + std::unique_lock<std::mutex> lock(mutex_); + worker_start_command_ = worker_command; + // Workers will need this lock, so release it before they wake up. + lock.unlock(); + worker_start_cv_.notify_all(); + } + + static void ThreadFunc(ThreadPool* self, size_t thread) { + // Until kWorkerExit command received: + for (;;) { + std::unique_lock<std::mutex> lock(self->mutex_); + // Notify main thread that this thread is ready. + if (++self->workers_ready_ == self->num_threads_) { + self->workers_ready_cv_.notify_one(); + } + RESUME_WAIT: + // Wait for a command. + self->worker_start_cv_.wait(lock); + const WorkerCommand command = self->worker_start_command_; + switch (command) { + case kWorkerWait: // spurious wakeup: + goto RESUME_WAIT; // lock still held, avoid incrementing ready. + case kWorkerExit: + return; // exits thread + default: + break; + } + + lock.unlock(); + // Command is the maximum number of threads that should run the task. + HWY_ASSERT(command < self->NumThreads()); + if (thread < command) { + self->task_(self->data_, thread); + } + } + } + + const size_t num_threads_; + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector<std::thread> threads_; + + std::mutex mutex_; // guards both cv and their variables. + std::condition_variable workers_ready_cv_; + size_t workers_ready_ = 0; + std::condition_variable worker_start_cv_; + WorkerCommand worker_start_command_; + + // Written by main thread, read by workers (after mutex lock/unlock). + std::function<void(const void*, size_t)> task_; // points to CallClosure + const void* data_; // points to caller's Func +}; + +template <class Traits> +void RunWithoutVerify(Traits st, const Dist dist, const size_t num_keys, + const Algo algo, SharedState& shared, size_t thread) { + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + using Order = typename Traits::Order; + const size_t num_lanes = num_keys * st.LanesPerKey(); + auto aligned = hwy::AllocateAligned<LaneType>(num_lanes); + + (void)GenerateInput(dist, aligned.get(), num_lanes); + + const Timestamp t0; + Run<Order>(algo, reinterpret_cast<KeyType*>(aligned.get()), num_keys, shared, + thread); + HWY_ASSERT(aligned[0] < aligned[num_lanes - 1]); +} + +void BenchParallel() { + // Not interested in benchmark results for other targets on x86 + if (HWY_ARCH_X86 && (HWY_TARGET != HWY_AVX2 && HWY_TARGET != HWY_AVX3)) { + return; + } + + ThreadPool pool; + const size_t NT = pool.NumThreads(); + + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<int64_t>>> st; + using KeyType = typename decltype(st)::KeyType; + const size_t num_keys = size_t{100} * 1000 * 1000; + +#if HAVE_IPS4O + const Algo algo = Algo::kIPS4O; +#else + const Algo algo = Algo::kVQSort; +#endif + const Dist dist = Dist::kUniform32; + + SharedState shared; + shared.tls.resize(NT); + + std::vector<Result> results; + for (size_t nt = 1; nt < NT; nt += HWY_MAX(1, NT / 16)) { + Timestamp t0; + // Default capture because MSVC wants algo/dist but clang does not. + pool.RunOnThreads(nt, [=, &shared](size_t thread) { + RunWithoutVerify(st, dist, num_keys, algo, shared, thread); + }); + const double sec = SecondsSince(t0); + results.emplace_back(algo, dist, num_keys, nt, sec, sizeof(KeyType), + st.KeyString()); + results.back().Print(); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +namespace { +HWY_BEFORE_TEST(BenchParallel); +HWY_EXPORT_AND_TEST_P(BenchParallel, BenchParallel); +} // namespace +} // namespace hwy + +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/bench_sort.cc b/third_party/highway/hwy/contrib/sort/bench_sort.cc new file mode 100644 index 0000000000..a668fde907 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/bench_sort.cc @@ -0,0 +1,310 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdint.h> +#include <stdio.h> + +#include <vector> + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_sort.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/contrib/sort/sorting_networks-inl.h" // SharedTraits +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +// Mode for larger sorts because M1 is able to access more than the per-core +// share of L2, so 1M elements might still be in cache. +#define SORT_100M 0 + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +// Defined within HWY_ONCE, used by BenchAllSort. +extern int64_t first_sort_target; + +namespace HWY_NAMESPACE { +namespace { +using detail::TraitsLane; +using detail::OrderAscending; +using detail::OrderDescending; +using detail::SharedTraits; + +#if VQSORT_ENABLED || HWY_IDE +using detail::OrderAscending128; +using detail::OrderAscendingKV128; +using detail::Traits128; + +template <class Traits> +HWY_NOINLINE void BenchPartition() { + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + const SortTag<LaneType> d; + detail::SharedTraits<Traits> st; + const Dist dist = Dist::kUniform8; + double sum = 0.0; + + detail::Generator rng(&sum, 123); // for ChoosePivot + + const size_t max_log2 = AdjustedLog2Reps(20); + for (size_t log2 = max_log2; log2 < max_log2 + 1; ++log2) { + const size_t num_lanes = 1ull << log2; + const size_t num_keys = num_lanes / st.LanesPerKey(); + auto aligned = hwy::AllocateAligned<LaneType>(num_lanes); + auto buf = hwy::AllocateAligned<LaneType>( + HWY_MAX(hwy::SortConstants::PartitionBufNum(Lanes(d)), + hwy::SortConstants::PivotBufNum(sizeof(LaneType), Lanes(d)))); + + std::vector<double> seconds; + const size_t num_reps = (1ull << (14 - log2 / 2)) * 30; + for (size_t rep = 0; rep < num_reps; ++rep) { + (void)GenerateInput(dist, aligned.get(), num_lanes); + + // The pivot value can influence performance. Do exactly what vqsort will + // do so that the performance (influenced by prefetching and branch + // prediction) is likely to predict the actual performance inside vqsort. + detail::DrawSamples(d, st, aligned.get(), num_lanes, buf.get(), rng); + detail::SortSamples(d, st, buf.get()); + auto pivot = detail::ChoosePivotByRank(d, st, buf.get()); + + const Timestamp t0; + detail::Partition(d, st, aligned.get(), num_lanes - 1, pivot, buf.get()); + seconds.push_back(SecondsSince(t0)); + // 'Use' the result to prevent optimizing out the partition. + sum += static_cast<double>(aligned.get()[num_lanes / 2]); + } + + Result(Algo::kVQSort, dist, num_keys, 1, SummarizeMeasurements(seconds), + sizeof(KeyType), st.KeyString()) + .Print(); + } + HWY_ASSERT(sum != 999999); // Prevent optimizing out +} + +HWY_NOINLINE void BenchAllPartition() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3) { + return; + } + + BenchPartition<TraitsLane<OrderDescending<float>>>(); + BenchPartition<TraitsLane<OrderDescending<int32_t>>>(); + BenchPartition<TraitsLane<OrderDescending<int64_t>>>(); + BenchPartition<Traits128<OrderAscending128>>(); + // BenchPartition<Traits128<OrderDescending128>>(); + BenchPartition<Traits128<OrderAscendingKV128>>(); +} + +template <class Traits> +HWY_NOINLINE void BenchBase(std::vector<Result>& results) { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) { + return; + } + + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + const SortTag<LaneType> d; + detail::SharedTraits<Traits> st; + const Dist dist = Dist::kUniform32; + + const size_t N = Lanes(d); + const size_t num_lanes = SortConstants::BaseCaseNum(N); + const size_t num_keys = num_lanes / st.LanesPerKey(); + auto keys = hwy::AllocateAligned<LaneType>(num_lanes); + auto buf = hwy::AllocateAligned<LaneType>(num_lanes + N); + + std::vector<double> seconds; + double sum = 0; // prevents elision + constexpr size_t kMul = AdjustedReps(600); // ensures long enough to measure + + for (size_t rep = 0; rep < 30; ++rep) { + InputStats<LaneType> input_stats = + GenerateInput(dist, keys.get(), num_lanes); + + const Timestamp t0; + for (size_t i = 0; i < kMul; ++i) { + detail::BaseCase(d, st, keys.get(), keys.get() + num_lanes, num_lanes, + buf.get()); + sum += static_cast<double>(keys[0]); + } + seconds.push_back(SecondsSince(t0)); + // printf("%f\n", seconds.back()); + + HWY_ASSERT(VerifySort(st, input_stats, keys.get(), num_lanes, "BenchBase")); + } + HWY_ASSERT(sum < 1E99); + results.emplace_back(Algo::kVQSort, dist, num_keys * kMul, 1, + SummarizeMeasurements(seconds), sizeof(KeyType), + st.KeyString()); +} + +HWY_NOINLINE void BenchAllBase() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3) { + return; + } + + std::vector<Result> results; + BenchBase<TraitsLane<OrderAscending<float>>>(results); + BenchBase<TraitsLane<OrderDescending<int64_t>>>(results); + BenchBase<Traits128<OrderAscending128>>(results); + for (const Result& r : results) { + r.Print(); + } +} + +#else +void BenchAllPartition() {} +void BenchAllBase() {} +#endif // VQSORT_ENABLED + +std::vector<Algo> AlgoForBench() { + return { +#if HAVE_AVX2SORT + Algo::kSEA, +#endif +#if HAVE_PARALLEL_IPS4O + Algo::kParallelIPS4O, +#elif HAVE_IPS4O + Algo::kIPS4O, +#endif +#if HAVE_PDQSORT + Algo::kPDQ, +#endif +#if HAVE_SORT512 + Algo::kSort512, +#endif +// Only include if we're compiling for the target it supports. +#if HAVE_VXSORT && ((VXSORT_AVX3 && HWY_TARGET == HWY_AVX3) || \ + (!VXSORT_AVX3 && HWY_TARGET == HWY_AVX2)) + Algo::kVXSort, +#endif + +#if !HAVE_PARALLEL_IPS4O +#if !SORT_100M + // These are 10-20x slower, but that's OK for the default size when we + // are not testing the parallel nor 100M modes. + Algo::kStd, Algo::kHeap, +#endif + + Algo::kVQSort, // only ~4x slower, but not required for Table 1a +#endif + }; +} + +template <class Traits> +HWY_NOINLINE void BenchSort(size_t num_keys) { + if (first_sort_target == 0) first_sort_target = HWY_TARGET; + + SharedState shared; + detail::SharedTraits<Traits> st; + using Order = typename Traits::Order; + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + const size_t num_lanes = num_keys * st.LanesPerKey(); + auto aligned = hwy::AllocateAligned<LaneType>(num_lanes); + + const size_t reps = num_keys > 1000 * 1000 ? 10 : 30; + + for (Algo algo : AlgoForBench()) { + // Other algorithms don't depend on the vector instructions, so only run + // them for the first target. +#if !HAVE_VXSORT + if (algo != Algo::kVQSort && HWY_TARGET != first_sort_target) { + continue; + } +#endif + + for (Dist dist : AllDist()) { + std::vector<double> seconds; + for (size_t rep = 0; rep < reps; ++rep) { + InputStats<LaneType> input_stats = + GenerateInput(dist, aligned.get(), num_lanes); + + const Timestamp t0; + Run<Order>(algo, reinterpret_cast<KeyType*>(aligned.get()), num_keys, + shared, /*thread=*/0); + seconds.push_back(SecondsSince(t0)); + // printf("%f\n", seconds.back()); + + HWY_ASSERT( + VerifySort(st, input_stats, aligned.get(), num_lanes, "BenchSort")); + } + Result(algo, dist, num_keys, 1, SummarizeMeasurements(seconds), + sizeof(KeyType), st.KeyString()) + .Print(); + } // dist + } // algo +} + +HWY_NOINLINE void BenchAllSort() { + // Not interested in benchmark results for these targets + if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) { + return; + } + + constexpr size_t K = 1000; + constexpr size_t M = K * K; + (void)K; + (void)M; + for (size_t num_keys : { +#if HAVE_PARALLEL_IPS4O || SORT_100M + 100 * M, +#else + 1 * M, +#endif + }) { + BenchSort<TraitsLane<OrderAscending<float>>>(num_keys); + // BenchSort<TraitsLane<OrderDescending<double>>>(num_keys); + // BenchSort<TraitsLane<OrderAscending<int16_t>>>(num_keys); + BenchSort<TraitsLane<OrderDescending<int32_t>>>(num_keys); + BenchSort<TraitsLane<OrderAscending<int64_t>>>(num_keys); + // BenchSort<TraitsLane<OrderDescending<uint16_t>>>(num_keys); + // BenchSort<TraitsLane<OrderDescending<uint32_t>>>(num_keys); + // BenchSort<TraitsLane<OrderAscending<uint64_t>>>(num_keys); + +#if !HAVE_VXSORT && VQSORT_ENABLED + BenchSort<Traits128<OrderAscending128>>(num_keys); + BenchSort<Traits128<OrderAscendingKV128>>(num_keys); +#endif + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +int64_t first_sort_target = 0; // none run yet +namespace { +HWY_BEFORE_TEST(BenchSort); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllPartition); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllBase); +HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllSort); +} // namespace +} // namespace hwy + +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/print_network.cc b/third_party/highway/hwy/contrib/sort/print_network.cc new file mode 100644 index 0000000000..59cfebcfbd --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/print_network.cc @@ -0,0 +1,191 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdio.h> + +#include <algorithm> + +#include "hwy/base.h" + +// Based on A.7 in "Entwurf und Implementierung vektorisierter +// Sortieralgorithmen" and code by Mark Blacher. +void PrintMergeNetwork16x2() { + for (int i = 8; i < 16; ++i) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.SwapAdjacent(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1<kOrder>(d, v%x);\n", i, i); + } + printf("\n"); +} + +void PrintMergeNetwork16x4() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.Reverse4(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.Reverse4(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.Reverse4(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.Reverse4(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse4(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1<kOrder>(d, v%x);\n", i, i); + } +} + +void PrintMergeNetwork16x8() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.ReverseKeys8(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse8(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance2<kOrder>(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1<kOrder>(d, v%x);\n", i, i); + } +} + +void PrintMergeNetwork16x16() { + printf("\n"); + + for (int i = 8; i < 16; ++i) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i, i); + } + for (int i = 0; i < 8; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 15 - i); + } + for (int i = 0; i < 4; ++i) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 4, i + 4); + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 12, i + 12); + } + for (int i = 0; i < 4; ++i) { + printf("st.Sort2(d, v%x, v%x);\n", i, 7 - i); + printf("st.Sort2(d, v%x, v%x);\n", i + 8, 15 - i); + } + for (int i = 0; i < 16; i += 4) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 2, i + 2); + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 3, i + 3); + } + for (int i = 0; i < 16; i += 4) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 3); + printf("st.Sort2(d, v%x, v%x);\n", i + 1, i + 2); + } + for (int i = 0; i < 16; i += 2) { + printf("v%x = st.ReverseKeys16(d, v%x);\n", i + 1, i + 1); + } + for (int i = 0; i < 16; i += 2) { + printf("st.Sort2(d, v%x, v%x);\n", i, i + 1); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsReverse16<kOrder>(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance4<kOrder>(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance2<kOrder>(d, v%x);\n", i, i); + } + for (int i = 0; i < 16; ++i) { + printf("v%x = st.SortPairsDistance1<kOrder>(d, v%x);\n", i, i); + } +} + +int main(int argc, char** argv) { + PrintMergeNetwork16x2(); + PrintMergeNetwork16x4(); + PrintMergeNetwork16x8(); + PrintMergeNetwork16x16(); + return 0; +} diff --git a/third_party/highway/hwy/contrib/sort/result-inl.h b/third_party/highway/hwy/contrib/sort/result-inl.h new file mode 100644 index 0000000000..f3d842dfbd --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/result-inl.h @@ -0,0 +1,139 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/algo-inl.h" + +// Normal include guard for non-SIMD parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +#include <time.h> + +#include <algorithm> // std::sort +#include <string> + +#include "hwy/base.h" +#include "hwy/nanobenchmark.h" + +namespace hwy { + +struct Timestamp { + Timestamp() { t = platform::Now(); } + double t; +}; + +static inline double SecondsSince(const Timestamp& t0) { + const Timestamp t1; + return t1.t - t0.t; +} + +// Returns trimmed mean (we don't want to run an out-of-L3-cache sort often +// enough for the mode to be reliable). +static inline double SummarizeMeasurements(std::vector<double>& seconds) { + std::sort(seconds.begin(), seconds.end()); + double sum = 0; + int count = 0; + const size_t num = seconds.size(); + for (size_t i = num / 4; i < num / 2; ++i) { + sum += seconds[i]; + count += 1; + } + return sum / count; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct Result { + Result() {} + Result(const Algo algo, Dist dist, size_t num_keys, size_t num_threads, + double sec, size_t sizeof_key, const std::string& key_name) + : target(HWY_TARGET), + algo(algo), + dist(dist), + num_keys(num_keys), + num_threads(num_threads), + sec(sec), + sizeof_key(sizeof_key), + key_name(key_name) {} + + void Print() const { + const double bytes = static_cast<double>(num_keys) * + static_cast<double>(num_threads) * + static_cast<double>(sizeof_key); + printf("%10s: %12s: %7s: %9s: %.2E %4.0f MB/s (%2zu threads)\n", + hwy::TargetName(target), AlgoName(algo), key_name.c_str(), + DistName(dist), static_cast<double>(num_keys), bytes * 1E-6 / sec, + num_threads); + } + + int64_t target; + Algo algo; + Dist dist; + size_t num_keys = 0; + size_t num_threads = 0; + double sec = 0.0; + size_t sizeof_key = 0; + std::string key_name; +}; + +template <class Traits, typename LaneType> +bool VerifySort(Traits st, const InputStats<LaneType>& input_stats, + const LaneType* out, size_t num_lanes, const char* caller) { + constexpr size_t N1 = st.LanesPerKey(); + HWY_ASSERT(num_lanes >= N1); + + InputStats<LaneType> output_stats; + // Ensure it matches the sort order + for (size_t i = 0; i < num_lanes - N1; i += N1) { + output_stats.Notify(out[i]); + if (N1 == 2) output_stats.Notify(out[i + 1]); + // Reverse order instead of checking !Compare1 so we accept equal keys. + if (st.Compare1(out + i + N1, out + i)) { + printf("%s: i=%d of %d lanes: N1=%d %5.0f %5.0f vs. %5.0f %5.0f\n\n", + caller, static_cast<int>(i), static_cast<int>(num_lanes), + static_cast<int>(N1), static_cast<double>(out[i + 1]), + static_cast<double>(out[i + 0]), + static_cast<double>(out[i + N1 + 1]), + static_cast<double>(out[i + N1])); + HWY_ABORT("%d-bit sort is incorrect\n", + static_cast<int>(sizeof(LaneType) * 8 * N1)); + } + } + output_stats.Notify(out[num_lanes - N1]); + if (N1 == 2) output_stats.Notify(out[num_lanes - N1 + 1]); + + return input_stats == output_stats; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/shared-inl.h b/third_party/highway/hwy/contrib/sort/shared-inl.h new file mode 100644 index 0000000000..735f95ee22 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/shared-inl.h @@ -0,0 +1,134 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions shared between vqsort-inl and sorting_networks-inl. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +#include "hwy/base.h" + +namespace hwy { + +// Internal constants - these are to avoid magic numbers/literals and cannot be +// changed without also changing the associated code. +struct SortConstants { +// SortingNetwork reshapes its input into a matrix. This is the maximum number +// of *keys* per vector. +#if HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD + static constexpr size_t kMaxCols = 8; // avoid build timeout/stack overflow +#else + static constexpr size_t kMaxCols = 16; // enough for u32 in 512-bit vector +#endif + + // 16 rows is a compromise between using the 32 AVX-512/SVE/RVV registers, + // fitting within 16 AVX2 registers with only a few spills, keeping BaseCase + // code size reasonable (7 KiB for AVX-512 and 16 cols), and minimizing the + // extra logN factor for larger networks (for which only loose upper bounds + // on size are known). + static constexpr size_t kMaxRowsLog2 = 4; + static constexpr size_t kMaxRows = size_t{1} << kMaxRowsLog2; + + static constexpr HWY_INLINE size_t BaseCaseNum(size_t N) { + return kMaxRows * HWY_MIN(N, kMaxCols); + } + + // Unrolling is important (pipelining and amortizing branch mispredictions); + // 2x is sufficient to reach full memory bandwidth on SKX in Partition, but + // somewhat slower for sorting than 4x. + // + // To change, must also update left + 3 * N etc. in the loop. + static constexpr size_t kPartitionUnroll = 4; + + static constexpr HWY_INLINE size_t PartitionBufNum(size_t N) { + // The main loop reads kPartitionUnroll vectors, and first loads from + // both left and right beforehand, so it requires min = 2 * + // kPartitionUnroll vectors. To handle smaller amounts (only guaranteed + // >= BaseCaseNum), we partition the right side into a buffer. We need + // another vector at the end so CompressStore does not overwrite anything. + return (2 * kPartitionUnroll + 1) * N; + } + + // Chunk := group of keys loaded for sampling a pivot. Matches the typical + // cache line size of 64 bytes to get maximum benefit per L2 miss. Sort() + // ensures vectors are no larger than that, so this can be independent of the + // vector size and thus constexpr. + static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t) { + return 64 / sizeof_t; + } + + static constexpr HWY_INLINE size_t PivotBufNum(size_t sizeof_t, size_t N) { + // 3 chunks of medians, 1 chunk of median medians plus two padding vectors. + return (3 + 1) * LanesPerChunk(sizeof_t) + 2 * N; + } + + template <typename T> + static constexpr HWY_INLINE size_t BufNum(size_t N) { + // One extra for padding plus another for full-vector loads. + return HWY_MAX(BaseCaseNum(N) + 2 * N, + HWY_MAX(PartitionBufNum(N), PivotBufNum(sizeof(T), N))); + } + + template <typename T> + static constexpr HWY_INLINE size_t BufBytes(size_t vector_size) { + return sizeof(T) * BufNum<T>(vector_size / sizeof(T)); + } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE +#endif + +#include "hwy/highway.h" + +// vqsort isn't available on HWY_SCALAR, and builds time out on MSVC opt and +// Arm v7 debug. +#undef VQSORT_ENABLED +#if (HWY_TARGET == HWY_SCALAR) || \ + (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \ + (HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) +#define VQSORT_ENABLED 0 +#else +#define VQSORT_ENABLED 1 +#endif + +namespace hwy { +namespace HWY_NAMESPACE { + +// Default tag / vector width selector. +#if HWY_TARGET == HWY_RVV +// Use LMUL = 1/2; for SEW=64 this ends up emulated via vsetvl. +template <typename T> +using SortTag = ScalableTag<T, -1>; +#else +template <typename T> +using SortTag = ScalableTag<T>; +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SHARED_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/sort_test.cc b/third_party/highway/hwy/contrib/sort/sort_test.cc new file mode 100644 index 0000000000..2d1f1d5169 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/sort_test.cc @@ -0,0 +1,626 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> +#include <stdint.h> +#include <stdio.h> +#include <string.h> // memcpy + +#include <unordered_map> +#include <vector> + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/sort_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/sort/vqsort.h" +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // BaseCase +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +using detail::OrderAscending; +using detail::OrderDescending; +using detail::SharedTraits; +using detail::TraitsLane; +#if VQSORT_ENABLED || HWY_IDE +using detail::OrderAscending128; +using detail::OrderAscendingKV128; +using detail::OrderAscendingKV64; +using detail::OrderDescending128; +using detail::OrderDescendingKV128; +using detail::OrderDescendingKV64; +using detail::Traits128; + +template <class Traits> +static HWY_NOINLINE void TestMedian3() { + using LaneType = typename Traits::LaneType; + using D = CappedTag<LaneType, 1>; + SharedTraits<Traits> st; + const D d; + using V = Vec<D>; + for (uint32_t bits = 0; bits < 8; ++bits) { + const V v0 = Set(d, LaneType{(bits & (1u << 0)) ? 1u : 0u}); + const V v1 = Set(d, LaneType{(bits & (1u << 1)) ? 1u : 0u}); + const V v2 = Set(d, LaneType{(bits & (1u << 2)) ? 1u : 0u}); + const LaneType m = GetLane(detail::MedianOf3(st, v0, v1, v2)); + // If at least half(rounded up) of bits are 1, so is the median. + const size_t count = PopCount(bits); + HWY_ASSERT_EQ((count >= 2) ? static_cast<LaneType>(1) : 0, m); + } +} + +HWY_NOINLINE void TestAllMedian() { + TestMedian3<TraitsLane<OrderAscending<uint64_t> > >(); +} + +template <class Traits> +static HWY_NOINLINE void TestBaseCaseAscDesc() { + using LaneType = typename Traits::LaneType; + SharedTraits<Traits> st; + const SortTag<LaneType> d; + const size_t N = Lanes(d); + const size_t base_case_num = SortConstants::BaseCaseNum(N); + const size_t N1 = st.LanesPerKey(); + + constexpr int kDebug = 0; + auto aligned_lanes = hwy::AllocateAligned<LaneType>(N + base_case_num + N); + auto buf = hwy::AllocateAligned<LaneType>(base_case_num + 2 * N); + + std::vector<size_t> lengths; + lengths.push_back(HWY_MAX(1, N1)); + lengths.push_back(3 * N1); + lengths.push_back(base_case_num / 2); + lengths.push_back(base_case_num / 2 + N1); + lengths.push_back(base_case_num - N1); + lengths.push_back(base_case_num); + + std::vector<size_t> misalignments; + misalignments.push_back(0); + misalignments.push_back(1); + if (N >= 6) misalignments.push_back(N / 2 - 1); + misalignments.push_back(N / 2); + misalignments.push_back(N / 2 + 1); + misalignments.push_back(HWY_MIN(2 * N / 3 + 3, size_t{N - 1})); + + for (bool asc : {false, true}) { + for (size_t len : lengths) { + for (size_t misalign : misalignments) { + LaneType* HWY_RESTRICT lanes = aligned_lanes.get() + misalign; + if (kDebug) { + printf("============%s asc %d N1 %d len %d misalign %d\n", + st.KeyString().c_str(), asc, static_cast<int>(N1), + static_cast<int>(len), static_cast<int>(misalign)); + } + + for (size_t i = 0; i < misalign; ++i) { + aligned_lanes[i] = hwy::LowestValue<LaneType>(); + } + InputStats<LaneType> input_stats; + for (size_t i = 0; i < len; ++i) { + lanes[i] = asc ? static_cast<LaneType>(LaneType(i) + 1) + : static_cast<LaneType>(LaneType(len) - LaneType(i)); + input_stats.Notify(lanes[i]); + if (kDebug >= 2) { + printf("%3zu: %f\n", i, static_cast<double>(lanes[i])); + } + } + for (size_t i = len; i < base_case_num + N; ++i) { + lanes[i] = hwy::LowestValue<LaneType>(); + } + + detail::BaseCase(d, st, lanes, lanes + len, len, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = 0; i < len; ++i) { + printf("%3zu: %f\n", i, static_cast<double>(lanes[i])); + } + } + + HWY_ASSERT(VerifySort(st, input_stats, lanes, len, "BaseAscDesc")); + for (size_t i = 0; i < misalign; ++i) { + if (aligned_lanes[i] != hwy::LowestValue<LaneType>()) + HWY_ABORT("Overrun misalign at %d\n", static_cast<int>(i)); + } + for (size_t i = len; i < base_case_num + N; ++i) { + if (lanes[i] != hwy::LowestValue<LaneType>()) + HWY_ABORT("Overrun right at %d\n", static_cast<int>(i)); + } + } // misalign + } // len + } // asc +} + +template <class Traits> +static HWY_NOINLINE void TestBaseCase01() { + using LaneType = typename Traits::LaneType; + SharedTraits<Traits> st; + const SortTag<LaneType> d; + const size_t N = Lanes(d); + const size_t base_case_num = SortConstants::BaseCaseNum(N); + const size_t N1 = st.LanesPerKey(); + + constexpr int kDebug = 0; + auto lanes = hwy::AllocateAligned<LaneType>(base_case_num + N); + auto buf = hwy::AllocateAligned<LaneType>(base_case_num + 2 * N); + + std::vector<size_t> lengths; + lengths.push_back(HWY_MAX(1, N1)); + lengths.push_back(3 * N1); + lengths.push_back(base_case_num / 2); + lengths.push_back(base_case_num / 2 + N1); + lengths.push_back(base_case_num - N1); + lengths.push_back(base_case_num); + + for (size_t len : lengths) { + if (kDebug) { + printf("============%s 01 N1 %d len %d\n", st.KeyString().c_str(), + static_cast<int>(N1), static_cast<int>(len)); + } + const uint64_t kMaxBits = AdjustedLog2Reps(HWY_MIN(len, size_t{14})); + for (uint64_t bits = 0; bits < ((1ull << kMaxBits) - 1); ++bits) { + InputStats<LaneType> input_stats; + for (size_t i = 0; i < len; ++i) { + lanes[i] = (i < 64 && (bits & (1ull << i))) ? 1 : 0; + input_stats.Notify(lanes[i]); + if (kDebug >= 2) { + printf("%3zu: %f\n", i, static_cast<double>(lanes[i])); + } + } + for (size_t i = len; i < base_case_num + N; ++i) { + lanes[i] = hwy::LowestValue<LaneType>(); + } + + detail::BaseCase(d, st, lanes.get(), lanes.get() + len, len, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = 0; i < len; ++i) { + printf("%3zu: %f\n", i, static_cast<double>(lanes[i])); + } + } + + HWY_ASSERT(VerifySort(st, input_stats, lanes.get(), len, "Base01")); + for (size_t i = len; i < base_case_num + N; ++i) { + if (lanes[i] != hwy::LowestValue<LaneType>()) + HWY_ABORT("Overrun right at %d\n", static_cast<int>(i)); + } + } // bits + } // len +} + +template <class Traits> +static HWY_NOINLINE void TestBaseCase() { + TestBaseCaseAscDesc<Traits>(); + TestBaseCase01<Traits>(); +} + +HWY_NOINLINE void TestAllBaseCase() { + // Workaround for stack overflow on MSVC debug. +#if defined(_MSC_VER) + return; +#endif + TestBaseCase<TraitsLane<OrderAscending<int32_t> > >(); + TestBaseCase<TraitsLane<OrderDescending<int64_t> > >(); + TestBaseCase<Traits128<OrderAscending128> >(); + TestBaseCase<Traits128<OrderDescending128> >(); +} + +template <class Traits> +static HWY_NOINLINE void VerifyPartition( + Traits st, typename Traits::LaneType* HWY_RESTRICT lanes, size_t left, + size_t border, size_t right, const size_t N1, + const typename Traits::LaneType* pivot) { + /* for (size_t i = left; i < right; ++i) { + if (i == border) printf("--\n"); + printf("%4zu: %3d\n", i, lanes[i]); + }*/ + + HWY_ASSERT(left % N1 == 0); + HWY_ASSERT(border % N1 == 0); + HWY_ASSERT(right % N1 == 0); + const bool asc = typename Traits::Order().IsAscending(); + for (size_t i = left; i < border; i += N1) { + if (st.Compare1(pivot, lanes + i)) { + HWY_ABORT( + "%s: asc %d left[%d] piv %.0f %.0f compares before %.0f %.0f " + "border %d", + st.KeyString().c_str(), asc, static_cast<int>(i), + static_cast<double>(pivot[1]), static_cast<double>(pivot[0]), + static_cast<double>(lanes[i + 1]), static_cast<double>(lanes[i + 0]), + static_cast<int>(border)); + } + } + for (size_t i = border; i < right; i += N1) { + if (!st.Compare1(pivot, lanes + i)) { + HWY_ABORT( + "%s: asc %d right[%d] piv %.0f %.0f compares after %.0f %.0f " + "border %d", + st.KeyString().c_str(), asc, static_cast<int>(i), + static_cast<double>(pivot[1]), static_cast<double>(pivot[0]), + static_cast<double>(lanes[i + 1]), static_cast<double>(lanes[i]), + static_cast<int>(border)); + } + } +} + +template <class Traits> +static HWY_NOINLINE void TestPartition() { + using LaneType = typename Traits::LaneType; + const SortTag<LaneType> d; + SharedTraits<Traits> st; + const bool asc = typename Traits::Order().IsAscending(); + const size_t N = Lanes(d); + constexpr int kDebug = 0; + const size_t base_case_num = SortConstants::BaseCaseNum(N); + // left + len + align + const size_t total = 32 + (base_case_num + 4 * HWY_MAX(N, 4)) + 2 * N; + auto aligned_lanes = hwy::AllocateAligned<LaneType>(total); + auto buf = hwy::AllocateAligned<LaneType>(SortConstants::PartitionBufNum(N)); + + const size_t N1 = st.LanesPerKey(); + for (bool in_asc : {false, true}) { + for (int left_i : {0, 1, 4, 6, 7, 8, 12, 15, 22, 28, 30, 31}) { + const size_t left = static_cast<size_t>(left_i) & ~(N1 - 1); + for (size_t ofs : {N, N + 1, N + 3, 2 * N, 2 * N + 2, 2 * N + 3, + 3 * N - 1, 4 * N - 3, 4 * N - 2}) { + const size_t len = (base_case_num + ofs) & ~(N1 - 1); + for (LaneType pivot1 : + {LaneType(0), LaneType(len / 3), LaneType(len / 2), + LaneType(2 * len / 3), LaneType(len)}) { + const LaneType pivot2[2] = {pivot1, 0}; + const auto pivot = st.SetKey(d, pivot2); + for (size_t misalign = 0; misalign < N; + misalign += st.LanesPerKey()) { + LaneType* HWY_RESTRICT lanes = aligned_lanes.get() + misalign; + const size_t right = left + len; + if (kDebug) { + printf( + "=========%s asc %d left %d len %d right %d piv %.0f %.0f\n", + st.KeyString().c_str(), asc, static_cast<int>(left), + static_cast<int>(len), static_cast<int>(right), + static_cast<double>(pivot2[1]), + static_cast<double>(pivot2[0])); + } + + for (size_t i = 0; i < misalign; ++i) { + aligned_lanes[i] = hwy::LowestValue<LaneType>(); + } + for (size_t i = 0; i < left; ++i) { + lanes[i] = hwy::LowestValue<LaneType>(); + } + std::unordered_map<LaneType, int> counts; + for (size_t i = left; i < right; ++i) { + lanes[i] = static_cast<LaneType>( + in_asc ? LaneType(i + 1) - static_cast<LaneType>(left) + : static_cast<LaneType>(right) - LaneType(i)); + ++counts[lanes[i]]; + if (kDebug >= 2) { + printf("%3zu: %f\n", i, static_cast<double>(lanes[i])); + } + } + for (size_t i = right; i < total - misalign; ++i) { + lanes[i] = hwy::LowestValue<LaneType>(); + } + + size_t border = + left + detail::Partition(d, st, lanes + left, right - left, + pivot, buf.get()); + + if (kDebug >= 2) { + printf("out>>>>>>\n"); + for (size_t i = left; i < right; ++i) { + printf("%3zu: %f\n", i, static_cast<double>(lanes[i])); + } + for (size_t i = right; i < total - misalign; ++i) { + printf("%3zu: sentinel %f\n", i, static_cast<double>(lanes[i])); + } + } + for (size_t i = left; i < right; ++i) { + --counts[lanes[i]]; + } + for (auto kv : counts) { + if (kv.second != 0) { + PrintValue(kv.first); + HWY_ABORT("Incorrect count %d\n", kv.second); + } + } + VerifyPartition(st, lanes, left, border, right, N1, pivot2); + for (size_t i = 0; i < misalign; ++i) { + if (aligned_lanes[i] != hwy::LowestValue<LaneType>()) + HWY_ABORT("Overrun misalign at %d\n", static_cast<int>(i)); + } + for (size_t i = 0; i < left; ++i) { + if (lanes[i] != hwy::LowestValue<LaneType>()) + HWY_ABORT("Overrun left at %d\n", static_cast<int>(i)); + } + for (size_t i = right; i < total - misalign; ++i) { + if (lanes[i] != hwy::LowestValue<LaneType>()) + HWY_ABORT("Overrun right at %d\n", static_cast<int>(i)); + } + } // misalign + } // pivot + } // len + } // left + } // asc +} + +HWY_NOINLINE void TestAllPartition() { + TestPartition<TraitsLane<OrderDescending<int32_t> > >(); + TestPartition<Traits128<OrderAscending128> >(); + +#if !HWY_IS_DEBUG_BUILD + TestPartition<TraitsLane<OrderAscending<int16_t> > >(); + TestPartition<TraitsLane<OrderAscending<int64_t> > >(); + TestPartition<TraitsLane<OrderDescending<float> > >(); +#if HWY_HAVE_FLOAT64 + TestPartition<TraitsLane<OrderDescending<double> > >(); +#endif + TestPartition<Traits128<OrderDescending128> >(); +#endif +} + +// (used for sample selection for choosing a pivot) +template <typename TU> +static HWY_NOINLINE void TestRandomGenerator() { + static_assert(!hwy::IsSigned<TU>(), ""); + SortTag<TU> du; + const size_t N = Lanes(du); + + detail::Generator rng(&N, N); + + const size_t lanes_per_block = HWY_MAX(64 / sizeof(TU), N); // power of two + + for (uint32_t num_blocks = 2; num_blocks < 100000; + num_blocks = 3 * num_blocks / 2) { + // Generate some numbers and ensure all are in range + uint64_t sum = 0; + constexpr size_t kReps = 10000; + for (size_t rep = 0; rep < kReps; ++rep) { + const uint32_t bits = rng() & 0xFFFFFFFF; + const size_t index = detail::RandomChunkIndex(num_blocks, bits); + HWY_ASSERT(((index + 1) * lanes_per_block) <= + num_blocks * lanes_per_block); + + sum += index; + } + + // Also ensure the mean is near the middle of the range + const double expected = (num_blocks - 1) / 2.0; + const double actual = static_cast<double>(sum) / kReps; + HWY_ASSERT(0.9 * expected <= actual && actual <= 1.1 * expected); + } +} + +HWY_NOINLINE void TestAllGenerator() { + TestRandomGenerator<uint32_t>(); + TestRandomGenerator<uint64_t>(); +} + +#else +static void TestAllMedian() {} +static void TestAllBaseCase() {} +static void TestAllPartition() {} +static void TestAllGenerator() {} +#endif // VQSORT_ENABLED + +// Remembers input, and compares results to that of a reference algorithm. +template <class Traits> +class CompareResults { + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + + public: + CompareResults(const LaneType* in, size_t num_lanes) { + copy_.resize(num_lanes); + memcpy(copy_.data(), in, num_lanes * sizeof(LaneType)); + } + + bool Verify(const LaneType* output) { +#if HAVE_PDQSORT + const Algo reference = Algo::kPDQ; +#else + const Algo reference = Algo::kStd; +#endif + SharedState shared; + using Order = typename Traits::Order; + const Traits st; + const size_t num_keys = copy_.size() / st.LanesPerKey(); + Run<Order>(reference, reinterpret_cast<KeyType*>(copy_.data()), num_keys, + shared, /*thread=*/0); +#if VQSORT_PRINT >= 3 + fprintf(stderr, "\nExpected:\n"); + for (size_t i = 0; i < copy_.size(); ++i) { + PrintValue(copy_[i]); + } + fprintf(stderr, "\n"); +#endif + for (size_t i = 0; i < copy_.size(); ++i) { + if (copy_[i] != output[i]) { + if (sizeof(KeyType) == 16) { + fprintf(stderr, + "%s Asc %d mismatch at %d of %d: %" PRIu64 " %" PRIu64 "\n", + st.KeyString().c_str(), Order().IsAscending(), + static_cast<int>(i), static_cast<int>(copy_.size()), + static_cast<uint64_t>(copy_[i]), + static_cast<uint64_t>(output[i])); + } else { + fprintf(stderr, "Type %s Asc %d mismatch at %d of %d: ", + st.KeyString().c_str(), Order().IsAscending(), + static_cast<int>(i), static_cast<int>(copy_.size())); + PrintValue(copy_[i]); + PrintValue(output[i]); + fprintf(stderr, "\n"); + } + return false; + } + } + return true; + } + + private: + std::vector<LaneType> copy_; +}; + +std::vector<Algo> AlgoForTest() { + return { +#if HAVE_AVX2SORT + Algo::kSEA, +#endif +#if HAVE_IPS4O + Algo::kIPS4O, +#endif +#if HAVE_PDQSORT + Algo::kPDQ, +#endif +#if HAVE_SORT512 + Algo::kSort512, +#endif + Algo::kHeap, Algo::kVQSort, + }; +} + +template <class Traits> +void TestSort(size_t num_lanes) { +// Workaround for stack overflow on clang-cl (/F 8388608 does not help). +#if defined(_MSC_VER) + return; +#endif + using Order = typename Traits::Order; + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + SharedState shared; + SharedTraits<Traits> st; + + // Round up to a whole number of keys. + num_lanes += (st.Is128() && (num_lanes & 1)); + const size_t num_keys = num_lanes / st.LanesPerKey(); + + constexpr size_t kMaxMisalign = 16; + auto aligned = + hwy::AllocateAligned<LaneType>(kMaxMisalign + num_lanes + kMaxMisalign); + for (Algo algo : AlgoForTest()) { + for (Dist dist : AllDist()) { + for (size_t misalign : {size_t{0}, size_t{st.LanesPerKey()}, + size_t{3 * st.LanesPerKey()}, kMaxMisalign / 2}) { + LaneType* lanes = aligned.get() + misalign; + + // Set up red zones before/after the keys to sort + for (size_t i = 0; i < misalign; ++i) { + aligned[i] = hwy::LowestValue<LaneType>(); + } + for (size_t i = 0; i < kMaxMisalign; ++i) { + lanes[num_lanes + i] = hwy::HighestValue<LaneType>(); + } +#if HWY_IS_MSAN + __msan_poison(aligned.get(), misalign * sizeof(LaneType)); + __msan_poison(lanes + num_lanes, kMaxMisalign * sizeof(LaneType)); +#endif + InputStats<LaneType> input_stats = + GenerateInput(dist, lanes, num_lanes); + + CompareResults<Traits> compare(lanes, num_lanes); + Run<Order>(algo, reinterpret_cast<KeyType*>(lanes), num_keys, shared, + /*thread=*/0); + HWY_ASSERT(compare.Verify(lanes)); + HWY_ASSERT(VerifySort(st, input_stats, lanes, num_lanes, "TestSort")); + + // Check red zones +#if HWY_IS_MSAN + __msan_unpoison(aligned.get(), misalign * sizeof(LaneType)); + __msan_unpoison(lanes + num_lanes, kMaxMisalign * sizeof(LaneType)); +#endif + for (size_t i = 0; i < misalign; ++i) { + if (aligned[i] != hwy::LowestValue<LaneType>()) + HWY_ABORT("Overrun left at %d\n", static_cast<int>(i)); + } + for (size_t i = num_lanes; i < num_lanes + kMaxMisalign; ++i) { + if (lanes[i] != hwy::HighestValue<LaneType>()) + HWY_ABORT("Overrun right at %d\n", static_cast<int>(i)); + } + } // misalign + } // dist + } // algo +} + +void TestAllSort() { + for (int num : {129, 504, 3 * 1000, 34567}) { + const size_t num_lanes = AdjustedReps(static_cast<size_t>(num)); + TestSort<TraitsLane<OrderAscending<int16_t> > >(num_lanes); + TestSort<TraitsLane<OrderDescending<uint16_t> > >(num_lanes); + + TestSort<TraitsLane<OrderDescending<int32_t> > >(num_lanes); + TestSort<TraitsLane<OrderDescending<uint32_t> > >(num_lanes); + + TestSort<TraitsLane<OrderAscending<int64_t> > >(num_lanes); + TestSort<TraitsLane<OrderAscending<uint64_t> > >(num_lanes); + + // WARNING: for float types, SIMD comparisons will flush denormals to + // zero, causing mismatches with scalar sorts. In this test, we avoid + // generating denormal inputs. + TestSort<TraitsLane<OrderAscending<float> > >(num_lanes); +#if HWY_HAVE_FLOAT64 // protects algo-inl's GenerateRandom + if (Sorter::HaveFloat64()) { + TestSort<TraitsLane<OrderDescending<double> > >(num_lanes); + } +#endif + +// Our HeapSort does not support 128-bit keys. +#if VQSORT_ENABLED + TestSort<Traits128<OrderAscending128> >(num_lanes); + TestSort<Traits128<OrderDescending128> >(num_lanes); + + TestSort<TraitsLane<OrderAscendingKV64> >(num_lanes); + TestSort<TraitsLane<OrderDescendingKV64> >(num_lanes); + + TestSort<Traits128<OrderAscendingKV128> >(num_lanes); + TestSort<Traits128<OrderDescendingKV128> >(num_lanes); +#endif + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +namespace { +HWY_BEFORE_TEST(SortTest); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllMedian); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllBaseCase); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllPartition); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllGenerator); +HWY_EXPORT_AND_TEST_P(SortTest, TestAllSort); +} // namespace +} // namespace hwy + +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h b/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h new file mode 100644 index 0000000000..2615a04b68 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h @@ -0,0 +1,707 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE +#endif + +#include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if VQSORT_ENABLED + +using Constants = hwy::SortConstants; + +// ------------------------------ SharedTraits + +// Code shared between all traits. It's unclear whether these can profitably be +// specialized for Lane vs Block, or optimized like SortPairsDistance1 using +// Compare/DupOdd. +template <class Base> +struct SharedTraits : public Base { + // Conditionally swaps lane 0 with 2, 1 with 3 etc. + template <class D> + HWY_INLINE Vec<D> SortPairsDistance2(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->SwapAdjacentPairs(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template <class D> + HWY_INLINE Vec<D> SortPairsReverse8(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->ReverseKeys8(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template <class D> + HWY_INLINE Vec<D> SortPairsReverse16(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + static_assert(Constants::kMaxCols <= 16, "Need actual Reverse16"); + Vec<D> swapped = base->ReverseKeys(d, v); + base->Sort2(d, v, swapped); + return ConcatUpperLower(d, swapped, v); // 8 = half of the vector + } +}; + +// ------------------------------ Sorting network + +// (Green's irregular) sorting network for independent columns in 16 vectors. +template <class D, class Traits, class V = Vec<D>> +HWY_INLINE void Sort16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v4, v6); + st.Sort2(d, v5, v7); + st.Sort2(d, v8, va); + st.Sort2(d, v9, vb); + st.Sort2(d, vc, ve); + st.Sort2(d, vd, vf); + st.Sort2(d, v0, v4); + st.Sort2(d, v1, v5); + st.Sort2(d, v2, v6); + st.Sort2(d, v3, v7); + st.Sort2(d, v8, vc); + st.Sort2(d, v9, vd); + st.Sort2(d, va, ve); + st.Sort2(d, vb, vf); + st.Sort2(d, v0, v8); + st.Sort2(d, v1, v9); + st.Sort2(d, v2, va); + st.Sort2(d, v3, vb); + st.Sort2(d, v4, vc); + st.Sort2(d, v5, vd); + st.Sort2(d, v6, ve); + st.Sort2(d, v7, vf); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v3, vc); + st.Sort2(d, v7, vb); + st.Sort2(d, vd, ve); + st.Sort2(d, v4, v8); + st.Sort2(d, v1, v2); + st.Sort2(d, v1, v4); + st.Sort2(d, v7, vd); + st.Sort2(d, v2, v8); + st.Sort2(d, vb, ve); + st.Sort2(d, v2, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vd); + st.Sort2(d, v3, v8); + st.Sort2(d, v7, vc); + st.Sort2(d, v3, v5); + st.Sort2(d, v6, v8); + st.Sort2(d, v7, v9); + st.Sort2(d, va, vc); + st.Sort2(d, v3, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v7, v8); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vc); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); +} + +// ------------------------------ Merging networks + +// Blacher's hybrid bitonic/odd-even networks, generated by print_network.cc. + +template <class D, class Traits, class V = Vec<D>> +HWY_INLINE void Merge2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys2(d, v8); + v9 = st.ReverseKeys2(d, v9); + va = st.ReverseKeys2(d, va); + vb = st.ReverseKeys2(d, vb); + vc = st.ReverseKeys2(d, vc); + vd = st.ReverseKeys2(d, vd); + ve = st.ReverseKeys2(d, ve); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys2(d, v4); + vc = st.ReverseKeys2(d, vc); + v5 = st.ReverseKeys2(d, v5); + vd = st.ReverseKeys2(d, vd); + v6 = st.ReverseKeys2(d, v6); + ve = st.ReverseKeys2(d, ve); + v7 = st.ReverseKeys2(d, v7); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys2(d, v2); + v3 = st.ReverseKeys2(d, v3); + v6 = st.ReverseKeys2(d, v6); + v7 = st.ReverseKeys2(d, v7); + va = st.ReverseKeys2(d, va); + vb = st.ReverseKeys2(d, vb); + ve = st.ReverseKeys2(d, ve); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys2(d, v1); + v3 = st.ReverseKeys2(d, v3); + v5 = st.ReverseKeys2(d, v5); + v7 = st.ReverseKeys2(d, v7); + v9 = st.ReverseKeys2(d, v9); + vb = st.ReverseKeys2(d, vb); + vd = st.ReverseKeys2(d, vd); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template <class D, class Traits, class V = Vec<D>> +HWY_INLINE void Merge4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys4(d, v8); + v9 = st.ReverseKeys4(d, v9); + va = st.ReverseKeys4(d, va); + vb = st.ReverseKeys4(d, vb); + vc = st.ReverseKeys4(d, vc); + vd = st.ReverseKeys4(d, vd); + ve = st.ReverseKeys4(d, ve); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys4(d, v4); + vc = st.ReverseKeys4(d, vc); + v5 = st.ReverseKeys4(d, v5); + vd = st.ReverseKeys4(d, vd); + v6 = st.ReverseKeys4(d, v6); + ve = st.ReverseKeys4(d, ve); + v7 = st.ReverseKeys4(d, v7); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys4(d, v2); + v3 = st.ReverseKeys4(d, v3); + v6 = st.ReverseKeys4(d, v6); + v7 = st.ReverseKeys4(d, v7); + va = st.ReverseKeys4(d, va); + vb = st.ReverseKeys4(d, vb); + ve = st.ReverseKeys4(d, ve); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys4(d, v1); + v3 = st.ReverseKeys4(d, v3); + v5 = st.ReverseKeys4(d, v5); + v7 = st.ReverseKeys4(d, v7); + v9 = st.ReverseKeys4(d, v9); + vb = st.ReverseKeys4(d, vb); + vd = st.ReverseKeys4(d, vd); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse4(d, v0); + v1 = st.SortPairsReverse4(d, v1); + v2 = st.SortPairsReverse4(d, v2); + v3 = st.SortPairsReverse4(d, v3); + v4 = st.SortPairsReverse4(d, v4); + v5 = st.SortPairsReverse4(d, v5); + v6 = st.SortPairsReverse4(d, v6); + v7 = st.SortPairsReverse4(d, v7); + v8 = st.SortPairsReverse4(d, v8); + v9 = st.SortPairsReverse4(d, v9); + va = st.SortPairsReverse4(d, va); + vb = st.SortPairsReverse4(d, vb); + vc = st.SortPairsReverse4(d, vc); + vd = st.SortPairsReverse4(d, vd); + ve = st.SortPairsReverse4(d, ve); + vf = st.SortPairsReverse4(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template <class D, class Traits, class V = Vec<D>> +HWY_INLINE void Merge8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + v8 = st.ReverseKeys8(d, v8); + v9 = st.ReverseKeys8(d, v9); + va = st.ReverseKeys8(d, va); + vb = st.ReverseKeys8(d, vb); + vc = st.ReverseKeys8(d, vc); + vd = st.ReverseKeys8(d, vd); + ve = st.ReverseKeys8(d, ve); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys8(d, v4); + vc = st.ReverseKeys8(d, vc); + v5 = st.ReverseKeys8(d, v5); + vd = st.ReverseKeys8(d, vd); + v6 = st.ReverseKeys8(d, v6); + ve = st.ReverseKeys8(d, ve); + v7 = st.ReverseKeys8(d, v7); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys8(d, v2); + v3 = st.ReverseKeys8(d, v3); + v6 = st.ReverseKeys8(d, v6); + v7 = st.ReverseKeys8(d, v7); + va = st.ReverseKeys8(d, va); + vb = st.ReverseKeys8(d, vb); + ve = st.ReverseKeys8(d, ve); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys8(d, v1); + v3 = st.ReverseKeys8(d, v3); + v5 = st.ReverseKeys8(d, v5); + v7 = st.ReverseKeys8(d, v7); + v9 = st.ReverseKeys8(d, v9); + vb = st.ReverseKeys8(d, vb); + vd = st.ReverseKeys8(d, vd); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse8(d, v0); + v1 = st.SortPairsReverse8(d, v1); + v2 = st.SortPairsReverse8(d, v2); + v3 = st.SortPairsReverse8(d, v3); + v4 = st.SortPairsReverse8(d, v4); + v5 = st.SortPairsReverse8(d, v5); + v6 = st.SortPairsReverse8(d, v6); + v7 = st.SortPairsReverse8(d, v7); + v8 = st.SortPairsReverse8(d, v8); + v9 = st.SortPairsReverse8(d, v9); + va = st.SortPairsReverse8(d, va); + vb = st.SortPairsReverse8(d, vb); + vc = st.SortPairsReverse8(d, vc); + vd = st.SortPairsReverse8(d, vd); + ve = st.SortPairsReverse8(d, ve); + vf = st.SortPairsReverse8(d, vf); + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +// Unused on MSVC, see below +#if !HWY_COMPILER_MSVC + +template <class D, class Traits, class V = Vec<D>> +HWY_INLINE void Merge16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, + V& vd, V& ve, V& vf) { + v8 = st.ReverseKeys16(d, v8); + v9 = st.ReverseKeys16(d, v9); + va = st.ReverseKeys16(d, va); + vb = st.ReverseKeys16(d, vb); + vc = st.ReverseKeys16(d, vc); + vd = st.ReverseKeys16(d, vd); + ve = st.ReverseKeys16(d, ve); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + v4 = st.ReverseKeys16(d, v4); + vc = st.ReverseKeys16(d, vc); + v5 = st.ReverseKeys16(d, v5); + vd = st.ReverseKeys16(d, vd); + v6 = st.ReverseKeys16(d, v6); + ve = st.ReverseKeys16(d, ve); + v7 = st.ReverseKeys16(d, v7); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v7); + st.Sort2(d, v8, vf); + st.Sort2(d, v1, v6); + st.Sort2(d, v9, ve); + st.Sort2(d, v2, v5); + st.Sort2(d, va, vd); + st.Sort2(d, v3, v4); + st.Sort2(d, vb, vc); + v2 = st.ReverseKeys16(d, v2); + v3 = st.ReverseKeys16(d, v3); + v6 = st.ReverseKeys16(d, v6); + v7 = st.ReverseKeys16(d, v7); + va = st.ReverseKeys16(d, va); + vb = st.ReverseKeys16(d, vb); + ve = st.ReverseKeys16(d, ve); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + v1 = st.ReverseKeys16(d, v1); + v3 = st.ReverseKeys16(d, v3); + v5 = st.ReverseKeys16(d, v5); + v7 = st.ReverseKeys16(d, v7); + v9 = st.ReverseKeys16(d, v9); + vb = st.ReverseKeys16(d, vb); + vd = st.ReverseKeys16(d, vd); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + v0 = st.SortPairsReverse16(d, v0); + v1 = st.SortPairsReverse16(d, v1); + v2 = st.SortPairsReverse16(d, v2); + v3 = st.SortPairsReverse16(d, v3); + v4 = st.SortPairsReverse16(d, v4); + v5 = st.SortPairsReverse16(d, v5); + v6 = st.SortPairsReverse16(d, v6); + v7 = st.SortPairsReverse16(d, v7); + v8 = st.SortPairsReverse16(d, v8); + v9 = st.SortPairsReverse16(d, v9); + va = st.SortPairsReverse16(d, va); + vb = st.SortPairsReverse16(d, vb); + vc = st.SortPairsReverse16(d, vc); + vd = st.SortPairsReverse16(d, vd); + ve = st.SortPairsReverse16(d, ve); + vf = st.SortPairsReverse16(d, vf); + v0 = st.SortPairsDistance4(d, v0); + v1 = st.SortPairsDistance4(d, v1); + v2 = st.SortPairsDistance4(d, v2); + v3 = st.SortPairsDistance4(d, v3); + v4 = st.SortPairsDistance4(d, v4); + v5 = st.SortPairsDistance4(d, v5); + v6 = st.SortPairsDistance4(d, v6); + v7 = st.SortPairsDistance4(d, v7); + v8 = st.SortPairsDistance4(d, v8); + v9 = st.SortPairsDistance4(d, v9); + va = st.SortPairsDistance4(d, va); + vb = st.SortPairsDistance4(d, vb); + vc = st.SortPairsDistance4(d, vc); + vd = st.SortPairsDistance4(d, vd); + ve = st.SortPairsDistance4(d, ve); + vf = st.SortPairsDistance4(d, vf); + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +#endif // !HWY_COMPILER_MSVC + +// Reshapes `buf` into a matrix, sorts columns independently, and then merges +// into a sorted 1D array without transposing. +// +// `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. +// +// References: +// https://drops.dagstuhl.de/opus/volltexte/2021/13775/pdf/LIPIcs-SEA-2021-3.pdf +// https://github.com/simd-sorting/fast-and-robust/blob/master/avx2_sort_demo/avx2sort.h +// "Entwurf und Implementierung vektorisierter Sortieralgorithmen" (M. Blacher) +template <class Traits, class V> +HWY_INLINE void SortingNetwork(Traits st, size_t cols, V& v0, V& v1, V& v2, + V& v3, V& v4, V& v5, V& v6, V& v7, V& v8, V& v9, + V& va, V& vb, V& vc, V& vd, V& ve, V& vf) { + const CappedTag<typename Traits::LaneType, Constants::kMaxCols> d; + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // The network width depends on the number of keys, not lanes. + constexpr size_t kLanesPerKey = st.LanesPerKey(); + const size_t keys = cols / kLanesPerKey; + constexpr size_t kMaxKeys = MaxLanes(d) / kLanesPerKey; + + Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf); + + // Checking MaxLanes avoids generating HWY_ASSERT code for the unreachable + // code paths: if MaxLanes < 2, then keys <= cols < 2. + if (HWY_LIKELY(keys >= 2 && kMaxKeys >= 2)) { + Merge2(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, + vf); + + if (HWY_LIKELY(keys >= 4 && kMaxKeys >= 4)) { + Merge4(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, + vf); + + if (HWY_LIKELY(keys >= 8 && kMaxKeys >= 8)) { + Merge8(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, + ve, vf); + + // Avoids build timeout. Must match #if condition in kMaxCols. +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + if (HWY_LIKELY(keys >= 16 && kMaxKeys >= 16)) { + Merge16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, + ve, vf); + + static_assert(Constants::kMaxCols <= 16, "Add more branches"); + } +#endif + } + } + } +} + +// As above, but loads from/stores to `buf`. This ensures full vectors are +// aligned, and enables loads/stores without bounds checks. +// +// NOINLINE because this is large and called twice from vqsort-inl.h. +template <class Traits, typename T> +HWY_NOINLINE void SortingNetwork(Traits st, T* HWY_RESTRICT buf, size_t cols) { + const CappedTag<T, Constants::kMaxCols> d; + using V = decltype(Zero(d)); + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // These are aligned iff cols == Lanes(d). We prefer unaligned/non-constexpr + // offsets to duplicating this code for every value of cols. + static_assert(Constants::kMaxRows == 16, "Update loads/stores/args"); + V v0 = LoadU(d, buf + 0x0 * cols); + V v1 = LoadU(d, buf + 0x1 * cols); + V v2 = LoadU(d, buf + 0x2 * cols); + V v3 = LoadU(d, buf + 0x3 * cols); + V v4 = LoadU(d, buf + 0x4 * cols); + V v5 = LoadU(d, buf + 0x5 * cols); + V v6 = LoadU(d, buf + 0x6 * cols); + V v7 = LoadU(d, buf + 0x7 * cols); + V v8 = LoadU(d, buf + 0x8 * cols); + V v9 = LoadU(d, buf + 0x9 * cols); + V va = LoadU(d, buf + 0xa * cols); + V vb = LoadU(d, buf + 0xb * cols); + V vc = LoadU(d, buf + 0xc * cols); + V vd = LoadU(d, buf + 0xd * cols); + V ve = LoadU(d, buf + 0xe * cols); + V vf = LoadU(d, buf + 0xf * cols); + + SortingNetwork(st, cols, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, + vd, ve, vf); + + StoreU(v0, d, buf + 0x0 * cols); + StoreU(v1, d, buf + 0x1 * cols); + StoreU(v2, d, buf + 0x2 * cols); + StoreU(v3, d, buf + 0x3 * cols); + StoreU(v4, d, buf + 0x4 * cols); + StoreU(v5, d, buf + 0x5 * cols); + StoreU(v6, d, buf + 0x6 * cols); + StoreU(v7, d, buf + 0x7 * cols); + StoreU(v8, d, buf + 0x8 * cols); + StoreU(v9, d, buf + 0x9 * cols); + StoreU(va, d, buf + 0xa * cols); + StoreU(vb, d, buf + 0xb * cols); + StoreU(vc, d, buf + 0xc * cols); + StoreU(vd, d, buf + 0xd * cols); + StoreU(ve, d, buf + 0xe * cols); + StoreU(vf, d, buf + 0xf * cols); +} + +#else +template <class Base> +struct SharedTraits : public Base {}; +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/traits-inl.h b/third_party/highway/hwy/contrib/sort/traits-inl.h new file mode 100644 index 0000000000..8dfc639bbd --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/traits-inl.h @@ -0,0 +1,568 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#endif + +#include <string> + +#include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/contrib/sort/vqsort.h" // SortDescending +#include "hwy/highway.h" +#include "hwy/print.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if VQSORT_ENABLED || HWY_IDE + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +template <typename T> +struct KeyLane { + static constexpr bool Is128() { return false; } + // False indicates the entire key (i.e. lane) should be compared. KV stands + // for key-value. + static constexpr bool IsKV() { return false; } + constexpr size_t LanesPerKey() const { return 1; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = T; + // What type to pass to Sorter::operator(). + using KeyType = T; + + std::string KeyString() const { + char string100[100]; + hwy::detail::TypeName(hwy::detail::MakeTypeInfo<KeyType>(), 1, string100); + return string100; + } + + // For HeapSort + HWY_INLINE void Swap(T* a, T* b) const { + const T temp = *a; + *a = *b; + *b = temp; + } + + template <class V, class M> + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressNot(keys, mask); + } + + // Broadcasts one key into a vector + template <class D> + HWY_INLINE Vec<D> SetKey(D d, const T* key) const { + return Set(d, *key); + } + + template <class D> + HWY_INLINE Mask<D> EqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const { + return Eq(a, b); + } + + template <class D> + HWY_INLINE Mask<D> NotEqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const { + return Ne(a, b); + } + + // For keys=lanes, any difference counts. + template <class D> + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned<D> du; + return AllTrue(du, Eq(BitCast(du, diff), Zero(du))); + } + + HWY_INLINE bool Equal1(const T* a, const T* b) const { return *a == *b; } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const { + return Reverse(d, v); + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys2(D d, Vec<D> v) const { + return Reverse2(d, v); + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys4(D d, Vec<D> v) const { + return Reverse4(d, v); + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys8(D d, Vec<D> v) const { + return Reverse8(d, v); + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys16(D d, Vec<D> v) const { + static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit"); + return ReverseKeys(d, v); + } + + template <class V> + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEven(odd, even); + } + + template <class D, HWY_IF_LANE_SIZE_D(D, 2)> + HWY_INLINE Vec<D> SwapAdjacentPairs(D d, const Vec<D> v) const { + const Repartition<uint32_t, D> du32; + return BitCast(d, Shuffle2301(BitCast(du32, v))); + } + template <class D, HWY_IF_LANE_SIZE_D(D, 4)> + HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const { + return Shuffle1032(v); + } + template <class D, HWY_IF_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const { + return SwapAdjacentBlocks(v); + } + + template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> SwapAdjacentQuads(D d, const Vec<D> v) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide<D> dw; +#else + const RepartitionToWide<RebindToUnsigned<D> > dw; +#endif + return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v))); + } + template <class D, HWY_IF_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> SwapAdjacentQuads(D d, const Vec<D> v) const { + // Assumes max vector size = 512 + return ConcatLowerUpper(d, v, v); + } + + template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd, + const Vec<D> even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide<D> dw; +#else + const RepartitionToWide<RebindToUnsigned<D> > dw; +#endif + return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even))); + } + template <class D, HWY_IF_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> OddEvenPairs(D /* tag */, Vec<D> odd, Vec<D> even) const { + return OddEvenBlocks(odd, even); + } + + template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide<D> dw; +#else + const RepartitionToWide<RebindToUnsigned<D> > dw; +#endif + return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even))); + } + template <class D, HWY_IF_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const { + return ConcatUpperLower(d, odd, even); + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +template <typename T> +struct OrderAscending : public KeyLane<T> { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *a < *b; } + + template <class D> + HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const { + return Lt(a, b); + } + + // Two halves of Sort2, used in ScanMinMax. + template <class D> + HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Min(a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Max(a, b); + } + + template <class D> + HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template <class D> + HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::LowestValue<T>()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::HighestValue<T>()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + return Sub(v, Set(d, hwy::Epsilon<T>())); + } +}; + +template <typename T> +struct OrderDescending : public KeyLane<T> { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *b < *a; } + + template <class D> + HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const { + return Lt(b, a); + } + + template <class D> + HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Max(a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Min(a, b); + } + + template <class D> + HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template <class D> + HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::HighestValue<T>()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::LowestValue<T>()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + return Add(v, Set(d, hwy::Epsilon<T>())); + } +}; + +struct KeyValue64 : public KeyLane<uint64_t> { + // True indicates only part of the key (i.e. lane) should be compared. KV + // stands for key-value. + static constexpr bool IsKV() { return true; } + + template <class D> + HWY_INLINE Mask<D> EqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const { + return Eq(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + template <class D> + HWY_INLINE Mask<D> NotEqualKeys(D /*tag*/, Vec<D> a, Vec<D> b) const { + return Ne(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + HWY_INLINE bool Equal1(const uint64_t* a, const uint64_t* b) const { + return (*a >> 32) == (*b >> 32); + } + + // Only count differences in the actual key, not the value. + template <class D> + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned<D> du; + const Vec<decltype(du)> zero = Zero(du); + const Vec<decltype(du)> keys = ShiftRight<32>(diff); // clear values + return AllTrue(du, Eq(BitCast(du, keys), zero)); + } +}; + +struct OrderAscendingKV64 : public KeyValue64 { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (*a >> 32) < (*b >> 32); + } + + template <class D> + HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const { + return Lt(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template <class D> + HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Min(a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Max(a, b); + } + + template <class D> + HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template <class D> + HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + // Same as for regular lanes. + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::LowestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::HighestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + return Sub(v, Set(d, uint64_t{1})); + } +}; + +struct OrderDescendingKV64 : public KeyValue64 { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (*b >> 32) < (*a >> 32); + } + + template <class D> + HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const { + return Lt(ShiftRight<32>(b), ShiftRight<32>(a)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template <class D> + HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Max(a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const { + return Min(a, b); + } + + template <class D> + HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template <class D> + HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::HighestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::LowestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + return Add(v, Set(d, uint64_t{1})); + } +}; + +// Shared code that depends on Order. +template <class Base> +struct TraitsLane : public Base { + // For each lane i: replaces a[i] with the first and b[i] with the second + // according to Base. + // Corresponds to a conditional swap, which is one "node" of a sorting + // network. Min/Max are cheaper than compare + blend at least for integers. + template <class D> + HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const { + const Base* base = static_cast<const Base*>(this); + + const Vec<D> a_copy = a; + // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4 + // instructions. We can reduce it to a compare + 2 IfThenElse. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + if (sizeof(TFromD<D>) == 8) { + const Mask<D> cmp = base->Compare(d, a, b); + a = IfThenElse(cmp, a, b); + b = IfThenElse(cmp, b, a_copy); + return; + } +#endif + a = base->First(d, a, b); + b = base->Last(d, a_copy, b); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template <class D, HWY_IF_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->ReverseKeys2(d, v); + // Further to the above optimization, Sort2+OddEvenKeys compile to four + // instructions; we can save one by combining two blends. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + const Vec<D> cmp = VecFromMask(d, base->Compare(d, v, swapped)); + return IfVecThenElse(DupOdd(cmp), swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // (See above - we use Sort2 for non-64-bit types.) + template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)> + HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->ReverseKeys2(d, v); + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template <class D> + HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->ReverseKeys4(d, v); + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template <class D> + HWY_INLINE Vec<D> SortPairsDistance4(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->SwapAdjacentQuads(d, v); + // Only used in Merge16, so this will not be used on AVX2 (which only has 4 + // u64 lanes), so skip the above optimization for 64-bit AVX2. + Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } +}; + +#else + +// Base class shared between OrderAscending, OrderDescending. +template <typename T> +struct KeyLane { + constexpr bool Is128() const { return false; } + constexpr size_t LanesPerKey() const { return 1; } + + using LaneType = T; + using KeyType = T; + + std::string KeyString() const { + char string100[100]; + hwy::detail::TypeName(hwy::detail::MakeTypeInfo<KeyType>(), 1, string100); + return string100; + } +}; + +template <typename T> +struct OrderAscending : public KeyLane<T> { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *a < *b; } + + template <class D> + HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) { + return Lt(a, b); + } +}; + +template <typename T> +struct OrderDescending : public KeyLane<T> { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const T* a, const T* b) { return *b < *a; } + + template <class D> + HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) { + return Lt(b, a); + } +}; + +template <class Order> +struct TraitsLane : public Order { + // For HeapSort + template <typename T> // MSVC doesn't find typename Order::LaneType. + HWY_INLINE void Swap(T* a, T* b) const { + const T temp = *a; + *a = *b; + *b = temp; + } + + template <class D> + HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const { + return Set(d, *key); + } +}; + +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/traits128-inl.h b/third_party/highway/hwy/contrib/sort/traits128-inl.h new file mode 100644 index 0000000000..d889140868 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/traits128-inl.h @@ -0,0 +1,517 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#endif + +#include <string> + +#include "hwy/contrib/sort/shared-inl.h" +#include "hwy/contrib/sort/vqsort.h" // SortDescending +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +#if VQSORT_ENABLED || HWY_IDE + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +struct KeyAny128 { + static constexpr bool Is128() { return true; } + constexpr size_t LanesPerKey() const { return 2; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = uint64_t; + // KeyType and KeyString are defined by derived classes. + + HWY_INLINE void Swap(LaneType* a, LaneType* b) const { + const FixedTag<LaneType, 2> d; + const auto temp = LoadU(d, a); + StoreU(LoadU(d, b), d, a); + StoreU(temp, d, b); + } + + template <class V, class M> + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressBlocksNot(keys, mask); + } + + template <class D> + HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const { + return LoadDup128(d, key); + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const { + return ReverseBlocks(d, v); + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys2(D /* tag */, const Vec<D> v) const { + return SwapAdjacentBlocks(v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template <class D> + HWY_INLINE Vec<D> ReverseKeys4(D d, const Vec<D> v) const { + HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>)); + return ReverseKeys(d, v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template <class D> + HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd, + const Vec<D> even) const { + HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>)); + return ConcatUpperLower(d, odd, even); + } + + template <class V> + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEvenBlocks(odd, even); + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys8(D, Vec<D>) const { + HWY_ASSERT(0); // not supported: would require 1024-bit vectors + } + + template <class D> + HWY_INLINE Vec<D> ReverseKeys16(D, Vec<D>) const { + HWY_ASSERT(0); // not supported: would require 2048-bit vectors + } + + // This is only called for 8/16 col networks (not supported). + template <class D> + HWY_INLINE Vec<D> SwapAdjacentPairs(D, Vec<D>) const { + HWY_ASSERT(0); + } + + // This is only called for 16 col networks (not supported). + template <class D> + HWY_INLINE Vec<D> SwapAdjacentQuads(D, Vec<D>) const { + HWY_ASSERT(0); + } + + // This is only called for 8 col networks (not supported). + template <class D> + HWY_INLINE Vec<D> OddEvenQuads(D, Vec<D>, Vec<D>) const { + HWY_ASSERT(0); + } +}; + +// Base class shared between OrderAscending128, OrderDescending128. +struct Key128 : public KeyAny128 { + // False indicates the entire key should be compared. KV means key-value. + static constexpr bool IsKV() { return false; } + + // What type to pass to Sorter::operator(). + using KeyType = hwy::uint128_t; + + std::string KeyString() const { return "U128"; } + + template <class D> + HWY_INLINE Mask<D> EqualKeys(D d, Vec<D> a, Vec<D> b) const { + return Eq128(d, a, b); + } + + template <class D> + HWY_INLINE Mask<D> NotEqualKeys(D d, Vec<D> a, Vec<D> b) const { + return Ne128(d, a, b); + } + + // For keys=entire 128 bits, any difference counts. + template <class D> + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned<D> du; + return AllTrue(du, Eq(BitCast(du, diff), Zero(du))); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return a[0] == b[0] && a[1] == b[1]; + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +struct OrderAscending128 : public Key128 { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1]; + } + + template <class D> + HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const { + return Lt128(d, a, b); + } + + // Used by CompareTop + template <class V> + HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const { + return Lt(a, b); + } + + template <class D> + HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const { + return Min128(d, a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const { + return Max128(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::LowestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::HighestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + const Vec<D> k0 = Zero(d); + const Vec<D> k1 = OddEven(k0, Set(d, uint64_t{1})); + const Mask<D> borrow = Eq(v, k0); // don't-care, lo == 0 + // lo == 0? 1 : 0, 0 + const Vec<D> adjust = ShiftLeftLanes<1>(IfThenElseZero(borrow, k1)); + return Sub(Sub(v, k1), adjust); + } +}; + +struct OrderDescending128 : public Key128 { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1]; + } + + template <class D> + HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const { + return Lt128(d, b, a); + } + + // Used by CompareTop + template <class V> + HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const { + return Lt(b, a); + } + + template <class D> + HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const { + return Max128(d, a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const { + return Min128(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::HighestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::LowestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + const Vec<D> k1 = OddEven(Zero(d), Set(d, uint64_t{1})); + const Vec<D> added = Add(v, k1); + const Mask<D> overflowed = Lt(added, v); // false, overflowed + // overflowed? 1 : 0, 0 + const Vec<D> adjust = ShiftLeftLanes<1>(IfThenElseZero(overflowed, k1)); + return Add(added, adjust); + } +}; + +// Base class shared between OrderAscendingKV128, OrderDescendingKV128. +struct KeyValue128 : public KeyAny128 { + // True indicates only part of the key (the more significant lane) should be + // compared. KV stands for key-value. + static constexpr bool IsKV() { return true; } + + // What type to pass to Sorter::operator(). + using KeyType = K64V64; + + std::string KeyString() const { return "KV128"; } + + template <class D> + HWY_INLINE Mask<D> EqualKeys(D d, Vec<D> a, Vec<D> b) const { + return Eq128Upper(d, a, b); + } + + template <class D> + HWY_INLINE Mask<D> NotEqualKeys(D d, Vec<D> a, Vec<D> b) const { + return Ne128Upper(d, a, b); + } + + // Only count differences in the actual key, not the value. + template <class D> + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec<D> diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned<D> du; + const Vec<decltype(du)> zero = Zero(du); + const Vec<decltype(du)> keys = OddEven(diff, zero); // clear values + return AllTrue(du, Eq(BitCast(du, keys), zero)); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return a[1] == b[1]; + } +}; + +struct OrderAscendingKV128 : public KeyValue128 { + using Order = SortAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return a[1] < b[1]; + } + + template <class D> + HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const { + return Lt128Upper(d, a, b); + } + + // Used by CompareTop + template <class V> + HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const { + return Lt(a, b); + } + + template <class D> + HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const { + return Min128Upper(d, a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const { + return Max128Upper(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::LowestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::HighestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + const Vec<D> k1 = OddEven(Set(d, uint64_t{1}), Zero(d)); + return Sub(v, k1); + } +}; + +struct OrderDescendingKV128 : public KeyValue128 { + using Order = SortDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) { + return b[1] < a[1]; + } + + template <class D> + HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const { + return Lt128Upper(d, b, a); + } + + // Used by CompareTop + template <class V> + HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const { + return Lt(b, a); + } + + template <class D> + HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const { + return Max128Upper(d, a, b); + } + + template <class D> + HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const { + return Min128Upper(d, a, b); + } + + // Same as for regular lanes because 128-bit lanes are u64. + template <class D> + HWY_INLINE Vec<D> FirstValue(D d) const { + return Set(d, hwy::HighestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> LastValue(D d) const { + return Set(d, hwy::LowestValue<TFromD<D> >()); + } + + template <class D> + HWY_INLINE Vec<D> PrevValue(D d, Vec<D> v) const { + const Vec<D> k1 = OddEven(Set(d, uint64_t{1}), Zero(d)); + return Add(v, k1); + } +}; + +// Shared code that depends on Order. +template <class Base> +class Traits128 : public Base { + // Special case for >= 256 bit vectors +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256 + // Returns vector with only the top u64 lane valid. Useful when the next step + // is to replicate the mask anyway. + template <class D> + HWY_INLINE HWY_MAYBE_UNUSED Vec<D> CompareTop(D d, Vec<D> a, Vec<D> b) const { + const Base* base = static_cast<const Base*>(this); + const Mask<D> eqHL = Eq(a, b); + const Vec<D> ltHL = VecFromMask(d, base->CompareLanes(a, b)); +#if HWY_TARGET == HWY_SVE_256 + return IfThenElse(eqHL, DupEven(ltHL), ltHL); +#else + const Vec<D> ltLX = ShiftLeftLanes<1>(ltHL); + return OrAnd(ltHL, VecFromMask(d, eqHL), ltLX); +#endif + } + + // We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in + // the most-significant of those lanes (the result of CompareTop), so + // replicate it 4x. Only called for >= 256-bit vectors. + template <class V> + HWY_INLINE V ReplicateTop4x(V v) const { +#if HWY_TARGET == HWY_SVE_256 + return svdup_lane_u64(v, 3); +#elif HWY_TARGET <= HWY_AVX3 + return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +#else // AVX2 + return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +#endif + } +#endif // HWY_TARGET + + public: + template <class D> + HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v, + TFromD<D>* HWY_RESTRICT buf) const { + const Base* base = static_cast<const Base*>(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->First(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template <class D> + HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v, + TFromD<D>* HWY_RESTRICT buf) const { + const Base* base = static_cast<const Base*>(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->Last(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template <class D> + HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const { + const Base* base = static_cast<const Base*>(this); + + const Vec<D> a_copy = a; + const auto lt = base->Compare(d, a, b); + a = IfThenElse(lt, a, b); + b = IfThenElse(lt, b, a_copy); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template <class D> + HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->ReverseKeys2(d, v); + +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256 + const Vec<D> select = ReplicateTop4x(CompareTop(d, v, swapped)); + return IfVecThenElse(select, swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template <class D> + HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const { + const Base* base = static_cast<const Base*>(this); + Vec<D> swapped = base->ReverseKeys4(d, v); + + // Only specialize for AVX3 because this requires 512-bit vectors. +#if HWY_TARGET <= HWY_AVX3 + const Vec512<uint64_t> outHx = CompareTop(d, v, swapped); + // Similar to ReplicateTop4x, we want to gang together 2 comparison results + // (4 lanes). They are not contiguous, so use permute to replicate 4x. + alignas(64) uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7}; + const Vec512<uint64_t> select = + TableLookupLanes(outHx, SetTableIndices(d, kIndices)); + return IfVecThenElse(select, swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); +#endif + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template <class D> + HWY_INLINE Vec<D> SortPairsDistance4(D, Vec<D>) const { + // Only used by Merge16, which would require 2048 bit vectors (unsupported). + HWY_ASSERT(0); + } +}; + +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/vqsort-inl.h b/third_party/highway/hwy/contrib/sort/vqsort-inl.h new file mode 100644 index 0000000000..edebe4af11 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort-inl.h @@ -0,0 +1,1484 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +#ifndef VQSORT_PRINT +#define VQSORT_PRINT 0 +#endif + +// Makes it harder for adversaries to predict our sampling locations, at the +// cost of 1-2% increased runtime. +#ifndef VQSORT_SECURE_RNG +#define VQSORT_SECURE_RNG 0 +#endif + +#if VQSORT_SECURE_RNG +#include "third_party/absl/random/random.h" +#endif + +#include <stdio.h> // unconditional #include so we can use if(VQSORT_PRINT). +#include <string.h> // memcpy + +#include "hwy/cache_control.h" // Prefetch +#include "hwy/contrib/sort/vqsort.h" // Fill24Bytes + +#if HWY_IS_MSAN +#include <sanitizer/msan_interface.h> +#endif + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE +#endif + +#if VQSORT_PRINT +#include "hwy/print-inl.h" +#endif + +#include "hwy/contrib/sort/shared-inl.h" +#include "hwy/contrib/sort/sorting_networks-inl.h" +// Placeholder for internal instrumentation. Do not remove. +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +using Constants = hwy::SortConstants; + +// Wrappers to avoid #if in user code (interferes with code folding) + +HWY_INLINE void UnpoisonIfMemorySanitizer(void* p, size_t bytes) { +#if HWY_IS_MSAN + __msan_unpoison(p, bytes); +#else + (void)p; + (void)bytes; +#endif +} + +template <class D> +HWY_INLINE void MaybePrintVector(D d, const char* label, Vec<D> v, + size_t start = 0, size_t max_lanes = 16) { +#if VQSORT_PRINT >= 2 // Print is only defined #if + Print(d, label, v, start, max_lanes); +#else + (void)d; + (void)label; + (void)v; + (void)start; + (void)max_lanes; +#endif +} + +// ------------------------------ HeapSort + +template <class Traits, typename T> +void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes, + size_t start) { + constexpr size_t N1 = st.LanesPerKey(); + const FixedTag<T, N1> d; + + while (start < num_lanes) { + const size_t left = 2 * start + N1; + const size_t right = 2 * start + 2 * N1; + if (left >= num_lanes) break; + size_t idx_larger = start; + const auto key_j = st.SetKey(d, lanes + start); + if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) { + idx_larger = left; + } + if (right < num_lanes && + AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger), + st.SetKey(d, lanes + right)))) { + idx_larger = right; + } + if (idx_larger == start) break; + st.Swap(lanes + start, lanes + idx_larger); + start = idx_larger; + } +} + +// Heapsort: O(1) space, O(N*logN) worst-case comparisons. +// Based on LLVM sanitizer_common.h, licensed under Apache-2.0. +template <class Traits, typename T> +void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) { + constexpr size_t N1 = st.LanesPerKey(); + + if (num_lanes < 2 * N1) return; + + // Build heap. + for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) { + SiftDown(st, lanes, num_lanes, i); + } + + for (size_t i = num_lanes - N1; i != 0; i -= N1) { + // Swap root with last + st.Swap(lanes + 0, lanes + i); + + // Sift down the new root. + SiftDown(st, lanes, i, 0); + } +} + +#if VQSORT_ENABLED || HWY_IDE + +// ------------------------------ BaseCase + +// Sorts `keys` within the range [0, num) via sorting network. +template <class D, class Traits, typename T> +HWY_INLINE void BaseCase(D d, Traits st, T* HWY_RESTRICT keys, + T* HWY_RESTRICT keys_end, size_t num, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + using V = decltype(Zero(d)); + + // _Nonzero32 requires num - 1 != 0. + if (HWY_UNLIKELY(num <= 1)) return; + + // Reshape into a matrix with kMaxRows rows, and columns limited by the + // 1D `num`, which is upper-bounded by the vector width (see BaseCaseNum). + const size_t num_pow2 = size_t{1} + << (32 - Num0BitsAboveMS1Bit_Nonzero32( + static_cast<uint32_t>(num - 1))); + HWY_DASSERT(num <= num_pow2 && num_pow2 <= Constants::BaseCaseNum(N)); + const size_t cols = + HWY_MAX(st.LanesPerKey(), num_pow2 >> Constants::kMaxRowsLog2); + HWY_DASSERT(cols <= N); + + // We can avoid padding and load/store directly to `keys` after checking the + // original input array has enough space. Except at the right border, it's OK + // to sort more than the current sub-array. Even if we sort across a previous + // partition point, we know that keys will not migrate across it. However, we + // must use the maximum size of the sorting network, because the StoreU of its + // last vector would otherwise write invalid data starting at kMaxRows * cols. + const size_t N_sn = Lanes(CappedTag<T, Constants::kMaxCols>()); + if (HWY_LIKELY(keys + N_sn * Constants::kMaxRows <= keys_end)) { + SortingNetwork(st, keys, N_sn); + return; + } + + // Copy `keys` to `buf`. + size_t i; + for (i = 0; i + N <= num; i += N) { + Store(LoadU(d, keys + i), d, buf + i); + } + SafeCopyN(num - i, d, keys + i, buf + i); + i = num; + + // Fill with padding - last in sort order, not copied to keys. + const V kPadding = st.LastValue(d); + // Initialize an extra vector because SortingNetwork loads full vectors, + // which may exceed cols*kMaxRows. + for (; i < (cols * Constants::kMaxRows + N); i += N) { + StoreU(kPadding, d, buf + i); + } + + SortingNetwork(st, buf, cols); + + for (i = 0; i + N <= num; i += N) { + StoreU(Load(d, buf + i), d, keys + i); + } + SafeCopyN(num - i, d, buf + i, keys + i); +} + +// ------------------------------ Partition + +// Consumes from `keys` until a multiple of kUnroll*N remains. +// Temporarily stores the right side into `buf`, then moves behind `num`. +// Returns the number of keys consumed from the left side. +template <class D, class Traits, class T> +HWY_INLINE size_t PartitionToMultipleOfUnroll(D d, Traits st, + T* HWY_RESTRICT keys, size_t& num, + const Vec<D> pivot, + T* HWY_RESTRICT buf) { + constexpr size_t kUnroll = Constants::kPartitionUnroll; + const size_t N = Lanes(d); + size_t readL = 0; + T* HWY_RESTRICT posL = keys; + size_t bufR = 0; + // Partition requires both a multiple of kUnroll*N and at least + // 2*kUnroll*N for the initial loads. If less, consume all here. + const size_t num_rem = + (num < 2 * kUnroll * N) ? num : (num & (kUnroll * N - 1)); + size_t i = 0; + for (; i + N <= num_rem; i += N) { + const Vec<D> vL = LoadU(d, keys + readL); + readL += N; + + const auto comp = st.Compare(d, pivot, vL); + posL += CompressBlendedStore(vL, Not(comp), d, posL); + bufR += CompressStore(vL, comp, d, buf + bufR); + } + // Last iteration: only use valid lanes. + if (HWY_LIKELY(i != num_rem)) { + const auto mask = FirstN(d, num_rem - i); + const Vec<D> vL = LoadU(d, keys + readL); + + const auto comp = st.Compare(d, pivot, vL); + posL += CompressBlendedStore(vL, AndNot(comp, mask), d, posL); + bufR += CompressStore(vL, And(comp, mask), d, buf + bufR); + } + + // MSAN seems not to understand CompressStore. buf[0, bufR) are valid. + UnpoisonIfMemorySanitizer(buf, bufR * sizeof(T)); + + // Everything we loaded was put into buf, or behind the current `posL`, after + // which there is space for bufR items. First move items from `keys + num` to + // `posL` to free up space, then copy `buf` into the vacated `keys + num`. + // A loop with masked loads from `buf` is insufficient - we would also need to + // mask from `keys + num`. Combining a loop with memcpy for the remainders is + // slower than just memcpy, so we use that for simplicity. + num -= bufR; + memcpy(posL, keys + num, bufR * sizeof(T)); + memcpy(keys + num, buf, bufR * sizeof(T)); + return static_cast<size_t>(posL - keys); // caller will shrink num by this. +} + +template <class V> +V OrXor(const V o, const V x1, const V x2) { + // TODO(janwas): add op so we can benefit from AVX-512 ternlog? + return Or(o, Xor(x1, x2)); +} + +// Note: we could track the OrXor of v and pivot to see if the entire left +// partition is equal, but that happens rarely and thus is a net loss. +template <class D, class Traits, typename T> +HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v, + const Vec<D> pivot, T* HWY_RESTRICT keys, + size_t& writeL, size_t& remaining) { + const size_t N = Lanes(d); + + const auto comp = st.Compare(d, pivot, v); + + remaining -= N; + if (hwy::HWY_NAMESPACE::CompressIsPartition<T>::value || + (HWY_MAX_BYTES == 16 && st.Is128())) { + // Non-native Compress (e.g. AVX2): we are able to partition a vector using + // a single Compress+two StoreU instead of two Compress[Blended]Store. The + // latter are more expensive. Because we store entire vectors, the contents + // between the updated writeL and writeR are ignored and will be overwritten + // by subsequent calls. This works because writeL and writeR are at least + // two vectors apart. + const auto lr = st.CompressKeys(v, comp); + const size_t num_left = N - CountTrue(d, comp); + StoreU(lr, d, keys + writeL); + // Now write the right-side elements (if any), such that the previous writeR + // is one past the end of the newly written right elements, then advance. + StoreU(lr, d, keys + remaining + writeL); + writeL += num_left; + } else { + // Native Compress[Store] (e.g. AVX3), which only keep the left or right + // side, not both, hence we require two calls. + const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL); + writeL += num_left; + + (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL); + } +} + +template <class D, class Traits, typename T> +HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0, + const Vec<D> v1, const Vec<D> v2, + const Vec<D> v3, const Vec<D> pivot, + T* HWY_RESTRICT keys, size_t& writeL, + size_t& remaining) { + StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining); +} + +// Moves "<= pivot" keys to the front, and others to the back. pivot is +// broadcasted. Time-critical! +// +// Aligned loads do not seem to be worthwhile (not bottlenecked by load ports). +template <class D, class Traits, typename T> +HWY_INLINE size_t Partition(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + const Vec<D> pivot, T* HWY_RESTRICT buf) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // StoreLeftRight will CompressBlendedStore ending at `writeR`. Unless all + // lanes happen to be in the right-side partition, this will overrun `keys`, + // which triggers asan errors. Avoid by special-casing the last vector. + HWY_DASSERT(num > 2 * N); // ensured by HandleSpecialCases + num -= N; + size_t last = num; + const V vlast = LoadU(d, keys + last); + + const size_t consumedL = + PartitionToMultipleOfUnroll(d, st, keys, num, pivot, buf); + keys += consumedL; + last -= consumedL; + num -= consumedL; + constexpr size_t kUnroll = Constants::kPartitionUnroll; + + // Partition splits the vector into 3 sections, left to right: Elements + // smaller or equal to the pivot, unpartitioned elements and elements larger + // than the pivot. To write elements unconditionally on the loop body without + // overwriting existing data, we maintain two regions of the loop where all + // elements have been copied elsewhere (e.g. vector registers.). I call these + // bufferL and bufferR, for left and right respectively. + // + // These regions are tracked by the indices (writeL, writeR, left, right) as + // presented in the diagram below. + // + // writeL writeR + // \/ \/ + // | <= pivot | bufferL | unpartitioned | bufferR | > pivot | + // \/ \/ + // left right + // + // In the main loop body below we choose a side, load some elements out of the + // vector and move either `left` or `right`. Next we call into StoreLeftRight + // to partition the data, and the partitioned elements will be written either + // to writeR or writeL and the corresponding index will be moved accordingly. + // + // Note that writeR is not explicitly tracked as an optimization for platforms + // with conditional operations. Instead we track writeL and the number of + // elements left to process (`remaining`). From the diagram above we can see + // that: + // writeR - writeL = remaining => writeR = remaining + writeL + // + // Tracking `remaining` is advantageous because each iteration reduces the + // number of unpartitioned elements by a fixed amount, so we can compute + // `remaining` without data dependencies. + // + size_t writeL = 0; + size_t remaining = num; + + const T* HWY_RESTRICT readL = keys; + const T* HWY_RESTRICT readR = keys + num; + // Cannot load if there were fewer than 2 * kUnroll * N. + if (HWY_LIKELY(num != 0)) { + HWY_DASSERT(num >= 2 * kUnroll * N); + HWY_DASSERT((num & (kUnroll * N - 1)) == 0); + + // Make space for writing in-place by reading from readL/readR. + const V vL0 = LoadU(d, readL + 0 * N); + const V vL1 = LoadU(d, readL + 1 * N); + const V vL2 = LoadU(d, readL + 2 * N); + const V vL3 = LoadU(d, readL + 3 * N); + readL += kUnroll * N; + readR -= kUnroll * N; + const V vR0 = LoadU(d, readR + 0 * N); + const V vR1 = LoadU(d, readR + 1 * N); + const V vR2 = LoadU(d, readR + 2 * N); + const V vR3 = LoadU(d, readR + 3 * N); + + // readL/readR changed above, so check again before the loop. + while (readL != readR) { + V v0, v1, v2, v3; + + // Data-dependent but branching is faster than forcing branch-free. + const size_t capacityL = + static_cast<size_t>((readL - keys) - static_cast<ptrdiff_t>(writeL)); + HWY_DASSERT(capacityL <= num); // >= 0 + // Load data from the end of the vector with less data (front or back). + // The next paragraphs explain how this works. + // + // let block_size = (kUnroll * N) + // On the loop prelude we load block_size elements from the front of the + // vector and an additional block_size elements from the back. On each + // iteration k elements are written to the front of the vector and + // (block_size - k) to the back. + // + // This creates a loop invariant where the capacity on the front + // (capacityL) and on the back (capacityR) always add to 2 * block_size. + // In other words: + // capacityL + capacityR = 2 * block_size + // capacityR = 2 * block_size - capacityL + // + // This means that: + // capacityL < capacityR <=> + // capacityL < 2 * block_size - capacityL <=> + // 2 * capacityL < 2 * block_size <=> + // capacityL < block_size + // + // Thus the check on the next line is equivalent to capacityL > capacityR. + // + if (kUnroll * N < capacityL) { + readR -= kUnroll * N; + v0 = LoadU(d, readR + 0 * N); + v1 = LoadU(d, readR + 1 * N); + v2 = LoadU(d, readR + 2 * N); + v3 = LoadU(d, readR + 3 * N); + hwy::Prefetch(readR - 3 * kUnroll * N); + } else { + v0 = LoadU(d, readL + 0 * N); + v1 = LoadU(d, readL + 1 * N); + v2 = LoadU(d, readL + 2 * N); + v3 = LoadU(d, readL + 3 * N); + readL += kUnroll * N; + hwy::Prefetch(readL + 3 * kUnroll * N); + } + + StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining); + } + + // Now finish writing the saved vectors to the middle. + StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining); + StoreLeftRight4(d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, remaining); + } + + // We have partitioned [left, right) such that writeL is the boundary. + HWY_DASSERT(remaining == 0); + // Make space for inserting vlast: move up to N of the first right-side keys + // into the unused space starting at last. If we have fewer, ensure they are + // the last items in that vector by subtracting from the *load* address, + // which is safe because we have at least two vectors (checked above). + const size_t totalR = last - writeL; + const size_t startR = totalR < N ? writeL + totalR - N : writeL; + StoreU(LoadU(d, keys + startR), d, keys + last); + + // Partition vlast: write L, then R, into the single-vector gap at writeL. + const auto comp = st.Compare(d, pivot, vlast); + writeL += CompressBlendedStore(vlast, Not(comp), d, keys + writeL); + (void)CompressBlendedStore(vlast, comp, d, keys + writeL); + + return consumedL + writeL; +} + +// Returns true and partitions if [keys, keys + num) contains only {valueL, +// valueR}. Otherwise, sets third to the first differing value; keys may have +// been reordered and a regular Partition is still necessary. +// Called from two locations, hence NOINLINE. +template <class D, class Traits, typename T> +HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec<D> valueL, + const Vec<D> valueR, Vec<D>& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + + size_t i = 0; + size_t writeL = 0; + + // As long as all lanes are equal to L or R, we can overwrite with valueL. + // This is faster than first counting, then backtracking to fill L and R. + for (; i + N <= num; i += N) { + const Vec<D> v = LoadU(d, keys + i); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to AVX3. + const Mask<D> eqL = st.EqualKeys(d, v, valueL); + const Mask<D> eqR = st.EqualKeys(d, v, valueR); + // At least one other value present; will require a regular partition. + // On AVX-512, Or + AllTrue are folded into a single kortest if we are + // careful with the FindKnownFirstTrue argument, see below. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the + // loop, which is a pessimization because this if-true branch is cold. + // We can defeat this via Not(Xor), which is equivalent because eqL and + // eqR cannot be true at the same time. Can we elide the additional Not? + // FindFirstFalse instructions are generally unavailable, but we can + // fuse Not and Xor/Or into one ExclusiveNeither. + const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i, writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + for (; writeL + N <= i; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL); + return false; + } + StoreU(valueL, d, keys + writeL); + writeL += CountTrue(d, eqL); + } + + // Final vector, masked comparison (no effect if i == num) + const size_t remaining = num - i; + SafeCopyN(remaining, d, keys + i, buf); + const Vec<D> v = Load(d, buf); + const Mask<D> valid = FirstN(d, remaining); + const Mask<D> eqL = And(st.EqualKeys(d, v, valueL), valid); + const Mask<D> eqR = st.EqualKeys(d, v, valueR); + // Invalid lanes are considered equal. + const Mask<D> eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. + if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i, + writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + for (; writeL + N <= i; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + BlendedStore(valueR, FirstN(d, i - writeL), d, keys + writeL); + return false; + } + BlendedStore(valueL, valid, d, keys + writeL); + writeL += CountTrue(d, eqL); + + // Fill right side + i = writeL; + for (; i + N <= num; i += N) { + StoreU(valueR, d, keys + i); + } + BlendedStore(valueR, FirstN(d, num - i), d, keys + i); + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Successful MaybePartitionTwoValue\n"); + } + return true; +} + +// Same as above, except that the pivot equals valueR, so scan right to left. +template <class D, class Traits, typename T> +HWY_INLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec<D> valueL, + const Vec<D> valueR, Vec<D>& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + + HWY_DASSERT(num >= N); + size_t pos = num - N; // current read/write position + size_t countR = 0; // number of valueR found + + // For whole vectors, in descending address order: as long as all lanes are + // equal to L or R, overwrite with valueR. This is faster than counting, then + // filling both L and R. Loop terminates after unsigned wraparound. + for (; pos < num; pos -= N) { + const Vec<D> v = LoadU(d, keys + pos); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to AVX3. + const Mask<D> eqL = st.EqualKeys(d, v, valueL); + const Mask<D> eqR = st.EqualKeys(d, v, valueR); + // If there is a third value, stop and undo what we've done. On AVX-512, + // Or + AllTrue are folded into a single kortest, but only if we are + // careful with the FindKnownFirstTrue argument - see prior comment on that. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + pos + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. + HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + for (; pos + N <= endL; pos += N) { + StoreU(valueL, d, keys + pos); + } + BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos); + return false; + } + StoreU(valueR, d, keys + pos); + countR += CountTrue(d, eqR); + } + + // Final partial (or empty) vector, masked comparison. + const size_t remaining = pos + N; + HWY_DASSERT(remaining <= N); + const Vec<D> v = LoadU(d, keys); // Safe because num >= N. + const Mask<D> valid = FirstN(d, remaining); + const Mask<D> eqL = st.EqualKeys(d, v, valueL); + const Mask<D> eqR = And(st.EqualKeys(d, v, valueR), valid); + // Invalid lanes are considered equal. + const Mask<D> eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. + if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. + HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + for (; pos + N <= endL; pos += N) { + StoreU(valueL, d, keys + pos); + } + BlendedStore(valueL, FirstN(d, endL - pos), d, keys + pos); + return false; + } + const size_t lastR = CountTrue(d, eqR); + countR += lastR; + + // First finish writing valueR - [0, N) lanes were not yet written. + StoreU(valueR, d, keys); // Safe because num >= N. + + // Fill left side (ascending order for clarity) + const size_t endL = num - countR; + size_t i = 0; + for (; i + N <= endL; i += N) { + StoreU(valueL, d, keys + i); + } + Store(valueL, d, buf); + SafeCopyN(endL - i, d, buf, keys + i); // avoids asan overrun + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, + "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n", + countR, pos, i, endL); + } + + return true; +} + +// `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the +// second key. This is the first path into `MaybePartitionTwoValue`, called +// when all samples are equal. Returns false if there are at least a third +// value and sets `third`. Otherwise, partitions the array and returns true. +template <class D, class Traits, typename T> +HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec<D> pivot, + T* HWY_RESTRICT keys, size_t num, + const size_t idx_second, const Vec<D> second, + Vec<D>& third, T* HWY_RESTRICT buf) { + // True if second comes before pivot. + const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second)); + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second, + is_pivotR); + } + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot))); + + // If pivot is R, we scan backwards over the entire array. Otherwise, + // we already scanned up to idx_second and can leave those in place. + return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot, + third, buf) + : MaybePartitionTwoValue(d, st, keys + idx_second, + num - idx_second, pivot, second, + third, buf); +} + +// Second path into `MaybePartitionTwoValue`, called when not all samples are +// equal. `samples` is sorted. +template <class D, class Traits, typename T> +HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = 3 * 64 / sizeof(T); + constexpr size_t N1 = st.LanesPerKey(); + const Vec<D> valueL = st.SetKey(d, samples); + const Vec<D> valueR = st.SetKey(d, samples + kSampleLanes - N1); + HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR))); + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR))); + const Vec<D> prev = st.PrevValue(d, valueR); + // If the sample has more than two values, then the keys have at least that + // many, and thus this special case is inapplicable. + if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) { + return false; + } + + // Must not overwrite samples because if this returns false, caller wants to + // read the original samples again. + T* HWY_RESTRICT buf = samples + kSampleLanes; + Vec<D> third; // unused + return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf); +} + +// ------------------------------ Pivot sampling + +template <class Traits, class V> +HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) { + const DFromV<V> d; + // Slightly faster for 128-bit, apparently because not serially dependent. + if (st.Is128()) { + // Median = XOR-sum 'minus' the first and last. Calling First twice is + // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR. + const auto sum = Xor(Xor(v0, v1), v2); + const auto first = st.First(d, st.First(d, v0, v1), v2); + const auto last = st.Last(d, st.Last(d, v0, v1), v2); + return Xor(Xor(sum, first), last); + } + st.Sort2(d, v0, v2); + v1 = st.Last(d, v0, v1); + v1 = st.First(d, v1, v2); + return v1; +} + +#if VQSORT_SECURE_RNG +using Generator = absl::BitGen; +#else +// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028 +#pragma pack(push, 1) +class Generator { + public: + Generator(const void* heap, size_t num) { + Sorter::Fill24Bytes(heap, num, &a_); + k_ = 1; // stream index: must be odd + } + + explicit Generator(uint64_t seed) { + a_ = b_ = w_ = seed; + k_ = 1; + } + + uint64_t operator()() { + const uint64_t b = b_; + w_ += k_; + const uint64_t next = a_ ^ w_; + a_ = (b + (b << 3)) ^ (b >> 11); + const uint64_t rot = (b << 24) | (b >> 40); + b_ = rot + next; + return next; + } + + private: + uint64_t a_; + uint64_t b_; + uint64_t w_; + uint64_t k_; // increment +}; +#pragma pack(pop) + +#endif // !VQSORT_SECURE_RNG + +// Returns slightly biased random index of a chunk in [0, num_chunks). +// See https://www.pcg-random.org/posts/bounded-rands.html. +HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) { + const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32; + HWY_DASSERT(chunk_index < num_chunks); + return static_cast<size_t>(chunk_index); +} + +// Writes samples from `keys[0, num)` into `buf`. +template <class D, class Traits, typename T> +HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf, Generator& rng) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // Power of two + constexpr size_t kLanesPerChunk = Constants::LanesPerChunk(sizeof(T)); + + // Align start of keys to chunks. We always have at least 2 chunks because the + // base case would have handled anything up to 16 vectors, i.e. >= 4 chunks. + HWY_DASSERT(num >= 2 * kLanesPerChunk); + const size_t misalign = + (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (kLanesPerChunk - 1); + if (misalign != 0) { + const size_t consume = kLanesPerChunk - misalign; + keys += consume; + num -= consume; + } + + // Generate enough random bits for 9 uint32 + uint64_t* bits64 = reinterpret_cast<uint64_t*>(buf); + for (size_t i = 0; i < 5; ++i) { + bits64[i] = rng(); + } + const uint32_t* bits = reinterpret_cast<const uint32_t*>(buf); + + const size_t num_chunks64 = num / kLanesPerChunk; + // Clamp to uint32 for RandomChunkIndex + const uint32_t num_chunks = + static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull)); + + const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk; + const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk; + const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk; + const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk; + const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk; + const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk; + const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) * kLanesPerChunk; + const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) * kLanesPerChunk; + const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) * kLanesPerChunk; + for (size_t i = 0; i < kLanesPerChunk; i += N) { + const V v0 = Load(d, keys + offset0 + i); + const V v1 = Load(d, keys + offset1 + i); + const V v2 = Load(d, keys + offset2 + i); + const V medians0 = MedianOf3(st, v0, v1, v2); + Store(medians0, d, buf + i); + + const V v3 = Load(d, keys + offset3 + i); + const V v4 = Load(d, keys + offset4 + i); + const V v5 = Load(d, keys + offset5 + i); + const V medians1 = MedianOf3(st, v3, v4, v5); + Store(medians1, d, buf + i + kLanesPerChunk); + + const V v6 = Load(d, keys + offset6 + i); + const V v7 = Load(d, keys + offset7 + i); + const V v8 = Load(d, keys + offset8 + i); + const V medians2 = MedianOf3(st, v6, v7, v8); + Store(medians2, d, buf + i + kLanesPerChunk * 2); + } +} + +// For detecting inputs where (almost) all keys are equal. +template <class D, class Traits> +HWY_INLINE bool UnsortedSampleEqual(D d, Traits st, + const TFromD<D>* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = 3 * 64 / sizeof(TFromD<D>); + const size_t N = Lanes(d); + using V = Vec<D>; + + const V first = st.SetKey(d, samples); + // OR of XOR-difference may be faster than comparison. + V diff = Zero(d); + size_t i = 0; + for (; i + N <= kSampleLanes; i += N) { + const V v = Load(d, samples + i); + diff = OrXor(diff, first, v); + } + // Remainder, if any. + const V v = Load(d, samples + i); + const auto valid = FirstN(d, kSampleLanes - i); + diff = IfThenElse(valid, OrXor(diff, first, v), diff); + + return st.NoKeyDifference(d, diff); +} + +template <class D, class Traits, typename T> +HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) { + // buf contains 192 bytes, so 16 128-bit vectors are necessary and sufficient. + constexpr size_t kSampleLanes = 3 * 64 / sizeof(T); + const CappedTag<T, 16 / sizeof(T)> d128; + const size_t N128 = Lanes(d128); + constexpr size_t kCols = HWY_MIN(16 / sizeof(T), Constants::kMaxCols); + constexpr size_t kBytes = kCols * Constants::kMaxRows * sizeof(T); + static_assert(192 <= kBytes, ""); + // Fill with padding - last in sort order. + const auto kPadding = st.LastValue(d128); + // Initialize an extra vector because SortingNetwork loads full vectors, + // which may exceed cols*kMaxRows. + for (size_t i = kSampleLanes; i <= kBytes / sizeof(T); i += N128) { + StoreU(kPadding, d128, buf + i); + } + + SortingNetwork(st, buf, kCols); + + if (VQSORT_PRINT >= 2) { + const size_t N = Lanes(d); + fprintf(stderr, "Samples:\n"); + for (size_t i = 0; i < kSampleLanes; i += N) { + MaybePrintVector(d, "", Load(d, buf + i), 0, N); + } + } +} + +// ------------------------------ Pivot selection + +enum class PivotResult { + kDone, // stop without partitioning (all equal, or two-value partition) + kNormal, // partition and recurse left and right + kIsFirst, // partition but skip left recursion + kWasLast, // partition but skip right recursion +}; + +HWY_INLINE const char* PivotResultString(PivotResult result) { + switch (result) { + case PivotResult::kDone: + return "done"; + case PivotResult::kNormal: + return "normal"; + case PivotResult::kIsFirst: + return "first"; + case PivotResult::kWasLast: + return "last"; + } + return "unknown"; +} + +template <class Traits, typename T> +HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = 3 * 64 / sizeof(T); + constexpr size_t N1 = st.LanesPerKey(); + + constexpr size_t kRankMid = kSampleLanes / 2; + static_assert(kRankMid % N1 == 0, "Mid is not an aligned key"); + + // Find the previous value not equal to the median. + size_t rank_prev = kRankMid - N1; + for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) { + // All previous samples are equal to the median. + if (rank_prev == 0) return 0; + } + + size_t rank_next = rank_prev + N1; + for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) { + // The median is also the largest sample. If it is also the largest key, + // we'd end up with an empty right partition, so choose the previous key. + if (rank_next == kSampleLanes - N1) return rank_prev; + } + + // If we choose the median as pivot, the ratio of keys ending in the left + // partition will likely be rank_next/kSampleLanes (if the sample is + // representative). This is because equal-to-pivot values also land in the + // left - it's infeasible to do an in-place vectorized 3-way partition. + // Check whether prev would lead to a more balanced partition. + const size_t excess_if_median = rank_next - kRankMid; + const size_t excess_if_prev = kRankMid - rank_prev; + return excess_if_median < excess_if_prev ? kRankMid : rank_prev; +} + +// Returns pivot chosen from `samples`. It will never be the largest key +// (thus the right partition will never be empty). +template <class D, class Traits, typename T> +HWY_INLINE Vec<D> ChoosePivotByRank(D d, Traits st, + const T* HWY_RESTRICT samples) { + const size_t pivot_rank = PivotRank(st, samples); + const Vec<D> pivot = st.SetKey(d, samples + pivot_rank); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, " Pivot rank %zu = %f\n", pivot_rank, + static_cast<double>(GetLane(pivot))); + } + // Verify pivot is not equal to the last sample. + constexpr size_t kSampleLanes = 3 * 64 / sizeof(T); + constexpr size_t N1 = st.LanesPerKey(); + const Vec<D> last = st.SetKey(d, samples + kSampleLanes - N1); + const bool all_neq = AllTrue(d, st.NotEqualKeys(d, pivot, last)); + (void)all_neq; + HWY_DASSERT(all_neq); + return pivot; +} + +// Returns true if all keys equal `pivot`, otherwise returns false and sets +// `*first_mismatch' to the index of the first differing key. +template <class D, class Traits, typename T> +HWY_INLINE bool AllEqual(D d, Traits st, const Vec<D> pivot, + const T* HWY_RESTRICT keys, size_t num, + size_t* HWY_RESTRICT first_mismatch) { + const size_t N = Lanes(d); + // Ensures we can use overlapping loads for the tail; see HandleSpecialCases. + HWY_DASSERT(num >= N); + const Vec<D> zero = Zero(d); + + // Vector-align keys + i. + const size_t misalign = + (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (N - 1); + HWY_DASSERT(misalign % st.LanesPerKey() == 0); + const size_t consume = N - misalign; + { + const Vec<D> v = LoadU(d, keys); + // Only check masked lanes; consider others to be equal. + const Mask<D> diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot)); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = lane; + return false; + } + } + size_t i = consume; + HWY_DASSERT(((reinterpret_cast<uintptr_t>(keys + i) / sizeof(T)) & (N - 1)) == + 0); + + // Sticky bits registering any difference between `keys` and the first key. + // We use vector XOR because it may be cheaper than comparisons, especially + // for 128-bit. 2x unrolled for more ILP. + Vec<D> diff0 = zero; + Vec<D> diff1 = zero; + + // We want to stop once a difference has been found, but without slowing + // down the loop by comparing during each iteration. The compromise is to + // compare after a 'group', which consists of kLoops times two vectors. + constexpr size_t kLoops = 8; + const size_t lanes_per_group = kLoops * 2 * N; + + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec<D> v0 = Load(d, keys + i + loop * 2 * N); + const Vec<D> v1 = Load(d, keys + i + loop * 2 * N + N); + diff0 = OrXor(diff0, v0, pivot); + diff1 = OrXor(diff1, v1, pivot); + } + + // If there was a difference in the entire group: + if (HWY_UNLIKELY(!st.NoKeyDifference(d, Or(diff0, diff1)))) { + // .. then loop until the first one, with termination guarantee. + for (;; i += N) { + const Vec<D> v = Load(d, keys + i); + const Mask<D> diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + } + } + + // Whole vectors, no unrolling, compare directly + for (; i + N <= num; i += N) { + const Vec<D> v = Load(d, keys + i); + const Mask<D> diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + // Always re-check the last (unaligned) vector to reduce branching. + i = num - N; + const Vec<D> v = LoadU(d, keys + i); + const Mask<D> diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "All keys equal\n"); + } + return true; // all equal +} + +// Called from 'two locations', but only one is active (IsKV is constexpr). +template <class D, class Traits, typename T> +HWY_INLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec<D> pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for before\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec<D> first = pivot; + + // Whole group, unrolled + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec<D> curr = LoadU(d, keys + i + loop * N); + first = st.First(d, first, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + // Whole vectors, no unrolling + for (; i + N <= num; i += N) { + const Vec<D> curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec<D> curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the first +} + +// Called from 'two locations', but only one is active (IsKV is constexpr). +template <class D, class Traits, typename T> +HWY_INLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec<D> pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for after\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec<D> last = pivot; + + // Whole group, unrolled + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec<D> curr = LoadU(d, keys + i + loop * N); + last = st.Last(d, last, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + // Whole vectors, no unrolling + for (; i + N <= num; i += N) { + const Vec<D> curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec<D> curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the last +} + +// Returns pivot chosen from `keys[0, num)`. It will never be the largest key +// (thus the right partition will never be empty). +template <class D, class Traits, typename T> +HWY_INLINE Vec<D> ChoosePivotForEqualSamples(D d, Traits st, + T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT samples, + Vec<D> second, Vec<D> third, + PivotResult& result) { + const Vec<D> pivot = st.SetKey(d, samples); // the single unique sample + + // Early out for mostly-0 arrays, where pivot is often FirstValue. + if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) { + result = PivotResult::kIsFirst; + return pivot; + } + if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) { + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // If key-value, we didn't run PartitionIfTwo* and thus `third` is unknown and + // cannot be used. + if (st.IsKV()) { + // If true, pivot is either middle or last. + const bool before = !AllFalse(d, st.Compare(d, second, pivot)); + if (HWY_UNLIKELY(before)) { + // Not last, so middle. + if (HWY_UNLIKELY(ExistsAnyAfter(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // We didn't find anything after pivot, so it is the last. Because keys + // equal to the pivot go to the left partition, the right partition would + // be empty and Partition will not have changed anything. Instead use the + // previous value in sort order, which is not necessarily an actual key. + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Otherwise, pivot is first or middle. Rule out it being first: + if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + // It is first: fall through to shared code below. + } else { + // Check if pivot is between two known values. If so, it is not the first + // nor the last and we can avoid scanning. + st.Sort2(d, second, third); + HWY_DASSERT(AllTrue(d, st.Compare(d, second, third))); + const bool before = !AllFalse(d, st.Compare(d, second, pivot)); + const bool after = !AllFalse(d, st.Compare(d, pivot, third)); + // Only reached if there are three keys, which means pivot is either first, + // last, or in between. Thus there is another key that comes before or + // after. + HWY_DASSERT(before || after); + if (HWY_UNLIKELY(before)) { + // Neither first nor last. + if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // We didn't find anything after pivot, so it is the last. Because keys + // equal to the pivot go to the left partition, the right partition would + // be empty and Partition will not have changed anything. Instead use the + // previous value in sort order, which is not necessarily an actual key. + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Has after, and we found one before: in the middle. + if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + } + + // Pivot is first. We could consider a special partition mode that only + // reads from and writes to the right side, and later fills in the left + // side, which we know is equal to the pivot. However, that leads to more + // cache misses if the array is large, and doesn't save much, hence is a + // net loss. + result = PivotResult::kIsFirst; + return pivot; +} + +// ------------------------------ Quicksort recursion + +template <class D, class Traits, typename T> +HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 2) { + const size_t N = Lanes(d); + if (num < N) return; + + Vec<D> first = st.LastValue(d); + Vec<D> last = st.FirstValue(d); + + size_t i = 0; + for (; i + N <= num; i += N) { + const Vec<D> v = LoadU(d, keys + i); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + if (HWY_LIKELY(i != num)) { + HWY_DASSERT(num >= N); // See HandleSpecialCases + const Vec<D> v = LoadU(d, keys + num - N); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + + first = st.FirstOfLanes(d, first, buf); + last = st.LastOfLanes(d, last, buf); + MaybePrintVector(d, "first", first, 0, st.LanesPerKey()); + MaybePrintVector(d, "last", last, 0, st.LanesPerKey()); + } +} + +// keys_end is the end of the entire user input, not just the current subarray +// [keys, keys + num). +template <class D, class Traits, typename T> +HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys, + T* HWY_RESTRICT keys_end, const size_t num, + T* HWY_RESTRICT buf, Generator& rng, + const size_t remaining_levels) { + HWY_DASSERT(num != 0); + + if (HWY_UNLIKELY(num <= Constants::BaseCaseNum(Lanes(d)))) { + BaseCase(d, st, keys, keys_end, num, buf); + return; + } + + // Move after BaseCase so we skip printing for small subarrays. + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu\n", remaining_levels, + num); + PrintMinMax(d, st, keys, num, buf); + } + + DrawSamples(d, st, keys, num, buf, rng); + + Vec<D> pivot; + PivotResult result = PivotResult::kNormal; + if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) { + pivot = st.SetKey(d, buf); + size_t idx_second = 0; + if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) { + return; + } + HWY_DASSERT(idx_second % st.LanesPerKey() == 0); + // Must capture the value before PartitionIfTwoKeys may overwrite it. + const Vec<D> second = st.SetKey(d, keys + idx_second); + MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey()); + MaybePrintVector(d, "second", second, 0, st.LanesPerKey()); + + Vec<D> third; + // Not supported for key-value types because two 'keys' may be equivalent + // but not interchangeable (their values may differ). + if (HWY_UNLIKELY(!st.IsKV() && + PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second, + second, third, buf))) { + return; // Done, skip recursion because each side has all-equal keys. + } + + // We can no longer start scanning from idx_second because + // PartitionIfTwoKeys may have reordered keys. + pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third, + result); + // If kNormal, `pivot` is very common but not the first/last. It is + // tempting to do a 3-way partition (to avoid moving the =pivot keys a + // second time), but that is a net loss due to the extra comparisons. + } else { + SortSamples(d, st, buf); + + // Not supported for key-value types because two 'keys' may be equivalent + // but not interchangeable (their values may differ). + if (HWY_UNLIKELY(!st.IsKV() && + PartitionIfTwoSamples(d, st, keys, num, buf))) { + return; + } + + pivot = ChoosePivotByRank(d, st, buf); + } + + // Too many recursions. This is unlikely to happen because we select pivots + // from large (though still O(1)) samples. + if (HWY_UNLIKELY(remaining_levels == 0)) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "HeapSort reached, size=%zu\n", num); + } + HeapSort(st, keys, num); // Slow but N*logN. + return; + } + + const size_t bound = Partition(d, st, keys, num, pivot, buf); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "bound %zu num %zu result %s\n", bound, num, + PivotResultString(result)); + } + // The left partition is not empty because the pivot is one of the keys + // (unless kWasLast, in which case the pivot is PrevValue, but we still + // have at least one value <= pivot because AllEqual ruled out the case of + // only one unique value, and there is exactly one value after pivot). + HWY_DASSERT(bound != 0); + // ChoosePivot* ensure pivot != last, so the right partition is never empty. + HWY_DASSERT(bound != num); + + if (HWY_LIKELY(result != PivotResult::kIsFirst)) { + Recurse(d, st, keys, keys_end, bound, buf, rng, remaining_levels - 1); + } + if (HWY_LIKELY(result != PivotResult::kWasLast)) { + Recurse(d, st, keys + bound, keys_end, num - bound, buf, rng, + remaining_levels - 1); + } +} + +// Returns true if sorting is finished. +template <class D, class Traits, typename T> +HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, + size_t num) { + const size_t N = Lanes(d); + const size_t base_case_num = Constants::BaseCaseNum(N); + + // 128-bit keys require vectors with at least two u64 lanes, which is always + // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the + // hardware vector width is less than 128bit / fraction. + const bool partial_128 = !IsFull(d) && N < 2 && st.Is128(); + // Partition assumes its input is at least two vectors. If vectors are huge, + // base_case_num may actually be smaller. If so, which is only possible on + // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of + // HWY_LANES to account for the largest possible LMUL. + constexpr bool kPotentiallyHuge = + HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols; + const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num); + if (partial_128 || huge_vec) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "WARNING: using slow HeapSort: partial %d huge %d\n", + partial_128, huge_vec); + } + HeapSort(st, keys, num); + return true; + } + + // Small arrays are already handled by Recurse. + + // We could also check for already sorted/reverse/equal, but that's probably + // counterproductive if vqsort is used as a base case. + + return false; // not finished sorting +} + +#endif // VQSORT_ENABLED +} // namespace detail + +// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`. +// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons. +// Non-stable (order of equal keys may change), except for the common case where +// the upper bits of T are the key, and the lower bits are a sequential or at +// least unique ID. +// There is no upper limit on `num`, but note that pivots may be chosen by +// sampling only from the first 256 GiB. +// +// `d` is typically SortTag<T> (chooses between full and partial vectors). +// `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. +template <class D, class Traits, typename T> +void Sort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "=============== Sort num %zu\n", num); + } + +#if VQSORT_ENABLED || HWY_IDE +#if !HWY_HAVE_SCALABLE + // On targets with fixed-size vectors, avoid _using_ the allocated memory. + // We avoid (potentially expensive for small input sizes) allocations on + // platforms where no targets are scalable. For 512-bit vectors, this fits on + // the stack (several KiB). + HWY_ALIGN T storage[SortConstants::BufNum<T>(HWY_LANES(T))] = {}; + static_assert(sizeof(storage) <= 8192, "Unexpectedly large, check size"); + buf = storage; +#endif // !HWY_HAVE_SCALABLE + + if (detail::HandleSpecialCases(d, st, keys, num)) return; + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. + if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf); + } +#endif // HWY_MAX_BYTES > 64 + + detail::Generator rng(keys, num); + + // Introspection: switch to worst-case N*logN heapsort after this many. + const size_t max_levels = 2 * hwy::CeilLog2(num) + 4; + detail::Recurse(d, st, keys, keys + num, num, buf, rng, max_levels); +#else + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "WARNING: using slow HeapSort because vqsort disabled\n"); + } + return detail::HeapSort(st, keys, num); +#endif // VQSORT_ENABLED +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE diff --git a/third_party/highway/hwy/contrib/sort/vqsort.cc b/third_party/highway/hwy/contrib/sort/vqsort.cc new file mode 100644 index 0000000000..b3bac0720a --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort.cc @@ -0,0 +1,184 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#include <string.h> // memset + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/shared-inl.h" + +// Architectures for which we know HWY_HAVE_SCALABLE == 0. This opts into an +// optimization that replaces dynamic allocation with stack storage. +#ifndef VQSORT_STACK +#if HWY_ARCH_X86 || HWY_ARCH_WASM +#define VQSORT_STACK 1 +#else +#define VQSORT_STACK 0 +#endif +#endif // VQSORT_STACK + +#if !VQSORT_STACK +#include "hwy/aligned_allocator.h" +#endif + +// Check if we have sys/random.h. First skip some systems on which the check +// itself (features.h) might be problematic. +#if defined(ANDROID) || defined(__ANDROID__) || HWY_ARCH_RVV +#define VQSORT_GETRANDOM 0 +#endif + +#if !defined(VQSORT_GETRANDOM) && HWY_OS_LINUX +#include <features.h> + +// ---- which libc +#if defined(__UCLIBC__) +#define VQSORT_GETRANDOM 1 // added Mar 2015, before uclibc-ng 1.0 + +#elif defined(__GLIBC__) && defined(__GLIBC_PREREQ) +#if __GLIBC_PREREQ(2, 25) +#define VQSORT_GETRANDOM 1 +#else +#define VQSORT_GETRANDOM 0 +#endif + +#else +// Assume MUSL, which has getrandom since 2018. There is no macro to test, see +// https://www.openwall.com/lists/musl/2013/03/29/13. +#define VQSORT_GETRANDOM 1 + +#endif // ---- which libc +#endif // linux + +#if !defined(VQSORT_GETRANDOM) +#define VQSORT_GETRANDOM 0 +#endif + +// Seed source for SFC generator: 1=getrandom, 2=CryptGenRandom +// (not all Android support the getrandom wrapper) +#ifndef VQSORT_SECURE_SEED + +#if VQSORT_GETRANDOM +#define VQSORT_SECURE_SEED 1 +#elif defined(_WIN32) || defined(_WIN64) +#define VQSORT_SECURE_SEED 2 +#else +#define VQSORT_SECURE_SEED 0 +#endif + +#endif // VQSORT_SECURE_SEED + +#if !VQSORT_SECURE_RNG + +#include <time.h> +#if VQSORT_SECURE_SEED == 1 +#include <sys/random.h> +#elif VQSORT_SECURE_SEED == 2 +#include <windows.h> +#pragma comment(lib, "advapi32.lib") +// Must come after windows.h. +#include <wincrypt.h> +#endif // VQSORT_SECURE_SEED + +#endif // !VQSORT_SECURE_RNG + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +size_t VectorSize() { return Lanes(ScalableTag<uint8_t, 3>()); } +bool HaveFloat64() { return HWY_HAVE_FLOAT64; } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(VectorSize); +HWY_EXPORT(HaveFloat64); + +} // namespace + +Sorter::Sorter() { +#if VQSORT_STACK + ptr_ = nullptr; // Sort will use stack storage instead +#else + // Determine the largest buffer size required for any type by trying them all. + // (The capping of N in BaseCaseNum means that smaller N but larger sizeof_t + // may require a larger buffer.) + const size_t vector_size = HWY_DYNAMIC_DISPATCH(VectorSize)(); + const size_t max_bytes = + HWY_MAX(HWY_MAX(SortConstants::BufBytes<uint16_t>(vector_size), + SortConstants::BufBytes<uint32_t>(vector_size)), + SortConstants::BufBytes<uint64_t>(vector_size)); + ptr_ = hwy::AllocateAlignedBytes(max_bytes, nullptr, nullptr); + + // Prevent msan errors by initializing. + memset(ptr_, 0, max_bytes); +#endif +} + +void Sorter::Delete() { +#if !VQSORT_STACK + FreeAlignedBytes(ptr_, nullptr, nullptr); + ptr_ = nullptr; +#endif +} + +#if !VQSORT_SECURE_RNG + +void Sorter::Fill24Bytes(const void* seed_heap, size_t seed_num, void* bytes) { +#if VQSORT_SECURE_SEED == 1 + // May block if urandom is not yet initialized. + const ssize_t ret = getrandom(bytes, 24, /*flags=*/0); + if (ret == 24) return; +#elif VQSORT_SECURE_SEED == 2 + HCRYPTPROV hProvider{}; + if (CryptAcquireContextA(&hProvider, nullptr, nullptr, PROV_RSA_FULL, + CRYPT_VERIFYCONTEXT)) { + const BOOL ok = + CryptGenRandom(hProvider, 24, reinterpret_cast<BYTE*>(bytes)); + CryptReleaseContext(hProvider, 0); + if (ok) return; + } +#endif + + // VQSORT_SECURE_SEED == 0, or one of the above failed. Get some entropy from + // stack/heap/code addresses and the clock() timer. + uint64_t* words = reinterpret_cast<uint64_t*>(bytes); + uint64_t** seed_stack = &words; + void (*seed_code)(const void*, size_t, void*) = &Fill24Bytes; + const uintptr_t bits_stack = reinterpret_cast<uintptr_t>(seed_stack); + const uintptr_t bits_heap = reinterpret_cast<uintptr_t>(seed_heap); + const uintptr_t bits_code = reinterpret_cast<uintptr_t>(seed_code); + const uint64_t bits_time = static_cast<uint64_t>(clock()); + words[0] = bits_stack ^ bits_time ^ seed_num; + words[1] = bits_heap ^ bits_time ^ seed_num; + words[2] = bits_code ^ bits_time ^ seed_num; +} + +#endif // !VQSORT_SECURE_RNG + +bool Sorter::HaveFloat64() { return HWY_DYNAMIC_DISPATCH(HaveFloat64)(); } + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort.h b/third_party/highway/hwy/contrib/sort/vqsort.h new file mode 100644 index 0000000000..88d78ac7f9 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort.h @@ -0,0 +1,108 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Interface to vectorized quicksort with dynamic dispatch. +// Blog post: https://tinyurl.com/vqsort-blog +// Paper with measurements: https://arxiv.org/abs/2205.05982 +// +// To ensure the overhead of using wide vectors (e.g. AVX2 or AVX-512) is +// worthwhile, we recommend using this code for sorting arrays whose size is at +// least 512 KiB. + +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ + +#include "hwy/base.h" + +namespace hwy { + +// Tag arguments that determine the sort order. +struct SortAscending { + constexpr bool IsAscending() const { return true; } +}; +struct SortDescending { + constexpr bool IsAscending() const { return false; } +}; + +// Allocates O(1) space. Type-erased RAII wrapper over hwy/aligned_allocator.h. +// This allows amortizing the allocation over multiple sorts. +class HWY_CONTRIB_DLLEXPORT Sorter { + public: + Sorter(); + ~Sorter() { Delete(); } + + // Move-only + Sorter(const Sorter&) = delete; + Sorter& operator=(const Sorter&) = delete; + Sorter(Sorter&& other) { + Delete(); + ptr_ = other.ptr_; + other.ptr_ = nullptr; + } + Sorter& operator=(Sorter&& other) { + Delete(); + ptr_ = other.ptr_; + other.ptr_ = nullptr; + return *this; + } + + // Sorts keys[0, n). Dispatches to the best available instruction set, + // and does not allocate memory. + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // For internal use only + static void Fill24Bytes(const void* seed_heap, size_t seed_num, void* bytes); + static bool HaveFloat64(); + + private: + void Delete(); + + template <typename T> + T* Get() const { + return static_cast<T*>(ptr_); + } + + void* ptr_ = nullptr; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ diff --git a/third_party/highway/hwy/contrib/sort/vqsort_128a.cc b/third_party/highway/hwy/contrib/sort/vqsort_128a.cc new file mode 100644 index 0000000000..40daea85c7 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_128a.cc @@ -0,0 +1,62 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void Sort128Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag<uint64_t> d; + detail::SharedTraits<detail::Traits128<detail::OrderAscending128>> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Asc); +} // namespace + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(Sort128Asc) + (reinterpret_cast<uint64_t*>(keys), n * 2, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_128d.cc b/third_party/highway/hwy/contrib/sort/vqsort_128d.cc new file mode 100644 index 0000000000..357da840c1 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_128d.cc @@ -0,0 +1,62 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void Sort128Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag<uint64_t> d; + detail::SharedTraits<detail::Traits128<detail::OrderDescending128>> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Desc); +} // namespace + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(Sort128Desc) + (reinterpret_cast<uint64_t*>(keys), n * 2, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc new file mode 100644 index 0000000000..3856eea5dd --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc @@ -0,0 +1,53 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF32Asc(float* HWY_RESTRICT keys, size_t num, float* HWY_RESTRICT buf) { + SortTag<float> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<float>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Asc); +} // namespace + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortF32Asc)(keys, n, Get<float>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc new file mode 100644 index 0000000000..7f5f97cdf2 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF32Desc(float* HWY_RESTRICT keys, size_t num, + float* HWY_RESTRICT buf) { + SortTag<float> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<float>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Desc); +} // namespace + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortF32Desc)(keys, n, Get<float>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc new file mode 100644 index 0000000000..287d5214e5 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc @@ -0,0 +1,61 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF64Asc(double* HWY_RESTRICT keys, size_t num, + double* HWY_RESTRICT buf) { +#if HWY_HAVE_FLOAT64 + SortTag<double> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<double>>> st; + Sort(d, st, keys, num, buf); +#else + (void)keys; + (void)num; + (void)buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Asc); +} // namespace + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortF64Asc)(keys, n, Get<double>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc new file mode 100644 index 0000000000..74d40c1ed3 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc @@ -0,0 +1,61 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortF64Desc(double* HWY_RESTRICT keys, size_t num, + double* HWY_RESTRICT buf) { +#if HWY_HAVE_FLOAT64 + SortTag<double> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<double>>> st; + Sort(d, st, keys, num, buf); +#else + (void)keys; + (void)num; + (void)buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Desc); +} // namespace + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortF64Desc)(keys, n, Get<double>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc new file mode 100644 index 0000000000..ef4bb75bc4 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI16Asc(int16_t* HWY_RESTRICT keys, size_t num, + int16_t* HWY_RESTRICT buf) { + SortTag<int16_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<int16_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Asc); +} // namespace + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI16Asc)(keys, n, Get<int16_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc new file mode 100644 index 0000000000..6507ed6080 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI16Desc(int16_t* HWY_RESTRICT keys, size_t num, + int16_t* HWY_RESTRICT buf) { + SortTag<int16_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<int16_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Desc); +} // namespace + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI16Desc)(keys, n, Get<int16_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc new file mode 100644 index 0000000000..ae65be997e --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI32Asc(int32_t* HWY_RESTRICT keys, size_t num, + int32_t* HWY_RESTRICT buf) { + SortTag<int32_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<int32_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Asc); +} // namespace + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI32Asc)(keys, n, Get<int32_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc new file mode 100644 index 0000000000..3ce276ee9c --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI32Desc(int32_t* HWY_RESTRICT keys, size_t num, + int32_t* HWY_RESTRICT buf) { + SortTag<int32_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<int32_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Desc); +} // namespace + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI32Desc)(keys, n, Get<int32_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc new file mode 100644 index 0000000000..901b8ead8a --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI64Asc(int64_t* HWY_RESTRICT keys, size_t num, + int64_t* HWY_RESTRICT buf) { + SortTag<int64_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<int64_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Asc); +} // namespace + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortI64Asc)(keys, n, Get<int64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc new file mode 100644 index 0000000000..7713f2eb89 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortI64Desc(int64_t* HWY_RESTRICT keys, size_t num, + int64_t* HWY_RESTRICT buf) { + SortTag<int64_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<int64_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Desc); +} // namespace + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortI64Desc)(keys, n, Get<int64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv128a.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv128a.cc new file mode 100644 index 0000000000..1e02742ef1 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_kv128a.cc @@ -0,0 +1,65 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128a.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV128Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag<uint64_t> d; + detail::SharedTraits<detail::Traits128<detail::OrderAscendingKV128>> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV128Asc); +} // namespace + +void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortKV128Asc) + (reinterpret_cast<uint64_t*>(keys), n * 2, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv128d.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv128d.cc new file mode 100644 index 0000000000..3dd53b5da3 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_kv128d.cc @@ -0,0 +1,65 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128d.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV128Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag<uint64_t> d; + detail::SharedTraits<detail::Traits128<detail::OrderDescendingKV128>> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV128Desc); +} // namespace + +void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortKV128Desc) + (reinterpret_cast<uint64_t*>(keys), n * 2, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv64a.cc new file mode 100644 index 0000000000..c513e3c4ce --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_kv64a.cc @@ -0,0 +1,65 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64a.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV64Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag<uint64_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscendingKV64>> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV64Asc); +} // namespace + +void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortKV64Asc) + (reinterpret_cast<uint64_t*>(keys), n, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv64d.cc new file mode 100644 index 0000000000..c6c5fdcf74 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_kv64d.cc @@ -0,0 +1,65 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64d.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortKV64Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { +#if VQSORT_ENABLED + SortTag<uint64_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescendingKV64>> st; + Sort(d, st, keys, num, buf); +#else + (void) keys; + (void) num; + (void) buf; + HWY_ASSERT(0); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV64Desc); +} // namespace + +void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortKV64Desc) + (reinterpret_cast<uint64_t*>(keys), n, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc new file mode 100644 index 0000000000..0a97ffa923 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU16Asc(uint16_t* HWY_RESTRICT keys, size_t num, + uint16_t* HWY_RESTRICT buf) { + SortTag<uint16_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<uint16_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Asc); +} // namespace + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU16Asc)(keys, n, Get<uint16_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc new file mode 100644 index 0000000000..286ebbba65 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU16Desc(uint16_t* HWY_RESTRICT keys, size_t num, + uint16_t* HWY_RESTRICT buf) { + SortTag<uint16_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<uint16_t>>> + st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Desc); +} // namespace + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU16Desc)(keys, n, Get<uint16_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc new file mode 100644 index 0000000000..b6a69e6e28 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU32Asc(uint32_t* HWY_RESTRICT keys, size_t num, + uint32_t* HWY_RESTRICT buf) { + SortTag<uint32_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<uint32_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Asc); +} // namespace + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU32Asc)(keys, n, Get<uint32_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc new file mode 100644 index 0000000000..38fc1e1bfe --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU32Desc(uint32_t* HWY_RESTRICT keys, size_t num, + uint32_t* HWY_RESTRICT buf) { + SortTag<uint32_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<uint32_t>>> + st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Desc); +} // namespace + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU32Desc)(keys, n, Get<uint32_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc new file mode 100644 index 0000000000..a29824a6f9 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU64Asc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag<uint64_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<uint64_t>>> st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Asc); +} // namespace + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortAscending) const { + HWY_DYNAMIC_DISPATCH(SortU64Asc)(keys, n, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc new file mode 100644 index 0000000000..d692458623 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc @@ -0,0 +1,55 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +void SortU64Desc(uint64_t* HWY_RESTRICT keys, size_t num, + uint64_t* HWY_RESTRICT buf) { + SortTag<uint64_t> d; + detail::SharedTraits<detail::TraitsLane<detail::OrderDescending<uint64_t>>> + st; + Sort(d, st, keys, num, buf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Desc); +} // namespace + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortDescending) const { + HWY_DYNAMIC_DISPATCH(SortU64Desc)(keys, n, Get<uint64_t>()); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/detect_compiler_arch.h b/third_party/highway/hwy/detect_compiler_arch.h new file mode 100644 index 0000000000..466e30b308 --- /dev/null +++ b/third_party/highway/hwy/detect_compiler_arch.h @@ -0,0 +1,235 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ +#define HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ + +// Detects compiler and arch from predefined macros. Zero dependencies for +// inclusion by foreach_target.h. + +// Add to #if conditions to prevent IDE from graying out code. +#if (defined __CDT_PARSER__) || (defined __INTELLISENSE__) || \ + (defined Q_CREATOR_RUN) || (defined __CLANGD__) || \ + (defined GROK_ELLIPSIS_BUILD) +#define HWY_IDE 1 +#else +#define HWY_IDE 0 +#endif + +//------------------------------------------------------------------------------ +// Compiler + +// Actual MSVC, not clang-cl, which defines _MSC_VER but doesn't behave like +// MSVC in other aspects (e.g. HWY_DIAGNOSTICS). +#if defined(_MSC_VER) && !defined(__clang__) +#define HWY_COMPILER_MSVC _MSC_VER +#else +#define HWY_COMPILER_MSVC 0 +#endif + +#if defined(_MSC_VER) && defined(__clang__) +#define HWY_COMPILER_CLANGCL _MSC_VER +#else +#define HWY_COMPILER_CLANGCL 0 +#endif + +#ifdef __INTEL_COMPILER +#define HWY_COMPILER_ICC __INTEL_COMPILER +#else +#define HWY_COMPILER_ICC 0 +#endif + +#ifdef __INTEL_LLVM_COMPILER +#define HWY_COMPILER_ICX __INTEL_LLVM_COMPILER +#else +#define HWY_COMPILER_ICX 0 +#endif + +// HWY_COMPILER_GCC is a generic macro for all compilers implementing the GNU +// compiler extensions (eg. Clang, Intel...) +#ifdef __GNUC__ +#define HWY_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__) +#else +#define HWY_COMPILER_GCC 0 +#endif + +// Clang or clang-cl, not GCC. +#ifdef __clang__ +// In case of Apple LLVM (whose version number is unrelated to that of LLVM) or +// an invalid version number, deduce it from the presence of warnings. +// Adapted from https://github.com/simd-everywhere/simde/ simde-detect-clang.h. +#if defined(__apple_build_version__) || __clang_major__ >= 999 +#if __has_warning("-Wbitwise-instead-of-logical") +#define HWY_COMPILER_CLANG 1400 +#elif __has_warning("-Wreserved-identifier") +#define HWY_COMPILER_CLANG 1300 +#elif __has_warning("-Wformat-insufficient-args") +#define HWY_COMPILER_CLANG 1200 +#elif __has_warning("-Wimplicit-const-int-float-conversion") +#define HWY_COMPILER_CLANG 1100 +#elif __has_warning("-Wmisleading-indentation") +#define HWY_COMPILER_CLANG 1000 +#elif defined(__FILE_NAME__) +#define HWY_COMPILER_CLANG 900 +#elif __has_warning("-Wextra-semi-stmt") || \ + __has_builtin(__builtin_rotateleft32) +#define HWY_COMPILER_CLANG 800 +// For reasons unknown, XCode 10.3 (Apple LLVM version 10.0.1) is apparently +// based on Clang 7, but does not support the warning we test. +// See https://en.wikipedia.org/wiki/Xcode#Toolchain_versions and +// https://trac.macports.org/wiki/XcodeVersionInfo. +#elif __has_warning("-Wc++98-compat-extra-semi") || \ + (defined(__apple_build_version__) && __apple_build_version__ >= 10010000) +#define HWY_COMPILER_CLANG 700 +#else // Anything older than 7.0 is not recommended for Highway. +#define HWY_COMPILER_CLANG 600 +#endif // __has_warning chain +#else // use normal version +#define HWY_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__) +#endif +#else // Not clang +#define HWY_COMPILER_CLANG 0 +#endif + +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG +#define HWY_COMPILER_GCC_ACTUAL HWY_COMPILER_GCC +#else +#define HWY_COMPILER_GCC_ACTUAL 0 +#endif + +// More than one may be nonzero, but we want at least one. +#if 0 == (HWY_COMPILER_MSVC + HWY_COMPILER_CLANGCL + HWY_COMPILER_ICC + \ + HWY_COMPILER_GCC + HWY_COMPILER_CLANG) +#error "Unsupported compiler" +#endif + +// We should only detect one of these (only clang/clangcl overlap) +#if 1 < \ + (!!HWY_COMPILER_MSVC + !!HWY_COMPILER_ICC + !!HWY_COMPILER_GCC_ACTUAL + \ + !!(HWY_COMPILER_CLANGCL | HWY_COMPILER_CLANG)) +#error "Detected multiple compilers" +#endif + +#ifdef __has_builtin +#define HWY_HAS_BUILTIN(name) __has_builtin(name) +#else +#define HWY_HAS_BUILTIN(name) 0 +#endif + +#ifdef __has_attribute +#define HWY_HAS_ATTRIBUTE(name) __has_attribute(name) +#else +#define HWY_HAS_ATTRIBUTE(name) 0 +#endif + +#ifdef __has_feature +#define HWY_HAS_FEATURE(name) __has_feature(name) +#else +#define HWY_HAS_FEATURE(name) 0 +#endif + +//------------------------------------------------------------------------------ +// Architecture + +#if defined(__i386__) || defined(_M_IX86) +#define HWY_ARCH_X86_32 1 +#else +#define HWY_ARCH_X86_32 0 +#endif + +#if defined(__x86_64__) || defined(_M_X64) +#define HWY_ARCH_X86_64 1 +#else +#define HWY_ARCH_X86_64 0 +#endif + +#if HWY_ARCH_X86_32 && HWY_ARCH_X86_64 +#error "Cannot have both x86-32 and x86-64" +#endif + +#if HWY_ARCH_X86_32 || HWY_ARCH_X86_64 +#define HWY_ARCH_X86 1 +#else +#define HWY_ARCH_X86 0 +#endif + +#if defined(__powerpc64__) || defined(_M_PPC) +#define HWY_ARCH_PPC 1 +#else +#define HWY_ARCH_PPC 0 +#endif + +#if defined(__ARM_ARCH_ISA_A64) || defined(__aarch64__) || defined(_M_ARM64) +#define HWY_ARCH_ARM_A64 1 +#else +#define HWY_ARCH_ARM_A64 0 +#endif + +#if (defined(__ARM_ARCH) && __ARM_ARCH == 7) || (defined(_M_ARM) && _M_ARM == 7) +#define HWY_ARCH_ARM_V7 1 +#else +#define HWY_ARCH_ARM_V7 0 +#endif + +#if HWY_ARCH_ARM_A64 && HWY_ARCH_ARM_V7 +#error "Cannot have both A64 and V7" +#endif + +// Any *supported* version of Arm, i.e. 7 or later +#if HWY_ARCH_ARM_A64 || HWY_ARCH_ARM_V7 +#define HWY_ARCH_ARM 1 +#else +#define HWY_ARCH_ARM 0 +#endif + +// Older than v7 (e.g. armel aka Arm v5), in which case we do not support SIMD. +#if (defined(__arm__) || defined(_M_ARM)) && !HWY_ARCH_ARM +#define HWY_ARCH_ARM_OLD 1 +#else +#define HWY_ARCH_ARM_OLD 0 +#endif + +#if defined(__EMSCRIPTEN__) || defined(__wasm__) || defined(__WASM__) +#define HWY_ARCH_WASM 1 +#else +#define HWY_ARCH_WASM 0 +#endif + +#ifdef __riscv +#define HWY_ARCH_RVV 1 +#else +#define HWY_ARCH_RVV 0 +#endif + +// It is an error to detect multiple architectures at the same time, but OK to +// detect none of the above. +#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \ + HWY_ARCH_WASM + HWY_ARCH_RVV) > 1 +#error "Must not detect more than one architecture" +#endif + +#if defined(_WIN32) || defined(_WIN64) +#define HWY_OS_WIN 1 +#else +#define HWY_OS_WIN 0 +#endif + +#if defined(linux) || defined(__linux__) +#define HWY_OS_LINUX 1 +#else +#define HWY_OS_LINUX 0 +#endif + +#endif // HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ diff --git a/third_party/highway/hwy/detect_targets.h b/third_party/highway/hwy/detect_targets.h new file mode 100644 index 0000000000..2beca95bf5 --- /dev/null +++ b/third_party/highway/hwy/detect_targets.h @@ -0,0 +1,479 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_TARGETS_H_ +#define HIGHWAY_HWY_DETECT_TARGETS_H_ + +// Defines targets and chooses which to enable. + +#include "hwy/detect_compiler_arch.h" + +//------------------------------------------------------------------------------ +// Optional configuration + +// See g3doc/quick_reference.md for documentation of these macros. + +// Uncomment to override the default baseline determined from predefined macros: +// #define HWY_BASELINE_TARGETS (HWY_SSE4 | HWY_SCALAR) + +// Uncomment to override the default blocklist: +// #define HWY_BROKEN_TARGETS HWY_AVX3 + +// Uncomment to definitely avoid generating those target(s): +// #define HWY_DISABLED_TARGETS HWY_SSE4 + +// Uncomment to avoid emitting BMI/BMI2/FMA instructions (allows generating +// AVX2 target for VMs which support AVX2 but not the other instruction sets) +// #define HWY_DISABLE_BMI2_FMA + +// Uncomment to enable SSSE3/SSE4 on MSVC even if AVX is not enabled +// #define HWY_WANT_SSSE3 +// #define HWY_WANT_SSE4 + +//------------------------------------------------------------------------------ +// Targets + +// Unique bit value for each target. A lower value is "better" (e.g. more lanes) +// than a higher value within the same group/platform - see HWY_STATIC_TARGET. +// +// All values are unconditionally defined so we can test HWY_TARGETS without +// first checking the HWY_ARCH_*. +// +// The C99 preprocessor evaluates #if expressions using intmax_t types. This +// holds at least 64 bits in practice (verified 2022-07-18 via Godbolt on +// 32-bit clang/GCC/MSVC compilers for x86/Arm7/AArch32/RISC-V/WASM). We now +// avoid overflow when computing HWY_TARGETS (subtracting one instead of +// left-shifting 2^62), but still do not use bit 63 because it is the sign bit. + +// --------------------------- x86: 15 targets (+ one fallback) +// Bits 0..6 reserved (7 targets) +// Currently satisfiable by Ice Lake (VNNI, VPCLMULQDQ, VPOPCNTDQ, VBMI, VBMI2, +// VAES, BITALG). Later to be added: BF16 (Cooper Lake). VP2INTERSECT is only in +// Tiger Lake? We do not yet have uses for GFNI. +#define HWY_AVX3_DL (1LL << 7) // see HWY_WANT_AVX3_DL below +#define HWY_AVX3 (1LL << 8) +#define HWY_AVX2 (1LL << 9) +// Bit 10: reserved for AVX +#define HWY_SSE4 (1LL << 11) +#define HWY_SSSE3 (1LL << 12) +// Bits 13..14 reserved for SSE3 or SSE2 (2 targets) +// The highest bit in the HWY_TARGETS mask that a x86 target can have. Used for +// dynamic dispatch. All x86 target bits must be lower or equal to +// (1 << HWY_HIGHEST_TARGET_BIT_X86) and they can only use +// HWY_MAX_DYNAMIC_TARGETS in total. +#define HWY_HIGHEST_TARGET_BIT_X86 14 + +// --------------------------- Arm: 15 targets (+ one fallback) +// Bits 15..23 reserved (9 targets) +#define HWY_SVE2_128 (1LL << 24) // specialized target (e.g. Arm N2) +#define HWY_SVE_256 (1LL << 25) // specialized target (e.g. Arm V1) +#define HWY_SVE2 (1LL << 26) +#define HWY_SVE (1LL << 27) +#define HWY_NEON (1LL << 28) // On A64, includes/requires AES +// Bit 29 reserved (Helium?) +#define HWY_HIGHEST_TARGET_BIT_ARM 29 + +// --------------------------- RISC-V: 9 targets (+ one fallback) +// Bits 30..36 reserved (7 targets) +#define HWY_RVV (1LL << 37) +// Bit 38 reserved +#define HWY_HIGHEST_TARGET_BIT_RVV 38 + +// --------------------------- Future expansion: 4 targets +// Bits 39..42 reserved + + +// --------------------------- IBM Power: 9 targets (+ one fallback) +// Bits 43..48 reserved (6 targets) +#define HWY_PPC8 (1LL << 49) // v2.07 or 3 +// Bits 50..51 reserved for prior VSX/AltiVec (2 targets) +#define HWY_HIGHEST_TARGET_BIT_PPC 51 + +// --------------------------- WebAssembly: 9 targets (+ one fallback) +// Bits 52..57 reserved (6 targets) +#define HWY_WASM_EMU256 (1LL << 58) // Experimental +#define HWY_WASM (1LL << 59) +// Bits 60 reserved +#define HWY_HIGHEST_TARGET_BIT_WASM 60 + +// --------------------------- Emulation: 2 targets + +#define HWY_EMU128 (1LL << 61) +// We do not add/left-shift, so this will not overflow to a negative number. +#define HWY_SCALAR (1LL << 62) +#define HWY_HIGHEST_TARGET_BIT_SCALAR 62 + +// Do not use bit 63 - would be confusing to have negative numbers. + +//------------------------------------------------------------------------------ +// Set default blocklists + +// Disabled means excluded from enabled at user's request. A separate config +// macro allows disabling without deactivating the blocklist below. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS 0 +#endif + +// Broken means excluded from enabled due to known compiler issues. Allow the +// user to override this blocklist without any guarantee of success. +#ifndef HWY_BROKEN_TARGETS + +// x86 clang-6: we saw multiple AVX2/3 compile errors and in one case invalid +// SSE4 codegen (possibly only for msan), so disable all those targets. +#if HWY_ARCH_X86 && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_BROKEN_TARGETS (HWY_SSE4 | HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL) +// This entails a major speed reduction, so warn unless the user explicitly +// opts in to scalar-only. +#if !defined(HWY_COMPILE_ONLY_SCALAR) +#pragma message("x86 Clang <= 6: define HWY_COMPILE_ONLY_SCALAR or upgrade.") +#endif + +// 32-bit may fail to compile AVX2/3. +#elif HWY_ARCH_X86_32 +#define HWY_BROKEN_TARGETS (HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL) + +// MSVC AVX3 support is buggy: https://github.com/Mysticial/Flops/issues/16 +#elif HWY_COMPILER_MSVC != 0 +#define HWY_BROKEN_TARGETS (HWY_AVX3 | HWY_AVX3_DL) + +// armv7be has not been tested and is not yet supported. +#elif HWY_ARCH_ARM_V7 && \ + (defined(__ARM_BIG_ENDIAN) || \ + (defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN)) +#define HWY_BROKEN_TARGETS (HWY_NEON) + +// SVE[2] require recent clang or gcc versions. +#elif (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) +#define HWY_BROKEN_TARGETS (HWY_SVE | HWY_SVE2 | HWY_SVE_256 | HWY_SVE2_128) + +#else +#define HWY_BROKEN_TARGETS 0 +#endif + +#endif // HWY_BROKEN_TARGETS + +// Enabled means not disabled nor blocklisted. +#define HWY_ENABLED(targets) \ + ((targets) & ~((HWY_DISABLED_TARGETS) | (HWY_BROKEN_TARGETS))) + +// Opt-out for EMU128 (affected by a GCC bug on multiple arches, fixed in 12.3: +// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=106322). This is separate +// from HWY_BROKEN_TARGETS because it affects the fallback target, which must +// always be enabled. If 1, we instead choose HWY_SCALAR even without +// HWY_COMPILE_ONLY_SCALAR being set. +#if !defined(HWY_BROKEN_EMU128) // allow overriding +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1203) || \ + defined(HWY_NO_LIBCXX) +#define HWY_BROKEN_EMU128 1 +#else +#define HWY_BROKEN_EMU128 0 +#endif +#endif // HWY_BROKEN_EMU128 + +//------------------------------------------------------------------------------ +// Detect baseline targets using predefined macros + +// Baseline means the targets for which the compiler is allowed to generate +// instructions, implying the target CPU would have to support them. This does +// not take the blocklist into account. + +#if defined(HWY_COMPILE_ONLY_SCALAR) || HWY_BROKEN_EMU128 +#define HWY_BASELINE_SCALAR HWY_SCALAR +#else +#define HWY_BASELINE_SCALAR HWY_EMU128 +#endif + +// Also check HWY_ARCH to ensure that simulating unknown platforms ends up with +// HWY_TARGET == HWY_BASELINE_SCALAR. + +#if HWY_ARCH_WASM && defined(__wasm_simd128__) +#if defined(HWY_WANT_WASM2) +#define HWY_BASELINE_WASM HWY_WASM_EMU256 +#else +#define HWY_BASELINE_WASM HWY_WASM +#endif // HWY_WANT_WASM2 +#else +#define HWY_BASELINE_WASM 0 +#endif + +// Avoid choosing the PPC target until we have an implementation. +#if HWY_ARCH_PPC && defined(__VSX__) && 0 +#define HWY_BASELINE_PPC8 HWY_PPC8 +#else +#define HWY_BASELINE_PPC8 0 +#endif + +#define HWY_BASELINE_SVE2 0 +#define HWY_BASELINE_SVE 0 +#define HWY_BASELINE_NEON 0 + +#if HWY_ARCH_ARM + +#if defined(__ARM_FEATURE_SVE2) +#undef HWY_BASELINE_SVE2 // was 0, will be re-defined +// If user specified -msve-vector-bits=128, they assert the vector length is +// 128 bits and we should use the HWY_SVE2_128 (more efficient for some ops). +#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 128 +#define HWY_BASELINE_SVE2 HWY_SVE2_128 +// Otherwise we're not sure what the vector length will be. The baseline must be +// unconditionally valid, so we can only assume HWY_SVE2. However, when running +// on a CPU with 128-bit vectors, user code that supports dynamic dispatch will +// still benefit from HWY_SVE2_128 because we add it to HWY_ATTAINABLE_TARGETS. +#else +#define HWY_BASELINE_SVE2 HWY_SVE2 +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE2 + +#if defined(__ARM_FEATURE_SVE) +#undef HWY_BASELINE_SVE // was 0, will be re-defined +// See above. If user-specified vector length matches our optimization, use it. +#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 256 +#define HWY_BASELINE_SVE HWY_SVE_256 +#else +#define HWY_BASELINE_SVE HWY_SVE +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE + +// GCC 4.5.4 only defines __ARM_NEON__; 5.4 defines both. +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#undef HWY_BASELINE_NEON +#define HWY_BASELINE_NEON HWY_NEON +#endif + +#endif // HWY_ARCH_ARM + +// Special handling for MSVC because it has fewer predefined macros: +#if HWY_COMPILER_MSVC + +// 1) We can only be sure SSSE3/SSE4 are enabled if AVX is: +// https://stackoverflow.com/questions/18563978/. +#if defined(__AVX__) +#define HWY_CHECK_SSSE3 1 +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSSE3 0 +#define HWY_CHECK_SSE4 0 +#endif + +// 2) Cannot check for PCLMUL/AES and BMI2/FMA/F16C individually; we assume +// PCLMUL/AES are available if SSE4 is, and BMI2/FMA/F16C if AVX2 is. +#define HWY_CHECK_PCLMUL_AES 1 +#define HWY_CHECK_BMI2_FMA 1 +#define HWY_CHECK_F16C 1 + +#else // non-MSVC + +#if defined(__SSSE3__) +#define HWY_CHECK_SSSE3 1 +#else +#define HWY_CHECK_SSSE3 0 +#endif + +#if defined(__SSE4_1__) && defined(__SSE4_2__) +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSE4 0 +#endif + +// If these are disabled, they should not gate the availability of SSE4/AVX2. +#if defined(HWY_DISABLE_PCLMUL_AES) || (defined(__PCLMUL__) && defined(__AES__)) +#define HWY_CHECK_PCLMUL_AES 1 +#else +#define HWY_CHECK_PCLMUL_AES 0 +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) || (defined(__BMI2__) && defined(__FMA__)) +#define HWY_CHECK_BMI2_FMA 1 +#else +#define HWY_CHECK_BMI2_FMA 0 +#endif + +#if defined(HWY_DISABLE_F16C) || defined(__F16C__) +#define HWY_CHECK_F16C 1 +#else +#define HWY_CHECK_F16C 0 +#endif + +#endif // non-MSVC + +#if HWY_ARCH_X86 && (HWY_WANT_SSSE3 || HWY_CHECK_SSSE3) +#define HWY_BASELINE_SSSE3 HWY_SSSE3 +#else +#define HWY_BASELINE_SSSE3 0 +#endif + +#if HWY_ARCH_X86 && (HWY_WANT_SSE4 || (HWY_CHECK_SSE4 && HWY_CHECK_PCLMUL_AES)) +#define HWY_BASELINE_SSE4 HWY_SSE4 +#else +#define HWY_BASELINE_SSE4 0 +#endif + +#if HWY_BASELINE_SSE4 != 0 && HWY_CHECK_BMI2_FMA && HWY_CHECK_F16C && \ + defined(__AVX2__) +#define HWY_BASELINE_AVX2 HWY_AVX2 +#else +#define HWY_BASELINE_AVX2 0 +#endif + +// Require everything in AVX2 plus AVX-512 flags (also set by MSVC) +#if HWY_BASELINE_AVX2 != 0 && defined(__AVX512F__) && defined(__AVX512BW__) && \ + defined(__AVX512DQ__) && defined(__AVX512VL__) +#define HWY_BASELINE_AVX3 HWY_AVX3 +#else +#define HWY_BASELINE_AVX3 0 +#endif + +// TODO(janwas): not yet known whether these will be set by MSVC +#if HWY_BASELINE_AVX3 != 0 && defined(__AVXVNNI__) && defined(__VAES__) && \ + defined(__VPCLMULQDQ__) && defined(__AVX512VBMI__) && \ + defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__) && \ + defined(__AVX512BITALG__) +#define HWY_BASELINE_AVX3_DL HWY_AVX3_DL +#else +#define HWY_BASELINE_AVX3_DL 0 +#endif + +#if HWY_ARCH_RVV && defined(__riscv_vector) +#define HWY_BASELINE_RVV HWY_RVV +#else +#define HWY_BASELINE_RVV 0 +#endif + +// Allow the user to override this without any guarantee of success. +#ifndef HWY_BASELINE_TARGETS +#define HWY_BASELINE_TARGETS \ + (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \ + HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | HWY_BASELINE_NEON | \ + HWY_BASELINE_SSSE3 | HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | \ + HWY_BASELINE_AVX3 | HWY_BASELINE_AVX3_DL | HWY_BASELINE_RVV) +#endif // HWY_BASELINE_TARGETS + +//------------------------------------------------------------------------------ +// Choose target for static dispatch + +#define HWY_ENABLED_BASELINE HWY_ENABLED(HWY_BASELINE_TARGETS) +#if HWY_ENABLED_BASELINE == 0 +#error "At least one baseline target must be defined and enabled" +#endif + +// Best baseline, used for static dispatch. This is the least-significant 1-bit +// within HWY_ENABLED_BASELINE and lower bit values imply "better". +#define HWY_STATIC_TARGET (HWY_ENABLED_BASELINE & -HWY_ENABLED_BASELINE) + +// Start by assuming static dispatch. If we later use dynamic dispatch, this +// will be defined to other targets during the multiple-inclusion, and finally +// return to the initial value. Defining this outside begin/end_target ensures +// inl headers successfully compile by themselves (required by Bazel). +#define HWY_TARGET HWY_STATIC_TARGET + +//------------------------------------------------------------------------------ +// Choose targets for dynamic dispatch according to one of four policies + +#if 1 < (defined(HWY_COMPILE_ONLY_SCALAR) + defined(HWY_COMPILE_ONLY_EMU128) + \ + defined(HWY_COMPILE_ONLY_STATIC)) +#error "Can only define one of HWY_COMPILE_ONLY_{SCALAR|EMU128|STATIC} - bug?" +#endif +// Defining one of HWY_COMPILE_ONLY_* will trump HWY_COMPILE_ALL_ATTAINABLE. + +// Clang, GCC and MSVC allow runtime dispatch on x86. +#if HWY_ARCH_X86 +#define HWY_HAVE_RUNTIME_DISPATCH 1 +// On Arm, currently only GCC does, and we require Linux to detect CPU +// capabilities. +#elif HWY_ARCH_ARM && HWY_COMPILER_GCC_ACTUAL && HWY_OS_LINUX && !defined(TOOLCHAIN_MISS_SYS_AUXV_H) +#define HWY_HAVE_RUNTIME_DISPATCH 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH 0 +#endif + +// AVX3_DL is not widely available yet. To reduce code size and compile time, +// only include it in the set of attainable targets (for dynamic dispatch) if +// the user opts in, OR it is in the baseline (we check whether enabled below). +#if defined(HWY_WANT_AVX3_DL) || (HWY_BASELINE & HWY_AVX3_DL) +#define HWY_ATTAINABLE_AVX3_DL HWY_AVX3_DL +#else +#define HWY_ATTAINABLE_AVX3_DL 0 +#endif + +#if HWY_ARCH_ARM_A64 && (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE | HWY_SVE_256))) +#define HWY_ATTAINABLE_SVE HWY_ENABLED(HWY_SVE | HWY_SVE_256) +#else +#define HWY_ATTAINABLE_SVE 0 +#endif + +#if HWY_ARCH_ARM_A64 && (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE2 | HWY_SVE2_128))) +#define HWY_ATTAINABLE_SVE2 HWY_ENABLED(HWY_SVE2 | HWY_SVE2_128) +#else +#define HWY_ATTAINABLE_SVE2 0 +#endif + +// Attainable means enabled and the compiler allows intrinsics (even when not +// allowed to autovectorize). Used in 3 and 4. +#if HWY_ARCH_X86 +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_SSSE3 | HWY_SSE4 | HWY_AVX2 | \ + HWY_AVX3 | HWY_ATTAINABLE_AVX3_DL) +#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_NEON | HWY_ATTAINABLE_SVE | \ + HWY_ATTAINABLE_SVE2) +#else +#define HWY_ATTAINABLE_TARGETS \ + (HWY_ENABLED_BASELINE | HWY_ATTAINABLE_SVE | HWY_ATTAINABLE_SVE2) +#endif + +// 1) For older compilers: avoid SIMD intrinsics, but still support all ops. +#if defined(HWY_COMPILE_ONLY_EMU128) && !HWY_BROKEN_EMU128 +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_EMU128 // override baseline +#define HWY_TARGETS HWY_EMU128 + +// 1b) HWY_SCALAR is less capable than HWY_EMU128 (which supports all ops), but +// we currently still support it for backwards compatibility. +#elif defined(HWY_COMPILE_ONLY_SCALAR) || \ + (defined(HWY_COMPILE_ONLY_EMU128) && HWY_BROKEN_EMU128) +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_SCALAR // override baseline +#define HWY_TARGETS HWY_SCALAR + +// 2) For forcing static dispatch without code changes (removing HWY_EXPORT) +#elif defined(HWY_COMPILE_ONLY_STATIC) +#define HWY_TARGETS HWY_STATIC_TARGET + +// 3) For tests: include all attainable targets (in particular: scalar) +#elif defined(HWY_COMPILE_ALL_ATTAINABLE) || defined(HWY_IS_TEST) +#define HWY_TARGETS HWY_ATTAINABLE_TARGETS + +// 4) Default: attainable WITHOUT non-best baseline. This reduces code size by +// excluding superseded targets, in particular scalar. Note: HWY_STATIC_TARGET +// may be 2^62 (HWY_SCALAR), so we must not left-shift/add it. Subtracting one +// sets all lower bits (better targets), then we also include the static target. +#else +#define HWY_TARGETS \ + (HWY_ATTAINABLE_TARGETS & ((HWY_STATIC_TARGET - 1LL) | HWY_STATIC_TARGET)) + +#endif // target policy + +// HWY_ONCE and the multiple-inclusion mechanism rely on HWY_STATIC_TARGET being +// one of the dynamic targets. This also implies HWY_TARGETS != 0 and +// (HWY_TARGETS & HWY_ENABLED_BASELINE) != 0. +#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0 +#error "Logic error: best baseline should be included in dynamic targets" +#endif + +#endif // HIGHWAY_HWY_DETECT_TARGETS_H_ diff --git a/third_party/highway/hwy/examples/benchmark.cc b/third_party/highway/hwy/examples/benchmark.cc new file mode 100644 index 0000000000..55afd3bcca --- /dev/null +++ b/third_party/highway/hwy/examples/benchmark.cc @@ -0,0 +1,255 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> + +#include <cmath> // std::abs +#include <memory> +#include <numeric> // std::iota, std::inner_product + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/examples/benchmark.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// Must come after foreach_target.h to avoid redefinition errors. +#include "hwy/aligned_allocator.h" +#include "hwy/highway.h" +#include "hwy/nanobenchmark.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +#if HWY_TARGET != HWY_SCALAR +using hwy::HWY_NAMESPACE::CombineShiftRightLanes; +#endif + +class TwoArray { + public: + // Must be a multiple of the vector lane count * 8. + static size_t NumItems() { return 3456; } + + TwoArray() + : a_(AllocateAligned<float>(NumItems() * 2)), b_(a_.get() + NumItems()) { + // = 1, but compiler doesn't know + const float init = static_cast<float>(Unpredictable1()); + std::iota(a_.get(), a_.get() + NumItems(), init); + std::iota(b_, b_ + NumItems(), init); + } + + protected: + AlignedFreeUniquePtr<float[]> a_; + float* b_; +}; + +// Measures durations, verifies results, prints timings. +template <class Benchmark> +void RunBenchmark(const char* caption) { + printf("%10s: ", caption); + const size_t kNumInputs = 1; + const size_t num_items = Benchmark::NumItems() * size_t(Unpredictable1()); + const FuncInput inputs[kNumInputs] = {num_items}; + Result results[kNumInputs]; + + Benchmark benchmark; + + Params p; + p.verbose = false; + p.max_evals = 7; + p.target_rel_mad = 0.002; + const size_t num_results = MeasureClosure( + [&benchmark](const FuncInput input) { return benchmark(input); }, inputs, + kNumInputs, results, p); + if (num_results != kNumInputs) { + fprintf(stderr, "MeasureClosure failed.\n"); + } + + benchmark.Verify(num_items); + + for (size_t i = 0; i < num_results; ++i) { + const double cycles_per_item = + results[i].ticks / static_cast<double>(results[i].input); + const double mad = results[i].variability * cycles_per_item; + printf("%6" PRIu64 ": %6.3f (+/- %5.3f)\n", + static_cast<uint64_t>(results[i].input), cycles_per_item, mad); + } +} + +void Intro() { + const float in[16] = {1, 2, 3, 4, 5, 6}; + float out[16]; + const ScalableTag<float> d; // largest possible vector + for (size_t i = 0; i < 16; i += Lanes(d)) { + const auto vec = LoadU(d, in + i); // no alignment requirement + auto result = Mul(vec, vec); + result = Add(result, result); // can update if not const + StoreU(result, d, out + i); + } + printf("\nF(x)->2*x^2, F(%.0f) = %.1f\n", in[2], out[2]); +} + +// BEGINNER: dot product +// 0.4 cyc/float = bronze, 0.25 = silver, 0.15 = gold! +class BenchmarkDot : public TwoArray { + public: + BenchmarkDot() : dot_{-1.0f} {} + + FuncOutput operator()(const size_t num_items) { + const ScalableTag<float> d; + const size_t N = Lanes(d); + using V = decltype(Zero(d)); + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // We cannot use an array because V might be a sizeless type. For reasonable + // code, we unroll 4x, but 8x might help (2 FMA ports * 4 cycle latency). + V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); + const float* const HWY_RESTRICT pa = &a_[0]; + const float* const HWY_RESTRICT pb = b_; + for (size_t i = 0; i < num_items; i += 4 * N) { + const auto a0 = Load(d, pa + i + 0 * N); + const auto b0 = Load(d, pb + i + 0 * N); + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = Load(d, pa + i + 1 * N); + const auto b1 = Load(d, pb + i + 1 * N); + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = Load(d, pa + i + 2 * N); + const auto b2 = Load(d, pb + i + 2 * N); + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = Load(d, pa + i + 3 * N); + const auto b3 = Load(d, pb + i + 3 * N); + sum3 = MulAdd(a3, b3, sum3); + } + // Reduction tree: sum of all accumulators by pairs into sum0. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + dot_ = GetLane(SumOfLanes(d, sum0)); + return static_cast<FuncOutput>(dot_); + } + void Verify(size_t num_items) { + if (dot_ == -1.0f) { + fprintf(stderr, "Dot: must call Verify after benchmark"); + abort(); + } + + const float expected = + std::inner_product(a_.get(), a_.get() + num_items, b_, 0.0f); + const float rel_err = std::abs(expected - dot_) / expected; + if (rel_err > 1.1E-6f) { + fprintf(stderr, "Dot: expected %e actual %e (%e)\n", expected, dot_, + rel_err); + abort(); + } + } + + private: + float dot_; // for Verify +}; + +// INTERMEDIATE: delta coding +// 1.0 cycles/float = bronze, 0.7 = silver, 0.4 = gold! +struct BenchmarkDelta : public TwoArray { + FuncOutput operator()(const size_t num_items) const { +#if HWY_TARGET == HWY_SCALAR + b_[0] = a_[0]; + for (size_t i = 1; i < num_items; ++i) { + b_[i] = a_[i] - a_[i - 1]; + } +#elif HWY_CAP_GE256 + // Larger vectors are split into 128-bit blocks, easiest to use the + // unaligned load support to shift between them. + const ScalableTag<float> df; + const size_t N = Lanes(df); + size_t i; + b_[0] = a_[0]; + for (i = 1; i < N; ++i) { + b_[i] = a_[i] - a_[i - 1]; + } + for (; i < num_items; i += N) { + const auto a = Load(df, &a_[i]); + const auto shifted = LoadU(df, &a_[i - 1]); + Store(a - shifted, df, &b_[i]); + } +#else // 128-bit + // Slightly better than unaligned loads + const HWY_CAPPED(float, 4) df; + const size_t N = Lanes(df); + size_t i; + b_[0] = a_[0]; + for (i = 1; i < N; ++i) { + b_[i] = a_[i] - a_[i - 1]; + } + auto prev = Load(df, &a_[0]); + for (; i < num_items; i += Lanes(df)) { + const auto a = Load(df, &a_[i]); + const auto shifted = CombineShiftRightLanes<3>(df, a, prev); + prev = a; + Store(Sub(a, shifted), df, &b_[i]); + } +#endif + return static_cast<FuncOutput>(b_[num_items - 1]); + } + + void Verify(size_t num_items) { + for (size_t i = 0; i < num_items; ++i) { + const float expected = (i == 0) ? a_[0] : a_[i] - a_[i - 1]; + const float err = std::abs(expected - b_[i]); + if (err > 1E-6f) { + fprintf(stderr, "Delta: expected %e, actual %e\n", expected, b_[i]); + } + } + } +}; + +void RunBenchmarks() { + Intro(); + printf("------------------------ %s\n", TargetName(HWY_TARGET)); + RunBenchmark<BenchmarkDot>("dot"); + RunBenchmark<BenchmarkDelta>("delta"); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +HWY_EXPORT(RunBenchmarks); + +void Run() { + for (int64_t target : SupportedAndGeneratedTargets()) { + SetSupportedTargetsForTest(target); + HWY_DYNAMIC_DISPATCH(RunBenchmarks)(); + } + SetSupportedTargetsForTest(0); // Reset the mask afterwards. +} + +} // namespace hwy + +int main(int /*argc*/, char** /*argv*/) { + hwy::Run(); + return 0; +} +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/examples/skeleton-inl.h b/third_party/highway/hwy/examples/skeleton-inl.h new file mode 100644 index 0000000000..8aec33e666 --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton-inl.h @@ -0,0 +1,66 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo of functions that might be called from multiple SIMD modules (either +// other -inl.h files, or a .cc file between begin/end_target-inl). This is +// optional - all SIMD code can reside in .cc files. However, this allows +// splitting code into different files while still inlining instead of requiring +// calling through function pointers. + +// Per-target include guard. This is only required when using dynamic dispatch, +// i.e. including foreach_target.h. For static dispatch, a normal include +// guard would be fine because the header is only compiled once. +#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#else +#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#endif + +// It is fine to #include normal or *-inl headers. +#include <stddef.h> + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +// Highway ops reside here; ADL does not find templates nor builtins. +namespace hn = hwy::HWY_NAMESPACE; + +// Example of a type-agnostic (caller-specified lane type) and width-agnostic +// (uses best available instruction set) function in a header. +// +// Computes x[i] = mul_array[i] * x_array[i] + add_array[i] for i < size. +template <class D, typename T> +HWY_MAYBE_UNUSED void MulAddLoop(const D d, const T* HWY_RESTRICT mul_array, + const T* HWY_RESTRICT add_array, + const size_t size, T* HWY_RESTRICT x_array) { + for (size_t i = 0; i < size; i += hn::Lanes(d)) { + const auto mul = hn::Load(d, mul_array + i); + const auto add = hn::Load(d, add_array + i); + auto x = hn::Load(d, x_array + i); + x = hn::MulAdd(mul, x, add); + hn::Store(x, d, x_array + i); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#endif // include guard diff --git a/third_party/highway/hwy/examples/skeleton.cc b/third_party/highway/hwy/examples/skeleton.cc new file mode 100644 index 0000000000..778ba4ac0a --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton.cc @@ -0,0 +1,122 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/examples/skeleton.h" + +#include <stdio.h> + +// >>>> for dynamic dispatch only, skip if you want static dispatch + +// First undef to prevent error when re-included. +#undef HWY_TARGET_INCLUDE +// For dynamic dispatch, specify the name of the current file (unfortunately +// __FILE__ is not reliable) so that foreach_target.h can re-include it. +#define HWY_TARGET_INCLUDE "hwy/examples/skeleton.cc" +// Generates code for each enabled target by re-including this source file. +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// <<<< end of dynamic dispatch + +// Must come after foreach_target.h to avoid redefinition errors. +#include "hwy/highway.h" + +// Optional, can instead add HWY_ATTR to all functions. +HWY_BEFORE_NAMESPACE(); + +namespace skeleton { +// This namespace name is unique per target, which allows code for multiple +// targets to co-exist in the same translation unit. Required when using dynamic +// dispatch, otherwise optional. +namespace HWY_NAMESPACE { + +// Highway ops reside here; ADL does not find templates nor builtins. +namespace hn = hwy::HWY_NAMESPACE; + +// Computes log2 by converting to a vector of floats. Compiled once per target. +template <class DF> +HWY_ATTR_NO_MSAN void OneFloorLog2(const DF df, + const uint8_t* HWY_RESTRICT values, + uint8_t* HWY_RESTRICT log2) { + // Type tags for converting to other element types (Rebind = same count). + const hn::RebindToSigned<DF> d32; + const hn::Rebind<uint8_t, DF> d8; + using VI32 = hn::Vec<decltype(d32)>; + + const VI32 vi32 = hn::PromoteTo(d32, hn::Load(d8, values)); + const VI32 bits = hn::BitCast(d32, hn::ConvertTo(df, vi32)); + const VI32 exponent = hn::Sub(hn::ShiftRight<23>(bits), hn::Set(d32, 127)); + hn::Store(hn::DemoteTo(d8, exponent), d8, log2); +} + +void CodepathDemo() { + // Highway defaults to portability, but per-target codepaths may be selected + // via #if HWY_TARGET == HWY_SSE4 or by testing capability macros: +#if HWY_HAVE_INTEGER64 + const char* gather = "Has int64"; +#else + const char* gather = "No int64"; +#endif + printf("Target %s: %s\n", hwy::TargetName(HWY_TARGET), gather); +} + +void FloorLog2(const uint8_t* HWY_RESTRICT values, size_t count, + uint8_t* HWY_RESTRICT log2) { + CodepathDemo(); + + const hn::ScalableTag<float> df; + const size_t N = hn::Lanes(df); + size_t i = 0; + for (; i + N <= count; i += N) { + OneFloorLog2(df, values + i, log2 + i); + } + for (; i < count; ++i) { + hn::CappedTag<float, 1> d1; + OneFloorLog2(d1, values + i, log2 + i); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +// The table of pointers to the various implementations in HWY_NAMESPACE must +// be compiled only once (foreach_target #includes this file multiple times). +// HWY_ONCE is true for only one of these 'compilation passes'. +#if HWY_ONCE + +namespace skeleton { + +// This macro declares a static array used for dynamic dispatch; it resides in +// the same outer namespace that contains FloorLog2. +HWY_EXPORT(FloorLog2); + +// This function is optional and only needed in the case of exposing it in the +// header file. Otherwise using HWY_DYNAMIC_DISPATCH(FloorLog2) in this module +// is equivalent to inlining this function. +HWY_DLLEXPORT void CallFloorLog2(const uint8_t* HWY_RESTRICT in, + const size_t count, + uint8_t* HWY_RESTRICT out) { + // This must reside outside of HWY_NAMESPACE because it references (calls the + // appropriate one from) the per-target implementations there. + // For static dispatch, use HWY_STATIC_DISPATCH. + return HWY_DYNAMIC_DISPATCH(FloorLog2)(in, count, out); +} + +// Optional: anything to compile only once, e.g. non-SIMD implementations of +// public functions provided by this module, can go inside #if HWY_ONCE. + +} // namespace skeleton +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/examples/skeleton.h b/third_party/highway/hwy/examples/skeleton.h new file mode 100644 index 0000000000..381ac69af6 --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton.h @@ -0,0 +1,36 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Demo interface to target-specific code in skeleton.cc + +// Normal header with include guard and namespace. +#ifndef HIGHWAY_HWY_EXAMPLES_SKELETON_H_ +#define HIGHWAY_HWY_EXAMPLES_SKELETON_H_ + +#include <stddef.h> + +// Platform-specific definitions used for declaring an interface, independent of +// the SIMD instruction set. +#include "hwy/base.h" // HWY_RESTRICT + +namespace skeleton { + +// Computes base-2 logarithm by converting to float. Supports dynamic dispatch. +HWY_DLLEXPORT void CallFloorLog2(const uint8_t* HWY_RESTRICT in, + const size_t count, uint8_t* HWY_RESTRICT out); + +} // namespace skeleton + +#endif // HIGHWAY_HWY_EXAMPLES_SKELETON_H_ diff --git a/third_party/highway/hwy/examples/skeleton_test.cc b/third_party/highway/hwy/examples/skeleton_test.cc new file mode 100644 index 0000000000..c7c26bf5b4 --- /dev/null +++ b/third_party/highway/hwy/examples/skeleton_test.cc @@ -0,0 +1,110 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Example of unit test for the "skeleton" library. + +#include "hwy/examples/skeleton.h" + +#include <stdio.h> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "examples/skeleton_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// Must come after foreach_target.h to avoid redefinition errors. +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +// Optional: factor out parts of the implementation into *-inl.h +// (must also come after foreach_target.h to avoid redefinition errors) +#include "hwy/examples/skeleton-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace skeleton { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +// Calls function defined in skeleton.cc. +struct TestFloorLog2 { + template <class T, class DF> + HWY_NOINLINE void operator()(T /*unused*/, DF df) { + const size_t count = 5 * hn::Lanes(df); + auto in = hwy::AllocateAligned<uint8_t>(count); + auto expected = hwy::AllocateAligned<uint8_t>(count); + + hwy::RandomState rng; + for (size_t i = 0; i < count; ++i) { + expected[i] = Random32(&rng) & 7; + in[i] = static_cast<uint8_t>(1u << expected[i]); + } + auto out = hwy::AllocateAligned<uint8_t>(count); + CallFloorLog2(in.get(), count, out.get()); + int sum = 0; + for (size_t i = 0; i < count; ++i) { + HWY_ASSERT_EQ(expected[i], out[i]); + sum += out[i]; + } + hwy::PreventElision(sum); + } +}; + +HWY_NOINLINE void TestAllFloorLog2() { + hn::ForPartialVectors<TestFloorLog2>()(float()); +} + +// Calls function defined in skeleton-inl.h. +struct TestSumMulAdd { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + hwy::RandomState rng; + const size_t count = 4096; + EXPECT_EQ(0, count % hn::Lanes(d)); + auto mul = hwy::AllocateAligned<T>(count); + auto x = hwy::AllocateAligned<T>(count); + auto add = hwy::AllocateAligned<T>(count); + for (size_t i = 0; i < count; ++i) { + mul[i] = static_cast<T>(Random32(&rng) & 0xF); + x[i] = static_cast<T>(Random32(&rng) & 0xFF); + add[i] = static_cast<T>(Random32(&rng) & 0xFF); + } + double expected_sum = 0.0; + for (size_t i = 0; i < count; ++i) { + expected_sum += mul[i] * x[i] + add[i]; + } + + MulAddLoop(d, mul.get(), add.get(), count, x.get()); + HWY_ASSERT_EQ(4344240.0, expected_sum); + } +}; + +HWY_NOINLINE void TestAllSumMulAdd() { + hn::ForFloatTypes(hn::ForPartialVectors<TestSumMulAdd>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace skeleton +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace skeleton { +HWY_BEFORE_TEST(SkeletonTest); +HWY_EXPORT_AND_TEST_P(SkeletonTest, TestAllFloorLog2); +HWY_EXPORT_AND_TEST_P(SkeletonTest, TestAllSumMulAdd); +} // namespace skeleton + +#endif diff --git a/third_party/highway/hwy/foreach_target.h b/third_party/highway/hwy/foreach_target.h new file mode 100644 index 0000000000..3929905ca2 --- /dev/null +++ b/third_party/highway/hwy/foreach_target.h @@ -0,0 +1,261 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_FOREACH_TARGET_H_ +#define HIGHWAY_HWY_FOREACH_TARGET_H_ + +// Re-includes the translation unit zero or more times to compile for any +// targets except HWY_STATIC_TARGET. Defines unique HWY_TARGET each time so that +// highway.h defines the corresponding macro/namespace. + +#include "hwy/detect_targets.h" + +// *_inl.h may include other headers, which requires include guards to prevent +// repeated inclusion. The guards must be reset after compiling each target, so +// the header is again visible. This is done by flipping HWY_TARGET_TOGGLE, +// defining it if undefined and vice versa. This macro is initially undefined +// so that IDEs don't gray out the contents of each header. +#ifdef HWY_TARGET_TOGGLE +#error "This macro must not be defined outside foreach_target.h" +#endif + +#ifdef HWY_HIGHWAY_INCLUDED // highway.h include guard +// Trigger fixup at the bottom of this header. +#define HWY_ALREADY_INCLUDED + +// The next highway.h must re-include set_macros-inl.h because the first +// highway.h chose the static target instead of what we will set below. +#undef HWY_SET_MACROS_PER_TARGET +#endif + +// Disable HWY_EXPORT in user code until we have generated all targets. Note +// that a subsequent highway.h will not override this definition. +#undef HWY_ONCE +#define HWY_ONCE (0 || HWY_IDE) + +// Avoid warnings on #include HWY_TARGET_INCLUDE by hiding them from the IDE; +// also skip if only 1 target defined (no re-inclusion will be necessary). +#if !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +#if !defined(HWY_TARGET_INCLUDE) +#error ">1 target enabled => define HWY_TARGET_INCLUDE before foreach_target.h" +#endif + +#if (HWY_TARGETS & HWY_EMU128) && (HWY_STATIC_TARGET != HWY_EMU128) +#undef HWY_TARGET +#define HWY_TARGET HWY_EMU128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SCALAR) && (HWY_STATIC_TARGET != HWY_SCALAR) +#undef HWY_TARGET +#define HWY_TARGET HWY_SCALAR +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_NEON) && (HWY_STATIC_TARGET != HWY_NEON) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_RVV) && (HWY_STATIC_TARGET != HWY_RVV) +#undef HWY_TARGET +#define HWY_TARGET HWY_RVV +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE) && (HWY_STATIC_TARGET != HWY_SVE) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2) && (HWY_STATIC_TARGET != HWY_SVE2) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE_256) && (HWY_STATIC_TARGET != HWY_SVE_256) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE_256 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2_128) && (HWY_STATIC_TARGET != HWY_SVE2_128) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2_128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSSE3) && (HWY_STATIC_TARGET != HWY_SSSE3) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSSE3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSE4) && (HWY_STATIC_TARGET != HWY_SSE4) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSE4 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX2) && (HWY_STATIC_TARGET != HWY_AVX2) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3) && (HWY_STATIC_TARGET != HWY_AVX3) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3_DL) && (HWY_STATIC_TARGET != HWY_AVX3_DL) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_DL +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_WASM_EMU256) && (HWY_STATIC_TARGET != HWY_WASM_EMU256) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM_EMU256 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_WASM) && (HWY_STATIC_TARGET != HWY_WASM) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_PPC8) && (HWY_STATIC_TARGET != HWY_PPC8) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC8 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#endif // !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +// Now that all but the static target have been generated, re-enable HWY_EXPORT. +#undef HWY_ONCE +#define HWY_ONCE 1 + +// If we re-include once per enabled target, the translation unit's +// implementation would have to be skipped via #if to avoid redefining symbols. +// We instead skip the re-include for HWY_STATIC_TARGET, and generate its +// implementation when resuming compilation of the translation unit. +#undef HWY_TARGET +#define HWY_TARGET HWY_STATIC_TARGET + +#ifdef HWY_ALREADY_INCLUDED +// Revert the previous toggle to prevent redefinitions for the static target. +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif + +// Force re-inclusion of set_macros-inl.h now that HWY_TARGET is restored. +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif +#endif + +#endif // HIGHWAY_HWY_FOREACH_TARGET_H_ diff --git a/third_party/highway/hwy/highway.h b/third_party/highway/hwy/highway.h new file mode 100644 index 0000000000..8a7a7531e5 --- /dev/null +++ b/third_party/highway/hwy/highway.h @@ -0,0 +1,378 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This include guard is checked by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. NOTE: ops/*-inl.h are included +// after/outside this include guard. +#ifndef HWY_HIGHWAY_INCLUDED +#define HWY_HIGHWAY_INCLUDED + +// Main header required before using vector types. + +#include "hwy/base.h" +#include "hwy/targets.h" + +namespace hwy { + +// API version (https://semver.org/); keep in sync with CMakeLists.txt. +#define HWY_MAJOR 1 +#define HWY_MINOR 0 +#define HWY_PATCH 3 + +//------------------------------------------------------------------------------ +// Shorthand for tags (defined in shared-inl.h) used to select overloads. +// Note that ScalableTag<T> is preferred over HWY_FULL, and CappedTag<T, N> over +// HWY_CAPPED(T, N). + +// HWY_FULL(T[,LMUL=1]) is a native vector/group. LMUL is the number of +// registers in the group, and is ignored on targets that do not support groups. +#define HWY_FULL1(T) hwy::HWY_NAMESPACE::ScalableTag<T> +#define HWY_FULL2(T, LMUL) \ + hwy::HWY_NAMESPACE::ScalableTag<T, hwy::CeilLog2(HWY_MAX(0, LMUL))> +#define HWY_3TH_ARG(arg1, arg2, arg3, ...) arg3 +// Workaround for MSVC grouping __VA_ARGS__ into a single argument +#define HWY_FULL_RECOMPOSER(args_with_paren) HWY_3TH_ARG args_with_paren +// Trailing comma avoids -pedantic false alarm +#define HWY_CHOOSE_FULL(...) \ + HWY_FULL_RECOMPOSER((__VA_ARGS__, HWY_FULL2, HWY_FULL1, )) +#define HWY_FULL(...) HWY_CHOOSE_FULL(__VA_ARGS__())(__VA_ARGS__) + +// Vector of up to MAX_N lanes. It's better to use full vectors where possible. +#define HWY_CAPPED(T, MAX_N) hwy::HWY_NAMESPACE::CappedTag<T, MAX_N> + +//------------------------------------------------------------------------------ +// Export user functions for static/dynamic dispatch + +// Evaluates to 0 inside a translation unit if it is generating anything but the +// static target (the last one if multiple targets are enabled). Used to prevent +// redefinitions of HWY_EXPORT. Unless foreach_target.h is included, we only +// compile once anyway, so this is 1 unless it is or has been included. +#ifndef HWY_ONCE +#define HWY_ONCE 1 +#endif + +// HWY_STATIC_DISPATCH(FUNC_NAME) is the namespace-qualified FUNC_NAME for +// HWY_STATIC_TARGET (the only defined namespace unless HWY_TARGET_INCLUDE is +// defined), and can be used to deduce the return type of Choose*. +#if HWY_STATIC_TARGET == HWY_SCALAR +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SCALAR::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_EMU128 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_EMU128::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_RVV +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_RVV::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_WASM_EMU256 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM_EMU256::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_WASM +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_NEON +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE_256 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE_256::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SVE2_128 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2_128::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_PPC8 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC8::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SSSE3 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSSE3::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_SSE4 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE4::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX2 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX2::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3 +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3::FUNC_NAME +#elif HWY_STATIC_TARGET == HWY_AVX3_DL +#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_DL::FUNC_NAME +#endif + +// HWY_CHOOSE_*(FUNC_NAME) expands to the function pointer for that target or +// nullptr is that target was not compiled. +#if HWY_TARGETS & HWY_EMU128 +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_EMU128::FUNC_NAME +#elif HWY_TARGETS & HWY_SCALAR +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_SCALAR::FUNC_NAME +#else +// When HWY_SCALAR/HWY_EMU128 are not present and other targets were disabled at +// runtime, fall back to the baseline with HWY_STATIC_DISPATCH(). +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME) +#endif + +#if HWY_TARGETS & HWY_WASM_EMU256 +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) &N_WASM_EMU256::FUNC_NAME +#else +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_WASM +#define HWY_CHOOSE_WASM(FUNC_NAME) &N_WASM::FUNC_NAME +#else +#define HWY_CHOOSE_WASM(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_RVV +#define HWY_CHOOSE_RVV(FUNC_NAME) &N_RVV::FUNC_NAME +#else +#define HWY_CHOOSE_RVV(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_NEON +#define HWY_CHOOSE_NEON(FUNC_NAME) &N_NEON::FUNC_NAME +#else +#define HWY_CHOOSE_NEON(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE +#define HWY_CHOOSE_SVE(FUNC_NAME) &N_SVE::FUNC_NAME +#else +#define HWY_CHOOSE_SVE(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE2 +#define HWY_CHOOSE_SVE2(FUNC_NAME) &N_SVE2::FUNC_NAME +#else +#define HWY_CHOOSE_SVE2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE_256 +#define HWY_CHOOSE_SVE_256(FUNC_NAME) &N_SVE_256::FUNC_NAME +#else +#define HWY_CHOOSE_SVE_256(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SVE2_128 +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) &N_SVE2_128::FUNC_NAME +#else +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_PPC8 +#define HWY_CHOOSE_PCC8(FUNC_NAME) &N_PPC8::FUNC_NAME +#else +#define HWY_CHOOSE_PPC8(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SSSE3 +#define HWY_CHOOSE_SSSE3(FUNC_NAME) &N_SSSE3::FUNC_NAME +#else +#define HWY_CHOOSE_SSSE3(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_SSE4 +#define HWY_CHOOSE_SSE4(FUNC_NAME) &N_SSE4::FUNC_NAME +#else +#define HWY_CHOOSE_SSE4(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX2 +#define HWY_CHOOSE_AVX2(FUNC_NAME) &N_AVX2::FUNC_NAME +#else +#define HWY_CHOOSE_AVX2(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3 +#define HWY_CHOOSE_AVX3(FUNC_NAME) &N_AVX3::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3(FUNC_NAME) nullptr +#endif + +#if HWY_TARGETS & HWY_AVX3_DL +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) &N_AVX3_DL::FUNC_NAME +#else +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) nullptr +#endif + +// MSVC 2017 workaround: the non-type template parameter to ChooseAndCall +// apparently cannot be an array. Use a function pointer instead, which has the +// disadvantage that we call the static (not best) target on the first call to +// any HWY_DYNAMIC_DISPATCH. +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1915 +#define HWY_DISPATCH_WORKAROUND 1 +#else +#define HWY_DISPATCH_WORKAROUND 0 +#endif + +// Provides a static member function which is what is called during the first +// HWY_DYNAMIC_DISPATCH, where GetIndex is still zero, and instantiations of +// this function are the first entry in the tables created by HWY_EXPORT. +template <typename RetType, typename... Args> +struct FunctionCache { + public: + typedef RetType(FunctionType)(Args...); + +#if HWY_DISPATCH_WORKAROUND + template <FunctionType* const func> + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (*func)(args...); + } +#else + // A template function that when instantiated has the same signature as the + // function being called. This function initializes the bit array of targets + // supported by the current CPU and then calls the appropriate entry within + // the HWY_EXPORT table. Subsequent calls via HWY_DYNAMIC_DISPATCH to any + // exported functions, even those defined by different translation units, + // will dispatch directly to the best available target. + template <FunctionType* const table[]> + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (table[chosen_target.GetIndex()])(args...); + } +#endif // HWY_DISPATCH_WORKAROUND +}; + +// Used to deduce the template parameters RetType and Args from a function. +template <typename RetType, typename... Args> +FunctionCache<RetType, Args...> DeduceFunctionCache(RetType (*)(Args...)) { + return FunctionCache<RetType, Args...>(); +} + +#define HWY_DISPATCH_TABLE(FUNC_NAME) \ + HWY_CONCAT(FUNC_NAME, HighwayDispatchTable) + +// HWY_EXPORT(FUNC_NAME); expands to a static array that is used by +// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime. This +// static array must be defined at the same namespace level as the function +// it is exporting. +// After being exported, it can be called from other parts of the same source +// file using HWY_DYNAMIC_DISPATCH(), in particular from a function wrapper +// like in the following example: +// +// #include "hwy/highway.h" +// HWY_BEFORE_NAMESPACE(); +// namespace skeleton { +// namespace HWY_NAMESPACE { +// +// void MyFunction(int a, char b, const char* c) { ... } +// +// // NOLINTNEXTLINE(google-readability-namespace-comments) +// } // namespace HWY_NAMESPACE +// } // namespace skeleton +// HWY_AFTER_NAMESPACE(); +// +// namespace skeleton { +// HWY_EXPORT(MyFunction); // Defines the dispatch table in this scope. +// +// void MyFunction(int a, char b, const char* c) { +// return HWY_DYNAMIC_DISPATCH(MyFunction)(a, b, c); +// } +// } // namespace skeleton +// + +#if HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// Simplified version for IDE or the dynamic dispatch case with only one target. +// This case still uses a table, although of a single element, to provide the +// same compile error conditions as with the dynamic dispatch case when multiple +// targets are being compiled. +#define HWY_EXPORT(FUNC_NAME) \ + HWY_MAYBE_UNUSED static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const \ + HWY_DISPATCH_TABLE(FUNC_NAME)[1] = {&HWY_STATIC_DISPATCH(FUNC_NAME)} +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME) + +#else + +// Simplified version for MSVC 2017: function pointer instead of table. +#if HWY_DISPATCH_WORKAROUND + +#define HWY_EXPORT(FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + FUNC_NAME)[HWY_MAX_DYNAMIC_TARGETS + 2] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the function from HWY_STATIC_TARGET. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH( \ + FUNC_NAME)))::ChooseAndCall<&HWY_STATIC_DISPATCH(FUNC_NAME)>, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } + +#else + +// Dynamic dispatch case with one entry per dynamic target plus the fallback +// target and the initialization wrapper. +#define HWY_EXPORT(FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + FUNC_NAME)[HWY_MAX_DYNAMIC_TARGETS + 2] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH( \ + FUNC_NAME)))::ChooseAndCall<HWY_DISPATCH_TABLE(FUNC_NAME)>, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } + +#endif // HWY_DISPATCH_WORKAROUND + +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) \ + (*(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()])) + +#endif // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// DEPRECATED names; please use HWY_HAVE_* instead. +#define HWY_CAP_INTEGER64 HWY_HAVE_INTEGER64 +#define HWY_CAP_FLOAT16 HWY_HAVE_FLOAT16 +#define HWY_CAP_FLOAT64 HWY_HAVE_FLOAT64 + +} // namespace hwy + +#endif // HWY_HIGHWAY_INCLUDED + +//------------------------------------------------------------------------------ + +// NOTE: the following definitions and ops/*.h depend on HWY_TARGET, so we want +// to include them once per target, which is ensured by the toggle check. +// Because ops/*.h are included under it, they do not need their own guard. +#if defined(HWY_HIGHWAY_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_HIGHWAY_PER_TARGET +#undef HWY_HIGHWAY_PER_TARGET +#else +#define HWY_HIGHWAY_PER_TARGET +#endif + +// These define ops inside namespace hwy::HWY_NAMESPACE. +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 +#include "hwy/ops/x86_128-inl.h" +#elif HWY_TARGET == HWY_AVX2 +#include "hwy/ops/x86_256-inl.h" +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL +#include "hwy/ops/x86_512-inl.h" +#elif HWY_TARGET == HWY_PPC8 +#error "PPC is not yet supported" +#elif HWY_TARGET == HWY_NEON +#include "hwy/ops/arm_neon-inl.h" +#elif HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || \ + HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +#include "hwy/ops/arm_sve-inl.h" +#elif HWY_TARGET == HWY_WASM_EMU256 +#include "hwy/ops/wasm_256-inl.h" +#elif HWY_TARGET == HWY_WASM +#include "hwy/ops/wasm_128-inl.h" +#elif HWY_TARGET == HWY_RVV +#include "hwy/ops/rvv-inl.h" +#elif HWY_TARGET == HWY_EMU128 +#include "hwy/ops/emu128-inl.h" +#elif HWY_TARGET == HWY_SCALAR +#include "hwy/ops/scalar-inl.h" +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +#include "hwy/ops/generic_ops-inl.h" + +#endif // HWY_HIGHWAY_PER_TARGET diff --git a/third_party/highway/hwy/highway_export.h b/third_party/highway/hwy/highway_export.h new file mode 100644 index 0000000000..30edc17d01 --- /dev/null +++ b/third_party/highway/hwy/highway_export.h @@ -0,0 +1,74 @@ +// Pseudo-generated file to handle both cmake & bazel build system. + +// Initial generation done using cmake code: +// include(GenerateExportHeader) +// generate_export_header(hwy EXPORT_MACRO_NAME HWY_DLLEXPORT EXPORT_FILE_NAME +// hwy/highway_export.h) +// code reformatted using clang-format --style=Google + +#ifndef HWY_DLLEXPORT_H +#define HWY_DLLEXPORT_H + +#if !defined(HWY_SHARED_DEFINE) +#define HWY_DLLEXPORT +#define HWY_CONTRIB_DLLEXPORT +#define HWY_TEST_DLLEXPORT +#else // !HWY_SHARED_DEFINE + +#ifndef HWY_DLLEXPORT +#if defined(hwy_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllexport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllimport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_EXPORTS) +#endif // HWY_DLLEXPORT + +#ifndef HWY_CONTRIB_DLLEXPORT +#if defined(hwy_contrib_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllexport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_contrib_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllimport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_contrib_EXPORTS) +#endif // HWY_CONTRIB_DLLEXPORT + +#ifndef HWY_TEST_DLLEXPORT +#if defined(hwy_test_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllexport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_test_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllimport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_test_EXPORTS) +#endif // HWY_TEST_DLLEXPORT + +#endif // !HWY_SHARED_DEFINE + +#endif /* HWY_DLLEXPORT_H */ diff --git a/third_party/highway/hwy/highway_test.cc b/third_party/highway/hwy/highway_test.cc new file mode 100644 index 0000000000..d2caec067b --- /dev/null +++ b/third_party/highway/hwy/highway_test.cc @@ -0,0 +1,483 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> // std::fill +#include <bitset> + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "highway_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/nanobenchmark.h" // Unpredictable1 +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <size_t kLimit, typename T> +HWY_NOINLINE void TestCappedLimit(T /* tag */) { + CappedTag<T, kLimit> d; + // Ensure two ops compile + HWY_ASSERT_VEC_EQ(d, Zero(d), Set(d, T{0})); + + // Ensure we do not write more than kLimit lanes + const size_t N = Lanes(d); + if (kLimit < N) { + auto lanes = AllocateAligned<T>(N); + std::fill(lanes.get(), lanes.get() + N, T{0}); + Store(Set(d, T{1}), d, lanes.get()); + for (size_t i = kLimit; i < N; ++i) { + HWY_ASSERT_EQ(lanes[i], T{0}); + } + } +} + +// Adapter for ForAllTypes - we are constructing our own Simd<> and thus do not +// use ForPartialVectors etc. +struct TestCapped { + template <typename T> + void operator()(T t) const { + TestCappedLimit<1>(t); + TestCappedLimit<3>(t); + TestCappedLimit<5>(t); + TestCappedLimit<1ull << 15>(t); + } +}; + +HWY_NOINLINE void TestAllCapped() { ForAllTypes(TestCapped()); } + +// For testing that ForPartialVectors reaches every possible size: +using NumLanesSet = std::bitset<HWY_MAX_BYTES + 1>; + +// Monostate pattern because ForPartialVectors takes a template argument, not a +// functor by reference. +static NumLanesSet* NumLanesForSize(size_t sizeof_t) { + HWY_ASSERT(sizeof_t <= sizeof(uint64_t)); + static NumLanesSet num_lanes[sizeof(uint64_t) + 1]; + return num_lanes + sizeof_t; +} +static size_t* MaxLanesForSize(size_t sizeof_t) { + HWY_ASSERT(sizeof_t <= sizeof(uint64_t)); + static size_t num_lanes[sizeof(uint64_t) + 1] = {0}; + return num_lanes + sizeof_t; +} + +struct TestMaxLanes { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const size_t kMax = MaxLanes(d); // for RVV, includes LMUL + HWY_ASSERT(N <= kMax); + HWY_ASSERT(kMax <= (HWY_MAX_BYTES / sizeof(T))); + + NumLanesForSize(sizeof(T))->set(N); + *MaxLanesForSize(sizeof(T)) = HWY_MAX(*MaxLanesForSize(sizeof(T)), N); + } +}; + +HWY_NOINLINE void TestAllMaxLanes() { + ForAllTypes(ForPartialVectors<TestMaxLanes>()); + + // Ensure ForPartialVectors visited all powers of two [1, N]. + for (size_t sizeof_t : {sizeof(uint8_t), sizeof(uint16_t), sizeof(uint32_t), + sizeof(uint64_t)}) { + const size_t N = *MaxLanesForSize(sizeof_t); + for (size_t i = 1; i <= N; i += i) { + if (!NumLanesForSize(sizeof_t)->test(i)) { + fprintf(stderr, "T=%d: did not visit for N=%d, max=%d\n", + static_cast<int>(sizeof_t), static_cast<int>(i), + static_cast<int>(N)); + HWY_ASSERT(false); + } + } + } +} + +struct TestSet { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Zero + const auto v0 = Zero(d); + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + std::fill(expected.get(), expected.get() + N, T(0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), v0); + + // Set + const auto v2 = Set(d, T(2)); + for (size_t i = 0; i < N; ++i) { + expected[i] = 2; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), v2); + + // Iota + const auto vi = Iota(d, T(5)); + for (size_t i = 0; i < N; ++i) { + expected[i] = T(5 + i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), vi); + + // Undefined + const auto vu = Undefined(d); + Store(vu, d, expected.get()); + } +}; + +HWY_NOINLINE void TestAllSet() { ForAllTypes(ForPartialVectors<TestSet>()); } + +// Ensures wraparound (mod 2^bits) +struct TestOverflow { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(1)); + const auto vmax = Set(d, LimitsMax<T>()); + const auto vmin = Set(d, LimitsMin<T>()); + // Unsigned underflow / negative -> positive + HWY_ASSERT_VEC_EQ(d, vmax, Sub(vmin, v1)); + // Unsigned overflow / positive -> negative + HWY_ASSERT_VEC_EQ(d, vmin, Add(vmax, v1)); + } +}; + +HWY_NOINLINE void TestAllOverflow() { + ForIntegerTypes(ForPartialVectors<TestOverflow>()); +} + +struct TestClamp { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto v2 = Set(d, 2); + + HWY_ASSERT_VEC_EQ(d, v1, Clamp(v2, v0, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Clamp(v0, v1, v2)); + } +}; + +HWY_NOINLINE void TestAllClamp() { + ForAllTypes(ForPartialVectors<TestClamp>()); +} + +struct TestSignBitInteger { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto all = VecFromMask(d, Eq(v0, v0)); + const auto vs = SignBit(d); + const auto other = Sub(vs, Set(d, 1)); + + // Shifting left by one => overflow, equal zero + HWY_ASSERT_VEC_EQ(d, v0, Add(vs, vs)); + // Verify the lower bits are zero (only +/- and logical ops are available + // for all types) + HWY_ASSERT_VEC_EQ(d, all, Add(vs, other)); + } +}; + +struct TestSignBitFloat { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vs = SignBit(d); + const auto vp = Set(d, 2.25); + const auto vn = Set(d, -2.25); + HWY_ASSERT_VEC_EQ(d, Or(vp, vs), vn); + HWY_ASSERT_VEC_EQ(d, AndNot(vs, vn), vp); + HWY_ASSERT_VEC_EQ(d, v0, vs); + } +}; + +HWY_NOINLINE void TestAllSignBit() { + ForIntegerTypes(ForPartialVectors<TestSignBitInteger>()); + ForFloatTypes(ForPartialVectors<TestSignBitFloat>()); +} + +// inline to work around incorrect SVE codegen (only first 128 bits used). +template <class D, class V> +HWY_INLINE void AssertNaN(D d, VecArg<V> v, const char* file, int line) { + using T = TFromD<D>; + const size_t N = Lanes(d); + if (!AllTrue(d, IsNaN(v))) { + Print(d, "not all NaN", v, 0, N); + Print(d, "mask", VecFromMask(d, IsNaN(v)), 0, N); + const std::string type_name = TypeName(T(), N); + // RVV lacks PRIu64 and MSYS still has problems with %zu, so print bytes to + // avoid truncating doubles. + uint8_t bytes[HWY_MAX(sizeof(T), 8)] = {0}; + const T lane = GetLane(v); + CopyBytes<sizeof(T)>(&lane, bytes); + Abort(file, line, + "Expected %s NaN, got %E (bytes %02x %02x %02x %02x %02x %02x %02x " + "%02x)", + type_name.c_str(), lane, bytes[0], bytes[1], bytes[2], bytes[3], + bytes[4], bytes[5], bytes[6], bytes[7]); + } +} + +#define HWY_ASSERT_NAN(d, v) AssertNaN(d, v, __FILE__, __LINE__) + +struct TestNaN { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Vec<D> v1 = Set(d, static_cast<T>(Unpredictable1())); + const Vec<D> nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + HWY_ASSERT_NAN(d, nan); + + // Arithmetic + HWY_ASSERT_NAN(d, Add(nan, v1)); + HWY_ASSERT_NAN(d, Add(v1, nan)); + HWY_ASSERT_NAN(d, Sub(nan, v1)); + HWY_ASSERT_NAN(d, Sub(v1, nan)); + HWY_ASSERT_NAN(d, Mul(nan, v1)); + HWY_ASSERT_NAN(d, Mul(v1, nan)); + HWY_ASSERT_NAN(d, Div(nan, v1)); + HWY_ASSERT_NAN(d, Div(v1, nan)); + + // FMA + HWY_ASSERT_NAN(d, MulAdd(nan, v1, v1)); + HWY_ASSERT_NAN(d, MulAdd(v1, nan, v1)); + HWY_ASSERT_NAN(d, MulAdd(v1, v1, nan)); + HWY_ASSERT_NAN(d, MulSub(nan, v1, v1)); + HWY_ASSERT_NAN(d, MulSub(v1, nan, v1)); + HWY_ASSERT_NAN(d, MulSub(v1, v1, nan)); + HWY_ASSERT_NAN(d, NegMulAdd(nan, v1, v1)); + HWY_ASSERT_NAN(d, NegMulAdd(v1, nan, v1)); + HWY_ASSERT_NAN(d, NegMulAdd(v1, v1, nan)); + HWY_ASSERT_NAN(d, NegMulSub(nan, v1, v1)); + HWY_ASSERT_NAN(d, NegMulSub(v1, nan, v1)); + HWY_ASSERT_NAN(d, NegMulSub(v1, v1, nan)); + + // Rcp/Sqrt + HWY_ASSERT_NAN(d, Sqrt(nan)); + + // Sign manipulation + HWY_ASSERT_NAN(d, Abs(nan)); + HWY_ASSERT_NAN(d, Neg(nan)); + HWY_ASSERT_NAN(d, CopySign(nan, v1)); + HWY_ASSERT_NAN(d, CopySignToAbs(nan, v1)); + + // Rounding + HWY_ASSERT_NAN(d, Ceil(nan)); + HWY_ASSERT_NAN(d, Floor(nan)); + HWY_ASSERT_NAN(d, Round(nan)); + HWY_ASSERT_NAN(d, Trunc(nan)); + + // Logical (And/AndNot/Xor will clear NaN!) + HWY_ASSERT_NAN(d, Or(nan, v1)); + + // Comparison + HWY_ASSERT(AllFalse(d, Eq(nan, v1))); + HWY_ASSERT(AllFalse(d, Gt(nan, v1))); + HWY_ASSERT(AllFalse(d, Lt(nan, v1))); + HWY_ASSERT(AllFalse(d, Ge(nan, v1))); + HWY_ASSERT(AllFalse(d, Le(nan, v1))); + + // Reduction + HWY_ASSERT_NAN(d, SumOfLanes(d, nan)); +// TODO(janwas): re-enable after QEMU/Spike are fixed +#if HWY_TARGET != HWY_RVV + HWY_ASSERT_NAN(d, MinOfLanes(d, nan)); + HWY_ASSERT_NAN(d, MaxOfLanes(d, nan)); +#endif + + // Min/Max +#if (HWY_ARCH_X86 || HWY_ARCH_WASM) && (HWY_TARGET < HWY_EMU128) + // Native WASM or x86 SIMD return the second operand if any input is NaN. + HWY_ASSERT_VEC_EQ(d, v1, Min(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Max(nan, v1)); + HWY_ASSERT_NAN(d, Min(v1, nan)); + HWY_ASSERT_NAN(d, Max(v1, nan)); +#elif HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7 + // ARMv7 NEON returns NaN if any input is NaN. + HWY_ASSERT_NAN(d, Min(v1, nan)); + HWY_ASSERT_NAN(d, Max(v1, nan)); + HWY_ASSERT_NAN(d, Min(nan, v1)); + HWY_ASSERT_NAN(d, Max(nan, v1)); +#else + // IEEE 754-2019 minimumNumber is defined as the other argument if exactly + // one is NaN, and qNaN if both are. + HWY_ASSERT_VEC_EQ(d, v1, Min(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Max(nan, v1)); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, nan)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, nan)); +#endif + HWY_ASSERT_NAN(d, Min(nan, nan)); + HWY_ASSERT_NAN(d, Max(nan, nan)); + } +}; + +// For functions only available for float32 +struct TestF32NaN { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + HWY_ASSERT_NAN(d, ApproximateReciprocal(nan)); + HWY_ASSERT_NAN(d, ApproximateReciprocalSqrt(nan)); + HWY_ASSERT_NAN(d, AbsDiff(nan, v1)); + HWY_ASSERT_NAN(d, AbsDiff(v1, nan)); + } +}; + +HWY_NOINLINE void TestAllNaN() { + ForFloatTypes(ForPartialVectors<TestNaN>()); + ForPartialVectors<TestF32NaN>()(float()); +} + +struct TestIsNaN { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto inf = IfThenElse(Eq(v1, Set(d, T(1))), Inf(d), v1); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + const auto neg = Set(d, T{-1}); + HWY_ASSERT_NAN(d, nan); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(inf)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(CopySign(inf, neg))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsNaN(nan)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsNaN(CopySign(nan, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(v1)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(Zero(d))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(Set(d, hwy::LowestValue<T>()))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsNaN(Set(d, hwy::HighestValue<T>()))); + } +}; + +HWY_NOINLINE void TestAllIsNaN() { + ForFloatTypes(ForPartialVectors<TestIsNaN>()); +} + +struct TestIsInf { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto inf = IfThenElse(Eq(v1, Set(d, T(1))), Inf(d), v1); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + const auto neg = Set(d, T{-1}); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsInf(inf)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsInf(CopySign(inf, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(nan)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(CopySign(nan, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(v1)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(Zero(d))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(Set(d, hwy::LowestValue<T>()))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsInf(Set(d, hwy::HighestValue<T>()))); + } +}; + +HWY_NOINLINE void TestAllIsInf() { + ForFloatTypes(ForPartialVectors<TestIsInf>()); +} + +struct TestIsFinite { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Set(d, T(Unpredictable1())); + const auto inf = IfThenElse(Eq(v1, Set(d, T(1))), Inf(d), v1); + const auto nan = IfThenElse(Eq(v1, Set(d, T(1))), NaN(d), v1); + const auto neg = Set(d, T{-1}); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(inf)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(CopySign(inf, neg))); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(nan)); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), IsFinite(CopySign(nan, neg))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsFinite(v1)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsFinite(Zero(d))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), IsFinite(Set(d, hwy::LowestValue<T>()))); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), + IsFinite(Set(d, hwy::HighestValue<T>()))); + } +}; + +HWY_NOINLINE void TestAllIsFinite() { + ForFloatTypes(ForPartialVectors<TestIsFinite>()); +} + +struct TestCopyAndAssign { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // copy V + const auto v3 = Iota(d, 3); + auto v3b(v3); + HWY_ASSERT_VEC_EQ(d, v3, v3b); + + // assign V + auto v3c = Undefined(d); + v3c = v3; + HWY_ASSERT_VEC_EQ(d, v3, v3c); + } +}; + +HWY_NOINLINE void TestAllCopyAndAssign() { + ForAllTypes(ForPartialVectors<TestCopyAndAssign>()); +} + +struct TestGetLane { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + HWY_ASSERT_EQ(T(0), GetLane(Zero(d))); + HWY_ASSERT_EQ(T(1), GetLane(Set(d, 1))); + } +}; + +HWY_NOINLINE void TestAllGetLane() { + ForAllTypes(ForPartialVectors<TestGetLane>()); +} + +struct TestDFromV { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + using D0 = DFromV<decltype(v0)>; // not necessarily same as D + const auto v0b = And(v0, Set(D0(), 1)); // but vectors can interoperate + HWY_ASSERT_VEC_EQ(d, v0, v0b); + } +}; + +HWY_NOINLINE void TestAllDFromV() { + ForAllTypes(ForPartialVectors<TestDFromV>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HighwayTest); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllCapped); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllMaxLanes); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllSet); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllOverflow); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllClamp); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllSignBit); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllNaN); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllIsNaN); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllIsInf); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllIsFinite); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllCopyAndAssign); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllGetLane); +HWY_EXPORT_AND_TEST_P(HighwayTest, TestAllDFromV); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/hwy.version b/third_party/highway/hwy/hwy.version new file mode 100644 index 0000000000..9ff6be6a2d --- /dev/null +++ b/third_party/highway/hwy/hwy.version @@ -0,0 +1,19 @@ +HWY_0 { + global: + extern "C++" { + *hwy::*; + }; + + local: + # Hide all the std namespace symbols. std namespace is explicitly marked + # as visibility(default) and header-only functions or methods (such as those + # from templates) should be exposed in shared libraries as weak symbols but + # this is only needed when we expose those types in the shared library API + # in any way. We don't use C++ std types in the API and we also don't + # support exceptions in the library. + # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion + # about this. + extern "C++" { + *std::*; + }; +}; diff --git a/third_party/highway/hwy/nanobenchmark.cc b/third_party/highway/hwy/nanobenchmark.cc new file mode 100644 index 0000000000..b4dae93443 --- /dev/null +++ b/third_party/highway/hwy/nanobenchmark.cc @@ -0,0 +1,763 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/nanobenchmark.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> +#include <stddef.h> +#include <stdio.h> +#include <stdlib.h> +#include <time.h> // clock_gettime + +#include <algorithm> // std::sort, std::find_if +#include <array> +#include <atomic> +#include <chrono> //NOLINT +#include <limits> +#include <numeric> // std::iota +#include <random> +#include <string> +#include <utility> // std::pair +#include <vector> + +#if defined(_WIN32) || defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#include <windows.h> +#endif + +#if defined(__APPLE__) +#include <mach/mach.h> +#include <mach/mach_time.h> +#endif + +#if defined(__HAIKU__) +#include <OS.h> +#endif + +#include "hwy/base.h" +#if HWY_ARCH_PPC && defined(__GLIBC__) +#include <sys/platform/ppc.h> // NOLINT __ppc_get_timebase_freq +#elif HWY_ARCH_X86 + +#if HWY_COMPILER_MSVC +#include <intrin.h> +#else +#include <cpuid.h> // NOLINT +#endif // HWY_COMPILER_MSVC + +#endif // HWY_ARCH_X86 + +namespace hwy { +namespace { +namespace timer { + +// Ticks := platform-specific timer values (CPU cycles on x86). Must be +// unsigned to guarantee wraparound on overflow. +using Ticks = uint64_t; + +// Start/Stop return absolute timestamps and must be placed immediately before +// and after the region to measure. We provide separate Start/Stop functions +// because they use different fences. +// +// Background: RDTSC is not 'serializing'; earlier instructions may complete +// after it, and/or later instructions may complete before it. 'Fences' ensure +// regions' elapsed times are independent of such reordering. The only +// documented unprivileged serializing instruction is CPUID, which acts as a +// full fence (no reordering across it in either direction). Unfortunately +// the latency of CPUID varies wildly (perhaps made worse by not initializing +// its EAX input). Because it cannot reliably be deducted from the region's +// elapsed time, it must not be included in the region to measure (i.e. +// between the two RDTSC). +// +// The newer RDTSCP is sometimes described as serializing, but it actually +// only serves as a half-fence with release semantics. Although all +// instructions in the region will complete before the final timestamp is +// captured, subsequent instructions may leak into the region and increase the +// elapsed time. Inserting another fence after the final RDTSCP would prevent +// such reordering without affecting the measured region. +// +// Fortunately, such a fence exists. The LFENCE instruction is only documented +// to delay later loads until earlier loads are visible. However, Intel's +// reference manual says it acts as a full fence (waiting until all earlier +// instructions have completed, and delaying later instructions until it +// completes). AMD assigns the same behavior to MFENCE. +// +// We need a fence before the initial RDTSC to prevent earlier instructions +// from leaking into the region, and arguably another after RDTSC to avoid +// region instructions from completing before the timestamp is recorded. +// When surrounded by fences, the additional RDTSCP half-fence provides no +// benefit, so the initial timestamp can be recorded via RDTSC, which has +// lower overhead than RDTSCP because it does not read TSC_AUX. In summary, +// we define Start = LFENCE/RDTSC/LFENCE; Stop = RDTSCP/LFENCE. +// +// Using Start+Start leads to higher variance and overhead than Stop+Stop. +// However, Stop+Stop includes an LFENCE in the region measurements, which +// adds a delay dependent on earlier loads. The combination of Start+Stop +// is faster than Start+Start and more consistent than Stop+Stop because +// the first LFENCE already delayed subsequent loads before the measured +// region. This combination seems not to have been considered in prior work: +// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c +// +// Note: performance counters can measure 'exact' instructions-retired or +// (unhalted) cycle counts. The RDPMC instruction is not serializing and also +// requires fences. Unfortunately, it is not accessible on all OSes and we +// prefer to avoid kernel-mode drivers. Performance counters are also affected +// by several under/over-count errata, so we use the TSC instead. + +// Returns a 64-bit timestamp in unit of 'ticks'; to convert to seconds, +// divide by InvariantTicksPerSecond. +inline Ticks Start() { + Ticks t; +#if HWY_ARCH_PPC && defined(__GLIBC__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = __rdtsc(); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rdx", "memory", "cc"); +#elif HWY_ARCH_RVV + asm volatile("rdtime %0" : "=r"(t)); +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER counter; + (void)QueryPerformanceCounter(&counter); + t = counter.QuadPart; +#elif defined(__APPLE__) + t = mach_absolute_time(); +#elif defined(__HAIKU__) + t = system_time_nsecs(); // since boot +#else // POSIX + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + t = static_cast<Ticks>(ts.tv_sec * 1000000000LL + ts.tv_nsec); +#endif + return t; +} + +// WARNING: on x86, caller must check HasRDTSCP before using this! +inline Ticks Stop() { + uint64_t t; +#if HWY_ARCH_PPC && defined(__GLIBC__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + unsigned aux; + t = __rdtscp(&aux); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). + asm volatile( + "rdtscp\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rcx", "rdx", "memory", "cc"); +#else + t = Start(); +#endif + return t; +} + +} // namespace timer + +namespace robust_statistics { + +// Sorts integral values in ascending order (e.g. for Mode). About 3x faster +// than std::sort for input distributions with very few unique values. +template <class T> +void CountingSort(T* values, size_t num_values) { + // Unique values and their frequency (similar to flat_map). + using Unique = std::pair<T, int>; + std::vector<Unique> unique; + for (size_t i = 0; i < num_values; ++i) { + const T value = values[i]; + const auto pos = + std::find_if(unique.begin(), unique.end(), + [value](const Unique u) { return u.first == value; }); + if (pos == unique.end()) { + unique.push_back(std::make_pair(value, 1)); + } else { + ++pos->second; + } + } + + // Sort in ascending order of value (pair.first). + std::sort(unique.begin(), unique.end()); + + // Write that many copies of each unique value to the array. + T* HWY_RESTRICT p = values; + for (const auto& value_count : unique) { + std::fill(p, p + value_count.second, value_count.first); + p += value_count.second; + } + NANOBENCHMARK_CHECK(p == values + num_values); +} + +// @return i in [idx_begin, idx_begin + half_count) that minimizes +// sorted[i + half_count] - sorted[i]. +template <typename T> +size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin, + const size_t half_count) { + T min_range = std::numeric_limits<T>::max(); + size_t min_idx = 0; + + for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) { + NANOBENCHMARK_CHECK(sorted[idx] <= sorted[idx + half_count]); + const T range = sorted[idx + half_count] - sorted[idx]; + if (range < min_range) { + min_range = range; + min_idx = idx; + } + } + + return min_idx; +} + +// Returns an estimate of the mode by calling MinRange on successively +// halved intervals. "sorted" must be in ascending order. This is the +// Half Sample Mode estimator proposed by Bickel in "On a fast, robust +// estimator of the mode", with complexity O(N log N). The mode is less +// affected by outliers in highly-skewed distributions than the median. +// The averaging operation below assumes "T" is an unsigned integer type. +template <typename T> +T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) { + size_t idx_begin = 0; + size_t half_count = num_values / 2; + while (half_count > 1) { + idx_begin = MinRange(sorted, idx_begin, half_count); + half_count >>= 1; + } + + const T x = sorted[idx_begin + 0]; + if (half_count == 0) { + return x; + } + NANOBENCHMARK_CHECK(half_count == 1); + const T average = (x + sorted[idx_begin + 1] + 1) / 2; + return average; +} + +// Returns the mode. Side effect: sorts "values". +template <typename T> +T Mode(T* values, const size_t num_values) { + CountingSort(values, num_values); + return ModeOfSorted(values, num_values); +} + +template <typename T, size_t N> +T Mode(T (&values)[N]) { + return Mode(&values[0], N); +} + +// Returns the median value. Side effect: sorts "values". +template <typename T> +T Median(T* values, const size_t num_values) { + NANOBENCHMARK_CHECK(!values->empty()); + std::sort(values, values + num_values); + const size_t half = num_values / 2; + // Odd count: return middle + if (num_values % 2) { + return values[half]; + } + // Even count: return average of middle two. + return (values[half] + values[half - 1] + 1) / 2; +} + +// Returns a robust measure of variability. +template <typename T> +T MedianAbsoluteDeviation(const T* values, const size_t num_values, + const T median) { + NANOBENCHMARK_CHECK(num_values != 0); + std::vector<T> abs_deviations; + abs_deviations.reserve(num_values); + for (size_t i = 0; i < num_values; ++i) { + const int64_t abs = std::abs(static_cast<int64_t>(values[i]) - + static_cast<int64_t>(median)); + abs_deviations.push_back(static_cast<T>(abs)); + } + return Median(abs_deviations.data(), num_values); +} + +} // namespace robust_statistics +} // namespace +namespace platform { +namespace { + +// Prevents the compiler from eliding the computations that led to "output". +template <class T> +inline void PreventElision(T&& output) { +#if HWY_COMPILER_MSVC == 0 + // Works by indicating to the compiler that "output" is being read and + // modified. The +r constraint avoids unnecessary writes to memory, but only + // works for built-in types (typically FuncOutput). + asm volatile("" : "+r"(output) : : "memory"); +#else + // MSVC does not support inline assembly anymore (and never supported GCC's + // RTL constraints). Self-assignment with #pragma optimize("off") might be + // expected to prevent elision, but it does not with MSVC 2015. Type-punning + // with volatile pointers generates inefficient code on MSVC 2017. + static std::atomic<T> dummy(T{}); + dummy.store(output, std::memory_order_relaxed); +#endif +} + +// Measures the actual current frequency of Ticks. We cannot rely on the nominal +// frequency encoded in x86 BrandString because it is misleading on M1 Rosetta, +// and not reported by AMD. CPUID 0x15 is also not yet widely supported. Also +// used on RISC-V and ARM64. +HWY_MAYBE_UNUSED double MeasureNominalClockRate() { + double max_ticks_per_sec = 0.0; + // Arbitrary, enough to ignore 2 outliers without excessive init time. + for (int rep = 0; rep < 3; ++rep) { + auto time0 = std::chrono::steady_clock::now(); + using Time = decltype(time0); + const timer::Ticks ticks0 = timer::Start(); + const Time time_min = time0 + std::chrono::milliseconds(10); + + Time time1; + timer::Ticks ticks1; + for (;;) { + time1 = std::chrono::steady_clock::now(); + // Ideally this would be Stop, but that requires RDTSCP on x86. To avoid + // another codepath, just use Start instead. now() presumably has its own + // fence-like behavior. + ticks1 = timer::Start(); // Do not use Stop, see comment above + if (time1 >= time_min) break; + } + + const double dticks = static_cast<double>(ticks1 - ticks0); + std::chrono::duration<double, std::ratio<1>> dtime = time1 - time0; + const double ticks_per_sec = dticks / dtime.count(); + max_ticks_per_sec = std::max(max_ticks_per_sec, ticks_per_sec); + } + return max_ticks_per_sec; +} + +#if HWY_ARCH_X86 + +void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* HWY_RESTRICT abcd) { +#if HWY_COMPILER_MSVC + int regs[4]; + __cpuidex(regs, level, count); + for (int i = 0; i < 4; ++i) { + abcd[i] = regs[i]; + } +#else + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif +} + +bool HasRDTSCP() { + uint32_t abcd[4]; + Cpuid(0x80000001U, 0, abcd); // Extended feature flags + return (abcd[3] & (1u << 27)) != 0; // RDTSCP +} + +std::string BrandString() { + char brand_string[49]; + std::array<uint32_t, 4> abcd; + + // Check if brand string is supported (it is on all reasonable Intel/AMD) + Cpuid(0x80000000U, 0, abcd.data()); + if (abcd[0] < 0x80000004U) { + return std::string(); + } + + for (size_t i = 0; i < 3; ++i) { + Cpuid(static_cast<uint32_t>(0x80000002U + i), 0, abcd.data()); + CopyBytes<sizeof(abcd)>(&abcd[0], brand_string + i * 16); // not same size + } + brand_string[48] = 0; + return brand_string; +} + +#endif // HWY_ARCH_X86 + +} // namespace + +HWY_DLLEXPORT double InvariantTicksPerSecond() { +#if HWY_ARCH_PPC && defined(__GLIBC__) + return static_cast<double>(__ppc_get_timebase_freq()); +#elif HWY_ARCH_X86 || HWY_ARCH_RVV || (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) + // We assume the x86 TSC is invariant; it is on all recent Intel/AMD CPUs. + static const double freq = MeasureNominalClockRate(); + return freq; +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER freq; + (void)QueryPerformanceFrequency(&freq); + return static_cast<double>(freq.QuadPart); +#elif defined(__APPLE__) + // https://developer.apple.com/library/mac/qa/qa1398/_index.html + mach_timebase_info_data_t timebase; + (void)mach_timebase_info(&timebase); + return static_cast<double>(timebase.denom) / timebase.numer * 1E9; +#else + return 1E9; // Haiku and clock_gettime return nanoseconds. +#endif +} + +HWY_DLLEXPORT double Now() { + static const double mul = 1.0 / InvariantTicksPerSecond(); + return static_cast<double>(timer::Start()) * mul; +} + +HWY_DLLEXPORT uint64_t TimerResolution() { +#if HWY_ARCH_X86 + bool can_use_stop = platform::HasRDTSCP(); +#else + constexpr bool can_use_stop = true; +#endif + + // Nested loop avoids exceeding stack/L1 capacity. + timer::Ticks repetitions[Params::kTimerSamples]; + for (size_t rep = 0; rep < Params::kTimerSamples; ++rep) { + timer::Ticks samples[Params::kTimerSamples]; + if (can_use_stop) { + for (size_t i = 0; i < Params::kTimerSamples; ++i) { + const timer::Ticks t0 = timer::Start(); + const timer::Ticks t1 = timer::Stop(); // we checked HasRDTSCP above + samples[i] = t1 - t0; + } + } else { + for (size_t i = 0; i < Params::kTimerSamples; ++i) { + const timer::Ticks t0 = timer::Start(); + const timer::Ticks t1 = timer::Start(); // do not use Stop, see above + samples[i] = t1 - t0; + } + } + repetitions[rep] = robust_statistics::Mode(samples); + } + return robust_statistics::Mode(repetitions); +} + +} // namespace platform +namespace { + +static const timer::Ticks timer_resolution = platform::TimerResolution(); + +// Estimates the expected value of "lambda" values with a variable number of +// samples until the variability "rel_mad" is less than "max_rel_mad". +template <class Lambda> +timer::Ticks SampleUntilStable(const double max_rel_mad, double* rel_mad, + const Params& p, const Lambda& lambda) { + // Choose initial samples_per_eval based on a single estimated duration. + timer::Ticks t0 = timer::Start(); + lambda(); + timer::Ticks t1 = timer::Stop(); // Caller checks HasRDTSCP + timer::Ticks est = t1 - t0; + static const double ticks_per_second = platform::InvariantTicksPerSecond(); + const size_t ticks_per_eval = + static_cast<size_t>(ticks_per_second * p.seconds_per_eval); + size_t samples_per_eval = est == 0 + ? p.min_samples_per_eval + : static_cast<size_t>(ticks_per_eval / est); + samples_per_eval = HWY_MAX(samples_per_eval, p.min_samples_per_eval); + + std::vector<timer::Ticks> samples; + samples.reserve(1 + samples_per_eval); + samples.push_back(est); + + // Percentage is too strict for tiny differences, so also allow a small + // absolute "median absolute deviation". + const timer::Ticks max_abs_mad = (timer_resolution + 99) / 100; + *rel_mad = 0.0; // ensure initialized + + for (size_t eval = 0; eval < p.max_evals; ++eval, samples_per_eval *= 2) { + samples.reserve(samples.size() + samples_per_eval); + for (size_t i = 0; i < samples_per_eval; ++i) { + t0 = timer::Start(); + lambda(); + t1 = timer::Stop(); // Caller checks HasRDTSCP + samples.push_back(t1 - t0); + } + + if (samples.size() >= p.min_mode_samples) { + est = robust_statistics::Mode(samples.data(), samples.size()); + } else { + // For "few" (depends also on the variance) samples, Median is safer. + est = robust_statistics::Median(samples.data(), samples.size()); + } + NANOBENCHMARK_CHECK(est != 0); + + // Median absolute deviation (mad) is a robust measure of 'variability'. + const timer::Ticks abs_mad = robust_statistics::MedianAbsoluteDeviation( + samples.data(), samples.size(), est); + *rel_mad = static_cast<double>(abs_mad) / static_cast<double>(est); + + if (*rel_mad <= max_rel_mad || abs_mad <= max_abs_mad) { + if (p.verbose) { + printf("%6" PRIu64 " samples => %5" PRIu64 " (abs_mad=%4" PRIu64 + ", rel_mad=%4.2f%%)\n", + static_cast<uint64_t>(samples.size()), + static_cast<uint64_t>(est), static_cast<uint64_t>(abs_mad), + *rel_mad * 100.0); + } + return est; + } + } + + if (p.verbose) { + printf("WARNING: rel_mad=%4.2f%% still exceeds %4.2f%% after %6" PRIu64 + " samples.\n", + *rel_mad * 100.0, max_rel_mad * 100.0, + static_cast<uint64_t>(samples.size())); + } + return est; +} + +using InputVec = std::vector<FuncInput>; + +// Returns vector of unique input values. +InputVec UniqueInputs(const FuncInput* inputs, const size_t num_inputs) { + InputVec unique(inputs, inputs + num_inputs); + std::sort(unique.begin(), unique.end()); + unique.erase(std::unique(unique.begin(), unique.end()), unique.end()); + return unique; +} + +// Returns how often we need to call func for sufficient precision. +size_t NumSkip(const Func func, const uint8_t* arg, const InputVec& unique, + const Params& p) { + // Min elapsed ticks for any input. + timer::Ticks min_duration = ~timer::Ticks(0); + + for (const FuncInput input : unique) { + double rel_mad; + const timer::Ticks total = SampleUntilStable( + p.target_rel_mad, &rel_mad, p, + [func, arg, input]() { platform::PreventElision(func(arg, input)); }); + min_duration = HWY_MIN(min_duration, total - timer_resolution); + } + + // Number of repetitions required to reach the target resolution. + const size_t max_skip = p.precision_divisor; + // Number of repetitions given the estimated duration. + const size_t num_skip = + min_duration == 0 + ? 0 + : static_cast<size_t>((max_skip + min_duration - 1) / min_duration); + if (p.verbose) { + printf("res=%" PRIu64 " max_skip=%" PRIu64 " min_dur=%" PRIu64 + " num_skip=%" PRIu64 "\n", + static_cast<uint64_t>(timer_resolution), + static_cast<uint64_t>(max_skip), static_cast<uint64_t>(min_duration), + static_cast<uint64_t>(num_skip)); + } + return num_skip; +} + +// Replicates inputs until we can omit "num_skip" occurrences of an input. +InputVec ReplicateInputs(const FuncInput* inputs, const size_t num_inputs, + const size_t num_unique, const size_t num_skip, + const Params& p) { + InputVec full; + if (num_unique == 1) { + full.assign(p.subset_ratio * num_skip, inputs[0]); + return full; + } + + full.reserve(p.subset_ratio * num_skip * num_inputs); + for (size_t i = 0; i < p.subset_ratio * num_skip; ++i) { + full.insert(full.end(), inputs, inputs + num_inputs); + } + std::mt19937 rng; + std::shuffle(full.begin(), full.end(), rng); + return full; +} + +// Copies the "full" to "subset" in the same order, but with "num_skip" +// randomly selected occurrences of "input_to_skip" removed. +void FillSubset(const InputVec& full, const FuncInput input_to_skip, + const size_t num_skip, InputVec* subset) { + const size_t count = + static_cast<size_t>(std::count(full.begin(), full.end(), input_to_skip)); + // Generate num_skip random indices: which occurrence to skip. + std::vector<uint32_t> omit(count); + std::iota(omit.begin(), omit.end(), 0); + // omit[] is the same on every call, but that's OK because they identify the + // Nth instance of input_to_skip, so the position within full[] differs. + std::mt19937 rng; + std::shuffle(omit.begin(), omit.end(), rng); + omit.resize(num_skip); + std::sort(omit.begin(), omit.end()); + + uint32_t occurrence = ~0u; // 0 after preincrement + size_t idx_omit = 0; // cursor within omit[] + size_t idx_subset = 0; // cursor within *subset + for (const FuncInput next : full) { + if (next == input_to_skip) { + ++occurrence; + // Haven't removed enough already + if (idx_omit < num_skip) { + // This one is up for removal + if (occurrence == omit[idx_omit]) { + ++idx_omit; + continue; + } + } + } + if (idx_subset < subset->size()) { + (*subset)[idx_subset++] = next; + } + } + NANOBENCHMARK_CHECK(idx_subset == subset->size()); + NANOBENCHMARK_CHECK(idx_omit == omit.size()); + NANOBENCHMARK_CHECK(occurrence == count - 1); +} + +// Returns total ticks elapsed for all inputs. +timer::Ticks TotalDuration(const Func func, const uint8_t* arg, + const InputVec* inputs, const Params& p, + double* max_rel_mad) { + double rel_mad; + const timer::Ticks duration = + SampleUntilStable(p.target_rel_mad, &rel_mad, p, [func, arg, inputs]() { + for (const FuncInput input : *inputs) { + platform::PreventElision(func(arg, input)); + } + }); + *max_rel_mad = HWY_MAX(*max_rel_mad, rel_mad); + return duration; +} + +// (Nearly) empty Func for measuring timer overhead/resolution. +HWY_NOINLINE FuncOutput EmptyFunc(const void* /*arg*/, const FuncInput input) { + return input; +} + +// Returns overhead of accessing inputs[] and calling a function; this will +// be deducted from future TotalDuration return values. +timer::Ticks Overhead(const uint8_t* arg, const InputVec* inputs, + const Params& p) { + double rel_mad; + // Zero tolerance because repeatability is crucial and EmptyFunc is fast. + return SampleUntilStable(0.0, &rel_mad, p, [arg, inputs]() { + for (const FuncInput input : *inputs) { + platform::PreventElision(EmptyFunc(arg, input)); + } + }); +} + +} // namespace + +HWY_DLLEXPORT int Unpredictable1() { return timer::Start() != ~0ULL; } + +HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, + const FuncInput* inputs, const size_t num_inputs, + Result* results, const Params& p) { + NANOBENCHMARK_CHECK(num_inputs != 0); + +#if HWY_ARCH_X86 + if (!platform::HasRDTSCP()) { + fprintf(stderr, "CPU '%s' does not support RDTSCP, skipping benchmark.\n", + platform::BrandString().c_str()); + return 0; + } +#endif + + const InputVec& unique = UniqueInputs(inputs, num_inputs); + + const size_t num_skip = NumSkip(func, arg, unique, p); // never 0 + if (num_skip == 0) return 0; // NumSkip already printed error message + // (slightly less work on x86 to cast from signed integer) + const float mul = 1.0f / static_cast<float>(static_cast<int>(num_skip)); + + const InputVec& full = + ReplicateInputs(inputs, num_inputs, unique.size(), num_skip, p); + InputVec subset(full.size() - num_skip); + + const timer::Ticks overhead = Overhead(arg, &full, p); + const timer::Ticks overhead_skip = Overhead(arg, &subset, p); + if (overhead < overhead_skip) { + fprintf(stderr, "Measurement failed: overhead %" PRIu64 " < %" PRIu64 "\n", + static_cast<uint64_t>(overhead), + static_cast<uint64_t>(overhead_skip)); + return 0; + } + + if (p.verbose) { + printf("#inputs=%5" PRIu64 ",%5" PRIu64 " overhead=%5" PRIu64 ",%5" PRIu64 + "\n", + static_cast<uint64_t>(full.size()), + static_cast<uint64_t>(subset.size()), + static_cast<uint64_t>(overhead), + static_cast<uint64_t>(overhead_skip)); + } + + double max_rel_mad = 0.0; + const timer::Ticks total = TotalDuration(func, arg, &full, p, &max_rel_mad); + + for (size_t i = 0; i < unique.size(); ++i) { + FillSubset(full, unique[i], num_skip, &subset); + const timer::Ticks total_skip = + TotalDuration(func, arg, &subset, p, &max_rel_mad); + + if (total < total_skip) { + fprintf(stderr, "Measurement failed: total %" PRIu64 " < %" PRIu64 "\n", + static_cast<uint64_t>(total), static_cast<uint64_t>(total_skip)); + return 0; + } + + const timer::Ticks duration = + (total - overhead) - (total_skip - overhead_skip); + results[i].input = unique[i]; + results[i].ticks = static_cast<float>(duration) * mul; + results[i].variability = static_cast<float>(max_rel_mad); + } + + return unique.size(); +} + +} // namespace hwy diff --git a/third_party/highway/hwy/nanobenchmark.h b/third_party/highway/hwy/nanobenchmark.h new file mode 100644 index 0000000000..f0910b4b94 --- /dev/null +++ b/third_party/highway/hwy/nanobenchmark.h @@ -0,0 +1,194 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_NANOBENCHMARK_H_ +#define HIGHWAY_HWY_NANOBENCHMARK_H_ + +// Benchmarks functions of a single integer argument with realistic branch +// prediction hit rates. Uses a robust estimator to summarize the measurements. +// The precision is about 0.2%. +// +// Examples: see nanobenchmark_test.cc. +// +// Background: Microbenchmarks such as http://github.com/google/benchmark +// can measure elapsed times on the order of a microsecond. Shorter functions +// are typically measured by repeating them thousands of times and dividing +// the total elapsed time by this count. Unfortunately, repetition (especially +// with the same input parameter!) influences the runtime. In time-critical +// code, it is reasonable to expect warm instruction/data caches and TLBs, +// but a perfect record of which branches will be taken is unrealistic. +// Unless the application also repeatedly invokes the measured function with +// the same parameter, the benchmark is measuring something very different - +// a best-case result, almost as if the parameter were made a compile-time +// constant. This may lead to erroneous conclusions about branch-heavy +// algorithms outperforming branch-free alternatives. +// +// Our approach differs in three ways. Adding fences to the timer functions +// reduces variability due to instruction reordering, improving the timer +// resolution to about 40 CPU cycles. However, shorter functions must still +// be invoked repeatedly. For more realistic branch prediction performance, +// we vary the input parameter according to a user-specified distribution. +// Thus, instead of VaryInputs(Measure(Repeat(func))), we change the +// loop nesting to Measure(Repeat(VaryInputs(func))). We also estimate the +// central tendency of the measurement samples with the "half sample mode", +// which is more robust to outliers and skewed data than the mean or median. + +#include <stddef.h> +#include <stdint.h> + +#include "hwy/highway_export.h" + +// Enables sanity checks that verify correct operation at the cost of +// longer benchmark runs. +#ifndef NANOBENCHMARK_ENABLE_CHECKS +#define NANOBENCHMARK_ENABLE_CHECKS 0 +#endif + +#define NANOBENCHMARK_CHECK_ALWAYS(condition) \ + while (!(condition)) { \ + fprintf(stderr, "Nanobenchmark check failed at line %d\n", __LINE__); \ + abort(); \ + } + +#if NANOBENCHMARK_ENABLE_CHECKS +#define NANOBENCHMARK_CHECK(condition) NANOBENCHMARK_CHECK_ALWAYS(condition) +#else +#define NANOBENCHMARK_CHECK(condition) +#endif + +namespace hwy { + +namespace platform { + +// Returns tick rate, useful for converting measurements to seconds. Invariant +// means the tick counter frequency is independent of CPU throttling or sleep. +// This call may be expensive, callers should cache the result. +HWY_DLLEXPORT double InvariantTicksPerSecond(); + +// Returns current timestamp [in seconds] relative to an unspecified origin. +// Features: monotonic (no negative elapsed time), steady (unaffected by system +// time changes), high-resolution (on the order of microseconds). +HWY_DLLEXPORT double Now(); + +// Returns ticks elapsed in back to back timer calls, i.e. a function of the +// timer resolution (minimum measurable difference) and overhead. +// This call is expensive, callers should cache the result. +HWY_DLLEXPORT uint64_t TimerResolution(); + +} // namespace platform + +// Returns 1, but without the compiler knowing what the value is. This prevents +// optimizing out code. +HWY_DLLEXPORT int Unpredictable1(); + +// Input influencing the function being measured (e.g. number of bytes to copy). +using FuncInput = size_t; + +// "Proof of work" returned by Func to ensure the compiler does not elide it. +using FuncOutput = uint64_t; + +// Function to measure: either 1) a captureless lambda or function with two +// arguments or 2) a lambda with capture, in which case the first argument +// is reserved for use by MeasureClosure. +using Func = FuncOutput (*)(const void*, FuncInput); + +// Internal parameters that determine precision/resolution/measuring time. +struct Params { + // For measuring timer overhead/resolution. Used in a nested loop => + // quadratic time, acceptable because we know timer overhead is "low". + // constexpr because this is used to define array bounds. + static constexpr size_t kTimerSamples = 256; + + // Best-case precision, expressed as a divisor of the timer resolution. + // Larger => more calls to Func and higher precision. + size_t precision_divisor = 1024; + + // Ratio between full and subset input distribution sizes. Cannot be less + // than 2; larger values increase measurement time but more faithfully + // model the given input distribution. + size_t subset_ratio = 2; + + // Together with the estimated Func duration, determines how many times to + // call Func before checking the sample variability. Larger values increase + // measurement time, memory/cache use and precision. + double seconds_per_eval = 4E-3; + + // The minimum number of samples before estimating the central tendency. + size_t min_samples_per_eval = 7; + + // The mode is better than median for estimating the central tendency of + // skewed/fat-tailed distributions, but it requires sufficient samples + // relative to the width of half-ranges. + size_t min_mode_samples = 64; + + // Maximum permissible variability (= median absolute deviation / center). + double target_rel_mad = 0.002; + + // Abort after this many evals without reaching target_rel_mad. This + // prevents infinite loops. + size_t max_evals = 9; + + // Whether to print additional statistics to stdout. + bool verbose = true; +}; + +// Measurement result for each unique input. +struct Result { + FuncInput input; + + // Robust estimate (mode or median) of duration. + float ticks; + + // Measure of variability (median absolute deviation relative to "ticks"). + float variability; +}; + +// Precisely measures the number of ticks elapsed when calling "func" with the +// given inputs, shuffled to ensure realistic branch prediction hit rates. +// +// "func" returns a 'proof of work' to ensure its computations are not elided. +// "arg" is passed to Func, or reserved for internal use by MeasureClosure. +// "inputs" is an array of "num_inputs" (not necessarily unique) arguments to +// "func". The values should be chosen to maximize coverage of "func". This +// represents a distribution, so a value's frequency should reflect its +// probability in the real application. Order does not matter; for example, a +// uniform distribution over [0, 4) could be represented as {3,0,2,1}. +// Returns how many Result were written to "results": one per unique input, or +// zero if the measurement failed (an error message goes to stderr). +HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, + const FuncInput* inputs, const size_t num_inputs, + Result* results, const Params& p = Params()); + +// Calls operator() of the given closure (lambda function). +template <class Closure> +static FuncOutput CallClosure(const Closure* f, const FuncInput input) { + return (*f)(input); +} + +// Same as Measure, except "closure" is typically a lambda function of +// FuncInput -> FuncOutput with a capture list. +template <class Closure> +static inline size_t MeasureClosure(const Closure& closure, + const FuncInput* inputs, + const size_t num_inputs, Result* results, + const Params& p = Params()) { + return Measure(reinterpret_cast<Func>(&CallClosure<Closure>), + reinterpret_cast<const uint8_t*>(&closure), inputs, num_inputs, + results, p); +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_NANOBENCHMARK_H_ diff --git a/third_party/highway/hwy/nanobenchmark_test.cc b/third_party/highway/hwy/nanobenchmark_test.cc new file mode 100644 index 0000000000..0d153a14c5 --- /dev/null +++ b/third_party/highway/hwy/nanobenchmark_test.cc @@ -0,0 +1,94 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/nanobenchmark.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> +#include <stdint.h> +#include <stdio.h> + +#include <random> + +#include "hwy/tests/test_util-inl.h" + +namespace hwy { +namespace { + +// Governs duration of test; avoid timeout in debug builds. +#if HWY_IS_DEBUG_BUILD +constexpr size_t kMaxEvals = 3; +#else +constexpr size_t kMaxEvals = 4; +#endif + +FuncOutput Div(const void*, FuncInput in) { + // Here we're measuring the throughput because benchmark invocations are + // independent. Any dividend will do; the divisor is nonzero. + return 0xFFFFF / in; +} + +template <size_t N> +void MeasureDiv(const FuncInput (&inputs)[N]) { + printf("Measuring integer division (output on final two lines)\n"); + Result results[N]; + Params params; + params.max_evals = kMaxEvals; + const size_t num_results = Measure(&Div, nullptr, inputs, N, results, params); + for (size_t i = 0; i < num_results; ++i) { + printf("%5" PRIu64 ": %6.2f ticks; MAD=%4.2f%%\n", + static_cast<uint64_t>(results[i].input), results[i].ticks, + results[i].variability * 100.0); + } +} + +std::mt19937 rng; + +// A function whose runtime depends on rng. +FuncOutput Random(const void* /*arg*/, FuncInput in) { + const size_t r = rng() & 0xF; + FuncOutput ret = static_cast<FuncOutput>(in); + for (size_t i = 0; i < r; ++i) { + ret /= ((rng() & 1) + 2); + } + return ret; +} + +// Ensure the measured variability is high. +template <size_t N> +void MeasureRandom(const FuncInput (&inputs)[N]) { + Result results[N]; + Params p; + p.max_evals = kMaxEvals; + p.verbose = false; + const size_t num_results = Measure(&Random, nullptr, inputs, N, results, p); + for (size_t i = 0; i < num_results; ++i) { + NANOBENCHMARK_CHECK(results[i].variability > 1E-3); + } +} + +TEST(NanobenchmarkTest, RunAll) { + const int unpredictable = Unpredictable1(); // == 1, unknown to compiler. + static const FuncInput inputs[] = {static_cast<FuncInput>(unpredictable) + 2, + static_cast<FuncInput>(unpredictable + 9)}; + + MeasureDiv(inputs); + MeasureRandom(inputs); +} + +} // namespace +} // namespace hwy diff --git a/third_party/highway/hwy/ops/arm_neon-inl.h b/third_party/highway/hwy/ops/arm_neon-inl.h new file mode 100644 index 0000000000..7c3759aa3d --- /dev/null +++ b/third_party/highway/hwy/ops/arm_neon-inl.h @@ -0,0 +1,6810 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit ARM64 NEON vectors and operations. +// External include guard in highway.h - see comment there. + +// ARM NEON intrinsics are documented at: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon] + +#include <stddef.h> +#include <stdint.h> + +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); + +// Must come after HWY_BEFORE_NAMESPACE so that the intrinsics are compiled with +// the same target attribute as our code, see #834. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +#include <arm_neon.h> // NOLINT(build/include_order) +HWY_DIAGNOSTICS(pop) + +// Must come after arm_neon.h. +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { // for code folding and Raw128 + +// Macros used to define single and double function calls for multiple types +// for full and half vectors. These macros are undefined at the end of the file. + +// HWY_NEON_BUILD_TPL_* is the template<...> prefix to the function. +#define HWY_NEON_BUILD_TPL_1 +#define HWY_NEON_BUILD_TPL_2 +#define HWY_NEON_BUILD_TPL_3 + +// HWY_NEON_BUILD_RET_* is return type; type arg is without _t suffix so we can +// extend it to int32x4x2_t packs. +#define HWY_NEON_BUILD_RET_1(type, size) Vec128<type##_t, size> +#define HWY_NEON_BUILD_RET_2(type, size) Vec128<type##_t, size> +#define HWY_NEON_BUILD_RET_3(type, size) Vec128<type##_t, size> + +// HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives. +#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128<type##_t, size> a +#define HWY_NEON_BUILD_PARAM_2(type, size) \ + const Vec128<type##_t, size> a, const Vec128<type##_t, size> b +#define HWY_NEON_BUILD_PARAM_3(type, size) \ + const Vec128<type##_t, size> a, const Vec128<type##_t, size> b, \ + const Vec128<type##_t, size> c + +// HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying +// function. +#define HWY_NEON_BUILD_ARG_1 a.raw +#define HWY_NEON_BUILD_ARG_2 a.raw, b.raw +#define HWY_NEON_BUILD_ARG_3 a.raw, b.raw, c.raw + +// We use HWY_NEON_EVAL(func, ...) to delay the evaluation of func until after +// the __VA_ARGS__ have been expanded. This allows "func" to be a macro on +// itself like with some of the library "functions" such as vshlq_u8. For +// example, HWY_NEON_EVAL(vshlq_u8, MY_PARAMS) where MY_PARAMS is defined as +// "a, b" (without the quotes) will end up expanding "vshlq_u8(a, b)" if needed. +// Directly writing vshlq_u8(MY_PARAMS) would fail since vshlq_u8() macro +// expects two arguments. +#define HWY_NEON_EVAL(func, ...) func(__VA_ARGS__) + +// Main macro definition that defines a single function for the given type and +// size of vector, using the underlying (prefix##infix##suffix) function and +// the template, return type, parameters and arguments defined by the "args" +// parameters passed here (see HWY_NEON_BUILD_* macros defined before). +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_CONCAT(HWY_NEON_BUILD_TPL_, args) \ + HWY_API HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size) \ + name(HWY_CONCAT(HWY_NEON_BUILD_PARAM_, args)(type, size)) { \ + return HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)( \ + HWY_NEON_EVAL(prefix##infix##suffix, HWY_NEON_BUILD_ARG_##args)); \ + } + +// The HWY_NEON_DEF_FUNCTION_* macros define all the variants of a function +// called "name" using the set of neon functions starting with the given +// "prefix" for all the variants of certain types, as specified next to each +// macro. For example, the prefix "vsub" can be used to define the operator- +// using args=2. + +// uint8_t +#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 8, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 4, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 2, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 1, name, prefix, infix, u8, args) + +// int8_t +#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 8, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 4, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 2, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 1, name, prefix, infix, s8, args) + +// uint16_t +#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 4, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 2, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 1, name, prefix, infix, u16, args) + +// int16_t +#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 4, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 2, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 1, name, prefix, infix, s16, args) + +// uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 2, name, prefix, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 1, name, prefix, infix, u32, args) + +// int32_t +#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 2, name, prefix, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 1, name, prefix, infix, s32, args) + +// uint64_t +#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) + +// int64_t +#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) + +// float +#define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 2, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 1, name, prefix, infix, f32, args) + +// double +#if HWY_ARCH_ARM_A64 +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) \ + HWY_NEON_DEF_FUNCTION(float64, 1, name, prefix, infix, f64, args) +#else +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) +#endif + +// float and double + +#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) + +// Helper macros to define for more than one type. +// uint8_t, uint16_t and uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) + +// int8_t, int16_t and int32_t +#define HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) + +// uint8_t, uint16_t, uint32_t and uint64_t +#define HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) + +// int8_t, int16_t, int32_t and int64_t +#define HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) + +// All int*_t and uint*_t up to 64 +#define HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) + +// All previous types. +#define HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) + +#define HWY_NEON_DEF_FUNCTION_UIF81632(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) + +// For eor3q, which is only defined for full vectors. +#define HWY_NEON_DEF_FUNCTION_FULL_UI(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) + +// Emulation of some intrinsics on armv7. +#if HWY_ARCH_ARM_V7 +#define vuzp1_s8(x, y) vuzp_s8(x, y).val[0] +#define vuzp1_u8(x, y) vuzp_u8(x, y).val[0] +#define vuzp1_s16(x, y) vuzp_s16(x, y).val[0] +#define vuzp1_u16(x, y) vuzp_u16(x, y).val[0] +#define vuzp1_s32(x, y) vuzp_s32(x, y).val[0] +#define vuzp1_u32(x, y) vuzp_u32(x, y).val[0] +#define vuzp1_f32(x, y) vuzp_f32(x, y).val[0] +#define vuzp1q_s8(x, y) vuzpq_s8(x, y).val[0] +#define vuzp1q_u8(x, y) vuzpq_u8(x, y).val[0] +#define vuzp1q_s16(x, y) vuzpq_s16(x, y).val[0] +#define vuzp1q_u16(x, y) vuzpq_u16(x, y).val[0] +#define vuzp1q_s32(x, y) vuzpq_s32(x, y).val[0] +#define vuzp1q_u32(x, y) vuzpq_u32(x, y).val[0] +#define vuzp1q_f32(x, y) vuzpq_f32(x, y).val[0] +#define vuzp2_s8(x, y) vuzp_s8(x, y).val[1] +#define vuzp2_u8(x, y) vuzp_u8(x, y).val[1] +#define vuzp2_s16(x, y) vuzp_s16(x, y).val[1] +#define vuzp2_u16(x, y) vuzp_u16(x, y).val[1] +#define vuzp2_s32(x, y) vuzp_s32(x, y).val[1] +#define vuzp2_u32(x, y) vuzp_u32(x, y).val[1] +#define vuzp2_f32(x, y) vuzp_f32(x, y).val[1] +#define vuzp2q_s8(x, y) vuzpq_s8(x, y).val[1] +#define vuzp2q_u8(x, y) vuzpq_u8(x, y).val[1] +#define vuzp2q_s16(x, y) vuzpq_s16(x, y).val[1] +#define vuzp2q_u16(x, y) vuzpq_u16(x, y).val[1] +#define vuzp2q_s32(x, y) vuzpq_s32(x, y).val[1] +#define vuzp2q_u32(x, y) vuzpq_u32(x, y).val[1] +#define vuzp2q_f32(x, y) vuzpq_f32(x, y).val[1] +#define vzip1_s8(x, y) vzip_s8(x, y).val[0] +#define vzip1_u8(x, y) vzip_u8(x, y).val[0] +#define vzip1_s16(x, y) vzip_s16(x, y).val[0] +#define vzip1_u16(x, y) vzip_u16(x, y).val[0] +#define vzip1_f32(x, y) vzip_f32(x, y).val[0] +#define vzip1_u32(x, y) vzip_u32(x, y).val[0] +#define vzip1_s32(x, y) vzip_s32(x, y).val[0] +#define vzip1q_s8(x, y) vzipq_s8(x, y).val[0] +#define vzip1q_u8(x, y) vzipq_u8(x, y).val[0] +#define vzip1q_s16(x, y) vzipq_s16(x, y).val[0] +#define vzip1q_u16(x, y) vzipq_u16(x, y).val[0] +#define vzip1q_s32(x, y) vzipq_s32(x, y).val[0] +#define vzip1q_u32(x, y) vzipq_u32(x, y).val[0] +#define vzip1q_f32(x, y) vzipq_f32(x, y).val[0] +#define vzip2_s8(x, y) vzip_s8(x, y).val[1] +#define vzip2_u8(x, y) vzip_u8(x, y).val[1] +#define vzip2_s16(x, y) vzip_s16(x, y).val[1] +#define vzip2_u16(x, y) vzip_u16(x, y).val[1] +#define vzip2_s32(x, y) vzip_s32(x, y).val[1] +#define vzip2_u32(x, y) vzip_u32(x, y).val[1] +#define vzip2_f32(x, y) vzip_f32(x, y).val[1] +#define vzip2q_s8(x, y) vzipq_s8(x, y).val[1] +#define vzip2q_u8(x, y) vzipq_u8(x, y).val[1] +#define vzip2q_s16(x, y) vzipq_s16(x, y).val[1] +#define vzip2q_u16(x, y) vzipq_u16(x, y).val[1] +#define vzip2q_s32(x, y) vzipq_s32(x, y).val[1] +#define vzip2q_u32(x, y) vzipq_u32(x, y).val[1] +#define vzip2q_f32(x, y) vzipq_f32(x, y).val[1] +#endif + +// Wrappers over uint8x16x2_t etc. so we can define StoreInterleaved2 overloads +// for all vector types, even those (bfloat16_t) where the underlying vector is +// the same as others (uint16_t). +template <typename T, size_t N> +struct Tuple2; +template <typename T, size_t N> +struct Tuple3; +template <typename T, size_t N> +struct Tuple4; + +template <> +struct Tuple2<uint8_t, 16> { + uint8x16x2_t raw; +}; +template <size_t N> +struct Tuple2<uint8_t, N> { + uint8x8x2_t raw; +}; +template <> +struct Tuple2<int8_t, 16> { + int8x16x2_t raw; +}; +template <size_t N> +struct Tuple2<int8_t, N> { + int8x8x2_t raw; +}; +template <> +struct Tuple2<uint16_t, 8> { + uint16x8x2_t raw; +}; +template <size_t N> +struct Tuple2<uint16_t, N> { + uint16x4x2_t raw; +}; +template <> +struct Tuple2<int16_t, 8> { + int16x8x2_t raw; +}; +template <size_t N> +struct Tuple2<int16_t, N> { + int16x4x2_t raw; +}; +template <> +struct Tuple2<uint32_t, 4> { + uint32x4x2_t raw; +}; +template <size_t N> +struct Tuple2<uint32_t, N> { + uint32x2x2_t raw; +}; +template <> +struct Tuple2<int32_t, 4> { + int32x4x2_t raw; +}; +template <size_t N> +struct Tuple2<int32_t, N> { + int32x2x2_t raw; +}; +template <> +struct Tuple2<uint64_t, 2> { + uint64x2x2_t raw; +}; +template <size_t N> +struct Tuple2<uint64_t, N> { + uint64x1x2_t raw; +}; +template <> +struct Tuple2<int64_t, 2> { + int64x2x2_t raw; +}; +template <size_t N> +struct Tuple2<int64_t, N> { + int64x1x2_t raw; +}; + +template <> +struct Tuple2<float16_t, 8> { + uint16x8x2_t raw; +}; +template <size_t N> +struct Tuple2<float16_t, N> { + uint16x4x2_t raw; +}; +template <> +struct Tuple2<bfloat16_t, 8> { + uint16x8x2_t raw; +}; +template <size_t N> +struct Tuple2<bfloat16_t, N> { + uint16x4x2_t raw; +}; + +template <> +struct Tuple2<float32_t, 4> { + float32x4x2_t raw; +}; +template <size_t N> +struct Tuple2<float32_t, N> { + float32x2x2_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple2<float64_t, 2> { + float64x2x2_t raw; +}; +template <size_t N> +struct Tuple2<float64_t, N> { + float64x1x2_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template <> +struct Tuple3<uint8_t, 16> { + uint8x16x3_t raw; +}; +template <size_t N> +struct Tuple3<uint8_t, N> { + uint8x8x3_t raw; +}; +template <> +struct Tuple3<int8_t, 16> { + int8x16x3_t raw; +}; +template <size_t N> +struct Tuple3<int8_t, N> { + int8x8x3_t raw; +}; +template <> +struct Tuple3<uint16_t, 8> { + uint16x8x3_t raw; +}; +template <size_t N> +struct Tuple3<uint16_t, N> { + uint16x4x3_t raw; +}; +template <> +struct Tuple3<int16_t, 8> { + int16x8x3_t raw; +}; +template <size_t N> +struct Tuple3<int16_t, N> { + int16x4x3_t raw; +}; +template <> +struct Tuple3<uint32_t, 4> { + uint32x4x3_t raw; +}; +template <size_t N> +struct Tuple3<uint32_t, N> { + uint32x2x3_t raw; +}; +template <> +struct Tuple3<int32_t, 4> { + int32x4x3_t raw; +}; +template <size_t N> +struct Tuple3<int32_t, N> { + int32x2x3_t raw; +}; +template <> +struct Tuple3<uint64_t, 2> { + uint64x2x3_t raw; +}; +template <size_t N> +struct Tuple3<uint64_t, N> { + uint64x1x3_t raw; +}; +template <> +struct Tuple3<int64_t, 2> { + int64x2x3_t raw; +}; +template <size_t N> +struct Tuple3<int64_t, N> { + int64x1x3_t raw; +}; + +template <> +struct Tuple3<float16_t, 8> { + uint16x8x3_t raw; +}; +template <size_t N> +struct Tuple3<float16_t, N> { + uint16x4x3_t raw; +}; +template <> +struct Tuple3<bfloat16_t, 8> { + uint16x8x3_t raw; +}; +template <size_t N> +struct Tuple3<bfloat16_t, N> { + uint16x4x3_t raw; +}; + +template <> +struct Tuple3<float32_t, 4> { + float32x4x3_t raw; +}; +template <size_t N> +struct Tuple3<float32_t, N> { + float32x2x3_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple3<float64_t, 2> { + float64x2x3_t raw; +}; +template <size_t N> +struct Tuple3<float64_t, N> { + float64x1x3_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template <> +struct Tuple4<uint8_t, 16> { + uint8x16x4_t raw; +}; +template <size_t N> +struct Tuple4<uint8_t, N> { + uint8x8x4_t raw; +}; +template <> +struct Tuple4<int8_t, 16> { + int8x16x4_t raw; +}; +template <size_t N> +struct Tuple4<int8_t, N> { + int8x8x4_t raw; +}; +template <> +struct Tuple4<uint16_t, 8> { + uint16x8x4_t raw; +}; +template <size_t N> +struct Tuple4<uint16_t, N> { + uint16x4x4_t raw; +}; +template <> +struct Tuple4<int16_t, 8> { + int16x8x4_t raw; +}; +template <size_t N> +struct Tuple4<int16_t, N> { + int16x4x4_t raw; +}; +template <> +struct Tuple4<uint32_t, 4> { + uint32x4x4_t raw; +}; +template <size_t N> +struct Tuple4<uint32_t, N> { + uint32x2x4_t raw; +}; +template <> +struct Tuple4<int32_t, 4> { + int32x4x4_t raw; +}; +template <size_t N> +struct Tuple4<int32_t, N> { + int32x2x4_t raw; +}; +template <> +struct Tuple4<uint64_t, 2> { + uint64x2x4_t raw; +}; +template <size_t N> +struct Tuple4<uint64_t, N> { + uint64x1x4_t raw; +}; +template <> +struct Tuple4<int64_t, 2> { + int64x2x4_t raw; +}; +template <size_t N> +struct Tuple4<int64_t, N> { + int64x1x4_t raw; +}; + +template <> +struct Tuple4<float16_t, 8> { + uint16x8x4_t raw; +}; +template <size_t N> +struct Tuple4<float16_t, N> { + uint16x4x4_t raw; +}; +template <> +struct Tuple4<bfloat16_t, 8> { + uint16x8x4_t raw; +}; +template <size_t N> +struct Tuple4<bfloat16_t, N> { + uint16x4x4_t raw; +}; + +template <> +struct Tuple4<float32_t, 4> { + float32x4x4_t raw; +}; +template <size_t N> +struct Tuple4<float32_t, N> { + float32x2x4_t raw; +}; +#if HWY_ARCH_ARM_A64 +template <> +struct Tuple4<float64_t, 2> { + float64x2x4_t raw; +}; +template <size_t N> +struct Tuple4<float64_t, N> { + float64x1x4_t raw; +}; +#endif // HWY_ARCH_ARM_A64 + +template <typename T, size_t N> +struct Raw128; + +// 128 +template <> +struct Raw128<uint8_t, 16> { + using type = uint8x16_t; +}; + +template <> +struct Raw128<uint16_t, 8> { + using type = uint16x8_t; +}; + +template <> +struct Raw128<uint32_t, 4> { + using type = uint32x4_t; +}; + +template <> +struct Raw128<uint64_t, 2> { + using type = uint64x2_t; +}; + +template <> +struct Raw128<int8_t, 16> { + using type = int8x16_t; +}; + +template <> +struct Raw128<int16_t, 8> { + using type = int16x8_t; +}; + +template <> +struct Raw128<int32_t, 4> { + using type = int32x4_t; +}; + +template <> +struct Raw128<int64_t, 2> { + using type = int64x2_t; +}; + +template <> +struct Raw128<float16_t, 8> { + using type = uint16x8_t; +}; + +template <> +struct Raw128<bfloat16_t, 8> { + using type = uint16x8_t; +}; + +template <> +struct Raw128<float, 4> { + using type = float32x4_t; +}; + +#if HWY_ARCH_ARM_A64 +template <> +struct Raw128<double, 2> { + using type = float64x2_t; +}; +#endif + +// 64 +template <> +struct Raw128<uint8_t, 8> { + using type = uint8x8_t; +}; + +template <> +struct Raw128<uint16_t, 4> { + using type = uint16x4_t; +}; + +template <> +struct Raw128<uint32_t, 2> { + using type = uint32x2_t; +}; + +template <> +struct Raw128<uint64_t, 1> { + using type = uint64x1_t; +}; + +template <> +struct Raw128<int8_t, 8> { + using type = int8x8_t; +}; + +template <> +struct Raw128<int16_t, 4> { + using type = int16x4_t; +}; + +template <> +struct Raw128<int32_t, 2> { + using type = int32x2_t; +}; + +template <> +struct Raw128<int64_t, 1> { + using type = int64x1_t; +}; + +template <> +struct Raw128<float16_t, 4> { + using type = uint16x4_t; +}; + +template <> +struct Raw128<bfloat16_t, 4> { + using type = uint16x4_t; +}; + +template <> +struct Raw128<float, 2> { + using type = float32x2_t; +}; + +#if HWY_ARCH_ARM_A64 +template <> +struct Raw128<double, 1> { + using type = float64x1_t; +}; +#endif + +// 32 (same as 64) +template <> +struct Raw128<uint8_t, 4> : public Raw128<uint8_t, 8> {}; + +template <> +struct Raw128<uint16_t, 2> : public Raw128<uint16_t, 4> {}; + +template <> +struct Raw128<uint32_t, 1> : public Raw128<uint32_t, 2> {}; + +template <> +struct Raw128<int8_t, 4> : public Raw128<int8_t, 8> {}; + +template <> +struct Raw128<int16_t, 2> : public Raw128<int16_t, 4> {}; + +template <> +struct Raw128<int32_t, 1> : public Raw128<int32_t, 2> {}; + +template <> +struct Raw128<float16_t, 2> : public Raw128<float16_t, 4> {}; + +template <> +struct Raw128<bfloat16_t, 2> : public Raw128<bfloat16_t, 4> {}; + +template <> +struct Raw128<float, 1> : public Raw128<float, 2> {}; + +// 16 (same as 64) +template <> +struct Raw128<uint8_t, 2> : public Raw128<uint8_t, 8> {}; + +template <> +struct Raw128<uint16_t, 1> : public Raw128<uint16_t, 4> {}; + +template <> +struct Raw128<int8_t, 2> : public Raw128<int8_t, 8> {}; + +template <> +struct Raw128<int16_t, 1> : public Raw128<int16_t, 4> {}; + +template <> +struct Raw128<float16_t, 1> : public Raw128<float16_t, 4> {}; + +template <> +struct Raw128<bfloat16_t, 1> : public Raw128<bfloat16_t, 4> {}; + +// 8 (same as 64) +template <> +struct Raw128<uint8_t, 1> : public Raw128<uint8_t, 8> {}; + +template <> +struct Raw128<int8_t, 1> : public Raw128<int8_t, 8> {}; + +} // namespace detail + +template <typename T, size_t N = 16 / sizeof(T)> +class Vec128 { + using Raw = typename detail::Raw128<T, N>::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + HWY_INLINE Vec128() {} + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + HWY_INLINE explicit Vec128(const Raw raw) : raw(raw) {} + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template <typename T> +using Vec64 = Vec128<T, 8 / sizeof(T)>; + +template <typename T> +using Vec32 = Vec128<T, 4 / sizeof(T)>; + +// FF..FF or 0. +template <typename T, size_t N = 16 / sizeof(T)> +class Mask128 { + // ARM C Language Extensions return and expect unsigned type. + using Raw = typename detail::Raw128<MakeUnsigned<T>, N>::type; + + public: + HWY_INLINE Mask128() {} + Mask128(const Mask128&) = default; + Mask128& operator=(const Mask128&) = default; + HWY_INLINE explicit Mask128(const Raw raw) : raw(raw) {} + + Raw raw; +}; + +template <typename T> +using Mask64 = Mask128<T, 8 / sizeof(T)>; + +template <class V> +using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>; + +template <class V> +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +// Converts from Vec128<T, N> to Vec128<uint8_t, N * sizeof(T)> using the +// vreinterpret*_u8_*() set of functions. +#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \ + Vec128<uint8_t, size * sizeof(type##_t)> +#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128<type##_t, size> v +#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw + +// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined. +template <size_t N> +HWY_INLINE Vec128<uint8_t, N> BitCastToByte(Vec128<uint8_t, N> v) { + return v; +} + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) + +// Special cases for [b]float16_t, which have the same Raw as uint16_t. +template <size_t N> +HWY_INLINE Vec128<uint8_t, N * 2> BitCastToByte(Vec128<float16_t, N> v) { + return BitCastToByte(Vec128<uint16_t, N>(v.raw)); +} +template <size_t N> +HWY_INLINE Vec128<uint8_t, N * 2> BitCastToByte(Vec128<bfloat16_t, N> v) { + return BitCastToByte(Vec128<uint16_t, N>(v.raw)); +} + +#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 + +template <size_t N> +HWY_INLINE Vec128<uint8_t, N> BitCastFromByte(Simd<uint8_t, N, 0> /* tag */, + Vec128<uint8_t, N> v) { + return v; +} + +// 64-bit or less: + +template <size_t N, HWY_IF_LE64(int8_t, N)> +HWY_INLINE Vec128<int8_t, N> BitCastFromByte(Simd<int8_t, N, 0> /* tag */, + Vec128<uint8_t, N> v) { + return Vec128<int8_t, N>(vreinterpret_s8_u8(v.raw)); +} +template <size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_INLINE Vec128<uint16_t, N> BitCastFromByte(Simd<uint16_t, N, 0> /* tag */, + Vec128<uint8_t, N * 2> v) { + return Vec128<uint16_t, N>(vreinterpret_u16_u8(v.raw)); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_INLINE Vec128<int16_t, N> BitCastFromByte(Simd<int16_t, N, 0> /* tag */, + Vec128<uint8_t, N * 2> v) { + return Vec128<int16_t, N>(vreinterpret_s16_u8(v.raw)); +} +template <size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_INLINE Vec128<uint32_t, N> BitCastFromByte(Simd<uint32_t, N, 0> /* tag */, + Vec128<uint8_t, N * 4> v) { + return Vec128<uint32_t, N>(vreinterpret_u32_u8(v.raw)); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_INLINE Vec128<int32_t, N> BitCastFromByte(Simd<int32_t, N, 0> /* tag */, + Vec128<uint8_t, N * 4> v) { + return Vec128<int32_t, N>(vreinterpret_s32_u8(v.raw)); +} +template <size_t N, HWY_IF_LE64(float, N)> +HWY_INLINE Vec128<float, N> BitCastFromByte(Simd<float, N, 0> /* tag */, + Vec128<uint8_t, N * 4> v) { + return Vec128<float, N>(vreinterpret_f32_u8(v.raw)); +} +HWY_INLINE Vec64<uint64_t> BitCastFromByte(Full64<uint64_t> /* tag */, + Vec128<uint8_t, 1 * 8> v) { + return Vec64<uint64_t>(vreinterpret_u64_u8(v.raw)); +} +HWY_INLINE Vec64<int64_t> BitCastFromByte(Full64<int64_t> /* tag */, + Vec128<uint8_t, 1 * 8> v) { + return Vec64<int64_t>(vreinterpret_s64_u8(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec64<double> BitCastFromByte(Full64<double> /* tag */, + Vec128<uint8_t, 1 * 8> v) { + return Vec64<double>(vreinterpret_f64_u8(v.raw)); +} +#endif + +// 128-bit full: + +HWY_INLINE Vec128<int8_t> BitCastFromByte(Full128<int8_t> /* tag */, + Vec128<uint8_t> v) { + return Vec128<int8_t>(vreinterpretq_s8_u8(v.raw)); +} +HWY_INLINE Vec128<uint16_t> BitCastFromByte(Full128<uint16_t> /* tag */, + Vec128<uint8_t> v) { + return Vec128<uint16_t>(vreinterpretq_u16_u8(v.raw)); +} +HWY_INLINE Vec128<int16_t> BitCastFromByte(Full128<int16_t> /* tag */, + Vec128<uint8_t> v) { + return Vec128<int16_t>(vreinterpretq_s16_u8(v.raw)); +} +HWY_INLINE Vec128<uint32_t> BitCastFromByte(Full128<uint32_t> /* tag */, + Vec128<uint8_t> v) { + return Vec128<uint32_t>(vreinterpretq_u32_u8(v.raw)); +} +HWY_INLINE Vec128<int32_t> BitCastFromByte(Full128<int32_t> /* tag */, + Vec128<uint8_t> v) { + return Vec128<int32_t>(vreinterpretq_s32_u8(v.raw)); +} +HWY_INLINE Vec128<float> BitCastFromByte(Full128<float> /* tag */, + Vec128<uint8_t> v) { + return Vec128<float>(vreinterpretq_f32_u8(v.raw)); +} +HWY_INLINE Vec128<uint64_t> BitCastFromByte(Full128<uint64_t> /* tag */, + Vec128<uint8_t> v) { + return Vec128<uint64_t>(vreinterpretq_u64_u8(v.raw)); +} +HWY_INLINE Vec128<int64_t> BitCastFromByte(Full128<int64_t> /* tag */, + Vec128<uint8_t> v) { + return Vec128<int64_t>(vreinterpretq_s64_u8(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_INLINE Vec128<double> BitCastFromByte(Full128<double> /* tag */, + Vec128<uint8_t> v) { + return Vec128<double>(vreinterpretq_f64_u8(v.raw)); +} +#endif + +// Special cases for [b]float16_t, which have the same Raw as uint16_t. +template <size_t N> +HWY_INLINE Vec128<float16_t, N> BitCastFromByte(Simd<float16_t, N, 0> /* tag */, + Vec128<uint8_t, N * 2> v) { + return Vec128<float16_t, N>(BitCastFromByte(Simd<uint16_t, N, 0>(), v).raw); +} +template <size_t N> +HWY_INLINE Vec128<bfloat16_t, N> BitCastFromByte( + Simd<bfloat16_t, N, 0> /* tag */, Vec128<uint8_t, N * 2> v) { + return Vec128<bfloat16_t, N>(BitCastFromByte(Simd<uint16_t, N, 0>(), v).raw); +} + +} // namespace detail + +template <typename T, size_t N, typename FromT> +HWY_API Vec128<T, N> BitCast(Simd<T, N, 0> d, + Vec128<FromT, N * sizeof(T) / sizeof(FromT)> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns a vector with all lanes set to "t". +#define HWY_NEON_BUILD_TPL_HWY_SET1 +#define HWY_NEON_BUILD_RET_HWY_SET1(type, size) Vec128<type##_t, size> +#define HWY_NEON_BUILD_PARAM_HWY_SET1(type, size) \ + Simd<type##_t, size, 0> /* tag */, const type##_t t +#define HWY_NEON_BUILD_ARG_HWY_SET1 t + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(Set, vdup, _n_, HWY_SET1) + +#undef HWY_NEON_BUILD_TPL_HWY_SET1 +#undef HWY_NEON_BUILD_RET_HWY_SET1 +#undef HWY_NEON_BUILD_PARAM_HWY_SET1 +#undef HWY_NEON_BUILD_ARG_HWY_SET1 + +// Returns an all-zero vector. +template <typename T, size_t N> +HWY_API Vec128<T, N> Zero(Simd<T, N, 0> d) { + return Set(d, 0); +} + +template <size_t N> +HWY_API Vec128<bfloat16_t, N> Zero(Simd<bfloat16_t, N, 0> /* tag */) { + return Vec128<bfloat16_t, N>(Zero(Simd<uint16_t, N, 0>()).raw); +} + +template <class D> +using VFromD = decltype(Zero(D())); + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// Returns a vector with uninitialized elements. +template <typename T, size_t N> +HWY_API Vec128<T, N> Undefined(Simd<T, N, 0> /*d*/) { + typename detail::Raw128<T, N>::type a; + return Vec128<T, N>(a); +} + +HWY_DIAGNOSTICS(pop) + +// Returns a vector with lane i=[0, N) set to "first" + i. +template <typename T, size_t N, typename T2> +Vec128<T, N> Iota(const Simd<T, N, 0> d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ GetLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_GET template <size_t kLane> +#define HWY_NEON_BUILD_RET_HWY_GET(type, size) type##_t +#define HWY_NEON_BUILD_PARAM_HWY_GET(type, size) Vec128<type##_t, size> v +#define HWY_NEON_BUILD_ARG_HWY_GET v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(GetLane, vget, _lane_, HWY_GET) + +#undef HWY_NEON_BUILD_TPL_HWY_GET +#undef HWY_NEON_BUILD_RET_HWY_GET +#undef HWY_NEON_BUILD_PARAM_HWY_GET +#undef HWY_NEON_BUILD_ARG_HWY_GET + +} // namespace detail + +template <class V> +HWY_API TFromV<V> GetLane(const V v) { + return detail::GetLane<0>(v); +} + +// ------------------------------ ExtractLane + +// Requires one overload per vector length because GetLane<3> is a compile error +// if v is a uint32x2_t. +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 1> v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return detail::GetLane<0>(v); +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 2> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 4> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 8> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 16> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + case 8: + return detail::GetLane<8>(v); + case 9: + return detail::GetLane<9>(v); + case 10: + return detail::GetLane<10>(v); + case 11: + return detail::GetLane<11>(v); + case 12: + return detail::GetLane<12>(v); + case 13: + return detail::GetLane<13>(v); + case 14: + return detail::GetLane<14>(v); + case 15: + return detail::GetLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_INSERT template <size_t kLane> +#define HWY_NEON_BUILD_RET_HWY_INSERT(type, size) Vec128<type##_t, size> +#define HWY_NEON_BUILD_PARAM_HWY_INSERT(type, size) \ + Vec128<type##_t, size> v, type##_t t +#define HWY_NEON_BUILD_ARG_HWY_INSERT t, v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(InsertLane, vset, _lane_, HWY_INSERT) + +#undef HWY_NEON_BUILD_TPL_HWY_INSERT +#undef HWY_NEON_BUILD_RET_HWY_INSERT +#undef HWY_NEON_BUILD_PARAM_HWY_INSERT +#undef HWY_NEON_BUILD_ARG_HWY_INSERT + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error. + +template <typename T> +HWY_API Vec128<T, 1> InsertLane(const Vec128<T, 1> v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV<decltype(v)>(), t); +} + +template <typename T> +HWY_API Vec128<T, 2> InsertLane(const Vec128<T, 2> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 4> InsertLane(const Vec128<T, 4> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 8> InsertLane(const Vec128<T, 8> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 16> InsertLane(const Vec128<T, 16> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2) + +// ------------------------------ Subtraction +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2) + +// ------------------------------ SumsOf8 + +HWY_API Vec128<uint64_t> SumsOf8(const Vec128<uint8_t> v) { + return Vec128<uint64_t>(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v.raw)))); +} +HWY_API Vec64<uint64_t> SumsOf8(const Vec64<uint8_t> v) { + return Vec64<uint64_t>(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); +} + +// ------------------------------ SaturatedAdd +// Only defined for uint8_t, uint16_t and their signed versions, as in other +// architectures. + +// Returns a + b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedAdd, vqadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedAdd, vqadd, _, 2) + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_16(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(SaturatedSub, vqsub, _, 2) + +// Not part of API, used in implementation. +namespace detail { +HWY_NEON_DEF_FUNCTION_UINT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_64(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_32(SaturatedSub, vqsub, _, 2) +HWY_NEON_DEF_FUNCTION_INT_64(SaturatedSub, vqsub, _, 2) +} // namespace detail + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 +HWY_NEON_DEF_FUNCTION_UINT_8(AverageRound, vrhadd, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2) + +// ------------------------------ Neg + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below + +HWY_API Vec64<int64_t> Neg(const Vec64<int64_t> v) { +#if HWY_ARCH_ARM_A64 + return Vec64<int64_t>(vneg_s64(v.raw)); +#else + return Zero(Full64<int64_t>()) - v; +#endif +} + +HWY_API Vec128<int64_t> Neg(const Vec128<int64_t> v) { +#if HWY_ARCH_ARM_A64 + return Vec128<int64_t>(vnegq_s64(v.raw)); +#else + return Zero(Full128<int64_t>()) - v; +#endif +} + +// ------------------------------ ShiftLeft + +// Customize HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + template <int kBits> \ + HWY_API Vec128<type##_t, size> name(const Vec128<type##_t, size> v) { \ + return kBits == 0 ? v \ + : Vec128<type##_t, size>(HWY_NEON_EVAL( \ + prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ + } + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, ignored) + +HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ RotateRight (ShiftRight, Or) + +template <int kBits, size_t N> +HWY_API Vec128<uint32_t, N> RotateRight(const Vec128<uint32_t, N> v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(31, 32 - kBits)>(v)); +} + +template <int kBits, size_t N> +HWY_API Vec128<uint64_t, N> RotateRight(const Vec128<uint64_t, N> v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(63, 64 - kBits)>(v)); +} + +// NOTE: vxarq_u64 can be applied to uint64_t, but we do not yet have a +// mechanism for checking for extensions to ARMv8. + +// ------------------------------ Shl + +HWY_API Vec128<uint8_t> operator<<(const Vec128<uint8_t> v, + const Vec128<uint8_t> bits) { + return Vec128<uint8_t>(vshlq_u8(v.raw, vreinterpretq_s8_u8(bits.raw))); +} +template <size_t N, HWY_IF_LE64(uint8_t, N)> +HWY_API Vec128<uint8_t, N> operator<<(const Vec128<uint8_t, N> v, + const Vec128<uint8_t, N> bits) { + return Vec128<uint8_t, N>(vshl_u8(v.raw, vreinterpret_s8_u8(bits.raw))); +} + +HWY_API Vec128<uint16_t> operator<<(const Vec128<uint16_t> v, + const Vec128<uint16_t> bits) { + return Vec128<uint16_t>(vshlq_u16(v.raw, vreinterpretq_s16_u16(bits.raw))); +} +template <size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<uint16_t, N> operator<<(const Vec128<uint16_t, N> v, + const Vec128<uint16_t, N> bits) { + return Vec128<uint16_t, N>(vshl_u16(v.raw, vreinterpret_s16_u16(bits.raw))); +} + +HWY_API Vec128<uint32_t> operator<<(const Vec128<uint32_t> v, + const Vec128<uint32_t> bits) { + return Vec128<uint32_t>(vshlq_u32(v.raw, vreinterpretq_s32_u32(bits.raw))); +} +template <size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<uint32_t, N> operator<<(const Vec128<uint32_t, N> v, + const Vec128<uint32_t, N> bits) { + return Vec128<uint32_t, N>(vshl_u32(v.raw, vreinterpret_s32_u32(bits.raw))); +} + +HWY_API Vec128<uint64_t> operator<<(const Vec128<uint64_t> v, + const Vec128<uint64_t> bits) { + return Vec128<uint64_t>(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); +} +HWY_API Vec64<uint64_t> operator<<(const Vec64<uint64_t> v, + const Vec64<uint64_t> bits) { + return Vec64<uint64_t>(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); +} + +HWY_API Vec128<int8_t> operator<<(const Vec128<int8_t> v, + const Vec128<int8_t> bits) { + return Vec128<int8_t>(vshlq_s8(v.raw, bits.raw)); +} +template <size_t N, HWY_IF_LE64(int8_t, N)> +HWY_API Vec128<int8_t, N> operator<<(const Vec128<int8_t, N> v, + const Vec128<int8_t, N> bits) { + return Vec128<int8_t, N>(vshl_s8(v.raw, bits.raw)); +} + +HWY_API Vec128<int16_t> operator<<(const Vec128<int16_t> v, + const Vec128<int16_t> bits) { + return Vec128<int16_t>(vshlq_s16(v.raw, bits.raw)); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> operator<<(const Vec128<int16_t, N> v, + const Vec128<int16_t, N> bits) { + return Vec128<int16_t, N>(vshl_s16(v.raw, bits.raw)); +} + +HWY_API Vec128<int32_t> operator<<(const Vec128<int32_t> v, + const Vec128<int32_t> bits) { + return Vec128<int32_t>(vshlq_s32(v.raw, bits.raw)); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> operator<<(const Vec128<int32_t, N> v, + const Vec128<int32_t, N> bits) { + return Vec128<int32_t, N>(vshl_s32(v.raw, bits.raw)); +} + +HWY_API Vec128<int64_t> operator<<(const Vec128<int64_t> v, + const Vec128<int64_t> bits) { + return Vec128<int64_t>(vshlq_s64(v.raw, bits.raw)); +} +HWY_API Vec64<int64_t> operator<<(const Vec64<int64_t> v, + const Vec64<int64_t> bits) { + return Vec64<int64_t>(vshl_s64(v.raw, bits.raw)); +} + +// ------------------------------ Shr (Neg) + +HWY_API Vec128<uint8_t> operator>>(const Vec128<uint8_t> v, + const Vec128<uint8_t> bits) { + const int8x16_t neg_bits = Neg(BitCast(Full128<int8_t>(), bits)).raw; + return Vec128<uint8_t>(vshlq_u8(v.raw, neg_bits)); +} +template <size_t N, HWY_IF_LE64(uint8_t, N)> +HWY_API Vec128<uint8_t, N> operator>>(const Vec128<uint8_t, N> v, + const Vec128<uint8_t, N> bits) { + const int8x8_t neg_bits = Neg(BitCast(Simd<int8_t, N, 0>(), bits)).raw; + return Vec128<uint8_t, N>(vshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128<uint16_t> operator>>(const Vec128<uint16_t> v, + const Vec128<uint16_t> bits) { + const int16x8_t neg_bits = Neg(BitCast(Full128<int16_t>(), bits)).raw; + return Vec128<uint16_t>(vshlq_u16(v.raw, neg_bits)); +} +template <size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<uint16_t, N> operator>>(const Vec128<uint16_t, N> v, + const Vec128<uint16_t, N> bits) { + const int16x4_t neg_bits = Neg(BitCast(Simd<int16_t, N, 0>(), bits)).raw; + return Vec128<uint16_t, N>(vshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128<uint32_t> operator>>(const Vec128<uint32_t> v, + const Vec128<uint32_t> bits) { + const int32x4_t neg_bits = Neg(BitCast(Full128<int32_t>(), bits)).raw; + return Vec128<uint32_t>(vshlq_u32(v.raw, neg_bits)); +} +template <size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<uint32_t, N> operator>>(const Vec128<uint32_t, N> v, + const Vec128<uint32_t, N> bits) { + const int32x2_t neg_bits = Neg(BitCast(Simd<int32_t, N, 0>(), bits)).raw; + return Vec128<uint32_t, N>(vshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128<uint64_t> operator>>(const Vec128<uint64_t> v, + const Vec128<uint64_t> bits) { + const int64x2_t neg_bits = Neg(BitCast(Full128<int64_t>(), bits)).raw; + return Vec128<uint64_t>(vshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64<uint64_t> operator>>(const Vec64<uint64_t> v, + const Vec64<uint64_t> bits) { + const int64x1_t neg_bits = Neg(BitCast(Full64<int64_t>(), bits)).raw; + return Vec64<uint64_t>(vshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128<int8_t> operator>>(const Vec128<int8_t> v, + const Vec128<int8_t> bits) { + return Vec128<int8_t>(vshlq_s8(v.raw, Neg(bits).raw)); +} +template <size_t N, HWY_IF_LE64(int8_t, N)> +HWY_API Vec128<int8_t, N> operator>>(const Vec128<int8_t, N> v, + const Vec128<int8_t, N> bits) { + return Vec128<int8_t, N>(vshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128<int16_t> operator>>(const Vec128<int16_t> v, + const Vec128<int16_t> bits) { + return Vec128<int16_t>(vshlq_s16(v.raw, Neg(bits).raw)); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> operator>>(const Vec128<int16_t, N> v, + const Vec128<int16_t, N> bits) { + return Vec128<int16_t, N>(vshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128<int32_t> operator>>(const Vec128<int32_t> v, + const Vec128<int32_t> bits) { + return Vec128<int32_t>(vshlq_s32(v.raw, Neg(bits).raw)); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> operator>>(const Vec128<int32_t, N> v, + const Vec128<int32_t, N> bits) { + return Vec128<int32_t, N>(vshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128<int64_t> operator>>(const Vec128<int64_t> v, + const Vec128<int64_t> bits) { + return Vec128<int64_t>(vshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64<int64_t> operator>>(const Vec64<int64_t> v, + const Vec64<int64_t> bits) { + return Vec64<int64_t>(vshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ ShiftLeftSame (Shl) + +template <typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftSame(const Vec128<T, N> v, int bits) { + return v << Set(Simd<T, N, 0>(), static_cast<T>(bits)); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightSame(const Vec128<T, N> v, int bits) { + return v >> Set(Simd<T, N, 0>(), static_cast<T>(bits)); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec128<uint16_t> operator*(const Vec128<uint16_t> a, + const Vec128<uint16_t> b) { + return Vec128<uint16_t>(vmulq_u16(a.raw, b.raw)); +} +HWY_API Vec128<uint32_t> operator*(const Vec128<uint32_t> a, + const Vec128<uint32_t> b) { + return Vec128<uint32_t>(vmulq_u32(a.raw, b.raw)); +} + +template <size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<uint16_t, N> operator*(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>(vmul_u16(a.raw, b.raw)); +} +template <size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<uint32_t, N> operator*(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>(vmul_u32(a.raw, b.raw)); +} + +// Signed +HWY_API Vec128<int16_t> operator*(const Vec128<int16_t> a, + const Vec128<int16_t> b) { + return Vec128<int16_t>(vmulq_s16(a.raw, b.raw)); +} +HWY_API Vec128<int32_t> operator*(const Vec128<int32_t> a, + const Vec128<int32_t> b) { + return Vec128<int32_t>(vmulq_s32(a.raw, b.raw)); +} + +template <size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<int16_t, N> operator*(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>(vmul_s16(a.raw, b.raw)); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> operator*(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int32_t, N>(vmul_s32(a.raw, b.raw)); +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec128<int16_t> MulHigh(const Vec128<int16_t> a, + const Vec128<int16_t> b) { + int32x4_t rlo = vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); +#if HWY_ARCH_ARM_A64 + int32x4_t rhi = vmull_high_s16(a.raw, b.raw); +#else + int32x4_t rhi = vmull_s16(vget_high_s16(a.raw), vget_high_s16(b.raw)); +#endif + return Vec128<int16_t>( + vuzp2q_s16(vreinterpretq_s16_s32(rlo), vreinterpretq_s16_s32(rhi))); +} +HWY_API Vec128<uint16_t> MulHigh(const Vec128<uint16_t> a, + const Vec128<uint16_t> b) { + uint32x4_t rlo = vmull_u16(vget_low_u16(a.raw), vget_low_u16(b.raw)); +#if HWY_ARCH_ARM_A64 + uint32x4_t rhi = vmull_high_u16(a.raw, b.raw); +#else + uint32x4_t rhi = vmull_u16(vget_high_u16(a.raw), vget_high_u16(b.raw)); +#endif + return Vec128<uint16_t>( + vuzp2q_u16(vreinterpretq_u16_u32(rlo), vreinterpretq_u16_u32(rhi))); +} + +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> MulHigh(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + int16x8_t hi_lo = vreinterpretq_s16_s32(vmull_s16(a.raw, b.raw)); + return Vec128<int16_t, N>(vget_low_s16(vuzp2q_s16(hi_lo, hi_lo))); +} +template <size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<uint16_t, N> MulHigh(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + uint16x8_t hi_lo = vreinterpretq_u16_u32(vmull_u16(a.raw, b.raw)); + return Vec128<uint16_t, N>(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); +} + +HWY_API Vec128<int16_t> MulFixedPoint15(Vec128<int16_t> a, Vec128<int16_t> b) { + return Vec128<int16_t>(vqrdmulhq_s16(a.raw, b.raw)); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> MulFixedPoint15(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Vec128<int16_t, N>(vqrdmulh_s16(a.raw, b.raw)); +} + +// ------------------------------ Floating-point mul / div + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) + +// Approximate reciprocal +HWY_API Vec128<float> ApproximateReciprocal(const Vec128<float> v) { + return Vec128<float>(vrecpeq_f32(v.raw)); +} +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocal(const Vec128<float, N> v) { + return Vec128<float, N>(vrecpe_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator/, vdiv, _, 2) +#else +// Not defined on armv7: approximate +namespace detail { + +HWY_INLINE Vec128<float> ReciprocalNewtonRaphsonStep( + const Vec128<float> recip, const Vec128<float> divisor) { + return Vec128<float>(vrecpsq_f32(recip.raw, divisor.raw)); +} +template <size_t N> +HWY_INLINE Vec128<float, N> ReciprocalNewtonRaphsonStep( + const Vec128<float, N> recip, Vec128<float, N> divisor) { + return Vec128<float, N>(vrecps_f32(recip.raw, divisor.raw)); +} + +} // namespace detail + +template <size_t N> +HWY_API Vec128<float, N> operator/(const Vec128<float, N> a, + const Vec128<float, N> b) { + auto x = ApproximateReciprocal(b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + return a * x; +} +#endif + +// ------------------------------ Absolute value of difference. + +HWY_API Vec128<float> AbsDiff(const Vec128<float> a, const Vec128<float> b) { + return Vec128<float>(vabdq_f32(a.raw, b.raw)); +} +template <size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<float, N> AbsDiff(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>(vabd_f32(a.raw, b.raw)); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns add + mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template <size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<float, N> MulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { + return Vec128<float, N>(vfma_f32(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128<float> MulAdd(const Vec128<float> mul, const Vec128<float> x, + const Vec128<float> add) { + return Vec128<float>(vfmaq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. +template <size_t N> +HWY_API Vec128<float, N> MulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { + return mul * x + add; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_API Vec64<double> MulAdd(const Vec64<double> mul, const Vec64<double> x, + const Vec64<double> add) { + return Vec64<double>(vfma_f64(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128<double> MulAdd(const Vec128<double> mul, const Vec128<double> x, + const Vec128<double> add) { + return Vec128<double>(vfmaq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns add - mul * x +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +template <size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<float, N> NegMulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { + return Vec128<float, N>(vfms_f32(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128<float> NegMulAdd(const Vec128<float> mul, const Vec128<float> x, + const Vec128<float> add) { + return Vec128<float>(vfmsq_f32(add.raw, mul.raw, x.raw)); +} +#else +// Emulate FMA for floats. +template <size_t N> +HWY_API Vec128<float, N> NegMulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { + return add - mul * x; +} +#endif + +#if HWY_ARCH_ARM_A64 +HWY_API Vec64<double> NegMulAdd(const Vec64<double> mul, const Vec64<double> x, + const Vec64<double> add) { + return Vec64<double>(vfms_f64(add.raw, mul.raw, x.raw)); +} +HWY_API Vec128<double> NegMulAdd(const Vec128<double> mul, + const Vec128<double> x, + const Vec128<double> add) { + return Vec128<double>(vfmsq_f64(add.raw, mul.raw, x.raw)); +} +#endif + +// Returns mul * x - sub +template <size_t N> +HWY_API Vec128<float, N> MulSub(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> sub) { + return MulAdd(mul, x, Neg(sub)); +} + +// Returns -mul * x - sub +template <size_t N> +HWY_API Vec128<float, N> NegMulSub(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> sub) { + return Neg(MulAdd(mul, x, sub)); +} + +#if HWY_ARCH_ARM_A64 +template <size_t N> +HWY_API Vec128<double, N> MulSub(const Vec128<double, N> mul, + const Vec128<double, N> x, + const Vec128<double, N> sub) { + return MulAdd(mul, x, Neg(sub)); +} +template <size_t N> +HWY_API Vec128<double, N> NegMulSub(const Vec128<double, N> mul, + const Vec128<double, N> x, + const Vec128<double, N> sub) { + return Neg(MulAdd(mul, x, sub)); +} +#endif + +// ------------------------------ Floating-point square root (IfThenZeroElse) + +// Approximate reciprocal square root +HWY_API Vec128<float> ApproximateReciprocalSqrt(const Vec128<float> v) { + return Vec128<float>(vrsqrteq_f32(v.raw)); +} +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocalSqrt(const Vec128<float, N> v) { + return Vec128<float, N>(vrsqrte_f32(v.raw)); +} + +// Full precision square root +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Sqrt, vsqrt, _, 1) +#else +namespace detail { + +HWY_INLINE Vec128<float> ReciprocalSqrtStep(const Vec128<float> root, + const Vec128<float> recip) { + return Vec128<float>(vrsqrtsq_f32(root.raw, recip.raw)); +} +template <size_t N> +HWY_INLINE Vec128<float, N> ReciprocalSqrtStep(const Vec128<float, N> root, + Vec128<float, N> recip) { + return Vec128<float, N>(vrsqrts_f32(root.raw, recip.raw)); +} + +} // namespace detail + +// Not defined on armv7: approximate +template <size_t N> +HWY_API Vec128<float, N> Sqrt(const Vec128<float, N> v) { + auto recip = ApproximateReciprocalSqrt(v); + + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + + const auto root = v * recip; + return IfThenZeroElse(v == Zero(Simd<float, N, 0>()), root); +} +#endif + +// ================================================== LOGICAL + +// ------------------------------ Not + +// There is no 64-bit vmvn, so cast instead of using HWY_NEON_DEF_FUNCTION. +template <typename T> +HWY_API Vec128<T> Not(const Vec128<T> v) { + const Full128<T> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Vec128<uint8_t>(vmvnq_u8(BitCast(d8, v).raw))); +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> Not(const Vec128<T, N> v) { + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + using V8 = decltype(Zero(d8)); + return BitCast(d, V8(vmvn_u8(BitCast(d8, v).raw))); +} + +// ------------------------------ And +HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) + +// Uses the u32/64 defined above. +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Vec128<T, N> And(const Vec128<T, N> a, const Vec128<T, N> b) { + const DFromV<decltype(a)> d; + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, BitCast(du, a) & BitCast(du, b)); +} + +// ------------------------------ AndNot + +namespace detail { +// reversed_andnot returns a & ~b. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2) +} // namespace detail + +// Returns ~not_mask & mask. +template <typename T, size_t N, HWY_IF_NOT_FLOAT(T)> +HWY_API Vec128<T, N> AndNot(const Vec128<T, N> not_mask, + const Vec128<T, N> mask) { + return detail::reversed_andnot(mask, not_mask); +} + +// Uses the u32/64 defined above. +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Vec128<T, N> AndNot(const Vec128<T, N> not_mask, + const Vec128<T, N> mask) { + const DFromV<decltype(mask)> d; + const RebindToUnsigned<decltype(d)> du; + VFromD<decltype(du)> ret = + detail::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask)); + return BitCast(d, ret); +} + +// ------------------------------ Or + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) + +// Uses the u32/64 defined above. +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Vec128<T, N> Or(const Vec128<T, N> a, const Vec128<T, N> b) { + const DFromV<decltype(a)> d; + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, BitCast(du, a) | BitCast(du, b)); +} + +// ------------------------------ Xor + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) + +// Uses the u32/64 defined above. +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Vec128<T, N> Xor(const Vec128<T, N> a, const Vec128<T, N> b) { + const DFromV<decltype(a)> d; + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, BitCast(du, a) ^ BitCast(du, b)); +} + +// ------------------------------ Xor3 +#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3) +HWY_NEON_DEF_FUNCTION_FULL_UI(Xor3, veor3, _, 3) + +// Half vectors are not natively supported. Two Xor are likely more efficient +// than Combine to 128-bit. +template <typename T, size_t N, HWY_IF_LE64(T, N), HWY_IF_NOT_FLOAT(T)> +HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) { + return Xor(x1, Xor(x2, x3)); +} + +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Vec128<T, N> Xor3(const Vec128<T, N> x1, const Vec128<T, N> x2, + const Vec128<T, N> x3) { + const DFromV<decltype(x1)> d; + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +#else +template <typename T, size_t N> +HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) { + return Xor(x1, Xor(x2, x3)); +} +#endif + +// ------------------------------ Or3 + +template <typename T, size_t N> +HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template <typename T, size_t N> +HWY_API Vec128<T, N> OrAnd(Vec128<T, N> o, Vec128<T, N> a1, Vec128<T, N> a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator&(const Vec128<T, N> a, const Vec128<T, N> b) { + return And(a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator|(const Vec128<T, N> a, const Vec128<T, N> b) { + return Or(a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator^(const Vec128<T, N> a, const Vec128<T, N> b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template <typename T> +HWY_INLINE Vec128<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec128<T> v) { + const Full128<uint8_t> d8; + return Vec128<T>(vcntq_u8(BitCast(d8, v).raw)); +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128<T, N> v) { + const Simd<uint8_t, N, 0> d8; + return Vec128<T, N>(vcnt_u8(BitCast(d8, v).raw)); +} + +// ARM lacks popcount for lane sizes > 1, so take pairwise sums of the bytes. +template <typename T> +HWY_INLINE Vec128<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec128<T> v) { + const Full128<uint8_t> d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128<T>(vpaddlq_u8(bytes)); +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128<T, N> v) { + const Repartition<uint8_t, Simd<T, N, 0>> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128<T, N>(vpaddl_u8(bytes)); +} + +template <typename T> +HWY_INLINE Vec128<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec128<T> v) { + const Full128<uint8_t> d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128<T>(vpaddlq_u16(vpaddlq_u8(bytes))); +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128<T, N> v) { + const Repartition<uint8_t, Simd<T, N, 0>> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128<T, N>(vpaddl_u16(vpaddl_u8(bytes))); +} + +template <typename T> +HWY_INLINE Vec128<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec128<T> v) { + const Full128<uint8_t> d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128<T>(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(bytes)))); +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128<T, N> v) { + const Repartition<uint8_t, Simd<T, N, 0>> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128<T, N>(vpaddl_u32(vpaddl_u16(vpaddl_u8(bytes)))); +} + +} // namespace detail + +template <typename T, size_t N, HWY_IF_NOT_FLOAT(T)> +HWY_API Vec128<T, N> PopulationCount(Vec128<T, N> v) { + return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v); +} + +// ================================================== SIGN + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec128<int8_t> Abs(const Vec128<int8_t> v) { + return Vec128<int8_t>(vabsq_s8(v.raw)); +} +HWY_API Vec128<int16_t> Abs(const Vec128<int16_t> v) { + return Vec128<int16_t>(vabsq_s16(v.raw)); +} +HWY_API Vec128<int32_t> Abs(const Vec128<int32_t> v) { + return Vec128<int32_t>(vabsq_s32(v.raw)); +} +// i64 is implemented after BroadcastSignBit. +HWY_API Vec128<float> Abs(const Vec128<float> v) { + return Vec128<float>(vabsq_f32(v.raw)); +} + +template <size_t N, HWY_IF_LE64(int8_t, N)> +HWY_API Vec128<int8_t, N> Abs(const Vec128<int8_t, N> v) { + return Vec128<int8_t, N>(vabs_s8(v.raw)); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> Abs(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>(vabs_s16(v.raw)); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> Abs(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>(vabs_s32(v.raw)); +} +template <size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<float, N> Abs(const Vec128<float, N> v) { + return Vec128<float, N>(vabs_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 +HWY_API Vec128<double> Abs(const Vec128<double> v) { + return Vec128<double>(vabsq_f64(v.raw)); +} + +HWY_API Vec64<double> Abs(const Vec64<double> v) { + return Vec64<double>(vabs_f64(v.raw)); +} +#endif + +// ------------------------------ CopySign + +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySign(const Vec128<T, N> magn, + const Vec128<T, N> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd<T, N, 0>()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySignToAbs(const Vec128<T, N> abs, + const Vec128<T, N> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd<T, N, 0>()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template <typename T, size_t N, HWY_IF_SIGNED(T)> +HWY_API Vec128<T, N> BroadcastSignBit(const Vec128<T, N> v) { + return ShiftRight<sizeof(T) * 8 - 1>(v); +} + +// ================================================== MASK + +// ------------------------------ To/from vector + +// Mask and Vec have the same representation (true = FF..FF). +template <typename T, size_t N> +HWY_API Mask128<T, N> MaskFromVec(const Vec128<T, N> v) { + const Simd<MakeUnsigned<T>, N, 0> du; + return Mask128<T, N>(BitCast(du, v).raw); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> VecFromMask(Simd<T, N, 0> d, const Mask128<T, N> v) { + return BitCast(d, Vec128<MakeUnsigned<T>, N>(v.raw)); +} + +// ------------------------------ RebindMask + +template <typename TFrom, typename TTo, size_t N> +HWY_API Mask128<TTo, N> RebindMask(Simd<TTo, N, 0> dto, Mask128<TFrom, N> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MaskFromVec(BitCast(dto, VecFromMask(Simd<TFrom, N, 0>(), m))); +} + +// ------------------------------ IfThenElse(mask, yes, no) = mask ? b : a. + +#define HWY_NEON_BUILD_TPL_HWY_IF +#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128<type##_t, size> +#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ + const Mask128<type##_t, size> mask, const Vec128<type##_t, size> yes, \ + const Vec128<type##_t, size> no +#define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) + +#undef HWY_NEON_BUILD_TPL_HWY_IF +#undef HWY_NEON_BUILD_RET_HWY_IF +#undef HWY_NEON_BUILD_PARAM_HWY_IF +#undef HWY_NEON_BUILD_ARG_HWY_IF + +// mask ? yes : 0 +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElseZero(const Mask128<T, N> mask, + const Vec128<T, N> yes) { + return yes & VecFromMask(Simd<T, N, 0>(), mask); +} + +// mask ? 0 : no +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenZeroElse(const Mask128<T, N> mask, + const Vec128<T, N> no) { + return AndNot(VecFromMask(Simd<T, N, 0>(), mask), no); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfNegativeThenElse(Vec128<T, N> v, Vec128<T, N> yes, + Vec128<T, N> no) { + static_assert(IsSigned<T>(), "Only works for signed/float"); + const Simd<T, N, 0> d; + const RebindToSigned<decltype(d)> di; + + Mask128<T, N> m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ZeroIfNegative(Vec128<T, N> v) { + const auto zero = Zero(Simd<T, N, 0>()); + return Max(zero, v); +} + +// ------------------------------ Mask logical + +template <typename T, size_t N> +HWY_API Mask128<T, N> Not(const Mask128<T, N> m) { + return MaskFromVec(Not(VecFromMask(Simd<T, N, 0>(), m))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> And(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> AndNot(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Or(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Xor(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> ExclusiveNeither(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +// ------------------------------ Shuffle2301 (for i64 compares) + +// Swap 32-bit halves in 64-bits +HWY_API Vec64<uint32_t> Shuffle2301(const Vec64<uint32_t> v) { + return Vec64<uint32_t>(vrev64_u32(v.raw)); +} +HWY_API Vec64<int32_t> Shuffle2301(const Vec64<int32_t> v) { + return Vec64<int32_t>(vrev64_s32(v.raw)); +} +HWY_API Vec64<float> Shuffle2301(const Vec64<float> v) { + return Vec64<float>(vrev64_f32(v.raw)); +} +HWY_API Vec128<uint32_t> Shuffle2301(const Vec128<uint32_t> v) { + return Vec128<uint32_t>(vrev64q_u32(v.raw)); +} +HWY_API Vec128<int32_t> Shuffle2301(const Vec128<int32_t> v) { + return Vec128<int32_t>(vrev64q_s32(v.raw)); +} +HWY_API Vec128<float> Shuffle2301(const Vec128<float> v) { + return Vec128<float>(vrev64q_f32(v.raw)); +} + +#define HWY_NEON_BUILD_TPL_HWY_COMPARE +#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128<type##_t, size> +#define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ + const Vec128<type##_t, size> a, const Vec128<type##_t, size> b +#define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw + +// ------------------------------ Equality +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator==, vceq, _, HWY_COMPARE) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator==, vceq, _, HWY_COMPARE) +#else +// No 64-bit comparisons on armv7: emulate them below, after Shuffle2301. +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator==, vceq, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator==, vceq, _, HWY_COMPARE) +#endif + +// ------------------------------ Strict inequality (signed, float) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<, vclt, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<, vclt, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<, vclt, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<, vclt, _, HWY_COMPARE) + +// ------------------------------ Weak inequality (float) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) + +#undef HWY_NEON_BUILD_TPL_HWY_COMPARE +#undef HWY_NEON_BUILD_RET_HWY_COMPARE +#undef HWY_NEON_BUILD_PARAM_HWY_COMPARE +#undef HWY_NEON_BUILD_ARG_HWY_COMPARE + +// ------------------------------ ARMv7 i64 compare (Shuffle2301, Eq) + +#if HWY_ARCH_ARM_V7 + +template <size_t N> +HWY_API Mask128<int64_t, N> operator==(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + const Simd<int32_t, N * 2, 0> d32; + const Simd<int64_t, N, 0> d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +template <size_t N> +HWY_API Mask128<uint64_t, N> operator==(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + const Simd<uint32_t, N * 2, 0> d32; + const Simd<uint64_t, N, 0> d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +HWY_API Mask128<int64_t> operator<(const Vec128<int64_t> a, + const Vec128<int64_t> b) { + const int64x2_t sub = vqsubq_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec128<int64_t>(sub))); +} +HWY_API Mask128<int64_t, 1> operator<(const Vec64<int64_t> a, + const Vec64<int64_t> b) { + const int64x1_t sub = vqsub_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec64<int64_t>(sub))); +} + +template <size_t N> +HWY_API Mask128<uint64_t, N> operator<(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + const DFromV<decltype(a)> du; + const RebindToSigned<decltype(du)> di; + const Vec128<uint64_t, N> msb = AndNot(a, b) | AndNot(a ^ b, a - b); + return MaskFromVec(BitCast(du, BroadcastSignBit(BitCast(di, msb)))); +} + +#endif + +// ------------------------------ operator!= (operator==) + +// Customize HWY_NEON_DEF_FUNCTION to call 2 functions. +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +// This cannot have _any_ template argument (in x86_128 we can at least have N +// as an argument), otherwise it is not more specialized than rewritten +// operator== in C++20, leading to compile errors. +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_API Mask128<type##_t, size> name(Vec128<type##_t, size> a, \ + Vec128<type##_t, size> b) { \ + return Not(a == b); \ + } + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator!=, ignored, ignored, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ Reversed comparisons + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator>(Vec128<T, N> a, Vec128<T, N> b) { + return operator<(b, a); +} +template <typename T, size_t N> +HWY_API Mask128<T, N> operator>=(Vec128<T, N> a, Vec128<T, N> b) { + return operator<=(b, a); +} + +// ------------------------------ FirstN (Iota, Lt) + +template <typename T, size_t N> +HWY_API Mask128<T, N> FirstN(const Simd<T, N, 0> d, size_t num) { + const RebindToSigned<decltype(d)> di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast<MakeSigned<T>>(num))); +} + +// ------------------------------ TestBit (Eq) + +#define HWY_NEON_BUILD_TPL_HWY_TESTBIT +#define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128<type##_t, size> +#define HWY_NEON_BUILD_PARAM_HWY_TESTBIT(type, size) \ + Vec128<type##_t, size> v, Vec128<type##_t, size> bit +#define HWY_NEON_BUILD_ARG_HWY_TESTBIT v.raw, bit.raw + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(TestBit, vtst, _, HWY_TESTBIT) +#else +// No 64-bit versions on armv7 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) + +template <size_t N> +HWY_API Mask128<uint64_t, N> TestBit(Vec128<uint64_t, N> v, + Vec128<uint64_t, N> bit) { + return (v & bit) == bit; +} +template <size_t N> +HWY_API Mask128<int64_t, N> TestBit(Vec128<int64_t, N> v, + Vec128<int64_t, N> bit) { + return (v & bit) == bit; +} + +#endif +#undef HWY_NEON_BUILD_TPL_HWY_TESTBIT +#undef HWY_NEON_BUILD_RET_HWY_TESTBIT +#undef HWY_NEON_BUILD_PARAM_HWY_TESTBIT +#undef HWY_NEON_BUILD_ARG_HWY_TESTBIT + +// ------------------------------ Abs i64 (IfThenElse, BroadcastSignBit) +HWY_API Vec128<int64_t> Abs(const Vec128<int64_t> v) { +#if HWY_ARCH_ARM_A64 + return Vec128<int64_t>(vabsq_s64(v.raw)); +#else + const auto zero = Zero(Full128<int64_t>()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} +HWY_API Vec64<int64_t> Abs(const Vec64<int64_t> v) { +#if HWY_ARCH_ARM_A64 + return Vec64<int64_t>(vabs_s64(v.raw)); +#else + const auto zero = Zero(Full64<int64_t>()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +// ------------------------------ Min (IfThenElse, BroadcastSignBit) + +// Unsigned +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Min, vmin, _, 2) + +template <size_t N> +HWY_API Vec128<uint64_t, N> Min(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const DFromV<decltype(a)> du; + const RebindToSigned<decltype(du)> di; + return BitCast(du, BitCast(di, a) - BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Min, vmin, _, 2) + +template <size_t N> +HWY_API Vec128<int64_t, N> Min(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const Vec128<int64_t, N> sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), a, b); +#endif +} + +// Float: IEEE minimumNumber on v8, otherwise NaN if any is NaN. +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vminnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vmin, _, 2) +#endif + +// ------------------------------ Max (IfThenElse, BroadcastSignBit) + +// Unsigned (no u64) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Max, vmax, _, 2) + +template <size_t N> +HWY_API Vec128<uint64_t, N> Max(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const DFromV<decltype(a)> du; + const RebindToSigned<decltype(du)> di; + return BitCast(du, BitCast(di, b) + BitCast(di, detail::SaturatedSub(a, b))); +#endif +} + +// Signed (no i64) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Max, vmax, _, 2) + +template <size_t N> +HWY_API Vec128<int64_t, N> Max(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const Vec128<int64_t, N> sign = detail::SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), b, a); +#endif +} + +// Float: IEEE maximumNumber on v8, otherwise NaN if any is NaN. +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmaxnm, _, 2) +#else +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2) +#endif + +// ================================================== MEMORY + +// ------------------------------ Load 128 + +HWY_API Vec128<uint8_t> LoadU(Full128<uint8_t> /* tag */, + const uint8_t* HWY_RESTRICT unaligned) { + return Vec128<uint8_t>(vld1q_u8(unaligned)); +} +HWY_API Vec128<uint16_t> LoadU(Full128<uint16_t> /* tag */, + const uint16_t* HWY_RESTRICT unaligned) { + return Vec128<uint16_t>(vld1q_u16(unaligned)); +} +HWY_API Vec128<uint32_t> LoadU(Full128<uint32_t> /* tag */, + const uint32_t* HWY_RESTRICT unaligned) { + return Vec128<uint32_t>(vld1q_u32(unaligned)); +} +HWY_API Vec128<uint64_t> LoadU(Full128<uint64_t> /* tag */, + const uint64_t* HWY_RESTRICT unaligned) { + return Vec128<uint64_t>(vld1q_u64(unaligned)); +} +HWY_API Vec128<int8_t> LoadU(Full128<int8_t> /* tag */, + const int8_t* HWY_RESTRICT unaligned) { + return Vec128<int8_t>(vld1q_s8(unaligned)); +} +HWY_API Vec128<int16_t> LoadU(Full128<int16_t> /* tag */, + const int16_t* HWY_RESTRICT unaligned) { + return Vec128<int16_t>(vld1q_s16(unaligned)); +} +HWY_API Vec128<int32_t> LoadU(Full128<int32_t> /* tag */, + const int32_t* HWY_RESTRICT unaligned) { + return Vec128<int32_t>(vld1q_s32(unaligned)); +} +HWY_API Vec128<int64_t> LoadU(Full128<int64_t> /* tag */, + const int64_t* HWY_RESTRICT unaligned) { + return Vec128<int64_t>(vld1q_s64(unaligned)); +} +HWY_API Vec128<float> LoadU(Full128<float> /* tag */, + const float* HWY_RESTRICT unaligned) { + return Vec128<float>(vld1q_f32(unaligned)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec128<double> LoadU(Full128<double> /* tag */, + const double* HWY_RESTRICT unaligned) { + return Vec128<double>(vld1q_f64(unaligned)); +} +#endif + +// ------------------------------ Load 64 + +HWY_API Vec64<uint8_t> LoadU(Full64<uint8_t> /* tag */, + const uint8_t* HWY_RESTRICT p) { + return Vec64<uint8_t>(vld1_u8(p)); +} +HWY_API Vec64<uint16_t> LoadU(Full64<uint16_t> /* tag */, + const uint16_t* HWY_RESTRICT p) { + return Vec64<uint16_t>(vld1_u16(p)); +} +HWY_API Vec64<uint32_t> LoadU(Full64<uint32_t> /* tag */, + const uint32_t* HWY_RESTRICT p) { + return Vec64<uint32_t>(vld1_u32(p)); +} +HWY_API Vec64<uint64_t> LoadU(Full64<uint64_t> /* tag */, + const uint64_t* HWY_RESTRICT p) { + return Vec64<uint64_t>(vld1_u64(p)); +} +HWY_API Vec64<int8_t> LoadU(Full64<int8_t> /* tag */, + const int8_t* HWY_RESTRICT p) { + return Vec64<int8_t>(vld1_s8(p)); +} +HWY_API Vec64<int16_t> LoadU(Full64<int16_t> /* tag */, + const int16_t* HWY_RESTRICT p) { + return Vec64<int16_t>(vld1_s16(p)); +} +HWY_API Vec64<int32_t> LoadU(Full64<int32_t> /* tag */, + const int32_t* HWY_RESTRICT p) { + return Vec64<int32_t>(vld1_s32(p)); +} +HWY_API Vec64<int64_t> LoadU(Full64<int64_t> /* tag */, + const int64_t* HWY_RESTRICT p) { + return Vec64<int64_t>(vld1_s64(p)); +} +HWY_API Vec64<float> LoadU(Full64<float> /* tag */, + const float* HWY_RESTRICT p) { + return Vec64<float>(vld1_f32(p)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64<double> LoadU(Full64<double> /* tag */, + const double* HWY_RESTRICT p) { + return Vec64<double>(vld1_f64(p)); +} +#endif +// ------------------------------ Load 32 + +// Actual 32-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +HWY_API Vec32<uint32_t> LoadU(Full32<uint32_t> /*tag*/, + const uint32_t* HWY_RESTRICT p) { + return Vec32<uint32_t>(vld1_dup_u32(p)); +} +HWY_API Vec32<int32_t> LoadU(Full32<int32_t> /*tag*/, + const int32_t* HWY_RESTRICT p) { + return Vec32<int32_t>(vld1_dup_s32(p)); +} +HWY_API Vec32<float> LoadU(Full32<float> /*tag*/, const float* HWY_RESTRICT p) { + return Vec32<float>(vld1_dup_f32(p)); +} + +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x6)> // 1 or 2 bytes +HWY_API Vec32<T> LoadU(Full32<T> d, const T* HWY_RESTRICT p) { + const Repartition<uint32_t, decltype(d)> d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} + +// ------------------------------ Load 16 + +// Actual 16-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +HWY_API Vec128<uint16_t, 1> LoadU(Simd<uint16_t, 1, 0> /*tag*/, + const uint16_t* HWY_RESTRICT p) { + return Vec128<uint16_t, 1>(vld1_dup_u16(p)); +} +HWY_API Vec128<int16_t, 1> LoadU(Simd<int16_t, 1, 0> /*tag*/, + const int16_t* HWY_RESTRICT p) { + return Vec128<int16_t, 1>(vld1_dup_s16(p)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 2> LoadU(Simd<T, 2, 0> d, const T* HWY_RESTRICT p) { + const Repartition<uint16_t, decltype(d)> d16; + uint16_t buf; + CopyBytes<2>(p, &buf); + return BitCast(d, LoadU(d16, &buf)); +} + +// ------------------------------ Load 8 + +HWY_API Vec128<uint8_t, 1> LoadU(Simd<uint8_t, 1, 0>, + const uint8_t* HWY_RESTRICT p) { + return Vec128<uint8_t, 1>(vld1_dup_u8(p)); +} + +HWY_API Vec128<int8_t, 1> LoadU(Simd<int8_t, 1, 0>, + const int8_t* HWY_RESTRICT p) { + return Vec128<int8_t, 1>(vld1_dup_s8(p)); +} + +// [b]float16_t use the same Raw as uint16_t, so forward to that. +template <size_t N> +HWY_API Vec128<float16_t, N> LoadU(Simd<float16_t, N, 0> d, + const float16_t* HWY_RESTRICT p) { + const RebindToUnsigned<decltype(d)> du16; + const auto pu16 = reinterpret_cast<const uint16_t*>(p); + return Vec128<float16_t, N>(LoadU(du16, pu16).raw); +} +template <size_t N> +HWY_API Vec128<bfloat16_t, N> LoadU(Simd<bfloat16_t, N, 0> d, + const bfloat16_t* HWY_RESTRICT p) { + const RebindToUnsigned<decltype(d)> du16; + const auto pu16 = reinterpret_cast<const uint16_t*>(p); + return Vec128<bfloat16_t, N>(LoadU(du16, pu16).raw); +} + +// On ARM, Load is the same as LoadU. +template <typename T, size_t N> +HWY_API Vec128<T, N> Load(Simd<T, N, 0> d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> LoadDup128(Simd<T, N, 0> d, + const T* const HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store 128 + +HWY_API void StoreU(const Vec128<uint8_t> v, Full128<uint8_t> /* tag */, + uint8_t* HWY_RESTRICT unaligned) { + vst1q_u8(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<uint16_t> v, Full128<uint16_t> /* tag */, + uint16_t* HWY_RESTRICT unaligned) { + vst1q_u16(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<uint32_t> v, Full128<uint32_t> /* tag */, + uint32_t* HWY_RESTRICT unaligned) { + vst1q_u32(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<uint64_t> v, Full128<uint64_t> /* tag */, + uint64_t* HWY_RESTRICT unaligned) { + vst1q_u64(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<int8_t> v, Full128<int8_t> /* tag */, + int8_t* HWY_RESTRICT unaligned) { + vst1q_s8(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<int16_t> v, Full128<int16_t> /* tag */, + int16_t* HWY_RESTRICT unaligned) { + vst1q_s16(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<int32_t> v, Full128<int32_t> /* tag */, + int32_t* HWY_RESTRICT unaligned) { + vst1q_s32(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<int64_t> v, Full128<int64_t> /* tag */, + int64_t* HWY_RESTRICT unaligned) { + vst1q_s64(unaligned, v.raw); +} +HWY_API void StoreU(const Vec128<float> v, Full128<float> /* tag */, + float* HWY_RESTRICT unaligned) { + vst1q_f32(unaligned, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_API void StoreU(const Vec128<double> v, Full128<double> /* tag */, + double* HWY_RESTRICT unaligned) { + vst1q_f64(unaligned, v.raw); +} +#endif + +// ------------------------------ Store 64 + +HWY_API void StoreU(const Vec64<uint8_t> v, Full64<uint8_t> /* tag */, + uint8_t* HWY_RESTRICT p) { + vst1_u8(p, v.raw); +} +HWY_API void StoreU(const Vec64<uint16_t> v, Full64<uint16_t> /* tag */, + uint16_t* HWY_RESTRICT p) { + vst1_u16(p, v.raw); +} +HWY_API void StoreU(const Vec64<uint32_t> v, Full64<uint32_t> /* tag */, + uint32_t* HWY_RESTRICT p) { + vst1_u32(p, v.raw); +} +HWY_API void StoreU(const Vec64<uint64_t> v, Full64<uint64_t> /* tag */, + uint64_t* HWY_RESTRICT p) { + vst1_u64(p, v.raw); +} +HWY_API void StoreU(const Vec64<int8_t> v, Full64<int8_t> /* tag */, + int8_t* HWY_RESTRICT p) { + vst1_s8(p, v.raw); +} +HWY_API void StoreU(const Vec64<int16_t> v, Full64<int16_t> /* tag */, + int16_t* HWY_RESTRICT p) { + vst1_s16(p, v.raw); +} +HWY_API void StoreU(const Vec64<int32_t> v, Full64<int32_t> /* tag */, + int32_t* HWY_RESTRICT p) { + vst1_s32(p, v.raw); +} +HWY_API void StoreU(const Vec64<int64_t> v, Full64<int64_t> /* tag */, + int64_t* HWY_RESTRICT p) { + vst1_s64(p, v.raw); +} +HWY_API void StoreU(const Vec64<float> v, Full64<float> /* tag */, + float* HWY_RESTRICT p) { + vst1_f32(p, v.raw); +} +#if HWY_ARCH_ARM_A64 +HWY_API void StoreU(const Vec64<double> v, Full64<double> /* tag */, + double* HWY_RESTRICT p) { + vst1_f64(p, v.raw); +} +#endif + +// ------------------------------ Store 32 + +HWY_API void StoreU(const Vec32<uint32_t> v, Full32<uint32_t>, + uint32_t* HWY_RESTRICT p) { + vst1_lane_u32(p, v.raw, 0); +} +HWY_API void StoreU(const Vec32<int32_t> v, Full32<int32_t>, + int32_t* HWY_RESTRICT p) { + vst1_lane_s32(p, v.raw, 0); +} +HWY_API void StoreU(const Vec32<float> v, Full32<float>, + float* HWY_RESTRICT p) { + vst1_lane_f32(p, v.raw, 0); +} + +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x6)> // 1 or 2 bytes +HWY_API void StoreU(const Vec32<T> v, Full32<T> d, T* HWY_RESTRICT p) { + const Repartition<uint32_t, decltype(d)> d32; + const uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} + +// ------------------------------ Store 16 + +HWY_API void StoreU(const Vec128<uint16_t, 1> v, Simd<uint16_t, 1, 0>, + uint16_t* HWY_RESTRICT p) { + vst1_lane_u16(p, v.raw, 0); +} +HWY_API void StoreU(const Vec128<int16_t, 1> v, Simd<int16_t, 1, 0>, + int16_t* HWY_RESTRICT p) { + vst1_lane_s16(p, v.raw, 0); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API void StoreU(const Vec128<T, 2> v, Simd<T, 2, 0> d, T* HWY_RESTRICT p) { + const Repartition<uint16_t, decltype(d)> d16; + const uint16_t buf = GetLane(BitCast(d16, v)); + CopyBytes<2>(&buf, p); +} + +// ------------------------------ Store 8 + +HWY_API void StoreU(const Vec128<uint8_t, 1> v, Simd<uint8_t, 1, 0>, + uint8_t* HWY_RESTRICT p) { + vst1_lane_u8(p, v.raw, 0); +} +HWY_API void StoreU(const Vec128<int8_t, 1> v, Simd<int8_t, 1, 0>, + int8_t* HWY_RESTRICT p) { + vst1_lane_s8(p, v.raw, 0); +} + +// [b]float16_t use the same Raw as uint16_t, so forward to that. +template <size_t N> +HWY_API void StoreU(Vec128<float16_t, N> v, Simd<float16_t, N, 0> d, + float16_t* HWY_RESTRICT p) { + const RebindToUnsigned<decltype(d)> du16; + const auto pu16 = reinterpret_cast<uint16_t*>(p); + return StoreU(Vec128<uint16_t, N>(v.raw), du16, pu16); +} +template <size_t N> +HWY_API void StoreU(Vec128<bfloat16_t, N> v, Simd<bfloat16_t, N, 0> d, + bfloat16_t* HWY_RESTRICT p) { + const RebindToUnsigned<decltype(d)> du16; + const auto pu16 = reinterpret_cast<uint16_t*>(p); + return StoreU(Vec128<uint16_t, N>(v.raw), du16, pu16); +} + +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// On ARM, Store is the same as StoreU. +template <typename T, size_t N> +HWY_API void Store(Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT aligned) { + StoreU(v, d, aligned); +} + +HWY_DIAGNOSTICS(pop) + +template <typename T, size_t N> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, Simd<T, N, 0> d, + T* HWY_RESTRICT p) { + // Treat as unsigned so that we correctly support float16. + const RebindToUnsigned<decltype(d)> du; + const auto blended = + IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p))); + StoreU(BitCast(d, blended), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template <typename T, size_t N> +HWY_API void Stream(const Vec128<T, N> v, Simd<T, N, 0> d, + T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend to full vector. +HWY_API Vec128<uint16_t> PromoteTo(Full128<uint16_t> /* tag */, + const Vec64<uint8_t> v) { + return Vec128<uint16_t>(vmovl_u8(v.raw)); +} +HWY_API Vec128<uint32_t> PromoteTo(Full128<uint32_t> /* tag */, + const Vec32<uint8_t> v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128<uint32_t>(vmovl_u16(vget_low_u16(a))); +} +HWY_API Vec128<uint32_t> PromoteTo(Full128<uint32_t> /* tag */, + const Vec64<uint16_t> v) { + return Vec128<uint32_t>(vmovl_u16(v.raw)); +} +HWY_API Vec128<uint64_t> PromoteTo(Full128<uint64_t> /* tag */, + const Vec64<uint32_t> v) { + return Vec128<uint64_t>(vmovl_u32(v.raw)); +} +HWY_API Vec128<int16_t> PromoteTo(Full128<int16_t> d, const Vec64<uint8_t> v) { + return BitCast(d, Vec128<uint16_t>(vmovl_u8(v.raw))); +} +HWY_API Vec128<int32_t> PromoteTo(Full128<int32_t> d, const Vec32<uint8_t> v) { + uint16x8_t a = vmovl_u8(v.raw); + return BitCast(d, Vec128<uint32_t>(vmovl_u16(vget_low_u16(a)))); +} +HWY_API Vec128<int32_t> PromoteTo(Full128<int32_t> d, const Vec64<uint16_t> v) { + return BitCast(d, Vec128<uint32_t>(vmovl_u16(v.raw))); +} + +// Unsigned: zero-extend to half vector. +template <size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<uint16_t, N> PromoteTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { + return Vec128<uint16_t, N>(vget_low_u16(vmovl_u8(v.raw))); +} +template <size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<uint32_t, N> PromoteTo(Simd<uint32_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128<uint32_t, N>(vget_low_u32(vmovl_u16(vget_low_u16(a)))); +} +template <size_t N> +HWY_API Vec128<uint32_t, N> PromoteTo(Simd<uint32_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + return Vec128<uint32_t, N>(vget_low_u32(vmovl_u16(v.raw))); +} +template <size_t N, HWY_IF_LE64(uint64_t, N)> +HWY_API Vec128<uint64_t, N> PromoteTo(Simd<uint64_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + return Vec128<uint64_t, N>(vget_low_u64(vmovl_u32(v.raw))); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> PromoteTo(Simd<int16_t, N, 0> d, + const Vec128<uint8_t, N> v) { + return BitCast(d, Vec128<uint16_t, N>(vget_low_u16(vmovl_u8(v.raw)))); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { + uint16x8_t a = vmovl_u8(v.raw); + uint32x4_t b = vmovl_u16(vget_low_u16(a)); + return Vec128<int32_t, N>(vget_low_s32(vreinterpretq_s32_u32(b))); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + uint32x4_t a = vmovl_u16(v.raw); + return Vec128<int32_t, N>(vget_low_s32(vreinterpretq_s32_u32(a))); +} + +// Signed: replicate sign bit to full vector. +HWY_API Vec128<int16_t> PromoteTo(Full128<int16_t> /* tag */, + const Vec64<int8_t> v) { + return Vec128<int16_t>(vmovl_s8(v.raw)); +} +HWY_API Vec128<int32_t> PromoteTo(Full128<int32_t> /* tag */, + const Vec32<int8_t> v) { + int16x8_t a = vmovl_s8(v.raw); + return Vec128<int32_t>(vmovl_s16(vget_low_s16(a))); +} +HWY_API Vec128<int32_t> PromoteTo(Full128<int32_t> /* tag */, + const Vec64<int16_t> v) { + return Vec128<int32_t>(vmovl_s16(v.raw)); +} +HWY_API Vec128<int64_t> PromoteTo(Full128<int64_t> /* tag */, + const Vec64<int32_t> v) { + return Vec128<int64_t>(vmovl_s32(v.raw)); +} + +// Signed: replicate sign bit to half vector. +template <size_t N> +HWY_API Vec128<int16_t, N> PromoteTo(Simd<int16_t, N, 0> /* tag */, + const Vec128<int8_t, N> v) { + return Vec128<int16_t, N>(vget_low_s16(vmovl_s8(v.raw))); +} +template <size_t N> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<int8_t, N> v) { + int16x8_t a = vmovl_s8(v.raw); + int32x4_t b = vmovl_s16(vget_low_s16(a)); + return Vec128<int32_t, N>(vget_low_s32(b)); +} +template <size_t N> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<int32_t, N>(vget_low_s32(vmovl_s16(v.raw))); +} +template <size_t N> +HWY_API Vec128<int64_t, N> PromoteTo(Simd<int64_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<int64_t, N>(vget_low_s64(vmovl_s32(v.raw))); +} + +#if __ARM_FP & 2 + +HWY_API Vec128<float> PromoteTo(Full128<float> /* tag */, + const Vec128<float16_t, 4> v) { + const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); + return Vec128<float>(f32); +} +template <size_t N> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> /* tag */, + const Vec128<float16_t, N> v) { + const float32x4_t f32 = vcvt_f32_f16(vreinterpret_f16_u16(v.raw)); + return Vec128<float, N>(vget_low_f32(f32)); +} + +#else + +template <size_t N> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> df32, + const Vec128<float16_t, N> v) { + const RebindToSigned<decltype(df32)> di32; + const RebindToUnsigned<decltype(df32)> du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128<uint16_t, N>{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +#endif + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128<double> PromoteTo(Full128<double> /* tag */, + const Vec64<float> v) { + return Vec128<double>(vcvt_f64_f32(v.raw)); +} + +HWY_API Vec64<double> PromoteTo(Full64<double> /* tag */, + const Vec32<float> v) { + return Vec64<double>(vget_low_f64(vcvt_f64_f32(v.raw))); +} + +HWY_API Vec128<double> PromoteTo(Full128<double> /* tag */, + const Vec64<int32_t> v) { + const int64x2_t i64 = vmovl_s32(v.raw); + return Vec128<double>(vcvtq_f64_s64(i64)); +} + +HWY_API Vec64<double> PromoteTo(Full64<double> /* tag */, + const Vec32<int32_t> v) { + const int64x1_t i64 = vget_low_s64(vmovl_s32(v.raw)); + return Vec64<double>(vcvt_f64_s64(i64)); +} + +#endif + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +// From full vector to half or quarter +HWY_API Vec64<uint16_t> DemoteTo(Full64<uint16_t> /* tag */, + const Vec128<int32_t> v) { + return Vec64<uint16_t>(vqmovun_s32(v.raw)); +} +HWY_API Vec64<int16_t> DemoteTo(Full64<int16_t> /* tag */, + const Vec128<int32_t> v) { + return Vec64<int16_t>(vqmovn_s32(v.raw)); +} +HWY_API Vec32<uint8_t> DemoteTo(Full32<uint8_t> /* tag */, + const Vec128<int32_t> v) { + const uint16x4_t a = vqmovun_s32(v.raw); + return Vec32<uint8_t>(vqmovn_u16(vcombine_u16(a, a))); +} +HWY_API Vec64<uint8_t> DemoteTo(Full64<uint8_t> /* tag */, + const Vec128<int16_t> v) { + return Vec64<uint8_t>(vqmovun_s16(v.raw)); +} +HWY_API Vec32<int8_t> DemoteTo(Full32<int8_t> /* tag */, + const Vec128<int32_t> v) { + const int16x4_t a = vqmovn_s32(v.raw); + return Vec32<int8_t>(vqmovn_s16(vcombine_s16(a, a))); +} +HWY_API Vec64<int8_t> DemoteTo(Full64<int8_t> /* tag */, + const Vec128<int16_t> v) { + return Vec64<int8_t>(vqmovn_s16(v.raw)); +} + +// From half vector to partial half +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<uint16_t, N> DemoteTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<uint16_t, N>(vqmovun_s32(vcombine_s32(v.raw, v.raw))); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int16_t, N> DemoteTo(Simd<int16_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<int16_t, N>(vqmovn_s32(vcombine_s32(v.raw, v.raw))); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<uint8_t, N> DemoteTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); + return Vec128<uint8_t, N>(vqmovn_u16(vcombine_u16(a, a))); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<uint8_t, N> DemoteTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<uint8_t, N>(vqmovun_s16(vcombine_s16(v.raw, v.raw))); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int8_t, N> DemoteTo(Simd<int8_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, v.raw)); + return Vec128<int8_t, N>(vqmovn_s16(vcombine_s16(a, a))); +} +template <size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int8_t, N> DemoteTo(Simd<int8_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<int8_t, N>(vqmovn_s16(vcombine_s16(v.raw, v.raw))); +} + +#if __ARM_FP & 2 + +HWY_API Vec128<float16_t, 4> DemoteTo(Full64<float16_t> /* tag */, + const Vec128<float> v) { + return Vec128<float16_t, 4>{vreinterpret_u16_f16(vcvt_f16_f32(v.raw))}; +} +template <size_t N> +HWY_API Vec128<float16_t, N> DemoteTo(Simd<float16_t, N, 0> /* tag */, + const Vec128<float, N> v) { + const float16x4_t f16 = vcvt_f16_f32(vcombine_f32(v.raw, v.raw)); + return Vec128<float16_t, N>(vreinterpret_u16_f16(f16)); +} + +#else + +template <size_t N> +HWY_API Vec128<float16_t, N> DemoteTo(Simd<float16_t, N, 0> df16, + const Vec128<float, N> v) { + const RebindToUnsigned<decltype(df16)> du16; + const Rebind<uint32_t, decltype(du16)> du; + const RebindToSigned<decltype(du)> di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128<float16_t, N>(DemoteTo(du16, bits16).raw); +} + +#endif + +template <size_t N> +HWY_API Vec128<bfloat16_t, N> DemoteTo(Simd<bfloat16_t, N, 0> dbf16, + const Vec128<float, N> v) { + const Rebind<int32_t, decltype(dbf16)> di32; + const Rebind<uint32_t, decltype(dbf16)> du32; // for logical shift right + const Rebind<uint16_t, decltype(dbf16)> du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec64<float> DemoteTo(Full64<float> /* tag */, const Vec128<double> v) { + return Vec64<float>(vcvt_f32_f64(v.raw)); +} +HWY_API Vec32<float> DemoteTo(Full32<float> /* tag */, const Vec64<double> v) { + return Vec32<float>(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); +} + +HWY_API Vec64<int32_t> DemoteTo(Full64<int32_t> /* tag */, + const Vec128<double> v) { + const int64x2_t i64 = vcvtq_s64_f64(v.raw); + return Vec64<int32_t>(vqmovn_s64(i64)); +} +HWY_API Vec32<int32_t> DemoteTo(Full32<int32_t> /* tag */, + const Vec64<double> v) { + const int64x1_t i64 = vcvt_s64_f64(v.raw); + // There is no i64x1 -> i32x1 narrow, so expand to int64x2_t first. + const int64x2_t i64x2 = vcombine_s64(i64, i64); + return Vec32<int32_t>(vqmovn_s64(i64x2)); +} + +#endif + +HWY_API Vec32<uint8_t> U8FromU32(const Vec128<uint32_t> v) { + const uint8x16_t org_v = detail::BitCastToByte(v).raw; + const uint8x16_t w = vuzp1q_u8(org_v, org_v); + return Vec32<uint8_t>(vget_low_u8(vuzp1q_u8(w, w))); +} +template <size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<uint8_t, N> U8FromU32(const Vec128<uint32_t, N> v) { + const uint8x8_t org_v = detail::BitCastToByte(v).raw; + const uint8x8_t w = vuzp1_u8(org_v, org_v); + return Vec128<uint8_t, N>(vuzp1_u8(w, w)); +} + +// In the following DemoteTo functions, |b| is purposely undefined. +// The value a needs to be extended to 128 bits so that vqmovn can be +// used and |b| is undefined so that no extra overhead is introduced. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") + +template <size_t N> +HWY_API Vec128<uint8_t, N> DemoteTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<int32_t> v) { + Vec128<uint16_t, N> a = DemoteTo(Simd<uint16_t, N, 0>(), v); + Vec128<uint16_t, N> b; + uint16x8_t c = vcombine_u16(a.raw, b.raw); + return Vec128<uint8_t, N>(vqmovn_u16(c)); +} + +template <size_t N> +HWY_API Vec128<int8_t, N> DemoteTo(Simd<int8_t, N, 0> /* tag */, + const Vec128<int32_t> v) { + Vec128<int16_t, N> a = DemoteTo(Simd<int16_t, N, 0>(), v); + Vec128<int16_t, N> b; + int16x8_t c = vcombine_s16(a.raw, b.raw); + return Vec128<int8_t, N>(vqmovn_s16(c)); +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Convert integer <=> floating-point + +HWY_API Vec128<float> ConvertTo(Full128<float> /* tag */, + const Vec128<int32_t> v) { + return Vec128<float>(vcvtq_f32_s32(v.raw)); +} +template <size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<float, N> ConvertTo(Simd<float, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<float, N>(vcvt_f32_s32(v.raw)); +} + +HWY_API Vec128<float> ConvertTo(Full128<float> /* tag */, + const Vec128<uint32_t> v) { + return Vec128<float>(vcvtq_f32_u32(v.raw)); +} +template <size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<float, N> ConvertTo(Simd<float, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + return Vec128<float, N>(vcvt_f32_u32(v.raw)); +} + +// Truncates (rounds toward zero). +HWY_API Vec128<int32_t> ConvertTo(Full128<int32_t> /* tag */, + const Vec128<float> v) { + return Vec128<int32_t>(vcvtq_s32_f32(v.raw)); +} +template <size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<int32_t, N> ConvertTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<float, N> v) { + return Vec128<int32_t, N>(vcvt_s32_f32(v.raw)); +} + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128<double> ConvertTo(Full128<double> /* tag */, + const Vec128<int64_t> v) { + return Vec128<double>(vcvtq_f64_s64(v.raw)); +} +HWY_API Vec64<double> ConvertTo(Full64<double> /* tag */, + const Vec64<int64_t> v) { + return Vec64<double>(vcvt_f64_s64(v.raw)); +} + +HWY_API Vec128<double> ConvertTo(Full128<double> /* tag */, + const Vec128<uint64_t> v) { + return Vec128<double>(vcvtq_f64_u64(v.raw)); +} +HWY_API Vec64<double> ConvertTo(Full64<double> /* tag */, + const Vec64<uint64_t> v) { + return Vec64<double>(vcvt_f64_u64(v.raw)); +} + +// Truncates (rounds toward zero). +HWY_API Vec128<int64_t> ConvertTo(Full128<int64_t> /* tag */, + const Vec128<double> v) { + return Vec128<int64_t>(vcvtq_s64_f64(v.raw)); +} +HWY_API Vec64<int64_t> ConvertTo(Full64<int64_t> /* tag */, + const Vec64<double> v) { + return Vec64<int64_t>(vcvt_s64_f64(v.raw)); +} + +#endif + +// ------------------------------ Round (IfThenElse, mask, logical) + +#if HWY_ARCH_ARM_A64 +// Toward nearest integer +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Round, vrndn, _, 1) + +// Toward zero, aka truncate +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Trunc, vrnd, _, 1) + +// Toward +infinity, aka ceiling +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Ceil, vrndp, _, 1) + +// Toward -infinity, aka floor +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Floor, vrndm, _, 1) +#else + +// ------------------------------ Trunc + +// ARMv7 only supports truncation to integer. We can either convert back to +// float (3 floating-point and 2 logic operations) or manipulate the binary32 +// representation, clearing the lowest 23-exp mantissa bits. This requires 9 +// integer operations and 3 constants, which is likely more expensive. + +namespace detail { + +// The original value is already the desired result if NaN or the magnitude is +// large (i.e. the value is already an integer). +template <size_t N> +HWY_INLINE Mask128<float, N> UseInt(const Vec128<float, N> v) { + return Abs(v) < Set(Simd<float, N, 0>(), MantissaEnd<float>()); +} + +} // namespace detail + +template <size_t N> +HWY_API Vec128<float, N> Trunc(const Vec128<float, N> v) { + const DFromV<decltype(v)> df; + const RebindToSigned<decltype(df)> di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), int_f, v); +} + +template <size_t N> +HWY_API Vec128<float, N> Round(const Vec128<float, N> v) { + const DFromV<decltype(v)> df; + + // ARMv7 also lacks a native NearestInt, but we can instead rely on rounding + // (we assume the current mode is nearest-even) after addition with a large + // value such that no mantissa bits remain. We may need a compiler flag for + // precise floating-point to prevent this from being "optimized" out. + const auto max = Set(df, MantissaEnd<float>()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +template <size_t N> +HWY_API Vec128<float, N> Ceil(const Vec128<float, N> v) { + const DFromV<decltype(v)> df; + const RebindToSigned<decltype(df)> di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +template <size_t N> +HWY_API Vec128<float, N> Floor(const Vec128<float, N> v) { + const DFromV<decltype(v)> df; + const RebindToSigned<decltype(df)> di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#endif + +// ------------------------------ NearestInt (Round) + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128<int32_t> NearestInt(const Vec128<float> v) { + return Vec128<int32_t>(vcvtnq_s32_f32(v.raw)); +} +template <size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) { + return Vec128<int32_t, N>(vcvtn_s32_f32(v.raw)); +} + +#else + +template <size_t N> +HWY_API Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) { + const RebindToSigned<DFromV<decltype(v)>> di; + return ConvertTo(di, Round(v)); +} + +#endif + +// ------------------------------ Floating-point classification +template <typename T, size_t N> +HWY_API Mask128<T, N> IsNaN(const Vec128<T, N> v) { + return v != v; +} + +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Mask128<T, N> IsInf(const Vec128<T, N> v) { + const Simd<T, N, 0> d; + const RebindToSigned<decltype(d)> di; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2<T>()))); +} + +// Returns whether normal/subnormal/zero. +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Mask128<T, N> IsFinite(const Vec128<T, N> v) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + const VFromD<decltype(du)> vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD<decltype(di)> exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>()))); +} + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +// <= 64 bit: just return different type +template <typename T, size_t N, HWY_IF_LE64(uint8_t, N)> +HWY_API Vec128<T, N / 2> LowerHalf(const Vec128<T, N> v) { + return Vec128<T, N / 2>(v.raw); +} + +HWY_API Vec64<uint8_t> LowerHalf(const Vec128<uint8_t> v) { + return Vec64<uint8_t>(vget_low_u8(v.raw)); +} +HWY_API Vec64<uint16_t> LowerHalf(const Vec128<uint16_t> v) { + return Vec64<uint16_t>(vget_low_u16(v.raw)); +} +HWY_API Vec64<uint32_t> LowerHalf(const Vec128<uint32_t> v) { + return Vec64<uint32_t>(vget_low_u32(v.raw)); +} +HWY_API Vec64<uint64_t> LowerHalf(const Vec128<uint64_t> v) { + return Vec64<uint64_t>(vget_low_u64(v.raw)); +} +HWY_API Vec64<int8_t> LowerHalf(const Vec128<int8_t> v) { + return Vec64<int8_t>(vget_low_s8(v.raw)); +} +HWY_API Vec64<int16_t> LowerHalf(const Vec128<int16_t> v) { + return Vec64<int16_t>(vget_low_s16(v.raw)); +} +HWY_API Vec64<int32_t> LowerHalf(const Vec128<int32_t> v) { + return Vec64<int32_t>(vget_low_s32(v.raw)); +} +HWY_API Vec64<int64_t> LowerHalf(const Vec128<int64_t> v) { + return Vec64<int64_t>(vget_low_s64(v.raw)); +} +HWY_API Vec64<float> LowerHalf(const Vec128<float> v) { + return Vec64<float>(vget_low_f32(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64<double> LowerHalf(const Vec128<double> v) { + return Vec64<double>(vget_low_f64(v.raw)); +} +#endif +HWY_API Vec64<bfloat16_t> LowerHalf(const Vec128<bfloat16_t> v) { + const Full128<uint16_t> du; + const Full64<bfloat16_t> dbh; + return BitCast(dbh, LowerHalf(BitCast(du, v))); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> LowerHalf(Simd<T, N / 2, 0> /* tag */, + Vec128<T, N> v) { + return LowerHalf(v); +} + +// ------------------------------ CombineShiftRightBytes + +// 128-bit +template <int kBytes, typename T, class V128 = Vec128<T>> +HWY_API V128 CombineShiftRightBytes(Full128<T> d, V128 hi, V128 lo) { + static_assert(0 < kBytes && kBytes < 16, "kBytes must be in [1, 15]"); + const Repartition<uint8_t, decltype(d)> d8; + uint8x16_t v8 = vextq_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, Vec128<uint8_t>(v8)); +} + +// 64-bit +template <int kBytes, typename T> +HWY_API Vec64<T> CombineShiftRightBytes(Full64<T> d, Vec64<T> hi, Vec64<T> lo) { + static_assert(0 < kBytes && kBytes < 8, "kBytes must be in [1, 7]"); + const Repartition<uint8_t, decltype(d)> d8; + uint8x8_t v8 = vext_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, VFromD<decltype(d8)>(v8)); +} + +// <= 32-bit defined after ShiftLeftBytes. + +// ------------------------------ Shift vector by constant #bytes + +namespace detail { + +// Partially specialize because kBytes = 0 and >= size are compile errors; +// callers replace the latter with 0xFF for easier specialization. +template <int kBytes> +struct ShiftLeftBytesT { + // Full + template <class T> + HWY_INLINE Vec128<T> operator()(const Vec128<T> v) { + const Full128<T> d; + return CombineShiftRightBytes<16 - kBytes>(d, v, Zero(d)); + } + + // Partial + template <class T, size_t N, HWY_IF_LE64(T, N)> + HWY_INLINE Vec128<T, N> operator()(const Vec128<T, N> v) { + // Expand to 64-bit so we only use the native EXT instruction. + const Full64<T> d64; + const auto zero64 = Zero(d64); + const decltype(zero64) v64(v.raw); + return Vec128<T, N>( + CombineShiftRightBytes<8 - kBytes>(d64, v64, zero64).raw); + } +}; +template <> +struct ShiftLeftBytesT<0> { + template <class T, size_t N> + HWY_INLINE Vec128<T, N> operator()(const Vec128<T, N> v) { + return v; + } +}; +template <> +struct ShiftLeftBytesT<0xFF> { + template <class T, size_t N> + HWY_INLINE Vec128<T, N> operator()(const Vec128<T, N> /* v */) { + return Zero(Simd<T, N, 0>()); + } +}; + +template <int kBytes> +struct ShiftRightBytesT { + template <class T, size_t N> + HWY_INLINE Vec128<T, N> operator()(Vec128<T, N> v) { + const Simd<T, N, 0> d; + // For < 64-bit vectors, zero undefined lanes so we shift in zeros. + if (N * sizeof(T) < 8) { + constexpr size_t kReg = N * sizeof(T) == 16 ? 16 : 8; + const Simd<T, kReg / sizeof(T), 0> dreg; + v = Vec128<T, N>( + IfThenElseZero(FirstN(dreg, N), VFromD<decltype(dreg)>(v.raw)).raw); + } + return CombineShiftRightBytes<kBytes>(d, Zero(d), v); + } +}; +template <> +struct ShiftRightBytesT<0> { + template <class T, size_t N> + HWY_INLINE Vec128<T, N> operator()(const Vec128<T, N> v) { + return v; + } +}; +template <> +struct ShiftRightBytesT<0xFF> { + template <class T, size_t N> + HWY_INLINE Vec128<T, N> operator()(const Vec128<T, N> /* v */) { + return Zero(Simd<T, N, 0>()); + } +}; + +} // namespace detail + +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + return detail::ShiftLeftBytesT < kBytes >= N * sizeof(T) ? 0xFF + : kBytes > ()(v); +} + +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(const Vec128<T, N> v) { + return ShiftLeftBytes<kBytes>(Simd<T, N, 0>(), v); +} + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); +} + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(const Vec128<T, N> v) { + return ShiftLeftLanes<kLanes>(Simd<T, N, 0>(), v); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + return detail::ShiftRightBytesT < kBytes >= N * sizeof(T) ? 0xFF + : kBytes > ()(v); +} + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v))); +} + +// Calls ShiftLeftBytes +template <int kBytes, typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API Vec128<T, N> CombineShiftRightBytes(Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition<uint8_t, decltype(d)> d8; + const Full64<uint8_t> d_full8; + const Repartition<T, decltype(d_full8)> d_full; + using V64 = VFromD<decltype(d_full8)>; + const V64 hi64(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V64 lo64 = ShiftLeftBytes<8 - kSize>(V64(BitCast(d8, lo).raw)); + const V64 r = CombineShiftRightBytes<8 - kSize + kBytes>(d_full8, hi64, lo64); + // After casting to full 64-bit vector of correct type, shrink to 32-bit + return Vec128<T, N>(BitCast(d_full, r).raw); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input +HWY_API Vec64<uint8_t> UpperHalf(Full64<uint8_t> /* tag */, + const Vec128<uint8_t> v) { + return Vec64<uint8_t>(vget_high_u8(v.raw)); +} +HWY_API Vec64<uint16_t> UpperHalf(Full64<uint16_t> /* tag */, + const Vec128<uint16_t> v) { + return Vec64<uint16_t>(vget_high_u16(v.raw)); +} +HWY_API Vec64<uint32_t> UpperHalf(Full64<uint32_t> /* tag */, + const Vec128<uint32_t> v) { + return Vec64<uint32_t>(vget_high_u32(v.raw)); +} +HWY_API Vec64<uint64_t> UpperHalf(Full64<uint64_t> /* tag */, + const Vec128<uint64_t> v) { + return Vec64<uint64_t>(vget_high_u64(v.raw)); +} +HWY_API Vec64<int8_t> UpperHalf(Full64<int8_t> /* tag */, + const Vec128<int8_t> v) { + return Vec64<int8_t>(vget_high_s8(v.raw)); +} +HWY_API Vec64<int16_t> UpperHalf(Full64<int16_t> /* tag */, + const Vec128<int16_t> v) { + return Vec64<int16_t>(vget_high_s16(v.raw)); +} +HWY_API Vec64<int32_t> UpperHalf(Full64<int32_t> /* tag */, + const Vec128<int32_t> v) { + return Vec64<int32_t>(vget_high_s32(v.raw)); +} +HWY_API Vec64<int64_t> UpperHalf(Full64<int64_t> /* tag */, + const Vec128<int64_t> v) { + return Vec64<int64_t>(vget_high_s64(v.raw)); +} +HWY_API Vec64<float> UpperHalf(Full64<float> /* tag */, const Vec128<float> v) { + return Vec64<float>(vget_high_f32(v.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec64<double> UpperHalf(Full64<double> /* tag */, + const Vec128<double> v) { + return Vec64<double>(vget_high_f64(v.raw)); +} +#endif + +HWY_API Vec64<bfloat16_t> UpperHalf(Full64<bfloat16_t> dbh, + const Vec128<bfloat16_t> v) { + const RebindToUnsigned<decltype(dbh)> duh; + const Twice<decltype(duh)> du; + return BitCast(dbh, UpperHalf(duh, BitCast(du, v))); +} + +// Partial +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, (N + 1) / 2> UpperHalf(Half<Simd<T, N, 0>> /* tag */, + Vec128<T, N> v) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes<N * sizeof(T) / 2>(du, vu)); + return Vec128<T, (N + 1) / 2>(upper.raw); +} + +// ------------------------------ Broadcast/splat any lane + +#if HWY_ARCH_ARM_A64 +// Unsigned +template <int kLane> +HWY_API Vec128<uint16_t> Broadcast(const Vec128<uint16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128<uint16_t>(vdupq_laneq_u16(v.raw, kLane)); +} +template <int kLane, size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<uint16_t, N> Broadcast(const Vec128<uint16_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<uint16_t, N>(vdup_lane_u16(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<uint32_t> Broadcast(const Vec128<uint32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128<uint32_t>(vdupq_laneq_u32(v.raw, kLane)); +} +template <int kLane, size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<uint32_t, N> Broadcast(const Vec128<uint32_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<uint32_t, N>(vdup_lane_u32(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<uint64_t> Broadcast(const Vec128<uint64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128<uint64_t>(vdupq_laneq_u64(v.raw, kLane)); +} +// Vec64<uint64_t> is defined below. + +// Signed +template <int kLane> +HWY_API Vec128<int16_t> Broadcast(const Vec128<int16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128<int16_t>(vdupq_laneq_s16(v.raw, kLane)); +} +template <int kLane, size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> Broadcast(const Vec128<int16_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<int16_t, N>(vdup_lane_s16(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<int32_t> Broadcast(const Vec128<int32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128<int32_t>(vdupq_laneq_s32(v.raw, kLane)); +} +template <int kLane, size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> Broadcast(const Vec128<int32_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<int32_t, N>(vdup_lane_s32(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<int64_t> Broadcast(const Vec128<int64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128<int64_t>(vdupq_laneq_s64(v.raw, kLane)); +} +// Vec64<int64_t> is defined below. + +// Float +template <int kLane> +HWY_API Vec128<float> Broadcast(const Vec128<float> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128<float>(vdupq_laneq_f32(v.raw, kLane)); +} +template <int kLane, size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<float, N> Broadcast(const Vec128<float, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<float, N>(vdup_lane_f32(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<double> Broadcast(const Vec128<double> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128<double>(vdupq_laneq_f64(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec64<double> Broadcast(const Vec64<double> v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +#else +// No vdupq_laneq_* on armv7: use vgetq_lane_* + vdupq_n_*. + +// Unsigned +template <int kLane> +HWY_API Vec128<uint16_t> Broadcast(const Vec128<uint16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128<uint16_t>(vdupq_n_u16(vgetq_lane_u16(v.raw, kLane))); +} +template <int kLane, size_t N, HWY_IF_LE64(uint16_t, N)> +HWY_API Vec128<uint16_t, N> Broadcast(const Vec128<uint16_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<uint16_t, N>(vdup_lane_u16(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<uint32_t> Broadcast(const Vec128<uint32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128<uint32_t>(vdupq_n_u32(vgetq_lane_u32(v.raw, kLane))); +} +template <int kLane, size_t N, HWY_IF_LE64(uint32_t, N)> +HWY_API Vec128<uint32_t, N> Broadcast(const Vec128<uint32_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<uint32_t, N>(vdup_lane_u32(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<uint64_t> Broadcast(const Vec128<uint64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128<uint64_t>(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); +} +// Vec64<uint64_t> is defined below. + +// Signed +template <int kLane> +HWY_API Vec128<int16_t> Broadcast(const Vec128<int16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128<int16_t>(vdupq_n_s16(vgetq_lane_s16(v.raw, kLane))); +} +template <int kLane, size_t N, HWY_IF_LE64(int16_t, N)> +HWY_API Vec128<int16_t, N> Broadcast(const Vec128<int16_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<int16_t, N>(vdup_lane_s16(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<int32_t> Broadcast(const Vec128<int32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128<int32_t>(vdupq_n_s32(vgetq_lane_s32(v.raw, kLane))); +} +template <int kLane, size_t N, HWY_IF_LE64(int32_t, N)> +HWY_API Vec128<int32_t, N> Broadcast(const Vec128<int32_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<int32_t, N>(vdup_lane_s32(v.raw, kLane)); +} +template <int kLane> +HWY_API Vec128<int64_t> Broadcast(const Vec128<int64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128<int64_t>(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); +} +// Vec64<int64_t> is defined below. + +// Float +template <int kLane> +HWY_API Vec128<float> Broadcast(const Vec128<float> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128<float>(vdupq_n_f32(vgetq_lane_f32(v.raw, kLane))); +} +template <int kLane, size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<float, N> Broadcast(const Vec128<float, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<float, N>(vdup_lane_f32(v.raw, kLane)); +} + +#endif + +template <int kLane> +HWY_API Vec64<uint64_t> Broadcast(const Vec64<uint64_t> v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} +template <int kLane> +HWY_API Vec64<int64_t> Broadcast(const Vec64<int64_t> v) { + static_assert(0 <= kLane && kLane < 1, "Invalid lane"); + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template <typename T, size_t N> +struct Indices128 { + typename detail::Raw128<T, N>::type raw; +}; + +template <typename T, size_t N, typename TI, HWY_IF_LE128(T, N)> +HWY_API Indices128<T, N> IndicesFromVec(Simd<T, N, 0> d, Vec128<TI, N> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind<TI, decltype(d)> di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast<TI>(N))))); +#endif + + const Repartition<uint8_t, decltype(d)> d8; + using V8 = VFromD<decltype(d8)>; + const Repartition<uint16_t, decltype(d)> d16; + + // Broadcast each lane index to all bytes of T and shift to bytes + static_assert(sizeof(T) == 4 || sizeof(T) == 8, ""); + if (sizeof(T) == 4) { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); + return Indices128<T, N>{BitCast(d, sum).raw}; + } else { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<3>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + const V8 sum = Add(byte_indices, Load(d8, kByteOffsets)); + return Indices128<T, N>{BitCast(d, sum).raw}; + } +} + +template <typename T, size_t N, typename TI, HWY_IF_LE128(T, N)> +HWY_API Indices128<T, N> SetTableIndices(Simd<T, N, 0> d, const TI* idx) { + const Rebind<TI, decltype(d)> di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> TableLookupLanes(Vec128<T, N> v, Indices128<T, N> idx) { + const DFromV<decltype(v)> d; + const RebindToSigned<decltype(d)> di; + return BitCast( + d, TableLookupBytes(BitCast(di, v), BitCast(di, Vec128<T, N>{idx.raw}))); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template <typename T> +HWY_API Vec128<T, 1> Reverse(Simd<T, 1, 0> /* tag */, const Vec128<T, 1> v) { + return v; +} + +// Two lanes: shuffle +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, 2> Reverse(Simd<T, 2, 0> /* tag */, const Vec128<T, 2> v) { + return Vec128<T, 2>(Shuffle2301(v)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> Reverse(Full128<T> /* tag */, const Vec128<T> v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> Reverse(Full128<T> /* tag */, const Vec128<T> v) { + return Shuffle0123(v); +} + +// 16-bit +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse(Simd<T, N, 0> d, const Vec128<T, N> v) { + const RepartitionToWide<RebindToUnsigned<decltype(d)>> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------ Reverse2 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2), HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> d, const Vec128<T, N> v) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Vec128<uint16_t, N>(vrev32_u16(BitCast(du, v).raw))); +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T> Reverse2(Full128<T> d, const Vec128<T> v) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Vec128<uint16_t>(vrev32q_u16(BitCast(du, v).raw))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4), HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> d, const Vec128<T, N> v) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Vec128<uint32_t, N>(vrev64_u32(BitCast(du, v).raw))); +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> Reverse2(Full128<T> d, const Vec128<T> v) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Vec128<uint32_t>(vrev64q_u32(BitCast(du, v).raw))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2), HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> d, const Vec128<T, N> v) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Vec128<uint16_t, N>(vrev64_u16(BitCast(du, v).raw))); +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T> Reverse4(Full128<T> d, const Vec128<T> v) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Vec128<uint16_t>(vrev64q_u16(BitCast(du, v).raw))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return Shuffle0123(v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> /* tag */, const Vec128<T, N>) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse8(Simd<T, N, 0> d, const Vec128<T, N> v) { + return Reverse(d, v); +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse8(Simd<T, N, 0>, const Vec128<T, N>) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ Other shuffles (TableLookupBytes) + +// Notation: let Vec128<int32_t> have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 64-bit halves +template <typename T> +HWY_API Vec128<T> Shuffle1032(const Vec128<T> v) { + return CombineShiftRightBytes<8>(Full128<T>(), v, v); +} +template <typename T> +HWY_API Vec128<T> Shuffle01(const Vec128<T> v) { + return CombineShiftRightBytes<8>(Full128<T>(), v, v); +} + +// Rotate right 32 bits +template <typename T> +HWY_API Vec128<T> Shuffle0321(const Vec128<T> v) { + return CombineShiftRightBytes<4>(Full128<T>(), v, v); +} + +// Rotate left 32 bits +template <typename T> +HWY_API Vec128<T> Shuffle2103(const Vec128<T> v) { + return CombineShiftRightBytes<12>(Full128<T>(), v, v); +} + +// Reverse +template <typename T> +HWY_API Vec128<T> Shuffle0123(const Vec128<T> v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveLower, vzip1, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveLower, vzip1, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_API Vec128<uint64_t> InterleaveLower(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + return Vec128<uint64_t>(vzip1q_u64(a.raw, b.raw)); +} +HWY_API Vec128<int64_t> InterleaveLower(const Vec128<int64_t> a, + const Vec128<int64_t> b) { + return Vec128<int64_t>(vzip1q_s64(a.raw, b.raw)); +} +HWY_API Vec128<double> InterleaveLower(const Vec128<double> a, + const Vec128<double> b) { + return Vec128<double>(vzip1q_f64(a.raw, b.raw)); +} +#else +// ARMv7 emulation. +HWY_API Vec128<uint64_t> InterleaveLower(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + return CombineShiftRightBytes<8>(Full128<uint64_t>(), b, Shuffle01(a)); +} +HWY_API Vec128<int64_t> InterleaveLower(const Vec128<int64_t> a, + const Vec128<int64_t> b) { + return CombineShiftRightBytes<8>(Full128<int64_t>(), b, Shuffle01(a)); +} +#endif + +// Floats +HWY_API Vec128<float> InterleaveLower(const Vec128<float> a, + const Vec128<float> b) { + return Vec128<float>(vzip1q_f32(a.raw, b.raw)); +} +template <size_t N, HWY_IF_LE64(float, N)> +HWY_API Vec128<float, N> InterleaveLower(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>(vzip1_f32(a.raw, b.raw)); +} + +// < 64 bit parts +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API Vec128<T, N> InterleaveLower(Vec128<T, N> a, Vec128<T, N> b) { + return Vec128<T, N>(InterleaveLower(Vec64<T>(a.raw), Vec64<T>(b.raw)).raw); +} + +// Additional overload for the optional Simd<> tag. +template <typename T, size_t N, class V = Vec128<T, N>> +HWY_API V InterleaveLower(Simd<T, N, 0> /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { +HWY_NEON_DEF_FUNCTION_INT_8_16_32(InterleaveUpper, vzip2, _, 2) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(InterleaveUpper, vzip2, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_API Vec128<uint64_t> InterleaveUpper(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + return Vec128<uint64_t>(vzip2q_u64(a.raw, b.raw)); +} +HWY_API Vec128<int64_t> InterleaveUpper(Vec128<int64_t> a, Vec128<int64_t> b) { + return Vec128<int64_t>(vzip2q_s64(a.raw, b.raw)); +} +HWY_API Vec128<double> InterleaveUpper(Vec128<double> a, Vec128<double> b) { + return Vec128<double>(vzip2q_f64(a.raw, b.raw)); +} +#else +// ARMv7 emulation. +HWY_API Vec128<uint64_t> InterleaveUpper(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + return CombineShiftRightBytes<8>(Full128<uint64_t>(), Shuffle01(b), a); +} +HWY_API Vec128<int64_t> InterleaveUpper(Vec128<int64_t> a, Vec128<int64_t> b) { + return CombineShiftRightBytes<8>(Full128<int64_t>(), Shuffle01(b), a); +} +#endif + +HWY_API Vec128<float> InterleaveUpper(Vec128<float> a, Vec128<float> b) { + return Vec128<float>(vzip2q_f32(a.raw, b.raw)); +} +HWY_API Vec64<float> InterleaveUpper(const Vec64<float> a, + const Vec64<float> b) { + return Vec64<float>(vzip2_f32(a.raw, b.raw)); +} + +} // namespace detail + +// Full register +template <typename T, size_t N, HWY_IF_GE64(T, N), class V = Vec128<T, N>> +HWY_API V InterleaveUpper(Simd<T, N, 0> /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template <typename T, size_t N, HWY_IF_LE32(T, N), class V = Vec128<T, N>> +HWY_API V InterleaveUpper(Simd<T, N, 0> d, V a, V b) { + const Half<decltype(d)> d2; + return InterleaveLower(d, V(UpperHalf(d2, a).raw), V(UpperHalf(d2, b).raw)); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template <size_t N> +HWY_API Vec128<float, N> ReorderWidenMulAccumulate(Simd<float, N, 0> df32, + Vec128<bfloat16_t, 2 * N> a, + Vec128<bfloat16_t, 2 * N> b, + const Vec128<float, N> sum0, + Vec128<float, N>& sum1) { + const Rebind<uint32_t, decltype(df32)> du32; + using VU32 = VFromD<decltype(du32)>; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Avoid ZipLower/Upper so this also works on big-endian systems. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +HWY_API Vec128<int32_t> ReorderWidenMulAccumulate(Full128<int32_t> /*d32*/, + Vec128<int16_t> a, + Vec128<int16_t> b, + const Vec128<int32_t> sum0, + Vec128<int32_t>& sum1) { +#if HWY_ARCH_ARM_A64 + sum1 = Vec128<int32_t>(vmlal_high_s16(sum1.raw, a.raw, b.raw)); +#else + const Full64<int16_t> dh; + sum1 = Vec128<int32_t>( + vmlal_s16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + return Vec128<int32_t>( + vmlal_s16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); +} + +HWY_API Vec64<int32_t> ReorderWidenMulAccumulate(Full64<int32_t> d32, + Vec64<int16_t> a, + Vec64<int16_t> b, + const Vec64<int32_t> sum0, + Vec64<int32_t>& sum1) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128<int32_t> mul_3210(vmull_s16(a.raw, b.raw)); + const Vec64<int32_t> mul_32 = UpperHalf(d32, mul_3210); + sum1 += mul_32; + return sum0 + LowerHalf(mul_3210); +} + +HWY_API Vec32<int32_t> ReorderWidenMulAccumulate(Full32<int32_t> d32, + Vec32<int16_t> a, + Vec32<int16_t> b, + const Vec32<int32_t> sum0, + Vec32<int32_t>& sum1) { + const Vec128<int32_t> mul_xx10(vmull_s16(a.raw, b.raw)); + const Vec64<int32_t> mul_10(LowerHalf(mul_xx10)); + const Vec32<int32_t> mul0 = LowerHalf(d32, mul_10); + const Vec32<int32_t> mul1 = UpperHalf(d32, mul_10); + sum1 += mul1; + return sum0 + mul0; +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// Full result +HWY_API Vec128<uint8_t> Combine(Full128<uint8_t> /* tag */, Vec64<uint8_t> hi, + Vec64<uint8_t> lo) { + return Vec128<uint8_t>(vcombine_u8(lo.raw, hi.raw)); +} +HWY_API Vec128<uint16_t> Combine(Full128<uint16_t> /* tag */, + Vec64<uint16_t> hi, Vec64<uint16_t> lo) { + return Vec128<uint16_t>(vcombine_u16(lo.raw, hi.raw)); +} +HWY_API Vec128<uint32_t> Combine(Full128<uint32_t> /* tag */, + Vec64<uint32_t> hi, Vec64<uint32_t> lo) { + return Vec128<uint32_t>(vcombine_u32(lo.raw, hi.raw)); +} +HWY_API Vec128<uint64_t> Combine(Full128<uint64_t> /* tag */, + Vec64<uint64_t> hi, Vec64<uint64_t> lo) { + return Vec128<uint64_t>(vcombine_u64(lo.raw, hi.raw)); +} + +HWY_API Vec128<int8_t> Combine(Full128<int8_t> /* tag */, Vec64<int8_t> hi, + Vec64<int8_t> lo) { + return Vec128<int8_t>(vcombine_s8(lo.raw, hi.raw)); +} +HWY_API Vec128<int16_t> Combine(Full128<int16_t> /* tag */, Vec64<int16_t> hi, + Vec64<int16_t> lo) { + return Vec128<int16_t>(vcombine_s16(lo.raw, hi.raw)); +} +HWY_API Vec128<int32_t> Combine(Full128<int32_t> /* tag */, Vec64<int32_t> hi, + Vec64<int32_t> lo) { + return Vec128<int32_t>(vcombine_s32(lo.raw, hi.raw)); +} +HWY_API Vec128<int64_t> Combine(Full128<int64_t> /* tag */, Vec64<int64_t> hi, + Vec64<int64_t> lo) { + return Vec128<int64_t>(vcombine_s64(lo.raw, hi.raw)); +} + +HWY_API Vec128<float> Combine(Full128<float> /* tag */, Vec64<float> hi, + Vec64<float> lo) { + return Vec128<float>(vcombine_f32(lo.raw, hi.raw)); +} +#if HWY_ARCH_ARM_A64 +HWY_API Vec128<double> Combine(Full128<double> /* tag */, Vec64<double> hi, + Vec64<double> lo) { + return Vec128<double>(vcombine_f64(lo.raw, hi.raw)); +} +#endif + +// < 64bit input, <= 64 bit result +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> Combine(Simd<T, N, 0> d, Vec128<T, N / 2> hi, + Vec128<T, N / 2> lo) { + // First double N (only lower halves will be used). + const Vec128<T, N> hi2(hi.raw); + const Vec128<T, N> lo2(lo.raw); + // Repartition to two unsigned lanes (each the size of the valid input). + const Simd<UnsignedFromSize<N * sizeof(T) / 2>, 2, 0> du; + return BitCast(d, InterleaveLower(BitCast(du, lo2), BitCast(du, hi2))); +} + +// ------------------------------ RearrangeToOddPlusEven (Combine) + +template <size_t N> +HWY_API Vec128<float, N> RearrangeToOddPlusEven(const Vec128<float, N> sum0, + const Vec128<float, N> sum1) { + return Add(sum0, sum1); +} + +HWY_API Vec128<int32_t> RearrangeToOddPlusEven(const Vec128<int32_t> sum0, + const Vec128<int32_t> sum1) { +// vmlal_s16 multiplied the lower half into sum0 and upper into sum1. +#if HWY_ARCH_ARM_A64 // pairwise sum is available and what we want + return Vec128<int32_t>(vpaddq_s32(sum0.raw, sum1.raw)); +#else + const Full128<int32_t> d; + const Half<decltype(d)> d64; + const Vec64<int32_t> hi( + vpadd_s32(LowerHalf(d64, sum1).raw, UpperHalf(d64, sum1).raw)); + const Vec64<int32_t> lo( + vpadd_s32(LowerHalf(d64, sum0).raw, UpperHalf(d64, sum0).raw)); + return Combine(Full128<int32_t>(), hi, lo); +#endif +} + +HWY_API Vec64<int32_t> RearrangeToOddPlusEven(const Vec64<int32_t> sum0, + const Vec64<int32_t> sum1) { + // vmlal_s16 multiplied the lower half into sum0 and upper into sum1. + return Vec64<int32_t>(vpadd_s32(sum0.raw, sum1.raw)); +} + +HWY_API Vec32<int32_t> RearrangeToOddPlusEven(const Vec32<int32_t> sum0, + const Vec32<int32_t> sum1) { + // Only one widened sum per register, so add them for sum of odd and even. + return sum0 + sum1; +} + +// ------------------------------ ZeroExtendVector (Combine) + +template <typename T, size_t N> +HWY_API Vec128<T, N> ZeroExtendVector(Simd<T, N, 0> d, Vec128<T, N / 2> lo) { + return Combine(d, Zero(Half<decltype(d)>()), lo); +} + +// ------------------------------ ConcatLowerLower + +// 64 or 128-bit input: just interleave +template <typename T, size_t N, HWY_IF_GE64(T, N)> +HWY_API Vec128<T, N> ConcatLowerLower(const Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition<UnsignedFromSize<N * sizeof(T) / 2>, decltype(d)> du; + return BitCast(d, InterleaveLower(BitCast(du, lo), BitCast(du, hi))); +} + +namespace detail { +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF81632(InterleaveOdd, vtrn2, _, 2) +#else + +// vtrn returns a struct with even and odd result. +#define HWY_NEON_BUILD_TPL_HWY_TRN +#define HWY_NEON_BUILD_RET_HWY_TRN(type, size) type##x##size##x2_t +// Pass raw args so we can accept uint16x2 args, for which there is no +// corresponding uint16x2x2 return type. +#define HWY_NEON_BUILD_PARAM_HWY_TRN(TYPE, size) \ + Raw128<TYPE##_t, size>::type a, Raw128<TYPE##_t, size>::type b +#define HWY_NEON_BUILD_ARG_HWY_TRN a, b + +// Cannot use UINT8 etc. type macros because the x2_t tuples are only defined +// for full and half vectors. +HWY_NEON_DEF_FUNCTION(uint8, 16, InterleaveEvenOdd, vtrnq, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint8, 8, InterleaveEvenOdd, vtrn, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 8, InterleaveEvenOdd, vtrnq, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 4, InterleaveEvenOdd, vtrn, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 4, InterleaveEvenOdd, vtrnq, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 2, InterleaveEvenOdd, vtrn, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 16, InterleaveEvenOdd, vtrnq, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 8, InterleaveEvenOdd, vtrn, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 8, InterleaveEvenOdd, vtrnq, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 4, InterleaveEvenOdd, vtrn, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 4, InterleaveEvenOdd, vtrnq, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 2, InterleaveEvenOdd, vtrn, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 4, InterleaveEvenOdd, vtrnq, _, f32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 2, InterleaveEvenOdd, vtrn, _, f32, HWY_TRN) +#endif +} // namespace detail + +// <= 32-bit input/output +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API Vec128<T, N> ConcatLowerLower(const Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition<UnsignedFromSize<N * sizeof(T) / 2>, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveEven(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD<decltype(du)>; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[0])); +#endif +} + +// ------------------------------ ConcatUpperUpper + +// 64 or 128-bit input: just interleave +template <typename T, size_t N, HWY_IF_GE64(T, N)> +HWY_API Vec128<T, N> ConcatUpperUpper(const Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition<UnsignedFromSize<N * sizeof(T) / 2>, decltype(d)> du; + return BitCast(d, InterleaveUpper(du, BitCast(du, lo), BitCast(du, hi))); +} + +// <= 32-bit input/output +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API Vec128<T, N> ConcatUpperUpper(const Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition<UnsignedFromSize<N * sizeof(T) / 2>, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveOdd(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD<decltype(du)>; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[1])); +#endif +} + +// ------------------------------ ConcatLowerUpper (ShiftLeftBytes) + +// 64 or 128-bit input: extract from concatenated +template <typename T, size_t N, HWY_IF_GE64(T, N)> +HWY_API Vec128<T, N> ConcatLowerUpper(const Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + return CombineShiftRightBytes<N * sizeof(T) / 2>(d, hi, lo); +} + +// <= 32-bit input/output +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API Vec128<T, N> ConcatLowerUpper(const Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + constexpr size_t kSize = N * sizeof(T); + const Repartition<uint8_t, decltype(d)> d8; + const Full64<uint8_t> d8x8; + const Full64<T> d64; + using V8x8 = VFromD<decltype(d8x8)>; + const V8x8 hi8x8(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V8x8 lo8x8 = ShiftLeftBytes<8 - kSize>(V8x8(BitCast(d8, lo).raw)); + const V8x8 r = CombineShiftRightBytes<8 - kSize / 2>(d8x8, hi8x8, lo8x8); + // Back to original lane type, then shrink N. + return Vec128<T, N>(BitCast(d64, r).raw); +} + +// ------------------------------ ConcatUpperLower + +// Works for all N. +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatUpperLower(Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd (InterleaveUpper) + +namespace detail { +// There is no vuzpq_u64. +HWY_NEON_DEF_FUNCTION_UIF81632(ConcatEven, vuzp1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF81632(ConcatOdd, vuzp2, _, 2) +} // namespace detail + +// Full/half vector +template <typename T, size_t N, + hwy::EnableIf<N != 2 && sizeof(T) * N >= 8>* = nullptr> +HWY_API Vec128<T, N> ConcatOdd(Simd<T, N, 0> /* tag */, Vec128<T, N> hi, + Vec128<T, N> lo) { + return detail::ConcatOdd(lo, hi); +} + +// 8-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 4> ConcatOdd(Simd<T, 4, 0> d, Vec128<T, 4> hi, + Vec128<T, 4> lo) { + const Twice<decltype(d)> d2; + const Repartition<uint16_t, decltype(d2)> dw2; + const VFromD<decltype(d2)> hi2(hi.raw); + const VFromD<decltype(d2)> lo2(lo.raw); + const VFromD<decltype(dw2)> Hx1Lx1 = BitCast(dw2, ConcatOdd(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec128<T, 4>(BitCast(d2, ConcatEven(dw2, Hx1Lx1, Hx1Lx1)).raw); +} + +// Any type x2 +template <typename T> +HWY_API Vec128<T, 2> ConcatOdd(Simd<T, 2, 0> d, Vec128<T, 2> hi, + Vec128<T, 2> lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// Full/half vector +template <typename T, size_t N, + hwy::EnableIf<N != 2 && sizeof(T) * N >= 8>* = nullptr> +HWY_API Vec128<T, N> ConcatEven(Simd<T, N, 0> /* tag */, Vec128<T, N> hi, + Vec128<T, N> lo) { + return detail::ConcatEven(lo, hi); +} + +// 8-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 4> ConcatEven(Simd<T, 4, 0> d, Vec128<T, 4> hi, + Vec128<T, 4> lo) { + const Twice<decltype(d)> d2; + const Repartition<uint16_t, decltype(d2)> dw2; + const VFromD<decltype(d2)> hi2(hi.raw); + const VFromD<decltype(d2)> lo2(lo.raw); + const VFromD<decltype(dw2)> Hx0Lx0 = BitCast(dw2, ConcatEven(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec128<T, 4>(BitCast(d2, ConcatEven(dw2, Hx0Lx0, Hx0Lx0)).raw); +} + +// Any type x2 +template <typename T> +HWY_API Vec128<T, 2> ConcatEven(Simd<T, 2, 0> d, Vec128<T, 2> hi, + Vec128<T, 2> lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> DupEven(Vec128<T, N> v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(v, v); +#else + return Vec128<T, N>(detail::InterleaveEvenOdd(v.raw, v.raw).val[0]); +#endif +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> DupEven(const Vec128<T, N> v) { + return InterleaveLower(Simd<T, N, 0>(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> DupOdd(Vec128<T, N> v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(v, v); +#else + return Vec128<T, N>(detail::InterleaveEvenOdd(v.raw, v.raw).val[1]); +#endif +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> DupOdd(const Vec128<T, N> v) { + return InterleaveUpper(Simd<T, N, 0>(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template <typename T, size_t N> +HWY_API Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) { + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + alignas(16) constexpr uint8_t kBytes[16] = { + ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 0 : 0xFF, + ((2 / sizeof(T)) & 1) ? 0 : 0xFF, ((3 / sizeof(T)) & 1) ? 0 : 0xFF, + ((4 / sizeof(T)) & 1) ? 0 : 0xFF, ((5 / sizeof(T)) & 1) ? 0 : 0xFF, + ((6 / sizeof(T)) & 1) ? 0 : 0xFF, ((7 / sizeof(T)) & 1) ? 0 : 0xFF, + ((8 / sizeof(T)) & 1) ? 0 : 0xFF, ((9 / sizeof(T)) & 1) ? 0 : 0xFF, + ((10 / sizeof(T)) & 1) ? 0 : 0xFF, ((11 / sizeof(T)) & 1) ? 0 : 0xFF, + ((12 / sizeof(T)) & 1) ? 0 : 0xFF, ((13 / sizeof(T)) & 1) ? 0 : 0xFF, + ((14 / sizeof(T)) & 1) ? 0 : 0xFF, ((15 / sizeof(T)) & 1) ? 0 : 0xFF, + }; + const auto vec = BitCast(d, Load(d8, kBytes)); + return IfThenElse(MaskFromVec(vec), b, a); +} + +// ------------------------------ OddEvenBlocks +template <typename T, size_t N> +HWY_API Vec128<T, N> OddEvenBlocks(Vec128<T, N> /* odd */, Vec128<T, N> even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template <typename T, size_t N> +HWY_API Vec128<T, N> SwapAdjacentBlocks(Vec128<T, N> v) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template <typename T> +HWY_API Vec128<T> ReverseBlocks(Full128<T> /* tag */, const Vec128<T> v) { + return v; +} + +// ------------------------------ ReorderDemote2To (OddEven) + +template <size_t N> +HWY_API Vec128<bfloat16_t, 2 * N> ReorderDemote2To( + Simd<bfloat16_t, 2 * N, 0> dbf16, Vec128<float, N> a, Vec128<float, N> b) { + const RebindToUnsigned<decltype(dbf16)> du16; + const Repartition<uint32_t, decltype(dbf16)> du32; + const Vec128<uint32_t, N> b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec128<int16_t> ReorderDemote2To(Full128<int16_t> d16, + Vec128<int32_t> a, Vec128<int32_t> b) { + const Vec64<int16_t> a16(vqmovn_s32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128<int16_t>(vqmovn_high_s32(a16.raw, b.raw)); +#else + const Vec64<int16_t> b16(vqmovn_s32(b.raw)); + return Combine(d16, a16, b16); +#endif +} + +HWY_API Vec64<int16_t> ReorderDemote2To(Full64<int16_t> /*d16*/, + Vec64<int32_t> a, Vec64<int32_t> b) { + const Full128<int32_t> d32; + const Vec128<int32_t> ab = Combine(d32, a, b); + return Vec64<int16_t>(vqmovn_s32(ab.raw)); +} + +HWY_API Vec32<int16_t> ReorderDemote2To(Full32<int16_t> /*d16*/, + Vec32<int32_t> a, Vec32<int32_t> b) { + const Full128<int32_t> d32; + const Vec64<int32_t> ab(vzip1_s32(a.raw, b.raw)); + return Vec32<int16_t>(vqmovn_s32(Combine(d32, ab, ab).raw)); +} + +// ================================================== CRYPTO + +#if defined(__ARM_FEATURE_AES) || \ + (HWY_HAVE_RUNTIME_DISPATCH && HWY_ARCH_ARM_A64) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128<uint8_t> AESRound(Vec128<uint8_t> state, + Vec128<uint8_t> round_key) { + // NOTE: it is important that AESE and AESMC be consecutive instructions so + // they can be fused. AESE includes AddRoundKey, which is a different ordering + // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual + // round key (the compiler will hopefully optimize this for multiple rounds). + return Vec128<uint8_t>(vaesmcq_u8(vaeseq_u8(state.raw, vdupq_n_u8(0)))) ^ + round_key; +} + +HWY_API Vec128<uint8_t> AESLastRound(Vec128<uint8_t> state, + Vec128<uint8_t> round_key) { + return Vec128<uint8_t>(vaeseq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + +HWY_API Vec128<uint64_t> CLMulLower(Vec128<uint64_t> a, Vec128<uint64_t> b) { + return Vec128<uint64_t>((uint64x2_t)vmull_p64(GetLane(a), GetLane(b))); +} + +HWY_API Vec128<uint64_t> CLMulUpper(Vec128<uint64_t> a, Vec128<uint64_t> b) { + return Vec128<uint64_t>( + (uint64x2_t)vmull_high_p64((poly64x2_t)a.raw, (poly64x2_t)b.raw)); +} + +#endif // __ARM_FEATURE_AES + +// ================================================== MISC + +template <size_t N> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> df32, + const Vec128<bfloat16_t, N> v) { + const Rebind<uint16_t, decltype(df32)> du16; + const RebindToSigned<decltype(df32)> di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Truncations + +template <typename From, typename To, HWY_IF_UNSIGNED(From), + HWY_IF_UNSIGNED(To), + hwy::EnableIf<(sizeof(To) < sizeof(From))>* = nullptr> +HWY_API Vec128<To, 1> TruncateTo(Simd<To, 1, 0> /* tag */, + const Vec128<From, 1> v) { + const Repartition<To, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + return Vec128<To, 1>{v1.raw}; +} + +HWY_API Vec128<uint8_t, 2> TruncateTo(Simd<uint8_t, 2, 0> /* tag */, + const Vec128<uint64_t, 2> v) { + const Repartition<uint8_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + const auto v4 = detail::ConcatEven(v3, v3); + return LowerHalf(LowerHalf(LowerHalf(v4))); +} + +HWY_API Vec32<uint16_t> TruncateTo(Simd<uint16_t, 2, 0> /* tag */, + const Vec128<uint64_t, 2> v) { + const Repartition<uint16_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +HWY_API Vec64<uint32_t> TruncateTo(Simd<uint32_t, 2, 0> /* tag */, + const Vec128<uint64_t, 2> v) { + const Repartition<uint32_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + const Repartition<uint8_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint16_t, N> TruncateTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + const Repartition<uint16_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + const Repartition<uint8_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +// ------------------------------ MulEven (ConcatEven) + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec128<int64_t> MulEven(Vec128<int32_t> a, Vec128<int32_t> b) { + const Full128<int32_t> d; + int32x4_t a_packed = ConcatEven(d, a, a).raw; + int32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128<int64_t>( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_API Vec128<uint64_t> MulEven(Vec128<uint32_t> a, Vec128<uint32_t> b) { + const Full128<uint32_t> d; + uint32x4_t a_packed = ConcatEven(d, a, a).raw; + uint32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128<uint64_t>( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template <size_t N> +HWY_API Vec128<int64_t, (N + 1) / 2> MulEven(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + const DFromV<decltype(a)> d; + int32x2_t a_packed = ConcatEven(d, a, a).raw; + int32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128<int64_t, (N + 1) / 2>( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template <size_t N> +HWY_API Vec128<uint64_t, (N + 1) / 2> MulEven(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + const DFromV<decltype(a)> d; + uint32x2_t a_packed = ConcatEven(d, a, a).raw; + uint32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128<uint64_t, (N + 1) / 2>( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +HWY_INLINE Vec128<uint64_t> MulEven(Vec128<uint64_t> a, Vec128<uint64_t> b) { + uint64_t hi; + uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 0), vgetq_lane_u64(b.raw, 0), &hi); + return Vec128<uint64_t>(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +} + +HWY_INLINE Vec128<uint64_t> MulOdd(Vec128<uint64_t> a, Vec128<uint64_t> b) { + uint64_t hi; + uint64_t lo = Mul128(vgetq_lane_u64(a.raw, 1), vgetq_lane_u64(b.raw, 1), &hi); + return Vec128<uint64_t>(vsetq_lane_u64(hi, vdupq_n_u64(lo), 1)); +} + +// ------------------------------ TableLookupBytes (Combine, LowerHalf) + +// Both full +template <typename T, typename TI> +HWY_API Vec128<TI> TableLookupBytes(const Vec128<T> bytes, + const Vec128<TI> from) { + const Full128<TI> d; + const Repartition<uint8_t, decltype(d)> d8; +#if HWY_ARCH_ARM_A64 + return BitCast(d, Vec128<uint8_t>(vqtbl1q_u8(BitCast(d8, bytes).raw, + BitCast(d8, from).raw))); +#else + uint8x16_t table0 = BitCast(d8, bytes).raw; + uint8x8x2_t table; + table.val[0] = vget_low_u8(table0); + table.val[1] = vget_high_u8(table0); + uint8x16_t idx = BitCast(d8, from).raw; + uint8x8_t low = vtbl2_u8(table, vget_low_u8(idx)); + uint8x8_t hi = vtbl2_u8(table, vget_high_u8(idx)); + return BitCast(d, Vec128<uint8_t>(vcombine_u8(low, hi))); +#endif +} + +// Partial index vector +template <typename T, typename TI, size_t NI, HWY_IF_LE64(TI, NI)> +HWY_API Vec128<TI, NI> TableLookupBytes(const Vec128<T> bytes, + const Vec128<TI, NI> from) { + const Full128<TI> d_full; + const Vec64<TI> from64(from.raw); + const auto idx_full = Combine(d_full, from64, from64); + const auto out_full = TableLookupBytes(bytes, idx_full); + return Vec128<TI, NI>(LowerHalf(Half<decltype(d_full)>(), out_full).raw); +} + +// Partial table vector +template <typename T, size_t N, typename TI, HWY_IF_LE64(T, N)> +HWY_API Vec128<TI> TableLookupBytes(const Vec128<T, N> bytes, + const Vec128<TI> from) { + const Full128<T> d_full; + return TableLookupBytes(Combine(d_full, bytes, bytes), from); +} + +// Partial both +template <typename T, size_t N, typename TI, size_t NI, HWY_IF_LE64(T, N), + HWY_IF_LE64(TI, NI)> +HWY_API VFromD<Repartition<T, Simd<TI, NI, 0>>> TableLookupBytes( + Vec128<T, N> bytes, Vec128<TI, NI> from) { + const Simd<T, N, 0> d; + const Simd<TI, NI, 0> d_idx; + const Repartition<uint8_t, decltype(d_idx)> d_idx8; + // uint8x8 + const auto bytes8 = BitCast(Repartition<uint8_t, decltype(d)>(), bytes); + const auto from8 = BitCast(d_idx8, from); + const VFromD<decltype(d_idx8)> v8(vtbl1_u8(bytes8.raw, from8.raw)); + return BitCast(d_idx, v8); +} + +// For all vector widths; ARM anyway zeroes if >= 0x10. +template <class V, class VI> +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Scatter (Store) + +template <typename T, size_t N, typename Offset, HWY_IF_LE128(T, N)> +HWY_API void ScatterOffset(Vec128<T, N> v, Simd<T, N, 0> d, + T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind<Offset, decltype(d)>(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast<uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template <typename T, size_t N, typename Index, HWY_IF_LE128(T, N)> +HWY_API void ScatterIndex(Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind<Index, decltype(d)>(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template <typename T, size_t N, typename Offset> +HWY_API Vec128<T, N> GatherOffset(const Simd<T, N, 0> d, + const T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind<Offset, decltype(d)>(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast<const uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template <typename T, size_t N, typename Index> +HWY_API Vec128<T, N> GatherIndex(const Simd<T, N, 0> d, + const T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind<Index, decltype(d)>(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template <typename T> +HWY_INLINE Vec128<T, 1> SumOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} +template <typename T> +HWY_INLINE Vec128<T, 1> MinOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} +template <typename T> +HWY_INLINE Vec128<T, 1> MaxOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} + +// full vectors +#if HWY_ARCH_ARM_A64 +#define HWY_NEON_BUILD_RET_REDUCTION(type, size) Vec128<type##_t, size> +#define HWY_NEON_DEF_REDUCTION(type, size, name, prefix, infix, suffix, dup) \ + HWY_API HWY_NEON_BUILD_RET_REDUCTION(type, size) \ + name(hwy::SizeTag<sizeof(type##_t)>, const Vec128<type##_t, size> v) { \ + return HWY_NEON_BUILD_RET_REDUCTION( \ + type, size)(dup##suffix(HWY_NEON_EVAL(prefix##infix##suffix, v.raw))); \ + } + +#define HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint8, 8, name, prefix, _, u8, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(uint8, 16, name, prefix##q, _, u8, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(uint16, 4, name, prefix, _, u16, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(uint16, 8, name, prefix##q, _, u16, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(uint32, 2, name, prefix, _, u32, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(uint32, 4, name, prefix##q, _, u32, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int8, 8, name, prefix, _, s8, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(int8, 16, name, prefix##q, _, s8, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int16, 4, name, prefix, _, s16, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(int16, 8, name, prefix##q, _, s16, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int32, 2, name, prefix, _, s32, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(int32, 4, name, prefix##q, _, s32, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(float32, 2, name, prefix, _, f32, vdup_n_) \ + HWY_NEON_DEF_REDUCTION(float32, 4, name, prefix##q, _, f32, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(float64, 2, name, prefix##q, _, f64, vdupq_n_) + +HWY_NEON_DEF_REDUCTION_CORE_TYPES(MinOfLanes, vminv) +HWY_NEON_DEF_REDUCTION_CORE_TYPES(MaxOfLanes, vmaxv) + +// u64/s64 don't have horizontal min/max for some reason, but do have add. +#define HWY_NEON_DEF_REDUCTION_ALL_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint64, 2, name, prefix##q, _, u64, vdupq_n_) \ + HWY_NEON_DEF_REDUCTION(int64, 2, name, prefix##q, _, s64, vdupq_n_) + +HWY_NEON_DEF_REDUCTION_ALL_TYPES(SumOfLanes, vaddv) + +#undef HWY_NEON_DEF_REDUCTION_ALL_TYPES +#undef HWY_NEON_DEF_REDUCTION_CORE_TYPES +#undef HWY_NEON_DEF_REDUCTION +#undef HWY_NEON_BUILD_RET_REDUCTION + +// Need some fallback implementations for [ui]64x2 and [ui]16x2. +#define HWY_IF_SUM_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 1 << 2) +#define HWY_IF_MINMAX_REDUCTION(T) \ + HWY_IF_LANE_SIZE_ONE_OF(T, (1 << 8) | (1 << 2)) + +#else +// u32/i32/f32: N=2 +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec128<T, 2> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return v10 + Shuffle2301(v10); +} +template <typename T> +HWY_INLINE Vec128<T, 2> MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return Min(v10, Shuffle2301(v10)); +} +template <typename T> +HWY_INLINE Vec128<T, 2> MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return Max(v10, Shuffle2301(v10)); +} + +// ARMv7 version for everything except doubles. +HWY_INLINE Vec128<uint32_t> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<uint32_t> v) { + uint32x4x2_t v0 = vuzpq_u32(v.raw, v.raw); + uint32x4_t c0 = vaddq_u32(v0.val[0], v0.val[1]); + uint32x4x2_t v1 = vuzpq_u32(c0, c0); + return Vec128<uint32_t>(vaddq_u32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128<int32_t> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<int32_t> v) { + int32x4x2_t v0 = vuzpq_s32(v.raw, v.raw); + int32x4_t c0 = vaddq_s32(v0.val[0], v0.val[1]); + int32x4x2_t v1 = vuzpq_s32(c0, c0); + return Vec128<int32_t>(vaddq_s32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128<float> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<float> v) { + float32x4x2_t v0 = vuzpq_f32(v.raw, v.raw); + float32x4_t c0 = vaddq_f32(v0.val[0], v0.val[1]); + float32x4x2_t v1 = vuzpq_f32(c0, c0); + return Vec128<float>(vaddq_f32(v1.val[0], v1.val[1])); +} +HWY_INLINE Vec128<uint64_t> SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<uint64_t> v) { + return v + Shuffle01(v); +} +HWY_INLINE Vec128<int64_t> SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<int64_t> v) { + return v + Shuffle01(v); +} + +template <typename T> +HWY_INLINE Vec128<T> MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = Min(v3210, v1032); + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template <typename T> +HWY_INLINE Vec128<T> MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = Max(v3210, v1032); + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +#define HWY_NEON_BUILD_TYPE_T(type, size) type##x##size##_t +#define HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) Vec128<type##_t, size> +#define HWY_NEON_DEF_PAIRWISE_REDUCTION(type, size, name, prefix, suffix) \ + HWY_API HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) \ + name(hwy::SizeTag<sizeof(type##_t)>, const Vec128<type##_t, size> v) { \ + HWY_NEON_BUILD_TYPE_T(type, size) tmp = prefix##_##suffix(v.raw, v.raw); \ + if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION( \ + type, size)(HWY_NEON_EVAL(vdup##_lane_##suffix, tmp, 0)); \ + } +#define HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(type, size, half, name, prefix, \ + suffix) \ + HWY_API HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION(type, size) \ + name(hwy::SizeTag<sizeof(type##_t)>, const Vec128<type##_t, size> v) { \ + HWY_NEON_BUILD_TYPE_T(type, half) tmp; \ + tmp = prefix##_##suffix(vget_high_##suffix(v.raw), \ + vget_low_##suffix(v.raw)); \ + if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + if ((size / 8) > 1) tmp = prefix##_##suffix(tmp, tmp); \ + tmp = vdup_lane_##suffix(tmp, 0); \ + return HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION( \ + type, size)(HWY_NEON_EVAL(vcombine_##suffix, tmp, tmp)); \ + } + +#define HWY_NEON_DEF_PAIRWISE_REDUCTIONS(name, prefix) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(uint16, 4, name, prefix, u16) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(uint8, 8, name, prefix, u8) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(int16, 4, name, prefix, s16) \ + HWY_NEON_DEF_PAIRWISE_REDUCTION(int8, 8, name, prefix, s8) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint16, 8, 4, name, prefix, u16) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint8, 16, 8, name, prefix, u8) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int16, 8, 4, name, prefix, s16) \ + HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int8, 16, 8, name, prefix, s8) + +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(SumOfLanes, vpadd) +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(MinOfLanes, vpmin) +HWY_NEON_DEF_PAIRWISE_REDUCTIONS(MaxOfLanes, vpmax) + +#undef HWY_NEON_DEF_PAIRWISE_REDUCTIONS +#undef HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION +#undef HWY_NEON_DEF_PAIRWISE_REDUCTION +#undef HWY_NEON_BUILD_RET_PAIRWISE_REDUCTION +#undef HWY_NEON_BUILD_TYPE_T + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// Need fallback min/max implementations for [ui]64x2. +#define HWY_IF_SUM_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 0) +#define HWY_IF_MINMAX_REDUCTION(T) HWY_IF_LANE_SIZE_ONE_OF(T, 1 << 8) + +#endif + +// [ui]16/[ui]64: N=2 -- special case for pairs of very small or large lanes +template <typename T, HWY_IF_SUM_REDUCTION(T)> +HWY_API Vec128<T, 2> SumOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 2> v10) { + return v10 + Reverse2(Simd<T, 2, 0>(), v10); +} +template <typename T, HWY_IF_MINMAX_REDUCTION(T)> +HWY_API Vec128<T, 2> MinOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 2> v10) { + return Min(v10, Reverse2(Simd<T, 2, 0>(), v10)); +} +template <typename T, HWY_IF_MINMAX_REDUCTION(T)> +HWY_API Vec128<T, 2> MaxOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 2> v10) { + return Max(v10, Reverse2(Simd<T, 2, 0>(), v10)); +} + +#undef HWY_IF_SUM_REDUCTION +#undef HWY_IF_MINMAX_REDUCTION + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> SumOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::SumOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MinOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::MinOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MaxOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::MaxOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// Helper function to set 64 bits and potentially return a smaller vector. The +// overload is required to call the q vs non-q intrinsics. Note that 8-bit +// LoadMaskBits only requires 16 bits, but 64 avoids casting. +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE Vec128<T, N> Set64(Simd<T, N, 0> /* tag */, uint64_t mask_bits) { + const auto v64 = Vec64<uint64_t>(vdup_n_u64(mask_bits)); + return Vec128<T, N>(BitCast(Full64<T>(), v64).raw); +} +template <typename T> +HWY_INLINE Vec128<T> Set64(Full128<T> d, uint64_t mask_bits) { + return BitCast(d, Vec128<uint64_t>(vdupq_n_u64(mask_bits))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const auto vmask_bits = Set64(du, mask_bits); + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vmask_bits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast<uint16_t>(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast<uint32_t>(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(N + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Returns mask[i]? 0xF : 0 in each nibble. This is more efficient than +// BitsFromMask for use in (partial) CountTrue, FindFirstTrue and AllFalse. +template <typename T> +HWY_INLINE uint64_t NibblesFromMask(const Full128<T> d, Mask128<T> mask) { + const Full128<uint16_t> du16; + const Vec128<uint16_t> vu16 = BitCast(du16, VecFromMask(d, mask)); + const Vec64<uint8_t> nib(vshrn_n_u16(vu16.raw, 4)); + return GetLane(BitCast(Full64<uint64_t>(), nib)); +} + +template <typename T> +HWY_INLINE uint64_t NibblesFromMask(const Full64<T> d, Mask64<T> mask) { + // There is no vshrn_n_u16 for uint16x4, so zero-extend. + const Twice<decltype(d)> d2; + const Vec128<T> v128 = ZeroExtendVector(d2, VecFromMask(d, mask)); + // No need to mask, upper half is zero thanks to ZeroExtendVector. + return NibblesFromMask(d2, MaskFromVec(v128)); +} + +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_INLINE uint64_t NibblesFromMask(Simd<T, N, 0> /*d*/, Mask128<T, N> mask) { + const Mask64<T> mask64(mask.raw); + const uint64_t nib = NibblesFromMask(Full64<T>(), mask64); + // Clear nibbles from upper half of 64-bits + constexpr size_t kBytes = sizeof(T) * N; + return nib & ((1ull << (kBytes * 4)) - 1); +} + +template <typename T> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128<T> mask) { + alignas(16) constexpr uint8_t kSliceLanes[16] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, + }; + const Full128<uint8_t> du; + const Vec128<uint8_t> values = + BitCast(du, VecFromMask(Full128<T>(), mask)) & Load(du, kSliceLanes); + +#if HWY_ARCH_ARM_A64 + // Can't vaddv - we need two separate bytes (16 bits). + const uint8x8_t x2 = vget_low_u8(vpaddq_u8(values.raw, values.raw)); + const uint8x8_t x4 = vpadd_u8(x2, x2); + const uint8x8_t x8 = vpadd_u8(x4, x4); + return vget_lane_u64(vreinterpret_u64_u8(x8), 0); +#else + // Don't have vpaddq, so keep doubling lane size. + const uint16x8_t x2 = vpaddlq_u8(values.raw); + const uint32x4_t x4 = vpaddlq_u16(x2); + const uint64x2_t x8 = vpaddlq_u32(x4); + return (vgetq_lane_u64(x8, 1) << 8) | vgetq_lane_u64(x8, 0); +#endif +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128<T, N> mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + const Vec128<uint8_t, N> slice(Load(Full64<uint8_t>(), kSliceLanes).raw); + const Vec128<uint8_t, N> values = BitCast(du, VecFromMask(d, mask)) & slice; + +#if HWY_ARCH_ARM_A64 + return vaddv_u8(values.raw); +#else + const uint16x4_t x2 = vpaddl_u8(values.raw); + const uint32x2_t x4 = vpaddl_u16(x2); + const uint64x1_t x8 = vpaddl_u32(x4); + return vget_lane_u64(x8, 0); +#endif +} + +template <typename T> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128<T> mask) { + alignas(16) constexpr uint16_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const Full128<T> d; + const Full128<uint16_t> du; + const Vec128<uint16_t> values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u16(values.raw); +#else + const uint32x4_t x2 = vpaddlq_u16(values.raw); + const uint64x2_t x4 = vpaddlq_u32(x2); + return vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1); +#endif +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128<T, N> mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + const Vec128<uint16_t, N> slice(Load(Full64<uint16_t>(), kSliceLanes).raw); + const Vec128<uint16_t, N> values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u16(values.raw); +#else + const uint32x2_t x2 = vpaddl_u16(values.raw); + const uint64x1_t x4 = vpaddl_u32(x2); + return vget_lane_u64(x4, 0); +#endif +} + +template <typename T> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128<T> mask) { + alignas(16) constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8}; + const Full128<T> d; + const Full128<uint32_t> du; + const Vec128<uint32_t> values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u32(values.raw); +#else + const uint64x2_t x2 = vpaddlq_u32(values.raw); + return vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1); +#endif +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128<T, N> mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) constexpr uint32_t kSliceLanes[2] = {1, 2}; + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + const Vec128<uint32_t, N> slice(Load(Full64<uint32_t>(), kSliceLanes).raw); + const Vec128<uint32_t, N> values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return vaddv_u32(values.raw); +#else + const uint64x1_t x2 = vpaddl_u32(values.raw); + return vget_lane_u64(x2, 0); +#endif +} + +template <typename T> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, const Mask128<T> m) { + alignas(16) constexpr uint64_t kSliceLanes[2] = {1, 2}; + const Full128<T> d; + const Full128<uint64_t> du; + const Vec128<uint64_t> values = + BitCast(du, VecFromMask(d, m)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return vaddvq_u64(values.raw); +#else + return vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1); +#endif +} + +template <typename T> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128<T, 1> m) { + const Full64<T> d; + const Full64<uint64_t> du; + const Vec64<uint64_t> values = BitCast(du, VecFromMask(d, m)) & Set(du, 1); + return vget_lane_u64(values.raw, 0); +} + +// Returns the lowest N for the BitsFromMask result. +template <typename T, size_t N> +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) >= 8) ? bits : (bits & ((1ull << N) - 1)); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(const Mask128<T, N> mask) { + return OnlyActive<T, N>(BitsFromMask(hwy::SizeTag<sizeof(T)>(), mask)); +} + +// Returns number of lanes whose mask is set. +// +// Masks are either FF..FF or 0. Unfortunately there is no reduce-sub op +// ("vsubv"). ANDing with 1 would work but requires a constant. Negating also +// changes each lane to 1 (if mask set) or 0. +// NOTE: PopCount also operates on vectors, so we still have to do horizontal +// sums separately. We specialize CountTrue for full vectors (negating instead +// of PopCount because it avoids an extra shift), and use PopCount of +// NibblesFromMask for partial vectors. + +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> /*tag*/, const Mask128<T> mask) { + const Full128<int8_t> di; + const int8x16_t ones = + vnegq_s8(BitCast(di, VecFromMask(Full128<T>(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast<size_t>(vaddvq_s8(ones)); +#else + const int16x8_t x2 = vpaddlq_s8(ones); + const int32x4_t x4 = vpaddlq_s16(x2); + const int64x2_t x8 = vpaddlq_s32(x4); + return static_cast<size_t>(vgetq_lane_s64(x8, 0) + vgetq_lane_s64(x8, 1)); +#endif +} +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> /*tag*/, const Mask128<T> mask) { + const Full128<int16_t> di; + const int16x8_t ones = + vnegq_s16(BitCast(di, VecFromMask(Full128<T>(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast<size_t>(vaddvq_s16(ones)); +#else + const int32x4_t x2 = vpaddlq_s16(ones); + const int64x2_t x4 = vpaddlq_s32(x2); + return static_cast<size_t>(vgetq_lane_s64(x4, 0) + vgetq_lane_s64(x4, 1)); +#endif +} + +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128<T> mask) { + const Full128<int32_t> di; + const int32x4_t ones = + vnegq_s32(BitCast(di, VecFromMask(Full128<T>(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast<size_t>(vaddvq_s32(ones)); +#else + const int64x2_t x2 = vpaddlq_s32(ones); + return static_cast<size_t>(vgetq_lane_s64(x2, 0) + vgetq_lane_s64(x2, 1)); +#endif +} + +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128<T> mask) { +#if HWY_ARCH_ARM_A64 + const Full128<int64_t> di; + const int64x2_t ones = + vnegq_s64(BitCast(di, VecFromMask(Full128<T>(), mask)).raw); + return static_cast<size_t>(vaddvq_s64(ones)); +#else + const Full128<uint64_t> du; + const auto mask_u = VecFromMask(du, RebindMask(du, mask)); + const uint64x2_t ones = vshrq_n_u64(mask_u.raw, 63); + return static_cast<size_t>(vgetq_lane_u64(ones, 0) + vgetq_lane_u64(ones, 1)); +#endif +} + +} // namespace detail + +// Full +template <typename T> +HWY_API size_t CountTrue(Full128<T> /* tag */, const Mask128<T> mask) { + return detail::CountTrue(hwy::SizeTag<sizeof(T)>(), mask); +} + +// Partial +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API size_t CountTrue(Simd<T, N, 0> d, const Mask128<T, N> mask) { + constexpr int kDiv = 4 * sizeof(T); + return PopCount(detail::NibblesFromMask(d, mask)) / kDiv; +} + +template <typename T, size_t N> +HWY_API size_t FindKnownFirstTrue(const Simd<T, N, 0> d, + const Mask128<T, N> mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + constexpr size_t kDiv = 4 * sizeof(T); + return Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv; +} + +template <typename T, size_t N> +HWY_API intptr_t FindFirstTrue(const Simd<T, N, 0> d, + const Mask128<T, N> mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + if (nib == 0) return -1; + constexpr int kDiv = 4 * sizeof(T); + return static_cast<intptr_t>(Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv); +} + +// `p` points to at least 8 writable bytes. +template <typename T, size_t N> +HWY_API size_t StoreMaskBits(Simd<T, N, 0> /* tag */, const Mask128<T, N> mask, + uint8_t* bits) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(&mask_bits, bits); + return kNumBytes; +} + +template <typename T, size_t N> +HWY_API bool AllFalse(const Simd<T, N, 0> d, const Mask128<T, N> m) { + return detail::NibblesFromMask(d, m) == 0; +} + +// Full +template <typename T> +HWY_API bool AllTrue(const Full128<T> d, const Mask128<T> m) { + return detail::NibblesFromMask(d, m) == ~0ull; +} +// Partial +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API bool AllTrue(const Simd<T, N, 0> d, const Mask128<T, N> m) { + constexpr size_t kBytes = sizeof(T) * N; + return detail::NibblesFromMask(d, m) == (1ull << (kBytes * 4)) - 1; +} + +// ------------------------------ Compress + +template <typename T> +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +namespace detail { + +// Load 8 bytes, replicate into upper half so ZipLower can use the lower half. +HWY_INLINE Vec128<uint8_t> Load8Bytes(Full128<uint8_t> /*d*/, + const uint8_t* bytes) { + return Vec128<uint8_t>(vreinterpretq_u8_u64( + vld1q_dup_u64(reinterpret_cast<const uint64_t*>(bytes)))); +} + +// Load 8 bytes and return half-reg with N <= 8 bytes. +template <size_t N, HWY_IF_LE64(uint8_t, N)> +HWY_INLINE Vec128<uint8_t, N> Load8Bytes(Simd<uint8_t, N, 0> d, + const uint8_t* bytes) { + return Load(d, bytes); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IdxFromBits(hwy::SizeTag<2> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + const Simd<uint16_t, N, 0> du; + + // ARM does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128<uint8_t, 2 * N> byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128<uint16_t, N> pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IdxFromNotBits(hwy::SizeTag<2> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + const Simd<uint16_t, N, 0> du; + + // ARM does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128<uint8_t, 2 * N> byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128<uint16_t, N> pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IdxFromBits(hwy::SizeTag<4> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IdxFromNotBits(hwy::SizeTag<4> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IdxFromBits(hwy::SizeTag<8> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IdxFromNotBits(hwy::SizeTag<8> /*tag*/, + const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#endif + +// Helper function called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Compress(Vec128<T, N> v, const uint64_t mask_bits) { + const auto idx = + detail::IdxFromBits<T, N>(hwy::SizeTag<sizeof(T)>(), mask_bits); + using D = Simd<T, N, 0>; + const RebindToSigned<D> di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> CompressNot(Vec128<T, N> v, const uint64_t mask_bits) { + const auto idx = + detail::IdxFromNotBits<T, N>(hwy::SizeTag<sizeof(T)>(), mask_bits); + using D = Simd<T, N, 0>; + const RebindToSigned<D> di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> Compress(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +// Two lanes: conditional swap +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Compress(Vec128<T, N> v, const Mask128<T, N> mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. + const Simd<T, N, 0> d; + const Vec128<T, N> m = VecFromMask(d, mask); + const Vec128<T, N> maskL = DupEven(m); + const Vec128<T, N> maskH = DupOdd(m); + const Vec128<T, N> swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API Vec128<T, N> Compress(Vec128<T, N> v, const Mask128<T, N> mask) { + return detail::Compress(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> CompressNot(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +// Two lanes: conditional swap +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128<T> d; + const Vec128<T> m = VecFromMask(d, mask); + const Vec128<T> maskL = DupEven(m); + const Vec128<T> maskH = DupOdd(m); + const Vec128<T> swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, Mask128<T, N> mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNot(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v, + Mask128<uint64_t> /* m */) { + return v; +} + +// ------------------------------ CompressBits + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_INLINE Vec128<T, N> CompressBits(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressStore(Vec128<T, N> v, const Mask128<T, N> mask, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; // so we can support fp16/bf16 + using TU = TFromD<decltype(du)>; + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Mask128<T, N> store_mask = RebindMask(d, FirstN(du, count)); + const Vec128<TU, N> compressed = detail::Compress(BitCast(du, v), mask_bits); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBitsStore(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#define HWY_NEON_BUILD_ARG_HWY_LOAD_INT from + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_LOAD_INT(T, N) HWY_IF_GE64(T, N) +#define HWY_NEON_DEF_FUNCTION_LOAD_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64 +#define HWY_IF_LOAD_INT(T, N) \ + hwy::EnableIf<N * sizeof(T) >= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +// Must return raw tuple because Tuple2 lack a ctor, and we cannot use +// brace-initialization in HWY_NEON_DEF_FUNCTION because some functions return +// void. +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple2<type##_t, size>().raw) +// Tuple tag arg allows overloading (cannot just overload on return type) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple2<type##_t, size> +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved2, vld2, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple3<type##_t, size>().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple3<type##_t, size> +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved3, vld3, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple4<type##_t, size>().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const type##_t *from, Tuple4<type##_t, size> +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved4, vld4, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#undef HWY_NEON_DEF_FUNCTION_LOAD_INT +#undef HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#undef HWY_NEON_BUILD_ARG_HWY_LOAD_INT +} // namespace detail + +template <typename T, size_t N, HWY_IF_LOAD_INT(T, N)> +HWY_API void LoadInterleaved2(Simd<T, N, 0> /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128<T, N>& v0, + Vec128<T, N>& v1) { + auto raw = detail::LoadInterleaved2(unaligned, detail::Tuple2<T, N>()); + v0 = Vec128<T, N>(raw.val[0]); + v1 = Vec128<T, N>(raw.val[1]); +} + +// <= 32 bits: avoid loading more than N bytes by copying to buffer +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void LoadInterleaved2(Simd<T, N, 0> /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128<T, N>& v0, + Vec128<T, N>& v1) { + // The smallest vector registers are 64-bits and we want space for two. + alignas(16) T buf[2 * 8 / sizeof(T)] = {}; + CopyBytes<N * 2 * sizeof(T)>(unaligned, buf); + auto raw = detail::LoadInterleaved2(buf, detail::Tuple2<T, N>()); + v0 = Vec128<T, N>(raw.val[0]); + v1 = Vec128<T, N>(raw.val[1]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void LoadInterleaved2(Full128<T> d, T* HWY_RESTRICT unaligned, + Vec128<T>& v0, Vec128<T>& v1) { + const Half<decltype(d)> dh; + VFromD<decltype(dh)> v00, v10, v01, v11; + LoadInterleaved2(dh, unaligned, v00, v10); + LoadInterleaved2(dh, unaligned + 2, v01, v11); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved3 + +template <typename T, size_t N, HWY_IF_LOAD_INT(T, N)> +HWY_API void LoadInterleaved3(Simd<T, N, 0> /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128<T, N>& v0, + Vec128<T, N>& v1, Vec128<T, N>& v2) { + auto raw = detail::LoadInterleaved3(unaligned, detail::Tuple3<T, N>()); + v0 = Vec128<T, N>(raw.val[0]); + v1 = Vec128<T, N>(raw.val[1]); + v2 = Vec128<T, N>(raw.val[2]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void LoadInterleaved3(Simd<T, N, 0> /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128<T, N>& v0, + Vec128<T, N>& v1, Vec128<T, N>& v2) { + // The smallest vector registers are 64-bits and we want space for three. + alignas(16) T buf[3 * 8 / sizeof(T)] = {}; + CopyBytes<N * 3 * sizeof(T)>(unaligned, buf); + auto raw = detail::LoadInterleaved3(buf, detail::Tuple3<T, N>()); + v0 = Vec128<T, N>(raw.val[0]); + v1 = Vec128<T, N>(raw.val[1]); + v2 = Vec128<T, N>(raw.val[2]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void LoadInterleaved3(Full128<T> d, const T* HWY_RESTRICT unaligned, + Vec128<T>& v0, Vec128<T>& v1, Vec128<T>& v2) { + const Half<decltype(d)> dh; + VFromD<decltype(dh)> v00, v10, v20, v01, v11, v21; + LoadInterleaved3(dh, unaligned, v00, v10, v20); + LoadInterleaved3(dh, unaligned + 3, v01, v11, v21); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved4 + +template <typename T, size_t N, HWY_IF_LOAD_INT(T, N)> +HWY_API void LoadInterleaved4(Simd<T, N, 0> /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128<T, N>& v0, + Vec128<T, N>& v1, Vec128<T, N>& v2, + Vec128<T, N>& v3) { + auto raw = detail::LoadInterleaved4(unaligned, detail::Tuple4<T, N>()); + v0 = Vec128<T, N>(raw.val[0]); + v1 = Vec128<T, N>(raw.val[1]); + v2 = Vec128<T, N>(raw.val[2]); + v3 = Vec128<T, N>(raw.val[3]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void LoadInterleaved4(Simd<T, N, 0> /*tag*/, + const T* HWY_RESTRICT unaligned, Vec128<T, N>& v0, + Vec128<T, N>& v1, Vec128<T, N>& v2, + Vec128<T, N>& v3) { + alignas(16) T buf[4 * 8 / sizeof(T)] = {}; + CopyBytes<N * 4 * sizeof(T)>(unaligned, buf); + auto raw = detail::LoadInterleaved4(buf, detail::Tuple4<T, N>()); + v0 = Vec128<T, N>(raw.val[0]); + v1 = Vec128<T, N>(raw.val[1]); + v2 = Vec128<T, N>(raw.val[2]); + v3 = Vec128<T, N>(raw.val[3]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void LoadInterleaved4(Full128<T> d, const T* HWY_RESTRICT unaligned, + Vec128<T>& v0, Vec128<T>& v1, Vec128<T>& v2, + Vec128<T>& v3) { + const Half<decltype(d)> dh; + VFromD<decltype(dh)> v00, v10, v20, v30, v01, v11, v21, v31; + LoadInterleaved4(dh, unaligned, v00, v10, v20, v30); + LoadInterleaved4(dh, unaligned + 4, v01, v11, v21, v31); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); + v3 = Combine(d, v31, v30); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_LOAD_INT + +// ------------------------------ StoreInterleaved2 + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_STORE_INT +#define HWY_NEON_BUILD_RET_HWY_STORE_INT(type, size) void +#define HWY_NEON_BUILD_ARG_HWY_STORE_INT to, tup.raw + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_STORE_INT(T, N) HWY_IF_GE64(T, N) +#define HWY_NEON_DEF_FUNCTION_STORE_INT HWY_NEON_DEF_FUNCTION_ALL_TYPES +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64 +#define HWY_IF_STORE_INT(T, N) \ + hwy::EnableIf<N * sizeof(T) >= 8 && (N == 1 || sizeof(T) < 8)>* = nullptr +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple2<type##_t, size> tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved2, vst2, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple3<type##_t, size> tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved3, vst3, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple4<type##_t, size> tup, type##_t *to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved4, vst4, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#undef HWY_NEON_DEF_FUNCTION_STORE_INT +#undef HWY_NEON_BUILD_TPL_HWY_STORE_INT +#undef HWY_NEON_BUILD_RET_HWY_STORE_INT +#undef HWY_NEON_BUILD_ARG_HWY_STORE_INT +} // namespace detail + +template <typename T, size_t N, HWY_IF_STORE_INT(T, N)> +HWY_API void StoreInterleaved2(const Vec128<T, N> v0, const Vec128<T, N> v1, + Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple2<T, N> tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void StoreInterleaved2(const Vec128<T, N> v0, const Vec128<T, N> v1, + Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[2 * 8 / sizeof(T)]; + detail::Tuple2<T, N> tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, buf); + CopyBytes<N * 2 * sizeof(T)>(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void StoreInterleaved2(const Vec128<T> v0, const Vec128<T> v1, + Full128<T> d, T* HWY_RESTRICT unaligned) { + const Half<decltype(d)> dh; + StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, unaligned); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, unaligned + 2); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved3 + +template <typename T, size_t N, HWY_IF_STORE_INT(T, N)> +HWY_API void StoreInterleaved3(const Vec128<T, N> v0, const Vec128<T, N> v1, + const Vec128<T, N> v2, Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple3<T, N> tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void StoreInterleaved3(const Vec128<T, N> v0, const Vec128<T, N> v1, + const Vec128<T, N> v2, Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[3 * 8 / sizeof(T)]; + detail::Tuple3<T, N> tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, buf); + CopyBytes<N * 3 * sizeof(T)>(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void StoreInterleaved3(const Vec128<T> v0, const Vec128<T> v1, + const Vec128<T> v2, Full128<T> d, + T* HWY_RESTRICT unaligned) { + const Half<decltype(d)> dh; + StoreInterleaved3(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), dh, + unaligned); + StoreInterleaved3(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), dh, + unaligned + 3); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved4 + +template <typename T, size_t N, HWY_IF_STORE_INT(T, N)> +HWY_API void StoreInterleaved4(const Vec128<T, N> v0, const Vec128<T, N> v1, + const Vec128<T, N> v2, const Vec128<T, N> v3, + Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + detail::Tuple4<T, N> tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, unaligned); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void StoreInterleaved4(const Vec128<T, N> v0, const Vec128<T, N> v1, + const Vec128<T, N> v2, const Vec128<T, N> v3, + Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[4 * 8 / sizeof(T)]; + detail::Tuple4<T, N> tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, buf); + CopyBytes<N * 4 * sizeof(T)>(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void StoreInterleaved4(const Vec128<T> v0, const Vec128<T> v1, + const Vec128<T> v2, const Vec128<T> v3, + Full128<T> d, T* HWY_RESTRICT unaligned) { + const Half<decltype(d)> dh; + StoreInterleaved4(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), + LowerHalf(dh, v3), dh, unaligned); + StoreInterleaved4(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), + UpperHalf(dh, v3), dh, unaligned + 4); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_STORE_INT + +// ------------------------------ Lt128 + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Lt128(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128<T, N> eqHL = Eq(a, b); + const Vec128<T, N> ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). + const Vec128<T, N> ltLx = DupEven(ltHL); + const Vec128<T, N> outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Lt128Upper(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + const Vec128<T, N> ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Eq128(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64"); + const Vec128<T, N> eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Eq128Upper(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + const Vec128<T, N> eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Ne128(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64"); + const Vec128<T, N> neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Ne128Upper(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + const Vec128<T, N> neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. +template <class D> +HWY_INLINE VFromD<D> Min128(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template <class D> +HWY_INLINE VFromD<D> Max128(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template <class D> +HWY_INLINE VFromD<D> Min128Upper(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template <class D> +HWY_INLINE VFromD<D> Max128Upper(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +namespace detail { // for code folding +#if HWY_ARCH_ARM_V7 +#undef vuzp1_s8 +#undef vuzp1_u8 +#undef vuzp1_s16 +#undef vuzp1_u16 +#undef vuzp1_s32 +#undef vuzp1_u32 +#undef vuzp1_f32 +#undef vuzp1q_s8 +#undef vuzp1q_u8 +#undef vuzp1q_s16 +#undef vuzp1q_u16 +#undef vuzp1q_s32 +#undef vuzp1q_u32 +#undef vuzp1q_f32 +#undef vuzp2_s8 +#undef vuzp2_u8 +#undef vuzp2_s16 +#undef vuzp2_u16 +#undef vuzp2_s32 +#undef vuzp2_u32 +#undef vuzp2_f32 +#undef vuzp2q_s8 +#undef vuzp2q_u8 +#undef vuzp2q_s16 +#undef vuzp2q_u16 +#undef vuzp2q_s32 +#undef vuzp2q_u32 +#undef vuzp2q_f32 +#undef vzip1_s8 +#undef vzip1_u8 +#undef vzip1_s16 +#undef vzip1_u16 +#undef vzip1_s32 +#undef vzip1_u32 +#undef vzip1_f32 +#undef vzip1q_s8 +#undef vzip1q_u8 +#undef vzip1q_s16 +#undef vzip1q_u16 +#undef vzip1q_s32 +#undef vzip1q_u32 +#undef vzip1q_f32 +#undef vzip2_s8 +#undef vzip2_u8 +#undef vzip2_s16 +#undef vzip2_u16 +#undef vzip2_s32 +#undef vzip2_u32 +#undef vzip2_f32 +#undef vzip2q_s8 +#undef vzip2q_u8 +#undef vzip2q_s16 +#undef vzip2q_u16 +#undef vzip2q_s32 +#undef vzip2q_u32 +#undef vzip2q_f32 +#endif + +#undef HWY_NEON_BUILD_ARG_1 +#undef HWY_NEON_BUILD_ARG_2 +#undef HWY_NEON_BUILD_ARG_3 +#undef HWY_NEON_BUILD_PARAM_1 +#undef HWY_NEON_BUILD_PARAM_2 +#undef HWY_NEON_BUILD_PARAM_3 +#undef HWY_NEON_BUILD_RET_1 +#undef HWY_NEON_BUILD_RET_2 +#undef HWY_NEON_BUILD_RET_3 +#undef HWY_NEON_BUILD_TPL_1 +#undef HWY_NEON_BUILD_TPL_2 +#undef HWY_NEON_BUILD_TPL_3 +#undef HWY_NEON_DEF_FUNCTION +#undef HWY_NEON_DEF_FUNCTION_ALL_FLOATS +#undef HWY_NEON_DEF_FUNCTION_ALL_TYPES +#undef HWY_NEON_DEF_FUNCTION_FLOAT_64 +#undef HWY_NEON_DEF_FUNCTION_FULL_UI +#undef HWY_NEON_DEF_FUNCTION_INT_16 +#undef HWY_NEON_DEF_FUNCTION_INT_32 +#undef HWY_NEON_DEF_FUNCTION_INT_8 +#undef HWY_NEON_DEF_FUNCTION_INT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_INTS +#undef HWY_NEON_DEF_FUNCTION_INTS_UINTS +#undef HWY_NEON_DEF_FUNCTION_TPL +#undef HWY_NEON_DEF_FUNCTION_UIF81632 +#undef HWY_NEON_DEF_FUNCTION_UINT_16 +#undef HWY_NEON_DEF_FUNCTION_UINT_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_8 +#undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UINTS +#undef HWY_NEON_EVAL +} // namespace detail + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/arm_sve-inl.h b/third_party/highway/hwy/ops/arm_sve-inl.h new file mode 100644 index 0000000000..5b83017172 --- /dev/null +++ b/third_party/highway/hwy/ops/arm_sve-inl.h @@ -0,0 +1,3186 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// ARM SVE[2] vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include <arm_sve.h> +#include <stddef.h> +#include <stdint.h> + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +// If running on hardware whose vector length is known to be a power of two, we +// can skip fixups for non-power of two sizes. +#undef HWY_SVE_IS_POW2 +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +#define HWY_SVE_IS_POW2 1 +#else +#define HWY_SVE_IS_POW2 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <class V> +struct DFromV_t {}; // specialized in macros +template <class V> +using DFromV = typename DFromV_t<RemoveConst<V>>::type; + +template <class V> +using TFromV = TFromD<DFromV<V>>; + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// Unsigned: +#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 64, 32, NAME, OP) + +// Signed: +#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP) + +// Float: +#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 16, 16, NAME, OP) +#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 64, 32, NAME, OP) + +// For all element sizes: +#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// Commonly used type categories for a given element size: +#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// Commonly used type categories: +#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +// Assemble types for use in x-macros +#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t +#define HWY_SVE_D(BASE, BITS, N, POW2) Simd<HWY_SVE_T(BASE, BITS), N, POW2> +#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t + +} // namespace detail + +#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <> \ + struct DFromV_t<HWY_SVE_V(BASE, BITS)> { \ + using type = ScalableTag<HWY_SVE_T(BASE, BITS)>; \ + }; + +HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) +#undef HWY_SPECIALIZE + +// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX +// instructions, and we anyway only use it when the predicate is ptrue. + +// vector = f(vector), e.g. Not +#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } +#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +// vector = f(vector, scalar), e.g. detail::AddN +#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS(a, b, c); \ + } + +// ------------------------------ Lanes + +namespace detail { + +// Returns actual lanes of a hardware vector without rounding to a power of two. +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<1> /* tag */) { + return svcntb_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<2> /* tag */) { + return svcnth_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<4> /* tag */) { + return svcntw_pat(SV_ALL); +} +HWY_INLINE size_t AllHardwareLanes(hwy::SizeTag<8> /* tag */) { + return svcntd_pat(SV_ALL); +} + +// All-true mask from a macro +#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL) + +#if HWY_SVE_IS_POW2 +#define HWY_SVE_PTRUE(BITS) HWY_SVE_ALL_PTRUE(BITS) +#else +#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2) + +// Returns actual lanes of a hardware vector, rounded down to a power of two. +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE size_t HardwareLanes() { + return svcntb_pat(SV_POW2); +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE size_t HardwareLanes() { + return svcnth_pat(SV_POW2); +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE size_t HardwareLanes() { + return svcntw_pat(SV_POW2); +} +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE size_t HardwareLanes() { + return svcntd_pat(SV_POW2); +} + +#endif // HWY_SVE_IS_POW2 + +} // namespace detail + +// Returns actual number of lanes after capping by N and shifting. May return 0 +// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8). +#if HWY_TARGET == HWY_SVE_256 +template <typename T, size_t N, int kPow2> +HWY_API constexpr size_t Lanes(Simd<T, N, kPow2> /* d */) { + return HWY_MIN(detail::ScaleByPower(32 / sizeof(T), kPow2), N); +} +#elif HWY_TARGET == HWY_SVE2_128 +template <typename T, size_t N, int kPow2> +HWY_API constexpr size_t Lanes(Simd<T, N, kPow2> /* d */) { + return HWY_MIN(detail::ScaleByPower(16 / sizeof(T), kPow2), N); +} +#else +template <typename T, size_t N, int kPow2> +HWY_API size_t Lanes(Simd<T, N, kPow2> d) { + const size_t actual = detail::HardwareLanes<T>(); + // Common case of full vectors: avoid any extra instructions. + if (detail::IsFull(d)) return actual; + return HWY_MIN(detail::ScaleByPower(actual, kPow2), N); +} +#endif // HWY_TARGET + +// ================================================== MASK INIT + +// One mask bit per byte; only the one belonging to the lowest byte is valid. + +// ------------------------------ FirstN +#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) { \ + const size_t limit = detail::IsFull(d) ? count : HWY_MIN(Lanes(d), count); \ + return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast<uint32_t>(limit)); \ + } +HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) +#undef HWY_SVE_FIRSTN + +template <class D> +using MFromD = decltype(FirstN(D(), 0)); + +namespace detail { + +#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_PTRUE(BITS); \ + } \ + template <size_t N, int kPow2> \ + HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_ALL_PTRUE(BITS); \ + } + +HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true +#undef HWY_SVE_WRAP_PTRUE + +HWY_API svbool_t PFalse() { return svpfalse_b(); } + +// Returns all-true if d is HWY_FULL or FirstN(N) after capping N. +// +// This is used in functions that load/store memory; other functions (e.g. +// arithmetic) can ignore d and use PTrue instead. +template <class D> +svbool_t MakeMask(D d) { + return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d)); +} + +} // namespace detail + +// ================================================== INIT + +// ------------------------------ Set +// vector = f(d, scalar), e.g. Set +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) arg) { \ + return sv##OP##_##CHAR##BITS(arg); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) +#undef HWY_SVE_SET + +// Required for Zero and VFromD +template <size_t N, int kPow2> +svuint16_t Set(Simd<bfloat16_t, N, kPow2> d, bfloat16_t arg) { + return Set(RebindToUnsigned<decltype(d)>(), arg.bits); +} + +template <class D> +using VFromD = decltype(Set(D(), TFromD<D>())); + +// ------------------------------ Zero + +template <class D> +VFromD<D> Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return sv##OP##_##CHAR##BITS(); \ + } + +HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) + +// ------------------------------ BitCast + +namespace detail { + +// u8: no change +#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } + +// All other types +#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u8_##CHAR##BITS(v); \ + } \ + template <size_t N, int kPow2> \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \ + return sv##OP##_##CHAR##BITS##_u8(v); \ + } + +HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) +HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) + +#undef HWY_SVE_CAST_NOP +#undef HWY_SVE_CAST + +template <size_t N, int kPow2> +HWY_INLINE svuint16_t BitCastFromByte(Simd<bfloat16_t, N, kPow2> /* d */, + svuint8_t v) { + return BitCastFromByte(Simd<uint16_t, N, kPow2>(), v); +} + +} // namespace detail + +template <class D, class FromV> +HWY_API VFromD<D> BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ================================================== LOGICAL + +// detail::*N() functions accept a scalar argument to avoid extra Set(). + +// ------------------------------ Not +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not ) // NOLINT + +// ------------------------------ And + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V And(const V a, const V b) { + const DFromV<V> df; + const RebindToUnsigned<decltype(df)> du; + return BitCast(df, And(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Or + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V Or(const V a, const V b) { + const DFromV<V> df; + const RebindToUnsigned<decltype(df)> du; + return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Xor + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V Xor(const V a, const V b) { + const DFromV<V> df; + const RebindToUnsigned<decltype(df)> du; + return BitCast(df, Xor(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ AndNot + +namespace detail { +#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n) +#undef HWY_SVE_RETV_ARGPVN_SWAP +} // namespace detail + +#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic) +#undef HWY_SVE_RETV_ARGPVV_SWAP + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V AndNot(const V a, const V b) { + const DFromV<V> df; + const RebindToUnsigned<decltype(df)> du; + return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Xor3 + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V Xor3(const V x1, const V x2, const V x3) { + const DFromV<V> df; + const RebindToUnsigned<decltype(df)> du; + return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +#else +template <class V> +HWY_API V Xor3(V x1, V x2, V x3) { + return Xor(x1, Xor(x2, x3)); +} +#endif + +// ------------------------------ Or3 +template <class V> +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template <class V> +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +// Need to return original type instead of unsigned. +#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return BitCast(DFromV<decltype(v)>(), \ + sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt) +#undef HWY_SVE_POPCNT + +// ================================================== SIGN + +// ------------------------------ Neg +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg) + +// ------------------------------ Abs +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) + +// ------------------------------ CopySign[ToAbs] + +template <class V> +HWY_API V CopySign(const V magn, const V sign) { + const auto msb = SignBit(DFromV<V>()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template <class V> +HWY_API V CopySignToAbs(const V abs, const V sign) { + const auto msb = SignBit(DFromV<V>()); + return Or(abs, And(msb, sign)); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Add + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n) +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add) + +// ------------------------------ Sub + +namespace detail { +// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg. +#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(pg, a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n) +#undef HWY_SVE_RETV_ARGPVN_MASK +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub) + +// ------------------------------ SumsOf8 +HWY_API svuint64_t SumsOf8(const svuint8_t v) { + const ScalableTag<uint32_t> du32; + const ScalableTag<uint64_t> du64; + const svbool_t pg = detail::PTrue(du64); + + const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); + // Compute pairwise sum of u32 and extend to u64. + // TODO(janwas): on SVE2, we can instead use svaddp. + const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) + const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); + return Add(hi, lo); +} + +// ------------------------------ SaturatedAdd + +HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) + +// ------------------------------ SaturatedSub + +HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) + +// ------------------------------ AbsDiff +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) + +// ------------------------------ ShiftLeft[Same] + +#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <int kBits> \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n) + +// ------------------------------ ShiftRight[Same] + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n) + +#undef HWY_SVE_SHIFT_N + +// ------------------------------ RotateRight + +// TODO(janwas): svxar on SVE2 +template <int kBits, class V> +HWY_API V RotateRight(const V v) { + constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<kSizeInBits - kBits>(v)); +} + +// ------------------------------ Shl/r + +#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToUnsigned<DFromV<decltype(v)>> du; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + BitCast(du, bits)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl) + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr) + +#undef HWY_SVE_SHIFT + +// ------------------------------ Min/Max + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm) + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) +} // namespace detail + +// ------------------------------ Mul +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, Mul, mul) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_RETV_ARGPVV, Mul, mul) + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// ------------------------------ MulHigh +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) +// Not part of API, used internally: +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) +HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) + +// ------------------------------ MulFixedPoint15 +HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + return svqrdmulh_s16(a, b); +#else + const DFromV<decltype(a)> d; + const RebindToUnsigned<decltype(d)> du; + + const svuint16_t lo = BitCast(du, Mul(a, b)); + const svint16_t hi = MulHigh(a, b); + // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must + // carry that into the result. Instead isolate the top two bits because only + // they can influence the result. + const svuint16_t lo_top2 = ShiftRight<14>(lo); + // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0. + const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1)); + return Add(Add(hi, hi), BitCast(d, rounding)); +#endif +} + +// ------------------------------ Div +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div) + +// ------------------------------ ApproximateReciprocal +HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe) + +// ------------------------------ Sqrt +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) + +// ------------------------------ ApproximateReciprocalSqrt +HWY_SVE_FOREACH_F32(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte) + +// ------------------------------ MulAdd +#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ + HWY_SVE_V(BASE, BITS) add) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulAdd, mad) + +// ------------------------------ NegMulAdd +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulAdd, msb) + +// ------------------------------ MulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb) + +// ------------------------------ NegMulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) + +#undef HWY_SVE_FMA + +// ------------------------------ Round etc. + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz) + +// ================================================== MASK + +// ------------------------------ RebindMask +template <class D, typename MFrom> +HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) { + return mask; +} + +// ------------------------------ Mask logical + +HWY_API svbool_t Not(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Per ARM, such bits are ignored. + return svnot_b_z(HWY_SVE_PTRUE(8), m); +} +HWY_API svbool_t And(svbool_t a, svbool_t b) { + return svand_b_z(b, b, a); // same order as AndNot for consistency +} +HWY_API svbool_t AndNot(svbool_t a, svbool_t b) { + return svbic_b_z(b, b, a); // reversed order like NEON +} +HWY_API svbool_t Or(svbool_t a, svbool_t b) { + return svsel_b(a, a, b); // a ? true : b +} +HWY_API svbool_t Xor(svbool_t a, svbool_t b) { + return svsel_b(a, svnand_b_z(a, a, b), b); // a ? !(a & b) : b. +} + +HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) { + return svnor_b_z(HWY_SVE_PTRUE(8), a, b); // !a && !b, undefined if a && b. +} + +// ------------------------------ CountTrue + +#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \ + return sv##OP##_b##BITS(detail::MakeMask(d), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) +#undef HWY_SVE_COUNT_TRUE + +// For 16-bit Compress: full vector, not limited to SV_POW2. +namespace detail { + +#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \ + return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) +#undef HWY_SVE_COUNT_TRUE_FULL + +} // namespace detail + +// ------------------------------ AllFalse +template <class D> +HWY_API bool AllFalse(D d, svbool_t m) { + return !svptest_any(detail::MakeMask(d), m); +} + +// ------------------------------ AllTrue +template <class D> +HWY_API bool AllTrue(D d, svbool_t m) { + return CountTrue(d, m) == Lanes(d); +} + +// ------------------------------ FindFirstTrue +template <class D> +HWY_API intptr_t FindFirstTrue(D d, svbool_t m) { + return AllFalse(d, m) ? intptr_t{-1} + : static_cast<intptr_t>( + CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m))); +} + +// ------------------------------ FindKnownFirstTrue +template <class D> +HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) { + return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)); +} + +// ------------------------------ IfThenElse +#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(m, yes, no); \ + } + +HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +#undef HWY_SVE_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template <class V> +HWY_API V IfThenElseZero(const svbool_t mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV<V>())); +} + +// ------------------------------ IfThenZeroElse +template <class V> +HWY_API V IfThenZeroElse(const svbool_t mask, const V no) { + return IfThenElse(mask, Zero(DFromV<V>()), no); +} + +// ================================================== COMPARE + +// mask = f(vector, vector) +#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } + +// ------------------------------ Eq +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) +} // namespace detail + +// ------------------------------ Ne +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) +} // namespace detail + +// ------------------------------ Lt +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) +} // namespace detail + +// ------------------------------ Le +HWY_SVE_FOREACH_F(HWY_SVE_COMPARE, Le, cmple) + +#undef HWY_SVE_COMPARE +#undef HWY_SVE_COMPARE_N + +// ------------------------------ Gt/Ge (swapped order) +template <class V> +HWY_API svbool_t Gt(const V a, const V b) { + return Lt(b, a); +} +template <class V> +HWY_API svbool_t Ge(const V a, const V b) { + return Le(b, a); +} + +// ------------------------------ TestBit +template <class V> +HWY_API svbool_t TestBit(const V a, const V bit) { + return detail::NeN(And(a, bit), 0); +} + +// ------------------------------ MaskFromVec (Ne) +template <class V> +HWY_API svbool_t MaskFromVec(const V v) { + return detail::NeN(v, static_cast<TFromV<V>>(0)); +} + +// ------------------------------ VecFromMask +template <class D> +HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) { + const RebindToSigned<D> di; + // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which + // requires an extra instruction plus M0 pipeline. + return BitCast(d, IfThenElseZero(mask, Set(di, -1))); +} + +// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +#define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \ + HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(yes, no, mask); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl) +#undef HWY_SVE_IF_VEC + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + const DFromV<V> d; + const RebindToUnsigned<decltype(d)> du; + return BitCast( + d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#else + +template <class V> +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +#endif // HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +// ------------------------------ Floating-point classification (Ne) + +template <class V> +HWY_API svbool_t IsNaN(const V v) { + return Ne(v, v); // could also use cmpuo +} + +template <class V> +HWY_API svbool_t IsInf(const V v) { + using T = TFromV<V>; + const DFromV<decltype(v)> d; + const RebindToSigned<decltype(d)> di; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqN(Add(vi, vi), hwy::MaxExponentTimes2<T>())); +} + +// Returns whether normal/subnormal/zero. +template <class V> +HWY_API svbool_t IsFinite(const V v) { + using T = TFromV<V>; + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + const VFromD<decltype(du)> vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD<decltype(di)> exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>())); +} + +// ================================================== MEMORY + +// ------------------------------ Load/MaskedLoad/LoadDup128/Store/Stream + +#define HWY_SVE_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return sv##OP##_##CHAR##BITS(detail::MakeMask(d), p); \ + } + +#define HWY_SVE_MASKED_LOAD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return sv##OP##_##CHAR##BITS(m, p); \ + } + +#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + /* All-true predicate to load all 128 bits. */ \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), p); \ + } + +#define HWY_SVE_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), p, v); \ + } + +#define HWY_SVE_BLENDED_STORE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(m, p, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_LOAD, Load, ld1) +HWY_SVE_FOREACH(HWY_SVE_MASKED_LOAD, MaskedLoad, ld1) +HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDup128, ld1rq) +HWY_SVE_FOREACH(HWY_SVE_STORE, Store, st1) +HWY_SVE_FOREACH(HWY_SVE_STORE, Stream, stnt1) +HWY_SVE_FOREACH(HWY_SVE_BLENDED_STORE, BlendedStore, st1) + +#undef HWY_SVE_LOAD +#undef HWY_SVE_MASKED_LOAD +#undef HWY_SVE_LOAD_DUP128 +#undef HWY_SVE_STORE +#undef HWY_SVE_BLENDED_STORE + +// BF16 is the same as svuint16_t because BF16 is optional before v8.6. +template <size_t N, int kPow2> +HWY_API svuint16_t Load(Simd<bfloat16_t, N, kPow2> d, + const bfloat16_t* HWY_RESTRICT p) { + return Load(RebindToUnsigned<decltype(d)>(), + reinterpret_cast<const uint16_t * HWY_RESTRICT>(p)); +} + +template <size_t N, int kPow2> +HWY_API void Store(svuint16_t v, Simd<bfloat16_t, N, kPow2> d, + bfloat16_t* HWY_RESTRICT p) { + Store(v, RebindToUnsigned<decltype(d)>(), + reinterpret_cast<uint16_t * HWY_RESTRICT>(p)); +} + +// ------------------------------ Load/StoreU + +// SVE only requires lane alignment, not natural alignment of the entire +// vector. +template <class D> +HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) { + return Load(d, p); +} + +template <class V, class D> +HWY_API void StoreU(const V v, D d, TFromD<D>* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ ScatterOffset/Index + +#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ + v); \ + } + +#define HWY_SVE_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME( \ + HWY_SVE_V(BASE, BITS) v, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, HWY_SVE_V(int, BITS) index) { \ + sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, index, v); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_INDEX, ScatterIndex, st1_scatter) +#undef HWY_SVE_SCATTER_OFFSET +#undef HWY_SVE_SCATTER_INDEX + +// ------------------------------ GatherOffset/Index + +#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ + offset); \ + } +#define HWY_SVE_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) index) { \ + return sv##OP##_s##BITS##index_##CHAR##BITS(detail::MakeMask(d), base, \ + index); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather) +#undef HWY_SVE_GATHER_OFFSET +#undef HWY_SVE_GATHER_INDEX + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ + const sv##BASE##BITS##x2_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget2(tuple, 0); \ + v1 = svget2(tuple, 1); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) + +#undef HWY_SVE_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2) { \ + const sv##BASE##BITS##x3_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget3(tuple, 0); \ + v1 = svget3(tuple, 1); \ + v2 = svget3(tuple, 2); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) + +#undef HWY_SVE_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ + const sv##BASE##BITS##x4_t tuple = \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned); \ + v0 = svget4(tuple, 0); \ + v1 = svget4(tuple, 1); \ + v2 = svget4(tuple, 2); \ + v3 = svget4(tuple, 3); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) + +#undef HWY_SVE_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x2_t tuple = svcreate2##_##CHAR##BITS(v0, v1); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, tuple); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) + +#undef HWY_SVE_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x3_t triple = svcreate3##_##CHAR##BITS(v0, v1, v2); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, triple); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) + +#undef HWY_SVE_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + const sv##BASE##BITS##x4_t quad = \ + svcreate4##_##CHAR##BITS(v0, v1, v2, v3); \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, quad); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) + +#undef HWY_SVE_STORE4 + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// Same sign +#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) + +// 2x +template <size_t N, int kPow2> +HWY_API svuint32_t PromoteTo(Simd<uint32_t, N, kPow2> dto, svuint8_t vfrom) { + const RepartitionToWide<DFromV<decltype(vfrom)>> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template <size_t N, int kPow2> +HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svint8_t vfrom) { + const RepartitionToWide<DFromV<decltype(vfrom)>> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} + +// Sign change +template <size_t N, int kPow2> +HWY_API svint16_t PromoteTo(Simd<int16_t, N, kPow2> dto, svuint8_t vfrom) { + const RebindToUnsigned<decltype(dto)> du; + return BitCast(dto, PromoteTo(du, vfrom)); +} +template <size_t N, int kPow2> +HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svuint16_t vfrom) { + const RebindToUnsigned<decltype(dto)> du; + return BitCast(dto, PromoteTo(du, vfrom)); +} +template <size_t N, int kPow2> +HWY_API svint32_t PromoteTo(Simd<int32_t, N, kPow2> dto, svuint8_t vfrom) { + const Repartition<uint16_t, DFromV<decltype(vfrom)>> du16; + const Repartition<int16_t, decltype(du16)> di16; + return PromoteTo(dto, BitCast(di16, PromoteTo(du16, vfrom))); +} + +// ------------------------------ PromoteTo F + +// Unlike Highway's ZipLower, this returns the same type. +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1) +} // namespace detail + +template <size_t N, int kPow2> +HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. + const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f32_f16_x(detail::PTrue(Simd<float16_t, N, kPow2>()), vv); +} + +template <size_t N, int kPow2> +HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f32_x(detail::PTrue(Simd<float32_t, N, kPow2>()), vv); +} + +template <size_t N, int kPow2> +HWY_API svfloat64_t PromoteTo(Simd<float64_t, N, kPow2> /* d */, + const svint32_t v) { + const svint32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_s32_x(detail::PTrue(Simd<int32_t, N, kPow2>()), vv); +} + +// For 16-bit Compress +namespace detail { +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +#undef HWY_SVE_PROMOTE_TO + +template <size_t N, int kPow2> +HWY_API svfloat32_t PromoteUpperTo(Simd<float, N, kPow2> df, svfloat16_t v) { + const RebindToUnsigned<decltype(df)> du; + const RepartitionToNarrow<decltype(du)> dn; + return BitCast(df, PromoteUpperTo(du, BitCast(dn, v))); +} + +} // namespace detail + +// ------------------------------ DemoteTo U + +namespace detail { + +// Saturates unsigned vectors to half/quarter-width TN. +template <typename TN, class VU> +VU SaturateU(VU v) { + return detail::MinN(v, static_cast<TFromV<VU>>(LimitsMax<TN>())); +} + +// Saturates unsigned vectors to half/quarter-width TN. +template <typename TN, class VI> +VI SaturateI(VI v) { + return detail::MinN(detail::MaxN(v, LimitsMin<TN>()), LimitsMax<TN>()); +} + +} // namespace detail + +template <size_t N, int kPow2> +HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint16_t v) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + using TN = TFromD<decltype(dn)>; + // First clamp negative numbers to zero and cast to unsigned. + const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint8_t vn = BitCast(dn, detail::SaturateU<TN>(clamped)); + return svuzp1_u8(vn, vn); +} + +template <size_t N, int kPow2> +HWY_API svuint16_t DemoteTo(Simd<uint16_t, N, kPow2> dn, const svint32_t v) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + using TN = TFromD<decltype(dn)>; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint16_t vn = BitCast(dn, detail::SaturateU<TN>(clamped)); + return svuzp1_u16(vn, vn); +} + +template <size_t N, int kPow2> +HWY_API svuint8_t DemoteTo(Simd<uint8_t, N, kPow2> dn, const svint32_t v) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + const RepartitionToNarrow<decltype(du)> d2; + using TN = TFromD<decltype(dn)>; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and quarter the width. + const svuint16_t cast16 = BitCast(d2, detail::SaturateU<TN>(clamped)); + const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); + return svuzp1_u8(x2, x2); +} + +HWY_API svuint8_t U8FromU32(const svuint32_t v) { + const DFromV<svuint32_t> du32; + const RepartitionToNarrow<decltype(du32)> du16; + const RepartitionToNarrow<decltype(du16)> du8; + + const svuint16_t cast16 = BitCast(du16, v); + const svuint16_t x2 = svuzp1_u16(cast16, cast16); + const svuint8_t cast8 = BitCast(du8, x2); + return svuzp1_u8(cast8, cast8); +} + +// ------------------------------ Truncations + +template <size_t N, int kPow2> +HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, + const svuint64_t v) { + const DFromV<svuint8_t> d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + const svuint8_t v3 = svuzp1_u8(v2, v2); + return svuzp1_u8(v3, v3); +} + +template <size_t N, int kPow2> +HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */, + const svuint64_t v) { + const DFromV<svuint16_t> d; + const svuint16_t v1 = BitCast(d, v); + const svuint16_t v2 = svuzp1_u16(v1, v1); + return svuzp1_u16(v2, v2); +} + +template <size_t N, int kPow2> +HWY_API svuint32_t TruncateTo(Simd<uint32_t, N, kPow2> /* tag */, + const svuint64_t v) { + const DFromV<svuint32_t> d; + const svuint32_t v1 = BitCast(d, v); + return svuzp1_u32(v1, v1); +} + +template <size_t N, int kPow2> +HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, + const svuint32_t v) { + const DFromV<svuint8_t> d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + return svuzp1_u8(v2, v2); +} + +template <size_t N, int kPow2> +HWY_API svuint16_t TruncateTo(Simd<uint16_t, N, kPow2> /* tag */, + const svuint32_t v) { + const DFromV<svuint16_t> d; + const svuint16_t v1 = BitCast(d, v); + return svuzp1_u16(v1, v1); +} + +template <size_t N, int kPow2> +HWY_API svuint8_t TruncateTo(Simd<uint8_t, N, kPow2> /* tag */, + const svuint16_t v) { + const DFromV<svuint8_t> d; + const svuint8_t v1 = BitCast(d, v); + return svuzp1_u8(v1, v1); +} + +// ------------------------------ DemoteTo I + +template <size_t N, int kPow2> +HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint16_t v) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); +#else + using TN = TFromD<decltype(dn)>; + const svint8_t vn = BitCast(dn, detail::SaturateI<TN>(v)); +#endif + return svuzp1_s8(vn, vn); +} + +template <size_t N, int kPow2> +HWY_API svint16_t DemoteTo(Simd<int16_t, N, kPow2> dn, const svint32_t v) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); +#else + using TN = TFromD<decltype(dn)>; + const svint16_t vn = BitCast(dn, detail::SaturateI<TN>(v)); +#endif + return svuzp1_s16(vn, vn); +} + +template <size_t N, int kPow2> +HWY_API svint8_t DemoteTo(Simd<int8_t, N, kPow2> dn, const svint32_t v) { + const RepartitionToWide<decltype(dn)> d2; +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); +#else + using TN = TFromD<decltype(dn)>; + const svint16_t cast16 = BitCast(d2, detail::SaturateI<TN>(v)); +#endif + const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); + return BitCast(dn, svuzp1_s8(v2, v2)); +} + +// ------------------------------ ConcatEven/ConcatOdd + +// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the +// full vector length, not rounded down to a power of two as we require). +namespace detail { + +#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) +#endif +#undef HWY_SVE_CONCAT_EVERY_SECOND + +// Used to slide up / shift whole register left; mask indicates which range +// to take from lo, and the rest is filled from hi starting at its lowest. +#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) +#undef HWY_SVE_SPLICE + +} // namespace detail + +template <class D> +HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) { +#if HWY_SVE_IS_POW2 + (void)d; + return detail::ConcatOddFull(hi, lo); +#else + const VFromD<D> hi_odd = detail::ConcatOddFull(hi, hi); + const VFromD<D> lo_odd = detail::ConcatOddFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +#endif +} + +template <class D> +HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) { +#if HWY_SVE_IS_POW2 + (void)d; + return detail::ConcatEvenFull(hi, lo); +#else + const VFromD<D> hi_odd = detail::ConcatEvenFull(hi, hi); + const VFromD<D> lo_odd = detail::ConcatEvenFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +#endif +} + +// ------------------------------ DemoteTo F + +template <size_t N, int kPow2> +HWY_API svfloat16_t DemoteTo(Simd<float16_t, N, kPow2> d, const svfloat32_t v) { + const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template <size_t N, int kPow2> +HWY_API svuint16_t DemoteTo(Simd<bfloat16_t, N, kPow2> /* d */, svfloat32_t v) { + const svuint16_t in_even = BitCast(ScalableTag<uint16_t>(), v); + return detail::ConcatOddFull(in_even, in_even); // lower half +} + +template <size_t N, int kPow2> +HWY_API svfloat32_t DemoteTo(Simd<float32_t, N, kPow2> d, const svfloat64_t v) { + const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template <size_t N, int kPow2> +HWY_API svint32_t DemoteTo(Simd<int32_t, N, kPow2> d, const svfloat64_t v) { + const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +// ------------------------------ ConvertTo F + +#define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* signed integers */ \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* unsigned integers */ \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Truncates (rounds toward zero). */ \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } + +// API only requires f32 but we provide f64 for use by Iota. +HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) +#undef HWY_SVE_CONVERT + +// ------------------------------ NearestInt (Round, ConvertTo) +template <class VF, class DI = RebindToSigned<DFromV<VF>>> +HWY_API VFromD<DI> NearestInt(VF v) { + // No single instruction, round then truncate. + return ConvertTo(DI(), Round(v)); +} + +// ------------------------------ Iota (Add, ConvertTo) + +#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) first) { \ + return sv##OP##_##CHAR##BITS(first, 1); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) +#undef HWY_SVE_IOTA + +template <class D, HWY_IF_FLOAT_D(D)> +HWY_API VFromD<D> Iota(const D d, TFromD<D> first) { + const RebindToSigned<D> di; + return detail::AddN(ConvertTo(d, Iota(di, 0)), first); +} + +// ------------------------------ InterleaveLower + +template <class D, class V> +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch"); +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipLowerSame(a, b); +#else + // Move lower halves of blocks to lower half of vector. + const Repartition<uint64_t, decltype(d)> d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatEvenFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +template <class V> +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV<V>(), a, b); +} + +// ------------------------------ InterleaveUpper + +// Only use zip2 if vector are a powers of two, otherwise getting the actual +// "upper half" requires MaskUpperHalf. +#if HWY_TARGET == HWY_SVE2_128 +namespace detail { +// Unlike Highway's ZipUpper, this returns the same type. +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) +} // namespace detail +#endif + +// Full vector: guaranteed to have at least one block +template <class D, class V = VFromD<D>, + hwy::EnableIf<detail::IsFull(D())>* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipUpperSame(a, b); +#else + // Move upper halves of blocks to lower half of vector. + const Repartition<uint64_t, decltype(d)> d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatOddFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +// Capped/fraction: need runtime check +template <class D, class V = VFromD<D>, + hwy::EnableIf<!detail::IsFull(D())>* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { + // Less than one block: treat as capped + if (Lanes(d) * sizeof(TFromD<D>) < 16) { + const Half<decltype(d)> d2; + return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); + } + return InterleaveUpper(DFromV<V>(), a, b); +} + +// ================================================== COMBINE + +namespace detail { + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template <class D, HWY_IF_LANE_SIZE_D(D, 1)> +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 32: + return svptrue_pat_b8(SV_VL16); + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + default: + return svptrue_pat_b8(SV_VL1); + } +} +template <class D, HWY_IF_LANE_SIZE_D(D, 2)> +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 16: + return svptrue_pat_b16(SV_VL8); + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + default: + return svptrue_pat_b16(SV_VL1); + } +} +template <class D, HWY_IF_LANE_SIZE_D(D, 4)> +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 8: + return svptrue_pat_b32(SV_VL4); + case 4: + return svptrue_pat_b32(SV_VL2); + default: + return svptrue_pat_b32(SV_VL1); + } +} +template <class D, HWY_IF_LANE_SIZE_D(D, 8)> +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 4: + return svptrue_pat_b64(SV_VL2); + default: + return svptrue_pat_b64(SV_VL1); + } +} +#endif +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template <class D, HWY_IF_LANE_SIZE_D(D, 1)> +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b8(SV_VL1); + } +} +template <class D, HWY_IF_LANE_SIZE_D(D, 2)> +svbool_t MaskLowerHalf(D d) { + switch (Lanes(d)) { + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b16(SV_VL1); + } +} +template <class D, HWY_IF_LANE_SIZE_D(D, 4)> +svbool_t MaskLowerHalf(D d) { + return svptrue_pat_b32(Lanes(d) == 4 ? SV_VL2 : SV_VL1); +} +template <class D, HWY_IF_LANE_SIZE_D(D, 8)> +svbool_t MaskLowerHalf(D /*d*/) { + return svptrue_pat_b64(SV_VL1); +} +#endif // HWY_TARGET == HWY_SVE2_128 +#if HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128 +template <class D> +svbool_t MaskLowerHalf(D d) { + return FirstN(d, Lanes(d) / 2); +} +#endif + +template <class D> +svbool_t MaskUpperHalf(D d) { + // TODO(janwas): WHILEGE on pow2 SVE2 + if (HWY_SVE_IS_POW2 && IsFull(d)) { + return Not(MaskLowerHalf(d)); + } + + // For Splice to work as intended, make sure bits above Lanes(d) are zero. + return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); +} + +// Right-shift vector pair by constexpr; can be used to slide down (=N) or up +// (=Lanes()-N). +#define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t kIndex> \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \ + } +HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext) +#undef HWY_SVE_EXT + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template <class D, class V> +HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) { + return IfThenElse(detail::MaskLowerHalf(d), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template <class D, class V> +HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatEvenBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition<uint64_t, D> du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi))); +#endif + } + return detail::Splice(hi, lo, detail::MaskLowerHalf(d)); +} + +// ------------------------------ ConcatLowerUpper +template <class D, class V> +HWY_API V ConcatLowerUpper(const D d, const V hi, const V lo) { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes + if (detail::IsFull(d)) { + return detail::Ext<Lanes(d) / 2>(hi, lo); + } +#endif + return detail::Splice(hi, lo, detail::MaskUpperHalf(d)); +} + +// ------------------------------ ConcatUpperUpper +template <class D, class V> +HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatOddBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition<uint64_t, D> du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi))); +#endif + } + const svbool_t mask_upper = detail::MaskUpperHalf(d); + const V lo_upper = detail::Splice(lo, lo, mask_upper); + return IfThenElse(mask_upper, hi, lo_upper); +} + +// ------------------------------ Combine +template <class D, class V2> +HWY_API VFromD<D> Combine(const D d, const V2 hi, const V2 lo) { + return ConcatLowerLower(d, hi, lo); +} + +// ------------------------------ ZeroExtendVector +template <class D, class V> +HWY_API V ZeroExtendVector(const D d, const V lo) { + return Combine(d, Zero(Half<D>()), lo); +} + +// ------------------------------ Lower/UpperHalf + +template <class D2, class V> +HWY_API V LowerHalf(D2 /* tag */, const V v) { + return v; +} + +template <class V> +HWY_API V LowerHalf(const V v) { + return v; +} + +template <class DH, class V> +HWY_API V UpperHalf(const DH dh, const V v) { + const Twice<decltype(dh)> d; + // Cast so that we support bfloat16_t. + const RebindToUnsigned<decltype(d)> du; + const VFromD<decltype(du)> vu = BitCast(du, v); +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 // constexpr Lanes + return BitCast(d, detail::Ext<Lanes(dh)>(vu, vu)); +#else + const MFromD<decltype(du)> mask = detail::MaskUpperHalf(du); + return BitCast(d, detail::Splice(vu, vu, mask)); +#endif +} + +// ================================================== REDUCE + +// These return T, whereas the Highway op returns a broadcasted vector. +namespace detail { +#define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + /* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. */ \ + using T = HWY_SVE_T(BASE, BITS); \ + using TU = MakeUnsigned<T>; \ + constexpr uint64_t kMask = LimitsMax<TU>(); \ + return static_cast<T>(static_cast<TU>( \ + static_cast<uint64_t>(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \ + } + +#define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(pg, v); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv) + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv) +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv) +// NaN if all are +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv) + +#undef HWY_SVE_REDUCE +#undef HWY_SVE_REDUCE_ADD +} // namespace detail + +template <class D, class V> +V SumOfLanes(D d, V v) { + return Set(d, detail::SumOfLanesM(detail::MakeMask(d), v)); +} + +template <class D, class V> +V MinOfLanes(D d, V v) { + return Set(d, detail::MinOfLanesM(detail::MakeMask(d), v)); +} + +template <class D, class V> +V MaxOfLanes(D d, V v) { + return Set(d, detail::MaxOfLanesM(detail::MakeMask(d), v)); +} + + +// ================================================== SWIZZLE + +// ------------------------------ GetLane + +namespace detail { +#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_T(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta) +#undef HWY_SVE_GET_LANE +} // namespace detail + +template <class V> +HWY_API TFromV<V> GetLane(V v) { + return detail::GetLaneM(v, detail::PFalse()); +} + +// ------------------------------ ExtractLane +template <class V> +HWY_API TFromV<V> ExtractLane(V v, size_t i) { + return detail::GetLaneM(v, FirstN(DFromV<V>(), i)); +} + +// ------------------------------ InsertLane (IfThenElse) +template <class V> +HWY_API V InsertLane(const V v, size_t i, TFromV<V> t) { + const DFromV<V> d; + const auto is_i = detail::EqN(Iota(d, 0), static_cast<TFromV<V>>(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ DupEven + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) +} // namespace detail + +template <class V> +HWY_API V DupEven(const V v) { + return detail::InterleaveEven(v, v); +} + +// ------------------------------ DupOdd + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) +} // namespace detail + +template <class V> +HWY_API V DupOdd(const V v) { + return detail::InterleaveOdd(v, v); +} + +// ------------------------------ OddEven + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +#define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \ + return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n) +#undef HWY_SVE_ODD_EVEN + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V OddEven(const V odd, const V even) { + const DFromV<V> d; + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even))); +} + +#else + +template <class V> +HWY_API V OddEven(const V odd, const V even) { + const auto odd_in_even = detail::Ext<1>(odd, odd); + return detail::InterleaveEven(even, odd_in_even); +} + +#endif // HWY_TARGET + +// ------------------------------ OddEvenBlocks +template <class V> +HWY_API V OddEvenBlocks(const V odd, const V even) { + const DFromV<V> d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatUpperLower(d, odd, even); +#elif HWY_TARGET == HWY_SVE2_128 + (void)odd; + (void)d; + return even; +#else + const RebindToUnsigned<decltype(d)> du; + using TU = TFromD<decltype(du)>; + constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); + const auto idx_block = ShiftRight<kShift>(Iota(du, 0)); + const auto lsb = detail::AndN(idx_block, static_cast<TU>(1)); + const svbool_t is_even = detail::EqN(lsb, static_cast<TU>(0)); + return IfThenElse(is_even, even, odd); +#endif +} + +// ------------------------------ TableLookupLanes + +template <class D, class VI> +HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) { + using TI = TFromV<VI>; + static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index/lane size mismatch"); + const RebindToUnsigned<D> du; + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllTrue(du, detail::LtN(indices, static_cast<TI>(Lanes(d))))); +#else + (void)d; +#endif + return indices; +} + +template <class D, typename TI> +HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx)); +} + +// <32bit are not part of Highway API, but used in Broadcast. +#define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ + return sv##OP##_##CHAR##BITS(v, idx); \ + } + +HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) +#undef HWY_SVE_TABLE + +// ------------------------------ SwapAdjacentBlocks (TableLookupLanes) + +namespace detail { + +template <typename T, size_t N, int kPow2> +constexpr size_t LanesPerBlock(Simd<T, N, kPow2> /* tag */) { + // We might have a capped vector smaller than a block, so honor that. + return HWY_MIN(16 / sizeof(T), detail::ScaleByPower(N, kPow2)); +} + +} // namespace detail + +template <class V> +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV<V> d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatLowerUpper(d, v, v); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#else + const RebindToUnsigned<decltype(d)> du; + constexpr auto kLanesPerBlock = + static_cast<TFromD<decltype(du)>>(detail::LanesPerBlock(d)); + const VFromD<decltype(du)> idx = detail::XorN(Iota(du, 0), kLanesPerBlock); + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ Reverse + +namespace detail { + +#define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) +#undef HWY_SVE_REVERSE + +} // namespace detail + +template <class D, class V> +HWY_API V Reverse(D d, V v) { + using T = TFromD<D>; + const auto reversed = detail::ReverseFull(v); + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed; + // Shift right to remove extra (non-pow2 and remainder) lanes. + // TODO(janwas): on SVE2, use WHILEGE. + // Avoids FirstN truncating to the return vector size. Must also avoid Not + // because that is limited to SV_POW2. + const ScalableTag<T> dfull; + const svbool_t all_true = detail::AllPTrue(dfull); + const size_t all_lanes = detail::AllHardwareLanes(hwy::SizeTag<sizeof(T)>()); + const svbool_t mask = + svnot_b_z(all_true, FirstN(dfull, all_lanes - Lanes(d))); + return detail::Splice(reversed, reversed, mask); +} + +// ------------------------------ Reverse2 + +template <class D, HWY_IF_LANE_SIZE_D(D, 2)> +HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { + const RebindToUnsigned<decltype(d)> du; + const RepartitionToWide<decltype(du)> dw; + return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 4)> +HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { + const RebindToUnsigned<decltype(d)> du; + const RepartitionToWide<decltype(du)> dw; + return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { // 3210 +#if HWY_TARGET == HWY_SVE2_128 + if (detail::IsFull(d)) { + return detail::Ext<1>(v, v); + } +#endif + (void)d; + const auto odd_in_even = detail::Ext<1>(v, v); // x321 + return detail::InterleaveEven(odd_in_even, v); // 2301 +} +// ------------------------------ Reverse4 (TableLookupLanes) +template <class D> +HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { + if (HWY_TARGET == HWY_SVE_256 && sizeof(TFromD<D>) == 8 && + detail::IsFull(d)) { + return detail::ReverseFull(v); + } + // TODO(janwas): is this approach faster than Shuffle0123? + const RebindToUnsigned<decltype(d)> du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse8 (TableLookupLanes) +template <class D> +HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { + const RebindToUnsigned<decltype(d)> du; + const auto idx = detail::XorN(Iota(du, 0), 7); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Compress (PromoteTo) + +template <typename T> +struct CompressIsPartition { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + // Optimization for 64-bit lanes (could also be applied to 32-bit, but that + // requires a larger table). + enum { value = (sizeof(T) == 8) }; +#else + enum { value = 0 }; +#endif // HWY_TARGET == HWY_SVE_256 +}; + +#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact) +HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact) +#else +HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact) +#endif +#undef HWY_SVE_COMPRESS + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template <class V, HWY_IF_LANE_SIZE_V(V, 8)> +HWY_API V Compress(V v, svbool_t mask) { + const DFromV<V> d; + const RebindToUnsigned<decltype(d)> du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompress64x4Tables + 0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2, + 1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2, + 0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +} + +#endif // HWY_TARGET == HWY_SVE_256 +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template <class V, HWY_IF_LANE_SIZE_V(V, 8)> +HWY_API V Compress(V v, svbool_t mask) { + // If mask == 10: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot keep 10 + // unchanged and map everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(maskLL, mask)); +} + +#endif // HWY_TARGET == HWY_SVE2_128 + +template <class V, HWY_IF_LANE_SIZE_V(V, 2)> +HWY_API V Compress(V v, svbool_t mask16) { + static_assert(!IsSame<V, svfloat16_t>(), "Must use overload"); + const DFromV<V> d16; + + // Promote vector and mask to 32-bit + const RepartitionToWide<decltype(d16)> dw; + const auto v32L = PromoteTo(dw, v); + const auto v32H = detail::PromoteUpperTo(dw, v); + const svbool_t mask32L = svunpklo_b(mask16); + const svbool_t mask32H = svunpkhi_b(mask16); + + const auto compressedL = Compress(v32L, mask32L); + const auto compressedH = Compress(v32H, mask32H); + + // Demote to 16-bit (already in range) - separately so we can splice + const V evenL = BitCast(d16, compressedL); + const V evenH = BitCast(d16, compressedH); + const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half + const V v16H = detail::ConcatEvenFull(evenH, evenH); + + // We need to combine two vectors of non-constexpr length, so the only option + // is Splice, which requires us to synthesize a mask. NOTE: this function uses + // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt. + const size_t countL = detail::CountTrueFull(dw, mask32L); + const auto compressed_maskL = FirstN(d16, countL); + return detail::Splice(v16H, v16L, compressed_maskL); +} + +// Must treat float16_t as integers so we can ConcatEven. +HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) { + const DFromV<decltype(v)> df; + const RebindToSigned<decltype(df)> di; + return BitCast(df, Compress(BitCast(di, v), mask16)); +} + +// ------------------------------ CompressNot + +// 2 or 4 bytes +template <class V, typename T = TFromV<V>, HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API V CompressNot(V v, const svbool_t mask) { + return Compress(v, Not(mask)); +} + +template <class V, HWY_IF_LANE_SIZE_V(V, 8)> +HWY_API V CompressNot(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE + // If mask == 01: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot map + // 01 to 10, and everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(mask, maskLL)); +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + const DFromV<V> d; + const RebindToUnsigned<decltype(d)> du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompressNot64x4Tables + 0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3, + 0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3, + 2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif // HWY_TARGET == HWY_SVE_256 + + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 + (void)mask; + return v; +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + uint64_t bits = 0; // predicate reg is 32-bit + CopyBytes<4>(&mask, &bits); // not same size - 64-bit more efficient + // Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx. + const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u); + // See CompressIsPartition. Manually generated; flip halves if mask = [0, 1]. + alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1, + 0, 1, 2, 3, 0, 1, 2, 3}; + const ScalableTag<uint64_t> d; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif + + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template <class V, class D, HWY_IF_NOT_LANE_SIZE_D(D, 1)> +HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d, + TFromD<D>* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template <class V, class D, HWY_IF_NOT_LANE_SIZE_D(D, 1)> +HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d, + TFromD<D>* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + const svbool_t store_mask = FirstN(d, count); + BlendedStore(Compress(v, mask), store_mask, d, unaligned); + return count; +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes + +// Prevent accidentally using these for 128-bit vectors - should not be +// necessary. +#if HWY_TARGET != HWY_SVE2_128 +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template <class D, class V> +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned<TFromD<D>>; + return detail::AndNotN(static_cast<T>(LanesPerBlock(d) - 1), iota0); +} + +template <size_t kLanes, class D, HWY_IF_LANE_SIZE_D(D, 1)> +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned<decltype(d)> du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint8_t idx_mod = + svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock, + 9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock, + 12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock, + 15 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template <size_t kLanes, class D, HWY_IF_LANE_SIZE_D(D, 2)> +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned<decltype(d)> du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint16_t idx_mod = + svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template <size_t kLanes, class D, HWY_IF_LANE_SIZE_D(D, 4)> +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned<decltype(d)> du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint32_t idx_mod = + svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template <size_t kLanes, class D, HWY_IF_LANE_SIZE_D(D, 8)> +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned<decltype(d)> du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint64_t idx_mod = + svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 + +template <size_t kBytes, class D, class V = VFromD<D>> +HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { + const Repartition<uint8_t, decltype(d)> d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, detail::Ext<kBytes>(hi8, lo8)); +#else + const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes)); + const auto lo_down = detail::Ext<kBytes>(lo8, lo8); + const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +#endif +} + +// ------------------------------ Shuffle2301 +template <class V> +HWY_API V Shuffle2301(const V v) { + const DFromV<V> d; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + return Reverse2(d, v); +} + +// ------------------------------ Shuffle2103 +template <class V> +HWY_API V Shuffle2103(const V v) { + const DFromV<V> d; + const Repartition<uint8_t, decltype(d)> d8; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0321 +template <class V> +HWY_API V Shuffle0321(const V v) { + const DFromV<V> d; + const Repartition<uint8_t, decltype(d)> d8; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8)); +} + +// ------------------------------ Shuffle1032 +template <class V> +HWY_API V Shuffle1032(const V v) { + const DFromV<V> d; + const Repartition<uint8_t, decltype(d)> d8; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle01 +template <class V> +HWY_API V Shuffle01(const V v) { + const DFromV<V> d; + const Repartition<uint8_t, decltype(d)> d8; + static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0123 +template <class V> +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template <class D, class V = VFromD<D>> +HWY_API V ReverseBlocks(D d, V v) { +#if HWY_TARGET == HWY_SVE_256 + if (detail::IsFull(d)) { + return SwapAdjacentBlocks(v); + } else if (detail::IsFull(Twice<D>())) { + return v; + } +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#endif + const Repartition<uint64_t, D> du64; + return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); +} + +// ------------------------------ TableLookupBytes + +template <class V, class VI> +HWY_API VI TableLookupBytes(const V v, const VI idx) { + const DFromV<VI> d; + const Repartition<uint8_t, decltype(d)> du8; +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx))); +#else + const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0)); + const auto idx8 = Add(BitCast(du8, idx), offsets128); + return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8)); +#endif +} + +template <class V, class VI> +HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { + const DFromV<VI> d; + // Mask size must match vector type, so cast everything to this type. + const Repartition<int8_t, decltype(d)> di8; + + auto idx8 = BitCast(di8, idx); + const auto msb = detail::LtN(idx8, 0); + + const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Broadcast + +#if HWY_TARGET == HWY_SVE2_128 +namespace detail { +#define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <int kLane> \ + HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v, kLane); \ + } + +HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane) +#undef HWY_SVE_BROADCAST +} // namespace detail +#endif + +template <int kLane, class V> +HWY_API V Broadcast(const V v) { + const DFromV<V> d; + const RebindToUnsigned<decltype(d)> du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); +#if HWY_TARGET == HWY_SVE2_128 + return detail::BroadcastLane<kLane>(v); +#else + auto idx = detail::OffsetsOf128BitBlocks(du, Iota(du, 0)); + if (kLane != 0) { + idx = detail::AddN(idx, kLane); + } + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ ShiftLeftLanes + +template <size_t kLanes, class D, class V = VFromD<D>> +HWY_API V ShiftLeftLanes(D d, const V v) { + const auto zero = Zero(d); + const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); +#if HWY_TARGET == HWY_SVE2_128 + return shifted; +#else + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + return IfThenElse(detail::FirstNPerBlock<kLanes>(d), zero, shifted); +#endif +} + +template <size_t kLanes, class V> +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes<kLanes>(DFromV<V>(), v); +} + +// ------------------------------ ShiftRightLanes +template <size_t kLanes, class D, class V = VFromD<D>> +HWY_API V ShiftRightLanes(D d, V v) { + // For capped/fractional vectors, clear upper lanes so we shift in zeros. + if (!detail::IsFull(d)) { + v = IfThenElseZero(detail::MakeMask(d), v); + } + +#if HWY_TARGET == HWY_SVE2_128 + return detail::Ext<kLanes>(Zero(d), v); +#else + const auto shifted = detail::Ext<kLanes>(v, v); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + const svbool_t mask = detail::FirstNPerBlock<kLanesPerBlock - kLanes>(d); + return IfThenElseZero(mask, shifted); +#endif +} + +// ------------------------------ ShiftLeftBytes + +template <int kBytes, class D, class V = VFromD<D>> +HWY_API V ShiftLeftBytes(const D d, const V v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v))); +} + +template <int kBytes, class V> +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes<kBytes>(DFromV<V>(), v); +} + +// ------------------------------ ShiftRightBytes +template <int kBytes, class D, class V = VFromD<D>> +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v))); +} + +// ------------------------------ ZipLower + +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow<DW> dn; + static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipLower(const V a, const V b) { + return BitCast(DW(), InterleaveLower(D(), a, b)); +} + +// ------------------------------ ZipUpper +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow<DW> dn; + static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== Ops with dependencies + +// ------------------------------ PromoteTo bfloat16 (ZipLower) +template <size_t N, int kPow2> +HWY_API svfloat32_t PromoteTo(Simd<float32_t, N, kPow2> df32, + const svuint16_t v) { + return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), v)); +} + +// ------------------------------ ReorderDemote2To (OddEven) + +template <size_t N, int kPow2> +HWY_API svuint16_t ReorderDemote2To(Simd<bfloat16_t, N, kPow2> dbf16, + svfloat32_t a, svfloat32_t b) { + const RebindToUnsigned<decltype(dbf16)> du16; + const Repartition<uint32_t, decltype(dbf16)> du32; + const svuint32_t b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +template <size_t N, int kPow2> +HWY_API svint16_t ReorderDemote2To(Simd<int16_t, N, kPow2> d16, svint32_t a, + svint32_t b) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + (void)d16; + const svint16_t a_in_even = svqxtnb_s32(a); + return svqxtnt_s32(a_in_even, b); +#else + const Half<decltype(d16)> dh; + const svint16_t a16 = BitCast(dh, detail::SaturateI<int16_t>(a)); + const svint16_t b16 = BitCast(dh, detail::SaturateI<int16_t>(b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +// ------------------------------ ZeroIfNegative (Lt, IfThenElse) +template <class V> +HWY_API V ZeroIfNegative(const V v) { + return IfThenZeroElse(detail::LtN(v, 0), v); +} + +// ------------------------------ BroadcastSignBit (ShiftRight) +template <class V> +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template <class V> +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float"); + const DFromV<V> d; + const RebindToSigned<decltype(d)> di; + + const svbool_t m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ AverageRound (ShiftRight) + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +#else +template <class V> +V AverageRound(const V a, const V b) { + return ShiftRight<1>(detail::AddN(Add(a, b), 1)); +} +#endif // HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <class D, HWY_IF_LANE_SIZE_D(D, 1)> +HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned<D> du; + const svuint8_t iota = Iota(du, 0); + + // Load correct number of bytes (bits/8) with 7 zeros after each. + const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits)); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota)); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 2)> +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned<D> du; + const Repartition<uint8_t, D> du8; + + // There may be up to 128 bits; avoid reading past the end. + const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits); + + // Replicate bytes 16x such that each lane contains the bit that governs it. + const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0))); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, rep16), bit); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 4)> +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned<D> du; + const Repartition<uint8_t, D> du8; + + // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable, + // so we can skip computing the actual length (Lanes(du)+7)/8. + const svuint8_t bytes = svld1(FirstN(du8, 8), bits); + + // Replicate bytes 32x such that each lane contains the bit that governs it. + const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0))); + + // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 .. + const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7)); + + return TestBit(BitCast(du, rep32), bit); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 8)> +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned<D> du; + + // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane. + // The "at least 8 byte" guarantee in quick_reference ensures this is safe. + uint32_t mask_bits; + CopyBytes<4>(bits, &mask_bits); // copy from bytes + const auto vbits = Set(du, mask_bits); + + // 2 ^ {0,1, .., 31}, will not have more lanes than that. + const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0)); + + return TestBit(vbits, bit); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +// For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. +template <class T, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return svdup_n_u8_z(m, 1); +} +template <class T, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag<uint8_t> d8; + const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); + return detail::ConcatEvenFull(b16, b16); // lower half +} +template <class T, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return U8FromU32(svdup_n_u32_z(m, 1)); +} +template <class T, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag<uint32_t> d32; + const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); + return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half +} + +// Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. +HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { + const ScalableTag<uint8_t> d8; + const ScalableTag<uint16_t> d16; + const ScalableTag<uint32_t> d32; + const ScalableTag<uint64_t> d64; + // TODO(janwas): could use SVE2 BDEP, but it's optional. + x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); + x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); + x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); + return BitCast(d64, x); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +// TODO(janwas): specialize for HWY_SVE_256 +template <class D> +HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { + svuint64_t bits_in_u64 = + detail::BitsFromBool(detail::BoolFromMask<TFromD<D>>(m)); + + const size_t num_bits = Lanes(d); + const size_t num_bytes = (num_bits + 8 - 1) / 8; // Round up, see below + + // Truncate each u64 to 8 bits and store to u8. + svst1b_u64(FirstN(ScalableTag<uint64_t>(), num_bytes), bits, bits_in_u64); + + // Non-full byte, need to clear the undefined upper bits. Can happen for + // capped/fractional vectors or large T and small hardware vectors. + if (num_bits < 8) { + const int mask = static_cast<int>((1ull << num_bits) - 1); + bits[0] = static_cast<uint8_t>(bits[0] & mask); + } + // Else: we wrote full bytes because num_bits is a power of two >= 8. + + return num_bytes; +} + +// ------------------------------ CompressBits (LoadMaskBits) +template <class V, class D = DFromV<V>, HWY_IF_NOT_LANE_SIZE_D(D, 1)> +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(D(), bits)); +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) +template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 1)> +HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD<D>* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ MulEven (InterleaveEven) + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +namespace detail { +#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +#undef HWY_SVE_MUL_EVEN +} // namespace detail +#endif + +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> MulEven(const V a, const V b) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + return BitCast(DW(), detail::MulEvenNative(a, b)); +#else + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), detail::InterleaveEven(lo, hi)); +#endif +} + +HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveEven(lo, hi); +} + +HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template <size_t N, int kPow2> +HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd<float, N, kPow2> df32, + svuint16_t a, svuint16_t b, + const svfloat32_t sum0, + svfloat32_t& sum1) { + // TODO(janwas): svbfmlalb_f32 if __ARM_FEATURE_SVE_BF16. + const RebindToUnsigned<decltype(df32)> du32; + // Using shift/and instead of Zip leads to the odd/even order that + // RearrangeToOddPlusEven prefers. + using VU32 = VFromD<decltype(du32)>; + const VU32 odd = Set(du32, 0xFFFF0000u); + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +template <size_t N, int kPow2> +HWY_API svint32_t ReorderWidenMulAccumulate(Simd<int32_t, N, kPow2> d32, + svint16_t a, svint16_t b, + const svint32_t sum0, + svint32_t& sum1) { +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 + (void)d32; + sum1 = svmlalt_s32(sum1, a, b); + return svmlalb_s32(sum0, a, b); +#else + const svbool_t pg = detail::PTrue(d32); + // Shifting extracts the odd lanes as RearrangeToOddPlusEven prefers. + // Fortunately SVE has sign-extension for the even lanes. + const svint32_t ae = svexth_s32_x(pg, BitCast(d32, a)); + const svint32_t be = svexth_s32_x(pg, BitCast(d32, b)); + const svint32_t ao = ShiftRight<16>(BitCast(d32, a)); + const svint32_t bo = ShiftRight<16>(BitCast(d32, b)); + sum1 = svmla_s32_x(pg, sum1, ao, bo); + return svmla_s32_x(pg, sum0, ae, be); +#endif +} + +// ------------------------------ RearrangeToOddPlusEven +template <class VW> +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes. + return Add(sum0, sum1); +} + +// ------------------------------ AESRound / CLMul + +#if defined(__ARM_FEATURE_SVE2_AES) || \ + ((HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128) && \ + HWY_HAVE_RUNTIME_DISPATCH) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) { + // It is not clear whether E and MC fuse like they did on NEON. + const svuint8_t zero = svdup_n_u8(0); + return Xor(svaesmc_u8(svaese_u8(state, zero)), round_key); +} + +HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) { + return Xor(svaese_u8(state, svdup_n_u8(0)), round_key); +} + +HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) { + return svpmullb_pair(a, b); +} + +HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) { + return svpmullt_pair(a, b); +} + +#endif // __ARM_FEATURE_SVE2_AES + +// ------------------------------ Lt128 + +namespace detail { +#define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <size_t N, int kPow2> \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \ + return sv##OP##_b##BITS(m, m); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1) // actually for bool +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2) // actually for bool +#undef HWY_SVE_DUP + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template <class D> +HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t ltHL = VecFromMask(d, Lt(a, b)); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated. + const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL); + // Duplicate upper lane into lower. + return DupOdd(ltHx); +} +#endif +} // namespace detail + +template <class D> +HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Lt128Vec(d, a, b)); +#else + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + const svbool_t ltHL = Lt(a, b); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL); + // Duplicate upper lane into lower. + return detail::DupOddB(d, ltHx); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Lt128Upper + +template <class D> +HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const svbool_t ltHL = Lt(a, b); + return detail::DupOddB(d, ltHL); +} + +// ------------------------------ Eq128, Ne128 + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +namespace detail { + +template <class D> +HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t eqHL = VecFromMask(d, Eq(a, b)); + // Duplicate upper and lower. + const svuint64_t eqHH = DupOdd(eqHL); + const svuint64_t eqLL = DupEven(eqHL); + return And(eqLL, eqHH); +} + +template <class D> +HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t neHL = VecFromMask(d, Ne(a, b)); + // Duplicate upper and lower. + const svuint64_t neHH = DupOdd(neHL); + const svuint64_t neLL = DupEven(neHL); + return Or(neLL, neHH); +} + +} // namespace detail +#endif + +template <class D> +HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Eq128Vec(d, a, b)); +#else + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const svbool_t eqHL = Eq(a, b); + const svbool_t eqHH = detail::DupOddB(d, eqHL); + const svbool_t eqLL = detail::DupEvenB(d, eqHL); + return And(eqLL, eqHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +template <class D> +HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Ne128Vec(d, a, b)); +#else + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const svbool_t neHL = Ne(a, b); + const svbool_t neHH = detail::DupOddB(d, neHL); + const svbool_t neLL = detail::DupEvenB(d, neHL); + return Or(neLL, neHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Eq128Upper, Ne128Upper + +template <class D> +HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const svbool_t eqHL = Eq(a, b); + return detail::DupOddB(d, eqHL); +} + +template <class D> +HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const svbool_t neHL = Ne(a, b); + return detail::DupOddB(d, neHL); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template <class D> +HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +#else + return IfThenElse(Lt128(d, a, b), a, b); +#endif +} + +template <class D> +HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +#else + return IfThenElse(Lt128(d, b, a), a, b); +#endif +} + +template <class D> +HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template <class D> +HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +namespace detail { // for code folding +#undef HWY_IF_FLOAT_V +#undef HWY_IF_LANE_SIZE_V +#undef HWY_SVE_ALL_PTRUE +#undef HWY_SVE_D +#undef HWY_SVE_FOREACH +#undef HWY_SVE_FOREACH_F +#undef HWY_SVE_FOREACH_F16 +#undef HWY_SVE_FOREACH_F32 +#undef HWY_SVE_FOREACH_F64 +#undef HWY_SVE_FOREACH_I +#undef HWY_SVE_FOREACH_I08 +#undef HWY_SVE_FOREACH_I16 +#undef HWY_SVE_FOREACH_I32 +#undef HWY_SVE_FOREACH_I64 +#undef HWY_SVE_FOREACH_IF +#undef HWY_SVE_FOREACH_U +#undef HWY_SVE_FOREACH_U08 +#undef HWY_SVE_FOREACH_U16 +#undef HWY_SVE_FOREACH_U32 +#undef HWY_SVE_FOREACH_U64 +#undef HWY_SVE_FOREACH_UI +#undef HWY_SVE_FOREACH_UI08 +#undef HWY_SVE_FOREACH_UI16 +#undef HWY_SVE_FOREACH_UI32 +#undef HWY_SVE_FOREACH_UI64 +#undef HWY_SVE_FOREACH_UIF3264 +#undef HWY_SVE_PTRUE +#undef HWY_SVE_RETV_ARGPV +#undef HWY_SVE_RETV_ARGPVN +#undef HWY_SVE_RETV_ARGPVV +#undef HWY_SVE_RETV_ARGV +#undef HWY_SVE_RETV_ARGVN +#undef HWY_SVE_RETV_ARGVV +#undef HWY_SVE_RETV_ARGVVV +#undef HWY_SVE_T +#undef HWY_SVE_UNDEFINED +#undef HWY_SVE_V + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/emu128-inl.h b/third_party/highway/hwy/ops/emu128-inl.h new file mode 100644 index 0000000000..7fb934def0 --- /dev/null +++ b/third_party/highway/hwy/ops/emu128-inl.h @@ -0,0 +1,2503 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include <stddef.h> +#include <stdint.h> +#include <cmath> // std::abs, std::isnan + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <typename T> +using Full128 = Simd<T, 16 / sizeof(T), 0>; + +// (Wrapper class required for overloading comparison operators.) +template <typename T, size_t N = 16 / sizeof(T)> +struct Vec128 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + HWY_INLINE Vec128() = default; + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + // Behave like wasm128 (vectors can always hold 128 bits). generic_ops-inl.h + // relies on this for LoadInterleaved*. CAVEAT: this method of padding + // prevents using range for, especially in SumOfLanes, where it would be + // incorrect. Moving padding to another field would require handling the case + // where N = 16 / sizeof(T) (i.e. there is no padding), which is also awkward. + T raw[16 / sizeof(T)] = {}; +}; + +// 0 or FF..FF, same size as Vec128. +template <typename T, size_t N = 16 / sizeof(T)> +struct Mask128 { + using Raw = hwy::MakeUnsigned<T>; + static HWY_INLINE Raw FromBool(bool b) { + return b ? static_cast<Raw>(~Raw{0}) : 0; + } + + // Must match the size of Vec128. + Raw bits[16 / sizeof(T)] = {}; +}; + +template <class V> +using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>; + +template <class V> +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +template <typename T, size_t N, typename FromT, size_t FromN> +HWY_API Vec128<T, N> BitCast(Simd<T, N, 0> /* tag */, Vec128<FromT, FromN> v) { + Vec128<T, N> to; + CopySameSize(&v, &to); + return to; +} + +// ------------------------------ Set + +template <typename T, size_t N> +HWY_API Vec128<T, N> Zero(Simd<T, N, 0> /* tag */) { + Vec128<T, N> v; + ZeroBytes<sizeof(T) * N>(v.raw); + return v; +} + +template <class D> +using VFromD = decltype(Zero(D())); + +template <typename T, size_t N, typename T2> +HWY_API Vec128<T, N> Set(Simd<T, N, 0> /* tag */, const T2 t) { + Vec128<T, N> v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast<T>(t); + } + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Undefined(Simd<T, N, 0> d) { + return Zero(d); +} + +template <typename T, size_t N, typename T2> +HWY_API Vec128<T, N> Iota(const Simd<T, N, 0> /* tag */, T2 first) { + Vec128<T, N> v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = + AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i); + } + return v; +} + +// ================================================== LOGICAL + +// ------------------------------ Not +template <typename T, size_t N> +HWY_API Vec128<T, N> Not(const Vec128<T, N> v) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + using TU = TFromD<decltype(du)>; + VFromD<decltype(du)> vu = BitCast(du, v); + for (size_t i = 0; i < N; ++i) { + vu.raw[i] = static_cast<TU>(~vu.raw[i]); + } + return BitCast(d, vu); +} + +// ------------------------------ And +template <typename T, size_t N> +HWY_API Vec128<T, N> And(const Vec128<T, N> a, const Vec128<T, N> b) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] &= bu.raw[i]; + } + return BitCast(d, au); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> operator&(const Vec128<T, N> a, const Vec128<T, N> b) { + return And(a, b); +} + +// ------------------------------ AndNot +template <typename T, size_t N> +HWY_API Vec128<T, N> AndNot(const Vec128<T, N> a, const Vec128<T, N> b) { + return And(Not(a), b); +} + +// ------------------------------ Or +template <typename T, size_t N> +HWY_API Vec128<T, N> Or(const Vec128<T, N> a, const Vec128<T, N> b) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] |= bu.raw[i]; + } + return BitCast(d, au); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> operator|(const Vec128<T, N> a, const Vec128<T, N> b) { + return Or(a, b); +} + +// ------------------------------ Xor +template <typename T, size_t N> +HWY_API Vec128<T, N> Xor(const Vec128<T, N> a, const Vec128<T, N> b) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] ^= bu.raw[i]; + } + return BitCast(d, au); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> operator^(const Vec128<T, N> a, const Vec128<T, N> b) { + return Xor(a, b); +} + +// ------------------------------ Xor3 + +template <typename T, size_t N> +HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template <typename T, size_t N> +HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template <typename T, size_t N> +HWY_API Vec128<T, N> OrAnd(const Vec128<T, N> o, const Vec128<T, N> a1, + const Vec128<T, N> a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse +template <typename T, size_t N> +HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +// ------------------------------ CopySign +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySign(const Vec128<T, N> magn, + const Vec128<T, N> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + const auto msb = SignBit(Simd<T, N, 0>()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySignToAbs(const Vec128<T, N> abs, + const Vec128<T, N> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Simd<T, N, 0>()), sign)); +} + +// ------------------------------ BroadcastSignBit +template <typename T, size_t N> +HWY_API Vec128<T, N> BroadcastSignBit(Vec128<T, N> v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[i] < 0 ? T(-1) : T(0); + } + return v; +} + +// ------------------------------ Mask + +template <typename TFrom, typename TTo, size_t N> +HWY_API Mask128<TTo, N> RebindMask(Simd<TTo, N, 0> /*tag*/, + Mask128<TFrom, N> mask) { + Mask128<TTo, N> to; + CopySameSize(&mask, &to); + return to; +} + +// v must be 0 or FF..FF. +template <typename T, size_t N> +HWY_API Mask128<T, N> MaskFromVec(const Vec128<T, N> v) { + Mask128<T, N> mask; + CopySameSize(&v, &mask); + return mask; +} + +template <typename T, size_t N> +Vec128<T, N> VecFromMask(const Mask128<T, N> mask) { + Vec128<T, N> v; + CopySameSize(&mask, &v); + return v; +} + +template <typename T, size_t N> +Vec128<T, N> VecFromMask(Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + return VecFromMask(mask); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> FirstN(Simd<T, N, 0> /*tag*/, size_t n) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128<T, N>::FromBool(i < n); + } + return m; +} + +// Returns mask ? yes : no. +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElse(const Mask128<T, N> mask, + const Vec128<T, N> yes, const Vec128<T, N> no) { + return IfVecThenElse(VecFromMask(mask), yes, no); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElseZero(const Mask128<T, N> mask, + const Vec128<T, N> yes) { + return IfVecThenElse(VecFromMask(mask), yes, Zero(Simd<T, N, 0>())); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenZeroElse(const Mask128<T, N> mask, + const Vec128<T, N> no) { + return IfVecThenElse(VecFromMask(mask), Zero(Simd<T, N, 0>()), no); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfNegativeThenElse(Vec128<T, N> v, Vec128<T, N> yes, + Vec128<T, N> no) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[i] < 0 ? yes.raw[i] : no.raw[i]; + } + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ZeroIfNegative(const Vec128<T, N> v) { + return IfNegativeThenElse(v, Zero(Simd<T, N, 0>()), v); +} + +// ------------------------------ Mask logical + +template <typename T, size_t N> +HWY_API Mask128<T, N> Not(const Mask128<T, N> m) { + return MaskFromVec(Not(VecFromMask(Simd<T, N, 0>(), m))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> And(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> AndNot(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Or(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Xor(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> ExclusiveNeither(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template <int kBits, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeft(Vec128<T, N> v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast<hwy::MakeUnsigned<T>>(v.raw[i]) << kBits; + v.raw[i] = static_cast<T>(shifted); + } + return v; +} + +template <int kBits, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRight(Vec128<T, N> v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast<T>(v.raw[i] >> kBits); + } +#else + if (IsSigned<T>()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned<T>; + for (size_t i = 0; i < N; ++i) { + const TU shifted = static_cast<TU>(static_cast<TU>(v.raw[i]) >> kBits); + const TU sign = v.raw[i] < 0 ? static_cast<TU>(~TU{0}) : 0; + const size_t sign_shift = + static_cast<size_t>(static_cast<int>(sizeof(TU)) * 8 - 1 - kBits); + const TU upper = static_cast<TU>(sign << sign_shift); + v.raw[i] = static_cast<T>(shifted | upper); + } + } else { // T is unsigned + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast<T>(v.raw[i] >> kBits); + } + } +#endif + return v; +} + +// ------------------------------ RotateRight (ShiftRight) + +namespace detail { + +// For partial specialization: kBits == 0 results in an invalid shift count +template <int kBits> +struct RotateRight { + template <typename T, size_t N> + HWY_INLINE Vec128<T, N> operator()(const Vec128<T, N> v) const { + return Or(ShiftRight<kBits>(v), ShiftLeft<sizeof(T) * 8 - kBits>(v)); + } +}; + +template <> +struct RotateRight<0> { + template <typename T, size_t N> + HWY_INLINE Vec128<T, N> operator()(const Vec128<T, N> v) const { + return v; + } +}; + +} // namespace detail + +template <int kBits, typename T, size_t N> +HWY_API Vec128<T, N> RotateRight(const Vec128<T, N> v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return detail::RotateRight<kBits>()(v); +} + +// ------------------------------ ShiftLeftSame + +template <typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftSame(Vec128<T, N> v, int bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast<hwy::MakeUnsigned<T>>(v.raw[i]) << bits; + v.raw[i] = static_cast<T>(shifted); + } + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightSame(Vec128<T, N> v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast<T>(v.raw[i] >> bits); + } +#else + if (IsSigned<T>()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned<T>; + for (size_t i = 0; i < N; ++i) { + const TU shifted = static_cast<TU>(static_cast<TU>(v.raw[i]) >> bits); + const TU sign = v.raw[i] < 0 ? static_cast<TU>(~TU{0}) : 0; + const size_t sign_shift = + static_cast<size_t>(static_cast<int>(sizeof(TU)) * 8 - 1 - bits); + const TU upper = static_cast<TU>(sign << sign_shift); + v.raw[i] = static_cast<T>(shifted | upper); + } + } else { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast<T>(v.raw[i] >> bits); // unsigned, logical shift + } + } +#endif + return v; +} + +// ------------------------------ Shl + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, const Vec128<T, N> bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast<hwy::MakeUnsigned<T>>(v.raw[i]) + << bits.raw[i]; + v.raw[i] = static_cast<T>(shifted); + } + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator>>(Vec128<T, N> v, const Vec128<T, N> bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast<T>(v.raw[i] >> bits.raw[i]); + } +#else + if (IsSigned<T>()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned<T>; + for (size_t i = 0; i < N; ++i) { + const TU shifted = + static_cast<TU>(static_cast<TU>(v.raw[i]) >> bits.raw[i]); + const TU sign = v.raw[i] < 0 ? static_cast<TU>(~TU{0}) : 0; + const size_t sign_shift = static_cast<size_t>( + static_cast<int>(sizeof(TU)) * 8 - 1 - bits.raw[i]); + const TU upper = static_cast<TU>(sign << sign_shift); + v.raw[i] = static_cast<T>(shifted | upper); + } + } else { // T is unsigned + for (size_t i = 0; i < N; ++i) { + v.raw[i] = static_cast<T>(v.raw[i] >> bits.raw[i]); + } + } +#endif + return v; +} + +// ================================================== ARITHMETIC + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Add(hwy::NonFloatTag /*tag*/, Vec128<T, N> a, + Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast<uint64_t>(a.raw[i]); + const uint64_t b64 = static_cast<uint64_t>(b.raw[i]); + a.raw[i] = static_cast<T>((a64 + b64) & static_cast<uint64_t>(~T(0))); + } + return a; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Sub(hwy::NonFloatTag /*tag*/, Vec128<T, N> a, + Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast<uint64_t>(a.raw[i]); + const uint64_t b64 = static_cast<uint64_t>(b.raw[i]); + a.raw[i] = static_cast<T>((a64 - b64) & static_cast<uint64_t>(~T(0))); + } + return a; +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Add(hwy::FloatTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] += b.raw[i]; + } + return a; +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Sub(hwy::FloatTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] -= b.raw[i]; + } + return a; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator-(Vec128<T, N> a, const Vec128<T, N> b) { + return detail::Sub(hwy::IsFloatTag<T>(), a, b); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> operator+(Vec128<T, N> a, const Vec128<T, N> b) { + return detail::Add(hwy::IsFloatTag<T>(), a, b); +} + +// ------------------------------ SumsOf8 + +template <size_t N> +HWY_API Vec128<uint64_t, (N + 7) / 8> SumsOf8(const Vec128<uint8_t, N> v) { + Vec128<uint64_t, (N + 7) / 8> sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + +// ------------------------------ SaturatedAdd +template <typename T, size_t N> +HWY_API Vec128<T, N> SaturatedAdd(Vec128<T, N> a, const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast<T>( + HWY_MIN(HWY_MAX(hwy::LowestValue<T>(), a.raw[i] + b.raw[i]), + hwy::HighestValue<T>())); + } + return a; +} + +// ------------------------------ SaturatedSub +template <typename T, size_t N> +HWY_API Vec128<T, N> SaturatedSub(Vec128<T, N> a, const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast<T>( + HWY_MIN(HWY_MAX(hwy::LowestValue<T>(), a.raw[i] - b.raw[i]), + hwy::HighestValue<T>())); + } + return a; +} + +// ------------------------------ AverageRound +template <typename T, size_t N> +HWY_API Vec128<T, N> AverageRound(Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(!IsSigned<T>(), "Only for unsigned"); + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast<T>((a.raw[i] + b.raw[i] + 1) / 2); + } + return a; +} + +// ------------------------------ Abs + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Abs(SignedTag /*tag*/, Vec128<T, N> a) { + for (size_t i = 0; i < N; ++i) { + const T s = a.raw[i]; + const T min = hwy::LimitsMin<T>(); + a.raw[i] = static_cast<T>((s >= 0 || s == min) ? a.raw[i] : -s); + } + return a; +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Abs(hwy::FloatTag /*tag*/, Vec128<T, N> v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = std::abs(v.raw[i]); + } + return v; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> Abs(Vec128<T, N> a) { + return detail::Abs(hwy::TypeTag<T>(), a); +} + +// ------------------------------ Min/Max + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Min(hwy::NonFloatTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + return a; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Max(hwy::NonFloatTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + return a; +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Min(hwy::FloatTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + if (std::isnan(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (std::isnan(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + } + return a; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Max(hwy::FloatTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + if (std::isnan(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (std::isnan(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + } + return a; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> Min(Vec128<T, N> a, const Vec128<T, N> b) { + return detail::Min(hwy::IsFloatTag<T>(), a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Max(Vec128<T, N> a, const Vec128<T, N> b) { + return detail::Max(hwy::IsFloatTag<T>(), a, b); +} + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T, size_t N> +HWY_API Vec128<T, N> Neg(hwy::NonFloatTag /*tag*/, Vec128<T, N> v) { + return Zero(Simd<T, N, 0>()) - v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Neg(hwy::FloatTag /*tag*/, Vec128<T, N> v) { + return Xor(v, SignBit(Simd<T, N, 0>())); +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> Neg(Vec128<T, N> v) { + return detail::Neg(hwy::IsFloatTag<T>(), v); +} + +// ------------------------------ Mul/Div + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Mul(hwy::FloatTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] *= b.raw[i]; + } + return a; +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Mul(SignedTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast<T>(static_cast<uint64_t>(a.raw[i]) * + static_cast<uint64_t>(b.raw[i])); + } + return a; +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Mul(UnsignedTag /*tag*/, Vec128<T, N> a, + const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast<T>(static_cast<uint64_t>(a.raw[i]) * + static_cast<uint64_t>(b.raw[i])); + } + return a; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator*(Vec128<T, N> a, const Vec128<T, N> b) { + return detail::Mul(hwy::TypeTag<T>(), a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator/(Vec128<T, N> a, const Vec128<T, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] /= b.raw[i]; + } + return a; +} + +// Returns the upper 16 bits of a * b in each lane. +template <size_t N> +HWY_API Vec128<int16_t, N> MulHigh(Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast<int16_t>((int32_t{a.raw[i]} * b.raw[i]) >> 16); + } + return a; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> MulHigh(Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + for (size_t i = 0; i < N; ++i) { + // Cast to uint32_t first to prevent overflow. Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the + // result is the same but this way it is also defined. + a.raw[i] = static_cast<uint16_t>( + (static_cast<uint32_t>(a.raw[i]) * static_cast<uint32_t>(b.raw[i])) >> + 16); + } + return a; +} + +template <size_t N> +HWY_API Vec128<int16_t, N> MulFixedPoint15(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast<int16_t>((2 * a.raw[i] * b.raw[i] + 32768) >> 16); + } + return a; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +template <size_t N> +HWY_API Vec128<int64_t, (N + 1) / 2> MulEven(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + Vec128<int64_t, (N + 1) / 2> mul; + for (size_t i = 0; i < N; i += 2) { + const int64_t a64 = a.raw[i]; + mul.raw[i / 2] = a64 * b.raw[i]; + } + return mul; +} +template <size_t N> +HWY_API Vec128<uint64_t, (N + 1) / 2> MulEven(Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + Vec128<uint64_t, (N + 1) / 2> mul; + for (size_t i = 0; i < N; i += 2) { + const uint64_t a64 = a.raw[i]; + mul.raw[i / 2] = a64 * b.raw[i]; + } + return mul; +} + +template <size_t N> +HWY_API Vec128<int64_t, (N + 1) / 2> MulOdd(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + Vec128<int64_t, (N + 1) / 2> mul; + for (size_t i = 0; i < N; i += 2) { + const int64_t a64 = a.raw[i + 1]; + mul.raw[i / 2] = a64 * b.raw[i + 1]; + } + return mul; +} +template <size_t N> +HWY_API Vec128<uint64_t, (N + 1) / 2> MulOdd(Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + Vec128<uint64_t, (N + 1) / 2> mul; + for (size_t i = 0; i < N; i += 2) { + const uint64_t a64 = a.raw[i + 1]; + mul.raw[i / 2] = a64 * b.raw[i + 1]; + } + return mul; +} + +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocal(Vec128<float, N> v) { + for (size_t i = 0; i < N; ++i) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The result is arbitrary. + v.raw[i] = (std::abs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; + } + return v; +} + +template <size_t N> +HWY_API Vec128<float, N> AbsDiff(Vec128<float, N> a, const Vec128<float, N> b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template <typename T, size_t N> +HWY_API Vec128<T, N> MulAdd(Vec128<T, N> mul, const Vec128<T, N> x, + const Vec128<T, N> add) { + return mul * x + add; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> NegMulAdd(Vec128<T, N> mul, const Vec128<T, N> x, + const Vec128<T, N> add) { + return add - mul * x; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> MulSub(Vec128<T, N> mul, const Vec128<T, N> x, + const Vec128<T, N> sub) { + return mul * x - sub; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> NegMulSub(Vec128<T, N> mul, const Vec128<T, N> x, + const Vec128<T, N> sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocalSqrt(Vec128<float, N> v) { + for (size_t i = 0; i < N; ++i) { + const float half = v.raw[i] * 0.5f; + uint32_t bits; + CopySameSize(&v.raw[i], &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &v.raw[i]); + // One Newton-Raphson iteration + v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); + } + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Sqrt(Vec128<T, N> v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = std::sqrt(v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point rounding + +template <typename T, size_t N> +HWY_API Vec128<T, N> Round(Vec128<T, N> v) { + using TI = MakeSigned<T>; + const Vec128<T, N> a = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(a.raw[i] < MantissaEnd<T>())) { // Huge or NaN + continue; + } + const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast<TI>(v.raw[i] + bias); + if (rounded == 0) { + v.raw[i] = v.raw[i] < 0 ? T{-0} : T{0}; + continue; + } + const T rounded_f = static_cast<T>(rounded); + // Round to even + if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { + v.raw[i] = static_cast<T>(rounded - (v.raw[i] < T(0) ? -1 : 1)); + continue; + } + v.raw[i] = rounded_f; + } + return v; +} + +// Round-to-nearest even. +template <size_t N> +HWY_API Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) { + using T = float; + using TI = int32_t; + + const Vec128<float, N> abs = Abs(v); + Vec128<int32_t, N> ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = std::signbit(v.raw[i]); + + if (!(abs.raw[i] < MantissaEnd<T>())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs.raw[i] <= static_cast<T>(LimitsMax<TI>()))) { + ret.raw[i] = signbit ? LimitsMin<TI>() : LimitsMax<TI>(); + continue; + } + ret.raw[i] = static_cast<TI>(v.raw[i]); + continue; + } + const T bias = v.raw[i] < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast<TI>(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = static_cast<T>(rounded); + // Round to even + if ((rounded & 1) && std::abs(rounded_f - v.raw[i]) == T(0.5)) { + ret.raw[i] = rounded - (signbit ? -1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Trunc(Vec128<T, N> v) { + using TI = MakeSigned<T>; + const Vec128<T, N> abs = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(abs.raw[i] <= MantissaEnd<T>())) { // Huge or NaN + continue; + } + const TI truncated = static_cast<TI>(v.raw[i]); + if (truncated == 0) { + v.raw[i] = v.raw[i] < 0 ? -T{0} : T{0}; + continue; + } + v.raw[i] = static_cast<T>(truncated); + } + return v; +} + +// Toward +infinity, aka ceiling +template <typename Float, size_t N> +Vec128<Float, N> Ceil(Vec128<Float, N> v) { + constexpr int kMantissaBits = MantissaBits<Float>(); + using Bits = MakeUnsigned<Float>; + const Bits kExponentMask = MaxExponentField<Float>(); + const Bits kMantissaMask = MantissaMask<Float>(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool positive = v.raw[i] > Float(0.0); + + Bits bits; + CopySameSize(&v.raw[i], &bits); + + const int exponent = + static_cast<int>(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => 0 or 1. + if (exponent < 0) { + v.raw[i] = positive ? Float{1} : Float{-0.0}; + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &v.raw[i]); + } + return v; +} + +// Toward -infinity, aka floor +template <typename Float, size_t N> +Vec128<Float, N> Floor(Vec128<Float, N> v) { + constexpr int kMantissaBits = MantissaBits<Float>(); + using Bits = MakeUnsigned<Float>; + const Bits kExponentMask = MaxExponentField<Float>(); + const Bits kMantissaMask = MantissaMask<Float>(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool negative = v.raw[i] < Float(0.0); + + Bits bits; + CopySameSize(&v.raw[i], &bits); + + const int exponent = + static_cast<int>(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => -1 or 0. + if (exponent < 0) { + v.raw[i] = negative ? Float(-1.0) : Float(0.0); + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point classification + +template <typename T, size_t N> +HWY_API Mask128<T, N> IsNaN(const Vec128<T, N> v) { + Mask128<T, N> ret; + for (size_t i = 0; i < N; ++i) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + MakeUnsigned<T> bits; + CopySameSize(&v.raw[i], &bits); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + ret.bits[i] = Mask128<T, N>::FromBool(bits > ExponentMask<T>()); + } + return ret; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> IsInf(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Simd<T, N, 0> d; + const RebindToSigned<decltype(d)> di; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2<T>()))); +} + +// Returns whether normal/subnormal/zero. +template <typename T, size_t N> +HWY_API Mask128<T, N> IsFinite(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + using VI = VFromD<decltype(di)>; + using VU = VFromD<decltype(du)>; + const VU vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VI exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>()))); +} + +// ================================================== COMPARE + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator==(const Vec128<T, N> a, const Vec128<T, N> b) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128<T, N>::FromBool(a.raw[i] == b.raw[i]); + } + return m; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator!=(const Vec128<T, N> a, const Vec128<T, N> b) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128<T, N>::FromBool(a.raw[i] != b.raw[i]); + } + return m; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> TestBit(const Vec128<T, N> v, const Vec128<T, N> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator<(const Vec128<T, N> a, const Vec128<T, N> b) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128<T, N>::FromBool(a.raw[i] < b.raw[i]); + } + return m; +} +template <typename T, size_t N> +HWY_API Mask128<T, N> operator>(const Vec128<T, N> a, const Vec128<T, N> b) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128<T, N>::FromBool(a.raw[i] > b.raw[i]); + } + return m; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator<=(const Vec128<T, N> a, const Vec128<T, N> b) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128<T, N>::FromBool(a.raw[i] <= b.raw[i]); + } + return m; +} +template <typename T, size_t N> +HWY_API Mask128<T, N> operator>=(const Vec128<T, N> a, const Vec128<T, N> b) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128<T, N>::FromBool(a.raw[i] >= b.raw[i]); + } + return m; +} + +// ------------------------------ Lt128 + +// Only makes sense for full vectors of u64. +HWY_API Mask128<uint64_t> Lt128(Simd<uint64_t, 2, 0> /* tag */, + Vec128<uint64_t> a, const Vec128<uint64_t> b) { + const bool lt = + (a.raw[1] < b.raw[1]) || (a.raw[1] == b.raw[1] && a.raw[0] < b.raw[0]); + Mask128<uint64_t> ret; + ret.bits[0] = ret.bits[1] = Mask128<uint64_t>::FromBool(lt); + return ret; +} + +HWY_API Mask128<uint64_t> Lt128Upper(Simd<uint64_t, 2, 0> /* tag */, + Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + const bool lt = a.raw[1] < b.raw[1]; + Mask128<uint64_t> ret; + ret.bits[0] = ret.bits[1] = Mask128<uint64_t>::FromBool(lt); + return ret; +} + +// ------------------------------ Eq128 + +// Only makes sense for full vectors of u64. +HWY_API Mask128<uint64_t> Eq128(Simd<uint64_t, 2, 0> /* tag */, + Vec128<uint64_t> a, const Vec128<uint64_t> b) { + const bool eq = a.raw[1] == b.raw[1] && a.raw[0] == b.raw[0]; + Mask128<uint64_t> ret; + ret.bits[0] = ret.bits[1] = Mask128<uint64_t>::FromBool(eq); + return ret; +} + +HWY_API Mask128<uint64_t> Ne128(Simd<uint64_t, 2, 0> /* tag */, + Vec128<uint64_t> a, const Vec128<uint64_t> b) { + const bool ne = a.raw[1] != b.raw[1] || a.raw[0] != b.raw[0]; + Mask128<uint64_t> ret; + ret.bits[0] = ret.bits[1] = Mask128<uint64_t>::FromBool(ne); + return ret; +} + +HWY_API Mask128<uint64_t> Eq128Upper(Simd<uint64_t, 2, 0> /* tag */, + Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + const bool eq = a.raw[1] == b.raw[1]; + Mask128<uint64_t> ret; + ret.bits[0] = ret.bits[1] = Mask128<uint64_t>::FromBool(eq); + return ret; +} + +HWY_API Mask128<uint64_t> Ne128Upper(Simd<uint64_t, 2, 0> /* tag */, + Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + const bool ne = a.raw[1] != b.raw[1]; + Mask128<uint64_t> ret; + ret.bits[0] = ret.bits[1] = Mask128<uint64_t>::FromBool(ne); + return ret; +} + +// ------------------------------ Min128, Max128 (Lt128) + +template <class D, class V = VFromD<D>> +HWY_API V Min128(D d, const V a, const V b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template <class D, class V = VFromD<D>> +HWY_API V Max128(D d, const V a, const V b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template <class D, class V = VFromD<D>> +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template <class D, class V = VFromD<D>> +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template <typename T, size_t N> +HWY_API Vec128<T, N> Load(Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT aligned) { + Vec128<T, N> v; + CopyBytes<sizeof(T) * N>(aligned, v.raw); // copy from array + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> LoadU(Simd<T, N, 0> d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template <typename T, size_t N> +HWY_API Vec128<T, N> LoadDup128(Simd<T, N, 0> d, + const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template <typename T, size_t N> +HWY_API void Store(const Vec128<T, N> v, Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT aligned) { + CopyBytes<sizeof(T) * N>(v.raw, aligned); // copy to array +} + +template <typename T, size_t N> +HWY_API void StoreU(const Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template <typename T, size_t N> +HWY_API void BlendedStore(const Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + for (size_t i = 0; i < N; ++i) { + if (m.bits[i]) p[i] = v.raw[i]; + } +} + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +// We implement those here because scalar code is likely faster than emulation +// via shuffles. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template <typename T, size_t N> +HWY_API void LoadInterleaved2(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + Vec128<T, N>& v0, Vec128<T, N>& v1) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); +} + +template <typename T, size_t N> +HWY_API void LoadInterleaved3(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + Vec128<T, N>& v0, Vec128<T, N>& v1, + Vec128<T, N>& v2) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + alignas(16) T buf2[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); +} + +template <typename T, size_t N> +HWY_API void LoadInterleaved4(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + Vec128<T, N>& v0, Vec128<T, N>& v1, + Vec128<T, N>& v2, Vec128<T, N>& v3) { + alignas(16) T buf0[N]; + alignas(16) T buf1[N]; + alignas(16) T buf2[N]; + alignas(16) T buf3[N]; + for (size_t i = 0; i < N; ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + buf3[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); + v3 = Load(d, buf3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template <typename T, size_t N> +HWY_API void StoreInterleaved2(const Vec128<T, N> v0, const Vec128<T, N> v1, + Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + } +} + +template <typename T, size_t N> +HWY_API void StoreInterleaved3(const Vec128<T, N> v0, const Vec128<T, N> v1, + const Vec128<T, N> v2, Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + } +} + +template <typename T, size_t N> +HWY_API void StoreInterleaved4(const Vec128<T, N> v0, const Vec128<T, N> v1, + const Vec128<T, N> v2, const Vec128<T, N> v3, + Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < N; ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + *unaligned++ = v3.raw[i]; + } +} + +// ------------------------------ Stream + +template <typename T, size_t N> +HWY_API void Stream(const Vec128<T, N> v, Simd<T, N, 0> d, + T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template <typename T, size_t N, typename Offset> +HWY_API void ScatterOffset(Vec128<T, N> v, Simd<T, N, 0> /* tag */, T* base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + for (size_t i = 0; i < N; ++i) { + uint8_t* const base8 = reinterpret_cast<uint8_t*>(base) + offset.raw[i]; + CopyBytes<sizeof(T)>(&v.raw[i], base8); // copy to bytes + } +} + +template <typename T, size_t N, typename Index> +HWY_API void ScatterIndex(Vec128<T, N> v, Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT base, const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + for (size_t i = 0; i < N; ++i) { + base[index.raw[i]] = v.raw[i]; + } +} + +// ------------------------------ Gather + +template <typename T, size_t N, typename Offset> +HWY_API Vec128<T, N> GatherOffset(Simd<T, N, 0> /* tag */, const T* base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + Vec128<T, N> v; + for (size_t i = 0; i < N; ++i) { + const uint8_t* base8 = + reinterpret_cast<const uint8_t*>(base) + offset.raw[i]; + CopyBytes<sizeof(T)>(base8, &v.raw[i]); // copy from bytes + } + return v; +} + +template <typename T, size_t N, typename Index> +HWY_API Vec128<T, N> GatherIndex(Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + Vec128<T, N> v; + for (size_t i = 0; i < N; ++i) { + v.raw[i] = base[index.raw[i]]; + } + return v; +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +template <typename FromT, typename ToT, size_t N> +HWY_API Vec128<ToT, N> PromoteTo(Simd<ToT, N, 0> /* tag */, + Vec128<FromT, N> from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + Vec128<ToT, N> ret; + for (size_t i = 0; i < N; ++i) { + // For bits Y > X, floatX->floatY and intX->intY are always representable. + ret.raw[i] = static_cast<ToT>(from.raw[i]); + } + return ret; +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +template <size_t N> +HWY_API Vec128<float, N> DemoteTo(Simd<float, N, 0> /* tag */, + Vec128<double, N> from) { + Vec128<float, N> ret; + for (size_t i = 0; i < N; ++i) { + // Prevent ubsan errors when converting float to narrower integer/float + if (std::isinf(from.raw[i]) || + std::fabs(from.raw[i]) > static_cast<double>(HighestValue<float>())) { + ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue<float>() + : HighestValue<float>(); + continue; + } + ret.raw[i] = static_cast<float>(from.raw[i]); + } + return ret; +} +template <size_t N> +HWY_API Vec128<int32_t, N> DemoteTo(Simd<int32_t, N, 0> /* tag */, + Vec128<double, N> from) { + Vec128<int32_t, N> ret; + for (size_t i = 0; i < N; ++i) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (std::isinf(from.raw[i]) || + std::fabs(from.raw[i]) > static_cast<double>(HighestValue<int32_t>())) { + ret.raw[i] = std::signbit(from.raw[i]) ? LowestValue<int32_t>() + : HighestValue<int32_t>(); + continue; + } + ret.raw[i] = static_cast<int32_t>(from.raw[i]); + } + return ret; +} + +template <typename FromT, typename ToT, size_t N> +HWY_API Vec128<ToT, N> DemoteTo(Simd<ToT, N, 0> /* tag */, + Vec128<FromT, N> from) { + static_assert(!IsFloat<FromT>(), "FromT=double are handled above"); + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + Vec128<ToT, N> ret; + for (size_t i = 0; i < N; ++i) { + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw[i] = + HWY_MIN(HWY_MAX(LimitsMin<ToT>(), from.raw[i]), LimitsMax<ToT>()); + ret.raw[i] = static_cast<ToT>(from.raw[i]); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<bfloat16_t, 2 * N> ReorderDemote2To( + Simd<bfloat16_t, 2 * N, 0> dbf16, Vec128<float, N> a, Vec128<float, N> b) { + const Repartition<uint32_t, decltype(dbf16)> du32; + const Vec128<uint32_t, N> b_in_lower = ShiftRight<16>(BitCast(du32, b)); + // Avoid OddEven - we want the upper half of `a` even on big-endian systems. + const Vec128<uint32_t, N> a_mask = Set(du32, 0xFFFF0000); + return BitCast(dbf16, IfVecThenElse(a_mask, BitCast(du32, a), b_in_lower)); +} + +template <size_t N> +HWY_API Vec128<int16_t, 2 * N> ReorderDemote2To(Simd<int16_t, 2 * N, 0> /*d16*/, + Vec128<int32_t, N> a, + Vec128<int32_t, N> b) { + const int16_t min = LimitsMin<int16_t>(); + const int16_t max = LimitsMax<int16_t>(); + Vec128<int16_t, 2 * N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast<int16_t>(HWY_MIN(HWY_MAX(min, a.raw[i]), max)); + } + for (size_t i = 0; i < N; ++i) { + ret.raw[N + i] = static_cast<int16_t>(HWY_MIN(HWY_MAX(min, b.raw[i]), max)); + } + return ret; +} + +namespace detail { + +HWY_INLINE void StoreU16ToF16(const uint16_t val, + hwy::float16_t* HWY_RESTRICT to) { + CopySameSize(&val, to); +} + +HWY_INLINE uint16_t U16FromF16(const hwy::float16_t* HWY_RESTRICT from) { + uint16_t bits16; + CopySameSize(from, &bits16); + return bits16; +} + +} // namespace detail + +template <size_t N> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> /* tag */, + const Vec128<float16_t, N> v) { + Vec128<float, N> ret; + for (size_t i = 0; i < N; ++i) { + const uint16_t bits16 = detail::U16FromF16(&v.raw[i]); + const uint32_t sign = static_cast<uint32_t>(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast<float>(mantissa) * (1.0f / 1024)); + ret.raw[i] = sign ? -subnormal : subnormal; + continue; + } + + // Normalized: convert the representation directly (faster than + // ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + CopySameSize(&bits32, &ret.raw[i]); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> /* tag */, + const Vec128<bfloat16_t, N> v) { + Vec128<float, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = F32FromBF16(v.raw[i]); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<float16_t, N> DemoteTo(Simd<float16_t, N, 0> /* tag */, + const Vec128<float, N> v) { + Vec128<float16_t, N> ret; + for (size_t i = 0; i < N; ++i) { + uint32_t bits32; + CopySameSize(&v.raw[i], &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast<int32_t>(biased_exp32) - 127, 15); + + // Tiny or zero => zero. + if (exp < -24) { + ZeroBytes<sizeof(uint16_t)>(&ret.raw[i]); + continue; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast<uint32_t>(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = static_cast<uint32_t>((1u << (10 - sub_exp)) + + (mantissa32 >> (13 + sub_exp))); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast<uint32_t>(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast<uint16_t>(bits16); // big-endian safe + detail::StoreU16ToF16(narrowed, &ret.raw[i]); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<bfloat16_t, N> DemoteTo(Simd<bfloat16_t, N, 0> /* tag */, + const Vec128<float, N> v) { + Vec128<bfloat16_t, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = BF16FromF32(v.raw[i]); + } + return ret; +} + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename FromT, typename ToT, size_t N> +HWY_API Vec128<ToT, N> ConvertTo(hwy::FloatTag /*tag*/, + Simd<ToT, N, 0> /* tag */, + Vec128<FromT, N> from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + Vec128<ToT, N> ret; + for (size_t i = 0; i < N; ++i) { + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax<ToT> in FromT, so use double. + const double f = static_cast<double>(from.raw[i]); + if (std::isinf(from.raw[i]) || + std::fabs(f) > static_cast<double>(LimitsMax<ToT>())) { + ret.raw[i] = + std::signbit(from.raw[i]) ? LimitsMin<ToT>() : LimitsMax<ToT>(); + continue; + } + ret.raw[i] = static_cast<ToT>(from.raw[i]); + } + return ret; +} + +template <typename FromT, typename ToT, size_t N> +HWY_API Vec128<ToT, N> ConvertTo(hwy::NonFloatTag /*tag*/, + Simd<ToT, N, 0> /* tag */, + Vec128<FromT, N> from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + Vec128<ToT, N> ret; + for (size_t i = 0; i < N; ++i) { + // int## -> float##: no check needed + ret.raw[i] = static_cast<ToT>(from.raw[i]); + } + return ret; +} + +} // namespace detail + +template <typename FromT, typename ToT, size_t N> +HWY_API Vec128<ToT, N> ConvertTo(Simd<ToT, N, 0> d, Vec128<FromT, N> from) { + return detail::ConvertTo(hwy::IsFloatTag<FromT>(), d, from); +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> U8FromU32(const Vec128<uint32_t, N> v) { + return DemoteTo(Simd<uint8_t, N, 0>(), v); +} + +// ------------------------------ Truncations + +template <size_t N> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint64_t, N> v) { + Vec128<uint8_t, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast<uint8_t>(v.raw[i] & 0xFF); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<uint16_t, N> TruncateTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint64_t, N> v) { + Vec128<uint16_t, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast<uint16_t>(v.raw[i] & 0xFFFF); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<uint32_t, N> TruncateTo(Simd<uint32_t, N, 0> /* tag */, + const Vec128<uint64_t, N> v) { + Vec128<uint32_t, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast<uint32_t>(v.raw[i] & 0xFFFFFFFFu); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + Vec128<uint8_t, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast<uint8_t>(v.raw[i] & 0xFF); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<uint16_t, N> TruncateTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + Vec128<uint16_t, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast<uint16_t>(v.raw[i] & 0xFFFF); + } + return ret; +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + Vec128<uint8_t, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast<uint8_t>(v.raw[i] & 0xFF); + } + return ret; +} + +// ================================================== COMBINE + +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> LowerHalf(Vec128<T, N> v) { + Vec128<T, N / 2> ret; + CopyBytes<N / 2 * sizeof(T)>(v.raw, ret.raw); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> LowerHalf(Simd<T, N / 2, 0> /* tag */, + Vec128<T, N> v) { + return LowerHalf(v); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> UpperHalf(Simd<T, N / 2, 0> /* tag */, + Vec128<T, N> v) { + Vec128<T, N / 2> ret; + CopyBytes<N / 2 * sizeof(T)>(&v.raw[N / 2], ret.raw); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ZeroExtendVector(Simd<T, N, 0> /* tag */, + Vec128<T, N / 2> v) { + Vec128<T, N> ret; + CopyBytes<N / 2 * sizeof(T)>(v.raw, ret.raw); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Combine(Simd<T, N, 0> /* tag */, Vec128<T, N / 2> hi_half, + Vec128<T, N / 2> lo_half) { + Vec128<T, N> ret; + CopyBytes<N / 2 * sizeof(T)>(lo_half.raw, &ret.raw[0]); + CopyBytes<N / 2 * sizeof(T)>(hi_half.raw, &ret.raw[N / 2]); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatLowerLower(Simd<T, N, 0> /* tag */, Vec128<T, N> hi, + Vec128<T, N> lo) { + Vec128<T, N> ret; + CopyBytes<N / 2 * sizeof(T)>(lo.raw, &ret.raw[0]); + CopyBytes<N / 2 * sizeof(T)>(hi.raw, &ret.raw[N / 2]); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatUpperUpper(Simd<T, N, 0> /* tag */, Vec128<T, N> hi, + Vec128<T, N> lo) { + Vec128<T, N> ret; + CopyBytes<N / 2 * sizeof(T)>(&lo.raw[N / 2], &ret.raw[0]); + CopyBytes<N / 2 * sizeof(T)>(&hi.raw[N / 2], &ret.raw[N / 2]); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatLowerUpper(Simd<T, N, 0> /* tag */, + const Vec128<T, N> hi, + const Vec128<T, N> lo) { + Vec128<T, N> ret; + CopyBytes<N / 2 * sizeof(T)>(&lo.raw[N / 2], &ret.raw[0]); + CopyBytes<N / 2 * sizeof(T)>(hi.raw, &ret.raw[N / 2]); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatUpperLower(Simd<T, N, 0> /* tag */, Vec128<T, N> hi, + Vec128<T, N> lo) { + Vec128<T, N> ret; + CopyBytes<N / 2 * sizeof(T)>(lo.raw, &ret.raw[0]); + CopyBytes<N / 2 * sizeof(T)>(&hi.raw[N / 2], &ret.raw[N / 2]); + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatEven(Simd<T, N, 0> /* tag */, Vec128<T, N> hi, + Vec128<T, N> lo) { + Vec128<T, N> ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[i] = lo.raw[2 * i]; + } + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[N / 2 + i] = hi.raw[2 * i]; + } + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatOdd(Simd<T, N, 0> /* tag */, Vec128<T, N> hi, + Vec128<T, N> lo) { + Vec128<T, N> ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[i] = lo.raw[2 * i + 1]; + } + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[N / 2 + i] = hi.raw[2 * i + 1]; + } + return ret; +} + +// ------------------------------ CombineShiftRightBytes + +template <int kBytes, typename T, size_t N, class V = Vec128<T, N>> +HWY_API V CombineShiftRightBytes(Simd<T, N, 0> /* tag */, V hi, V lo) { + V ret; + const uint8_t* HWY_RESTRICT lo8 = + reinterpret_cast<const uint8_t * HWY_RESTRICT>(lo.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast<uint8_t * HWY_RESTRICT>(ret.raw); + CopyBytes<sizeof(T) * N - kBytes>(lo8 + kBytes, ret8); + CopyBytes<kBytes>(hi.raw, ret8 + sizeof(T) * N - kBytes); + return ret; +} + +// ------------------------------ ShiftLeftBytes + +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + Vec128<T, N> ret; + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast<uint8_t * HWY_RESTRICT>(ret.raw); + ZeroBytes<kBytes>(ret8); + CopyBytes<sizeof(T) * N - kBytes>(v.raw, ret8 + kBytes); + return ret; +} + +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(const Vec128<T, N> v) { + return ShiftLeftBytes<kBytes>(DFromV<decltype(v)>(), v); +} + +// ------------------------------ ShiftLeftLanes + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); +} + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(const Vec128<T, N> v) { + return ShiftLeftLanes<kLanes>(DFromV<decltype(v)>(), v); +} + +// ------------------------------ ShiftRightBytes +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + Vec128<T, N> ret; + const uint8_t* HWY_RESTRICT v8 = + reinterpret_cast<const uint8_t * HWY_RESTRICT>(v.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast<uint8_t * HWY_RESTRICT>(ret.raw); + CopyBytes<sizeof(T) * N - kBytes>(v8 + kBytes, ret8); + ZeroBytes<kBytes>(ret8 + sizeof(T) * N - kBytes); + return ret; +} + +// ------------------------------ ShiftRightLanes +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v))); +} + +// ================================================== SWIZZLE + +template <typename T, size_t N> +HWY_API T GetLane(const Vec128<T, N> v) { + return v.raw[0]; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> InsertLane(Vec128<T, N> v, size_t i, T t) { + v.raw[i] = t; + return v; +} + +template <typename T, size_t N> +HWY_API T ExtractLane(const Vec128<T, N> v, size_t i) { + return v.raw[i]; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> DupEven(Vec128<T, N> v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i + 1] = v.raw[i]; + } + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> DupOdd(Vec128<T, N> v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i] = v.raw[i + 1]; + } + return v; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> OddEven(Vec128<T, N> odd, Vec128<T, N> even) { + for (size_t i = 0; i < N; i += 2) { + odd.raw[i] = even.raw[i]; + } + return odd; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> OddEvenBlocks(Vec128<T, N> /* odd */, Vec128<T, N> even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template <typename T, size_t N> +HWY_API Vec128<T, N> SwapAdjacentBlocks(Vec128<T, N> v) { + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template <typename T, size_t N> +struct Indices128 { + MakeSigned<T> raw[N]; +}; + +template <typename T, size_t N, typename TI> +HWY_API Indices128<T, N> IndicesFromVec(Simd<T, N, 0>, Vec128<TI, N> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + Indices128<T, N> ret; + CopyBytes<N * sizeof(T)>(vec.raw, ret.raw); + return ret; +} + +template <typename T, size_t N, typename TI> +HWY_API Indices128<T, N> SetTableIndices(Simd<T, N, 0> d, const TI* idx) { + return IndicesFromVec(d, LoadU(Simd<TI, N, 0>(), idx)); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> TableLookupLanes(const Vec128<T, N> v, + const Indices128<T, N> idx) { + Vec128<T, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[idx.raw[i]]; + } + return ret; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template <typename T, size_t N> +HWY_API Vec128<T, N> ReverseBlocks(Simd<T, N, 0> /* tag */, + const Vec128<T, N> v) { + return v; +} + +// ------------------------------ Reverse + +template <typename T, size_t N> +HWY_API Vec128<T, N> Reverse(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + Vec128<T, N> ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[N - 1 - i]; + } + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + Vec128<T, N> ret; + for (size_t i = 0; i < N; i += 2) { + ret.raw[i + 0] = v.raw[i + 1]; + ret.raw[i + 1] = v.raw[i + 0]; + } + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + Vec128<T, N> ret; + for (size_t i = 0; i < N; i += 4) { + ret.raw[i + 0] = v.raw[i + 3]; + ret.raw[i + 1] = v.raw[i + 2]; + ret.raw[i + 2] = v.raw[i + 1]; + ret.raw[i + 3] = v.raw[i + 0]; + } + return ret; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> Reverse8(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + Vec128<T, N> ret; + for (size_t i = 0; i < N; i += 8) { + ret.raw[i + 0] = v.raw[i + 7]; + ret.raw[i + 1] = v.raw[i + 6]; + ret.raw[i + 2] = v.raw[i + 5]; + ret.raw[i + 3] = v.raw[i + 4]; + ret.raw[i + 4] = v.raw[i + 3]; + ret.raw[i + 5] = v.raw[i + 2]; + ret.raw[i + 6] = v.raw[i + 1]; + ret.raw[i + 7] = v.raw[i + 0]; + } + return ret; +} + +// ================================================== BLOCKWISE + +// ------------------------------ Shuffle* + +// Swap 32-bit halves in 64-bit halves. +template <typename T, size_t N> +HWY_API Vec128<T, N> Shuffle2301(const Vec128<T, N> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Reverse2(DFromV<decltype(v)>(), v); +} + +// Swap 64-bit halves +template <typename T> +HWY_API Vec128<T> Shuffle1032(const Vec128<T> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + Vec128<T> ret; + ret.raw[3] = v.raw[1]; + ret.raw[2] = v.raw[0]; + ret.raw[1] = v.raw[3]; + ret.raw[0] = v.raw[2]; + return ret; +} +template <typename T> +HWY_API Vec128<T> Shuffle01(const Vec128<T> v) { + static_assert(sizeof(T) == 8, "Only for 64-bit"); + return Reverse2(DFromV<decltype(v)>(), v); +} + +// Rotate right 32 bits +template <typename T> +HWY_API Vec128<T> Shuffle0321(const Vec128<T> v) { + Vec128<T> ret; + ret.raw[3] = v.raw[0]; + ret.raw[2] = v.raw[3]; + ret.raw[1] = v.raw[2]; + ret.raw[0] = v.raw[1]; + return ret; +} + +// Rotate left 32 bits +template <typename T> +HWY_API Vec128<T> Shuffle2103(const Vec128<T> v) { + Vec128<T> ret; + ret.raw[3] = v.raw[2]; + ret.raw[2] = v.raw[1]; + ret.raw[1] = v.raw[0]; + ret.raw[0] = v.raw[3]; + return ret; +} + +template <typename T> +HWY_API Vec128<T> Shuffle0123(const Vec128<T> v) { + return Reverse4(DFromV<decltype(v)>(), v); +} + +// ------------------------------ Broadcast/splat any lane + +template <int kLane, typename T, size_t N> +HWY_API Vec128<T, N> Broadcast(Vec128<T, N> v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[kLane]; + } + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template <typename T, size_t N, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytes(const Vec128<T, N> v, + const Vec128<TI, NI> indices) { + const uint8_t* HWY_RESTRICT v_bytes = + reinterpret_cast<const uint8_t * HWY_RESTRICT>(v.raw); + const uint8_t* HWY_RESTRICT idx_bytes = + reinterpret_cast<const uint8_t*>(indices.raw); + Vec128<TI, NI> ret; + uint8_t* HWY_RESTRICT ret_bytes = + reinterpret_cast<uint8_t * HWY_RESTRICT>(ret.raw); + for (size_t i = 0; i < NI * sizeof(TI); ++i) { + const size_t idx = idx_bytes[i]; + // Avoid out of bounds reads. + ret_bytes[i] = idx < sizeof(T) * N ? v_bytes[idx] : 0; + } + return ret; +} + +template <typename T, size_t N, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytesOr0(const Vec128<T, N> v, + const Vec128<TI, NI> indices) { + // Same as TableLookupBytes, which already returns 0 if out of bounds. + return TableLookupBytes(v, indices); +} + +// ------------------------------ InterleaveLower/InterleaveUpper + +template <typename T, size_t N> +HWY_API Vec128<T, N> InterleaveLower(const Vec128<T, N> a, + const Vec128<T, N> b) { + Vec128<T, N> ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[i]; + ret.raw[2 * i + 1] = b.raw[i]; + } + return ret; +} + +// Additional overload for the optional tag (also for 256/512). +template <class V> +HWY_API V InterleaveLower(DFromV<V> /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> InterleaveUpper(Simd<T, N, 0> /* tag */, + const Vec128<T, N> a, + const Vec128<T, N> b) { + Vec128<T, N> ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[N / 2 + i]; + ret.raw[2 * i + 1] = b.raw[N / 2 + i]; + } + return ret; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== MASK + +template <typename T, size_t N> +HWY_API bool AllFalse(Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + typename Mask128<T, N>::Raw or_sum = 0; + for (size_t i = 0; i < N; ++i) { + or_sum |= mask.bits[i]; + } + return or_sum == 0; +} + +template <typename T, size_t N> +HWY_API bool AllTrue(Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + constexpr uint64_t kAll = LimitsMax<typename Mask128<T, N>::Raw>(); + uint64_t and_sum = kAll; + for (size_t i = 0; i < N; ++i) { + and_sum &= mask.bits[i]; + } + return and_sum == kAll; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T, size_t N> +HWY_API Mask128<T, N> LoadMaskBits(Simd<T, N, 0> /* tag */, + const uint8_t* HWY_RESTRICT bits) { + Mask128<T, N> m; + for (size_t i = 0; i < N; ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + m.bits[i] = Mask128<T, N>::FromBool((bits[idx_byte] & bit) != 0); + } + return m; +} + +// `p` points to at least 8 writable bytes. +template <typename T, size_t N> +HWY_API size_t StoreMaskBits(Simd<T, N, 0> /* tag */, const Mask128<T, N> mask, + uint8_t* bits) { + bits[0] = 0; + if (N > 8) bits[1] = 0; // N <= 16, so max two bytes + for (size_t i = 0; i < N; ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + if (mask.bits[i]) { + bits[idx_byte] = static_cast<uint8_t>(bits[idx_byte] | bit); + } + } + return N > 8 ? 2 : 1; +} + +template <typename T, size_t N> +HWY_API size_t CountTrue(Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + size_t count = 0; + for (size_t i = 0; i < N; ++i) { + count += mask.bits[i] != 0; + } + return count; +} + +template <typename T, size_t N> +HWY_API size_t FindKnownFirstTrue(Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i] != 0) return i; + } + HWY_DASSERT(false); + return 0; +} + +template <typename T, size_t N> +HWY_API intptr_t FindFirstTrue(Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i] != 0) return static_cast<intptr_t>(i); + } + return intptr_t{-1}; +} + +// ------------------------------ Compress + +template <typename T> +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +template <typename T, size_t N> +HWY_API Vec128<T, N> Compress(Vec128<T, N> v, const Mask128<T, N> mask) { + size_t count = 0; + Vec128<T, N> ret; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressNot +template <typename T, size_t N> +HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, const Mask128<T, N> mask) { + size_t count = 0; + Vec128<T, N> ret; + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v, + Mask128<uint64_t> /* m */) { + return v; +} + +// ------------------------------ CompressBits +template <typename T, size_t N> +HWY_API Vec128<T, N> CompressBits(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Simd<T, N, 0>(), bits)); +} + +// ------------------------------ CompressStore +template <typename T, size_t N> +HWY_API size_t CompressStore(Vec128<T, N> v, const Mask128<T, N> mask, + Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT unaligned) { + size_t count = 0; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + unaligned[count++] = v.raw[i]; + } + } + return count; +} + +// ------------------------------ CompressBlendedStore +template <typename T, size_t N> +HWY_API size_t CompressBlendedStore(Vec128<T, N> v, const Mask128<T, N> mask, + Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + return CompressStore(v, mask, d, unaligned); +} + +// ------------------------------ CompressBitsStore +template <typename T, size_t N> +HWY_API size_t CompressBitsStore(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const Mask128<T, N> mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template <size_t N> +HWY_API Vec128<float, N> ReorderWidenMulAccumulate(Simd<float, N, 0> df32, + Vec128<bfloat16_t, 2 * N> a, + Vec128<bfloat16_t, 2 * N> b, + const Vec128<float, N> sum0, + Vec128<float, N>& sum1) { + const Rebind<uint32_t, decltype(df32)> du32; + using VU32 = VFromD<decltype(du32)>; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Avoid ZipLower/Upper so this also works on big-endian systems. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +template <size_t N> +HWY_API Vec128<int32_t, N> ReorderWidenMulAccumulate( + Simd<int32_t, N, 0> d32, Vec128<int16_t, 2 * N> a, Vec128<int16_t, 2 * N> b, + const Vec128<int32_t, N> sum0, Vec128<int32_t, N>& sum1) { + using VI32 = VFromD<decltype(d32)>; + // Manual sign extension requires two shifts for even lanes. + const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a))); + const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b))); + const VI32 ao = ShiftRight<16>(BitCast(d32, a)); + const VI32 bo = ShiftRight<16>(BitCast(d32, b)); + sum1 = Add(Mul(ao, bo), sum1); + return Add(Mul(ae, be), sum0); +} + +// ------------------------------ RearrangeToOddPlusEven +template <class VW> +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); +} + +// ================================================== REDUCTIONS + +template <typename T, size_t N> +HWY_API Vec128<T, N> SumOfLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + T sum = T{0}; + for (size_t i = 0; i < N; ++i) { + sum += v.raw[i]; + } + return Set(d, sum); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MinOfLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + T min = HighestValue<T>(); + for (size_t i = 0; i < N; ++i) { + min = HWY_MIN(min, v.raw[i]); + } + return Set(d, min); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MaxOfLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + T max = LowestValue<T>(); + for (size_t i = 0; i < N; ++i) { + max = HWY_MAX(max, v.raw[i]); + } + return Set(d, max); +} + +// ================================================== OPS WITH DEPENDENCIES + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +HWY_INLINE Vec128<uint64_t> MulEven(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + alignas(16) uint64_t mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128<uint64_t>(), mul); +} + +HWY_INLINE Vec128<uint64_t> MulOdd(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + alignas(16) uint64_t mul[2]; + const Half<Full128<uint64_t>> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128<uint64_t>(), mul); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/generic_ops-inl.h b/third_party/highway/hwy/ops/generic_ops-inl.h new file mode 100644 index 0000000000..5898518467 --- /dev/null +++ b/third_party/highway/hwy/ops/generic_ops-inl.h @@ -0,0 +1,1560 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-independent types/functions defined after target-specific ops. + +#include "hwy/base.h" + +// Define detail::Shuffle1230 etc, but only when viewing the current header; +// normally this is included via highway.h, which includes ops/*.h. +#if HWY_IDE && !defined(HWY_HIGHWAY_INCLUDED) +#include "hwy/ops/emu128-inl.h" +#endif // HWY_IDE + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The lane type of a vector type, e.g. float for Vec<ScalableTag<float>>. +template <class V> +using LaneType = decltype(GetLane(V())); + +// Vector type, e.g. Vec128<float> for CappedTag<float, 4>. Useful as the return +// type of functions that do not take a vector argument, or as an argument type +// if the function only has a template argument for D, or for explicit type +// names instead of auto. This may be a built-in type. +template <class D> +using Vec = decltype(Zero(D())); + +// Mask type. Useful as the return type of functions that do not take a mask +// argument, or as an argument type if the function only has a template argument +// for D, or for explicit type names instead of auto. +template <class D> +using Mask = decltype(MaskFromVec(Zero(D()))); + +// Returns the closest value to v within [lo, hi]. +template <class V> +HWY_API V Clamp(const V v, const V lo, const V hi) { + return Min(Max(lo, v), hi); +} + +// CombineShiftRightBytes (and -Lanes) are not available for the scalar target, +// and RVV has its own implementation of -Lanes. +#if HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV + +template <size_t kLanes, class D, class V = VFromD<D>> +HWY_API V CombineShiftRightLanes(D d, const V hi, const V lo) { + constexpr size_t kBytes = kLanes * sizeof(LaneType<V>); + static_assert(kBytes < 16, "Shift count is per-block"); + return CombineShiftRightBytes<kBytes>(d, hi, lo); +} + +#endif + +// Returns lanes with the most significant bit set and all other bits zero. +template <class D> +HWY_API Vec<D> SignBit(D d) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Set(du, SignMask<TFromD<D>>())); +} + +// Returns quiet NaN. +template <class D> +HWY_API Vec<D> NaN(D d) { + const RebindToSigned<D> di; + // LimitsMax sets all exponent and mantissa bits to 1. The exponent plus + // mantissa MSB (to indicate quiet) would be sufficient. + return BitCast(d, Set(di, LimitsMax<TFromD<decltype(di)>>())); +} + +// Returns positive infinity. +template <class D> +HWY_API Vec<D> Inf(D d) { + const RebindToUnsigned<D> du; + using T = TFromD<D>; + using TU = TFromD<decltype(du)>; + const TU max_x2 = static_cast<TU>(MaxExponentTimes2<T>()); + return BitCast(d, Set(du, max_x2 >> 1)); +} + +// ------------------------------ SafeFillN + +template <class D, typename T = TFromD<D>> +HWY_API void SafeFillN(const size_t num, const T value, D d, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = value; + } +#else + BlendedStore(Set(d, value), FirstN(d, num), d, to); +#endif +} + +// ------------------------------ SafeCopyN + +template <class D, typename T = TFromD<D>> +HWY_API void SafeCopyN(const size_t num, D d, const T* HWY_RESTRICT from, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = from[i]; + } +#else + const Mask<D> mask = FirstN(d, num); + BlendedStore(MaskedLoad(mask, d, from), mask, d, to); +#endif +} + +// "Include guard": skip if native instructions are available. The generic +// implementation is currently shared between x86_* and wasm_*, and is too large +// to duplicate. + +#if (defined(HWY_NATIVE_LOAD_STORE_INTERLEAVED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// ------------------------------ LoadInterleaved2 + +template <typename T, size_t N, class V> +HWY_API void LoadInterleaved2(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1) { + const V A = LoadU(d, unaligned + 0 * N); // v1[1] v0[1] v1[0] v0[0] + const V B = LoadU(d, unaligned + 1 * N); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +template <typename T, class V> +HWY_API void LoadInterleaved2(Simd<T, 1, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +// ------------------------------ LoadInterleaved3 (CombineShiftRightBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template <typename T, size_t N, class V, HWY_IF_LE128(T, N)> +HWY_API void LoadTransposedBlocks3(Simd<T, N, 0> d, + const T* HWY_RESTRICT unaligned, V& A, V& B, + V& C) { + A = LoadU(d, unaligned + 0 * N); + B = LoadU(d, unaligned + 1 * N); + C = LoadU(d, unaligned + 2 * N); +} + +} // namespace detail + +template <typename T, size_t N, class V, HWY_IF_LANES_PER_BLOCK(T, N, 16)> +HWY_API void LoadInterleaved3(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned<decltype(d)> du; + // Compact notation so these fit on one line: 12 := v1[2]. + V A; // 05 24 14 04 23 13 03 22 12 02 21 11 01 20 10 00 + V B; // 1a 0a 29 19 09 28 18 08 27 17 07 26 16 06 25 15 + V C; // 2f 1f 0f 2e 1e 0e 2d 1d 0d 2c 1c 0c 2b 1b 0b 2a + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. + constexpr uint8_t Z = 0x80; + alignas(16) constexpr uint8_t kIdx_v0A[16] = {0, 3, 6, 9, 12, 15, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0B[16] = {Z, Z, Z, Z, Z, Z, 2, 5, + 8, 11, 14, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, Z, 1, 4, 7, 10, 13}; + alignas(16) constexpr uint8_t kIdx_v1A[16] = {1, 4, 7, 10, 13, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1B[16] = {Z, Z, Z, Z, Z, 0, 3, 6, + 9, 12, 15, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, Z, 2, 5, 8, 11, 14}; + alignas(16) constexpr uint8_t kIdx_v2A[16] = {2, 5, 8, 11, 14, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2B[16] = {Z, Z, Z, Z, Z, 1, 4, 7, + 10, 13, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2C[16] = {Z, Z, Z, Z, Z, Z, Z, Z, + Z, Z, 0, 3, 6, 9, 12, 15}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 8-bit lanes x8 +template <typename T, size_t N, class V, HWY_IF_LANE_SIZE(T, 1), + HWY_IF_LANES_PER_BLOCK(T, N, 8)> +HWY_API void LoadInterleaved3(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned<decltype(d)> du; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. + constexpr uint8_t Z = 0x80; + alignas(16) constexpr uint8_t kIdx_v0A[16] = {0, 3, 6, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0B[16] = {Z, Z, Z, 1, 4, 7, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v0C[16] = {Z, Z, Z, Z, Z, Z, 2, 5}; + alignas(16) constexpr uint8_t kIdx_v1A[16] = {1, 4, 7, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1B[16] = {Z, Z, Z, 2, 5, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v1C[16] = {Z, Z, Z, Z, Z, 0, 3, 6}; + alignas(16) constexpr uint8_t kIdx_v2A[16] = {2, 5, Z, Z, Z, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2B[16] = {Z, Z, 0, 3, 6, Z, Z, Z}; + alignas(16) constexpr uint8_t kIdx_v2C[16] = {Z, Z, Z, Z, Z, 1, 4, 7}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 16-bit lanes x8 +template <typename T, size_t N, class V, HWY_IF_LANE_SIZE(T, 2), + HWY_IF_LANES_PER_BLOCK(T, N, 8)> +HWY_API void LoadInterleaved3(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + const RebindToUnsigned<decltype(d)> du; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. Same as above, + // but each element of the array contains two byte indices for a lane. + constexpr uint16_t Z = 0x8080; + alignas(16) constexpr uint16_t kIdx_v0A[8] = {0x0100, 0x0706, 0x0D0C, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v0B[8] = {Z, Z, Z, 0x0302, + 0x0908, 0x0F0E, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v0C[8] = {Z, Z, Z, Z, + Z, Z, 0x0504, 0x0B0A}; + alignas(16) constexpr uint16_t kIdx_v1A[8] = {0x0302, 0x0908, 0x0F0E, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v1B[8] = {Z, Z, Z, 0x0504, + 0x0B0A, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v1C[8] = {Z, Z, Z, Z, + Z, 0x0100, 0x0706, 0x0D0C}; + alignas(16) constexpr uint16_t kIdx_v2A[8] = {0x0504, 0x0B0A, Z, Z, + Z, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v2B[8] = {Z, Z, 0x0100, 0x0706, + 0x0D0C, Z, Z, Z}; + alignas(16) constexpr uint16_t kIdx_v2C[8] = {Z, Z, Z, Z, + Z, 0x0302, 0x0908, 0x0F0E}; + const V v0L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v0A))); + const V v0M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v0B))); + const V v0U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v0C))); + const V v1L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v1A))); + const V v1M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v1B))); + const V v1U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v1C))); + const V v2L = BitCast(d, TableLookupBytesOr0(A, LoadDup128(du, kIdx_v2A))); + const V v2M = BitCast(d, TableLookupBytesOr0(B, LoadDup128(du, kIdx_v2B))); + const V v2U = BitCast(d, TableLookupBytesOr0(C, LoadDup128(du, kIdx_v2C))); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +template <typename T, size_t N, class V, HWY_IF_LANES_PER_BLOCK(T, N, 4)> +HWY_API void LoadInterleaved3(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + V A; // v0[1] v2[0] v1[0] v0[0] + V B; // v1[2] v0[2] v2[1] v1[1] + V C; // v2[3] v1[3] v0[3] v2[2] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + + const V vxx_02_03_xx = OddEven(C, B); + v0 = detail::Shuffle1230(A, vxx_02_03_xx); + + // Shuffle2301 takes the upper/lower halves of the output from one input, so + // we cannot just combine 13 and 10 with 12 and 11 (similar to v0/v2). Use + // OddEven because it may have higher throughput than Shuffle. + const V vxx_xx_10_11 = OddEven(A, B); + const V v12_13_xx_xx = OddEven(B, C); + v1 = detail::Shuffle2301(vxx_xx_10_11, v12_13_xx_xx); + + const V vxx_20_21_xx = OddEven(B, A); + v2 = detail::Shuffle3012(vxx_20_21_xx, C); +} + +template <typename T, size_t N, class V, HWY_IF_LANES_PER_BLOCK(T, N, 2)> +HWY_API void LoadInterleaved3(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + V A; // v1[0] v0[0] + V B; // v0[1] v2[0] + V C; // v2[1] v1[1] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + v0 = OddEven(B, A); + v1 = CombineShiftRightBytes<sizeof(T)>(d, C, A); + v2 = OddEven(C, B); +} + +template <typename T, class V> +HWY_API void LoadInterleaved3(Simd<T, 1, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +// ------------------------------ LoadInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template <typename T, size_t N, class V, HWY_IF_LE128(T, N)> +HWY_API void LoadTransposedBlocks4(Simd<T, N, 0> d, + const T* HWY_RESTRICT unaligned, V& A, V& B, + V& C, V& D) { + A = LoadU(d, unaligned + 0 * N); + B = LoadU(d, unaligned + 1 * N); + C = LoadU(d, unaligned + 2 * N); + D = LoadU(d, unaligned + 3 * N); +} + +} // namespace detail + +template <typename T, size_t N, class V, HWY_IF_LANES_PER_BLOCK(T, N, 16)> +HWY_API void LoadInterleaved4(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + const Repartition<uint64_t, decltype(d)> d64; + using V64 = VFromD<decltype(d64)>; + // 16 lanes per block; the lowest four blocks are at the bottom of A,B,C,D. + // Here int[i] means the four interleaved values of the i-th 4-tuple and + // int[3..0] indicates four consecutive 4-tuples (0 = least-significant). + V A; // int[13..10] int[3..0] + V B; // int[17..14] int[7..4] + V C; // int[1b..18] int[b..8] + V D; // int[1f..1c] int[f..c] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + + // For brevity, the comments only list the lower block (upper = lower + 0x10) + const V v5140 = InterleaveLower(d, A, B); // int[5,1,4,0] + const V vd9c8 = InterleaveLower(d, C, D); // int[d,9,c,8] + const V v7362 = InterleaveUpper(d, A, B); // int[7,3,6,2] + const V vfbea = InterleaveUpper(d, C, D); // int[f,b,e,a] + + const V v6420 = InterleaveLower(d, v5140, v7362); // int[6,4,2,0] + const V veca8 = InterleaveLower(d, vd9c8, vfbea); // int[e,c,a,8] + const V v7531 = InterleaveUpper(d, v5140, v7362); // int[7,5,3,1] + const V vfdb9 = InterleaveUpper(d, vd9c8, vfbea); // int[f,d,b,9] + + const V64 v10L = BitCast(d64, InterleaveLower(d, v6420, v7531)); // v10[7..0] + const V64 v10U = BitCast(d64, InterleaveLower(d, veca8, vfdb9)); // v10[f..8] + const V64 v32L = BitCast(d64, InterleaveUpper(d, v6420, v7531)); // v32[7..0] + const V64 v32U = BitCast(d64, InterleaveUpper(d, veca8, vfdb9)); // v32[f..8] + + v0 = BitCast(d, InterleaveLower(d64, v10L, v10U)); + v1 = BitCast(d, InterleaveUpper(d64, v10L, v10U)); + v2 = BitCast(d, InterleaveLower(d64, v32L, v32U)); + v3 = BitCast(d, InterleaveUpper(d64, v32L, v32U)); +} + +template <typename T, size_t N, class V, HWY_IF_LANES_PER_BLOCK(T, N, 8)> +HWY_API void LoadInterleaved4(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + // In the last step, we interleave by half of the block size, which is usually + // 8 bytes but half that for 8-bit x8 vectors. + using TW = hwy::UnsignedFromSize<sizeof(T) * N == 8 ? 4 : 8>; + const Repartition<TW, decltype(d)> dw; + using VW = VFromD<decltype(dw)>; + + // (Comments are for 256-bit vectors.) + // 8 lanes per block; the lowest four blocks are at the bottom of A,B,C,D. + V A; // v3210[9]v3210[8] v3210[1]v3210[0] + V B; // v3210[b]v3210[a] v3210[3]v3210[2] + V C; // v3210[d]v3210[c] v3210[5]v3210[4] + V D; // v3210[f]v3210[e] v3210[7]v3210[6] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + + const V va820 = InterleaveLower(d, A, B); // v3210[a,8] v3210[2,0] + const V vec64 = InterleaveLower(d, C, D); // v3210[e,c] v3210[6,4] + const V vb931 = InterleaveUpper(d, A, B); // v3210[b,9] v3210[3,1] + const V vfd75 = InterleaveUpper(d, C, D); // v3210[f,d] v3210[7,5] + + const VW v10_b830 = // v10[b..8] v10[3..0] + BitCast(dw, InterleaveLower(d, va820, vb931)); + const VW v10_fc74 = // v10[f..c] v10[7..4] + BitCast(dw, InterleaveLower(d, vec64, vfd75)); + const VW v32_b830 = // v32[b..8] v32[3..0] + BitCast(dw, InterleaveUpper(d, va820, vb931)); + const VW v32_fc74 = // v32[f..c] v32[7..4] + BitCast(dw, InterleaveUpper(d, vec64, vfd75)); + + v0 = BitCast(d, InterleaveLower(dw, v10_b830, v10_fc74)); + v1 = BitCast(d, InterleaveUpper(dw, v10_b830, v10_fc74)); + v2 = BitCast(d, InterleaveLower(dw, v32_b830, v32_fc74)); + v3 = BitCast(d, InterleaveUpper(dw, v32_b830, v32_fc74)); +} + +template <typename T, size_t N, class V, HWY_IF_LANES_PER_BLOCK(T, N, 4)> +HWY_API void LoadInterleaved4(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + V A; // v3210[4] v3210[0] + V B; // v3210[5] v3210[1] + V C; // v3210[6] v3210[2] + V D; // v3210[7] v3210[3] + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + const V v10_ev = InterleaveLower(d, A, C); // v1[6,4] v0[6,4] v1[2,0] v0[2,0] + const V v10_od = InterleaveLower(d, B, D); // v1[7,5] v0[7,5] v1[3,1] v0[3,1] + const V v32_ev = InterleaveUpper(d, A, C); // v3[6,4] v2[6,4] v3[2,0] v2[2,0] + const V v32_od = InterleaveUpper(d, B, D); // v3[7,5] v2[7,5] v3[3,1] v2[3,1] + + v0 = InterleaveLower(d, v10_ev, v10_od); + v1 = InterleaveUpper(d, v10_ev, v10_od); + v2 = InterleaveLower(d, v32_ev, v32_od); + v3 = InterleaveUpper(d, v32_ev, v32_od); +} + +template <typename T, size_t N, class V, HWY_IF_LANES_PER_BLOCK(T, N, 2)> +HWY_API void LoadInterleaved4(Simd<T, N, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + V A, B, C, D; + detail::LoadTransposedBlocks4(d, unaligned, A, B, C, D); + v0 = InterleaveLower(d, A, C); + v1 = InterleaveUpper(d, A, C); + v2 = InterleaveLower(d, B, D); + v3 = InterleaveUpper(d, B, D); +} + +// Any T x1 +template <typename T, class V> +HWY_API void LoadInterleaved4(Simd<T, 1, 0> d, const T* HWY_RESTRICT unaligned, + V& v0, V& v1, V& v2, V& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template <typename T, size_t N, class V, HWY_IF_LE128(T, N)> +HWY_API void StoreTransposedBlocks2(const V A, const V B, Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); +} + +} // namespace detail + +// >= 128 bit vector +template <typename T, size_t N, class V, HWY_IF_GE128(T, N)> +HWY_API void StoreInterleaved2(const V v0, const V v1, Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + const auto v10L = InterleaveLower(d, v0, v1); // .. v1[0] v0[0] + const auto v10U = InterleaveUpper(d, v0, v1); // .. v1[N/2] v0[N/2] + detail::StoreTransposedBlocks2(v10L, v10U, d, unaligned); +} + +// <= 64 bits +template <class V, typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API void StoreInterleaved2(const V part0, const V part1, Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + const Twice<decltype(d)> d2; + const auto v0 = ZeroExtendVector(d2, part0); + const auto v1 = ZeroExtendVector(d2, part1); + const auto v10 = InterleaveLower(d2, v0, v1); + StoreU(v10, d2, unaligned); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template <typename T, size_t N, class V, HWY_IF_LE128(T, N)> +HWY_API void StoreTransposedBlocks3(const V A, const V B, const V C, + Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); + StoreU(C, d, unaligned + 2 * N); +} + +} // namespace detail + +// >= 128-bit vector, 8-bit lanes +template <typename T, size_t N, class V, HWY_IF_LANE_SIZE(T, 1), + HWY_IF_GE128(T, N)> +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; + using TU = TFromD<decltype(du)>; + const auto k5 = Set(du, TU{5}); + const auto k6 = Set(du, TU{6}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v0[5], v2[4],v1[4],v0[4] .. v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = LoadDup128(du, tbl_v0); + const auto shuf_A1 = LoadDup128(du, tbl_v1); // cannot reuse shuf_A0 (has 5) + const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const V A = BitCast(d, A0 | A1 | A2); + + // B: v1[10],v0[10], v2[9],v1[9],v0[9] .. , v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_B0 = shuf_A2 + k6; // .A..9..8..7..6.. + const auto shuf_B1 = shuf_A0 + k5; // A..9..8..7..6..5 + const auto shuf_B2 = shuf_A1 + k5; // ..9..8..7..6..5. + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B = BitCast(d, B0 | B1 | B2); + + // C: v2[15],v1[15],v0[15], v2[11],v1[11],v0[11], v2[10] + const auto shuf_C0 = shuf_B2 + k6; // ..F..E..D..C..B. + const auto shuf_C1 = shuf_B0 + k5; // .F..E..D..C..B.. + const auto shuf_C2 = shuf_B1 + k5; // F..E..D..C..B..A + const auto C0 = TableLookupBytesOr0(v0, shuf_C0); + const auto C1 = TableLookupBytesOr0(v1, shuf_C1); + const auto C2 = TableLookupBytesOr0(v2, shuf_C2); + const V C = BitCast(d, C0 | C1 | C2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 16-bit lanes +template <typename T, size_t N, class V, HWY_IF_LANE_SIZE(T, 2), + HWY_IF_GE128(T, N)> +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const Repartition<uint8_t, decltype(d)> du8; + const auto k2 = Set(du8, uint8_t{2 * sizeof(T)}); + const auto k3 = Set(du8, uint8_t{3 * sizeof(T)}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. Note that these are byte + // indices for 16-bit lanes. + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A1 = LoadDup128(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const auto shuf_A2 = LoadDup128(du8, tbl_v2); // ..1..0.. + + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); + const V A = BitCast(d, A0 | A1 | A2); + + // B: v0[5] v2[4],v1[4],v0[4], v2[3],v1[3],v0[3], v2[2] + const auto shuf_B0 = shuf_A1 + k3; // 5..4..3. + const auto shuf_B1 = shuf_A2 + k3; // ..4..3.. + const auto shuf_B2 = shuf_A0 + k2; // .4..3..2 + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B = BitCast(d, B0 | B1 | B2); + + // C: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_C0 = shuf_B1 + k3; // ..7..6.. + const auto shuf_C1 = shuf_B2 + k3; // .7..6..5 + const auto shuf_C2 = shuf_B0 + k2; // 7..6..5. + const auto C0 = TableLookupBytesOr0(v0, shuf_C0); + const auto C1 = TableLookupBytesOr0(v1, shuf_C1); + const auto C2 = TableLookupBytesOr0(v2, shuf_C2); + const V C = BitCast(d, C0 | C1 | C2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 32-bit lanes +template <typename T, size_t N, class V, HWY_IF_LANE_SIZE(T, 4), + HWY_IF_GE128(T, N)> +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const RepartitionToWide<decltype(d)> dw; + + const V v10_v00 = InterleaveLower(d, v0, v1); + const V v01_v20 = OddEven(v0, v2); + // A: v0[1], v2[0],v1[0],v0[0] (<- lane 0) + const V A = BitCast( + d, InterleaveLower(dw, BitCast(dw, v10_v00), BitCast(dw, v01_v20))); + + const V v1_321 = ShiftRightLanes<1>(d, v1); + const V v0_32 = ShiftRightLanes<2>(d, v0); + const V v21_v11 = OddEven(v2, v1_321); + const V v12_v02 = OddEven(v1_321, v0_32); + // B: v1[2],v0[2], v2[1],v1[1] + const V B = BitCast( + d, InterleaveLower(dw, BitCast(dw, v21_v11), BitCast(dw, v12_v02))); + + // Notation refers to the upper 2 lanes of the vector for InterleaveUpper. + const V v23_v13 = OddEven(v2, v1_321); + const V v03_v22 = OddEven(v0, v2); + // C: v2[3],v1[3],v0[3], v2[2] + const V C = BitCast( + d, InterleaveUpper(dw, BitCast(dw, v03_v22), BitCast(dw, v23_v13))); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template <typename T, size_t N, class V, HWY_IF_LANE_SIZE(T, 8), + HWY_IF_GE128(T, N)> +HWY_API void StoreInterleaved3(const V v0, const V v1, const V v2, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const V A = InterleaveLower(d, v0, v1); + const V B = OddEven(v0, v2); + const V C = InterleaveUpper(d, v1, v2); + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// 64-bit vector, 8-bit lanes +template <class V, typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API void StoreInterleaved3(const V part0, const V part1, const V part2, + Full64<T> d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors for the shuffles and first result. + const Full128<uint8_t> du; + const Full128<T> d_full; + const auto k5 = Set(du, uint8_t{5}); + const auto k6 = Set(du, uint8_t{6}); + + const Vec128<T> v0{part0.raw}; + const Vec128<T> v1{part1.raw}; + const Vec128<T> v2{part2.raw}; + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = Load(du, tbl_v0); + const auto shuf_A1 = Load(du, tbl_v1); // cannot reuse shuf_A0 (5 in MSB) + const auto shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned + 0 * N); + + // Second (HALF) vector: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const auto shuf_B0 = shuf_A2 + k6; // ..7..6.. + const auto shuf_B1 = shuf_A0 + k5; // .7..6..5 + const auto shuf_B2 = shuf_A1 + k5; // 7..6..5. + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const V B{(B0 | B1 | B2).raw}; + StoreU(B, d, unaligned + 1 * N); +} + +// 64-bit vector, 16-bit lanes +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API void StoreInterleaved3(const Vec64<T> part0, const Vec64<T> part1, + const Vec64<T> part2, Full64<T> dh, + T* HWY_RESTRICT unaligned) { + const Full128<T> d; + const Full128<uint8_t> du8; + constexpr size_t N = 16 / sizeof(T); + const auto k2 = Set(du8, uint8_t{2 * sizeof(T)}); + const auto k3 = Set(du8, uint8_t{3 * sizeof(T)}); + + const Vec128<T> v0{part0.raw}; + const Vec128<T> v1{part1.raw}; + const Vec128<T> v2{part2.raw}; + + // Interleave part (v0,v1,v2) to full (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A1 = Load(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const auto shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const auto shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); + const Vec128<T> A = BitCast(d, A0 | A1 | A2); + StoreU(A, d, unaligned + 0 * N); + + // Second (HALF) vector: v2[3],v1[3],v0[3], v2[2] + const auto shuf_B0 = shuf_A1 + k3; // ..3. + const auto shuf_B1 = shuf_A2 + k3; // .3.. + const auto shuf_B2 = shuf_A0 + k2; // 3..2 + const auto B0 = TableLookupBytesOr0(v0, shuf_B0); + const auto B1 = TableLookupBytesOr0(v1, shuf_B1); + const auto B2 = TableLookupBytesOr0(v2, shuf_B2); + const Vec128<T> B = BitCast(d, B0 | B1 | B2); + StoreU(Vec64<T>{B.raw}, dh, unaligned + 1 * N); +} + +// 64-bit vector, 32-bit lanes +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API void StoreInterleaved3(const Vec64<T> v0, const Vec64<T> v1, + const Vec64<T> v2, Full64<T> d, + T* HWY_RESTRICT unaligned) { + // (same code as 128-bit vector, 64-bit lanes) + constexpr size_t N = 2; + const Vec64<T> v10_v00 = InterleaveLower(d, v0, v1); + const Vec64<T> v01_v20 = OddEven(v0, v2); + const Vec64<T> v21_v11 = InterleaveUpper(d, v1, v2); + StoreU(v10_v00, d, unaligned + 0 * N); + StoreU(v01_v20, d, unaligned + 1 * N); + StoreU(v21_v11, d, unaligned + 2 * N); +} + +// 64-bit lanes are handled by the N=1 case below. + +// <= 32-bit vector, 8-bit lanes +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1), HWY_IF_LE32(T, N)> +HWY_API void StoreInterleaved3(const Vec128<T, N> part0, + const Vec128<T, N> part1, + const Vec128<T, N> part2, Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128<uint8_t> du; + const Full128<T> d_full; + + const Vec128<T> v0{part0.raw}; + const Vec128<T> v1{part1.raw}; + const Vec128<T> v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, + 0x80, 3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A0 = Load(du, tbl_v0); + const auto shuf_A1 = CombineShiftRightBytes<15>(du, shuf_A0, shuf_A0); + const auto shuf_A2 = CombineShiftRightBytes<14>(du, shuf_A0, shuf_A0); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ......3..2..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .....3..2..1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // ....3..2..1..0.. + const Vec128<T> A = BitCast(d_full, A0 | A1 | A2); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(A, d_full, buf); + CopyBytes<N * 3 * sizeof(T)>(buf, unaligned); +} + +// 32-bit vector, 16-bit lanes +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API void StoreInterleaved3(const Vec128<T, 2> part0, + const Vec128<T, 2> part1, + const Vec128<T, 2> part2, Simd<T, 2, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 4 / sizeof(T); + // Use full vectors for the shuffles and result. + const Full128<uint8_t> du8; + const Full128<T> d_full; + + const Vec128<T> v0{part0.raw}; + const Vec128<T> v1{part1.raw}; + const Vec128<T> v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const auto shuf_A2 = // ..1..0.. + Load(du8, tbl_v2); + const auto shuf_A1 = // ...1..0. + CombineShiftRightBytes<2>(du8, shuf_A2, shuf_A2); + const auto shuf_A0 = // ....1..0 + CombineShiftRightBytes<4>(du8, shuf_A2, shuf_A2); + const auto A0 = TableLookupBytesOr0(v0, shuf_A0); // ..1..0 + const auto A1 = TableLookupBytesOr0(v1, shuf_A1); // .1..0. + const auto A2 = TableLookupBytesOr0(v2, shuf_A2); // 1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(A, d_full, buf); + CopyBytes<N * 3 * sizeof(T)>(buf, unaligned); +} + +// Single-element vector, any lane size: just store directly +template <typename T> +HWY_API void StoreInterleaved3(const Vec128<T, 1> v0, const Vec128<T, 1> v1, + const Vec128<T, 1> v2, Simd<T, 1, 0> d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +// ------------------------------ StoreInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template <typename T, size_t N, class V, HWY_IF_LE128(T, N)> +HWY_API void StoreTransposedBlocks4(const V A, const V B, const V C, const V D, + Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + StoreU(A, d, unaligned + 0 * N); + StoreU(B, d, unaligned + 1 * N); + StoreU(C, d, unaligned + 2 * N); + StoreU(D, d, unaligned + 3 * N); +} + +} // namespace detail + +// >= 128-bit vector, 8..32-bit lanes +template <typename T, size_t N, class V, HWY_IF_NOT_LANE_SIZE(T, 8), + HWY_IF_GE128(T, N)> +HWY_API void StoreInterleaved4(const V v0, const V v1, const V v2, const V v3, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const RepartitionToWide<decltype(d)> dw; + const auto v10L = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32L = ZipLower(dw, v2, v3); + const auto v10U = ZipUpper(dw, v0, v1); + const auto v32U = ZipUpper(dw, v2, v3); + // The interleaved vectors are A, B, C, D. + const auto A = BitCast(d, InterleaveLower(dw, v10L, v32L)); // 3210 + const auto B = BitCast(d, InterleaveUpper(dw, v10L, v32L)); + const auto C = BitCast(d, InterleaveLower(dw, v10U, v32U)); + const auto D = BitCast(d, InterleaveUpper(dw, v10U, v32U)); + detail::StoreTransposedBlocks4(A, B, C, D, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template <typename T, size_t N, class V, HWY_IF_LANE_SIZE(T, 8), + HWY_IF_GE128(T, N)> +HWY_API void StoreInterleaved4(const V v0, const V v1, const V v2, const V v3, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + // The interleaved vectors are A, B, C, D. + const auto A = InterleaveLower(d, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d, v2, v3); + const auto C = InterleaveUpper(d, v0, v1); + const auto D = InterleaveUpper(d, v2, v3); + detail::StoreTransposedBlocks4(A, B, C, D, d, unaligned); +} + +// 64-bit vector, 8..32-bit lanes +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 8)> +HWY_API void StoreInterleaved4(const Vec64<T> part0, const Vec64<T> part1, + const Vec64<T> part2, const Vec64<T> part3, + Full64<T> /*tag*/, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors to reduce the number of stores. + const Full128<T> d_full; + const RepartitionToWide<decltype(d_full)> dw; + const Vec128<T> v0{part0.raw}; + const Vec128<T> v1{part1.raw}; + const Vec128<T> v2{part2.raw}; + const Vec128<T> v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto A = BitCast(d_full, InterleaveLower(dw, v10, v32)); + const auto B = BitCast(d_full, InterleaveUpper(dw, v10, v32)); + StoreU(A, d_full, unaligned + 0 * N); + StoreU(B, d_full, unaligned + 1 * N); +} + +// 64-bit vector, 64-bit lane +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void StoreInterleaved4(const Vec64<T> part0, const Vec64<T> part1, + const Vec64<T> part2, const Vec64<T> part3, + Full64<T> /*tag*/, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 16 / sizeof(T); + // Use full vectors to reduce the number of stores. + const Full128<T> d_full; + const Vec128<T> v0{part0.raw}; + const Vec128<T> v1{part1.raw}; + const Vec128<T> v2{part2.raw}; + const Vec128<T> v3{part3.raw}; + const auto A = InterleaveLower(d_full, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d_full, v2, v3); + StoreU(A, d_full, unaligned + 0 * N); + StoreU(B, d_full, unaligned + 1 * N); +} + +// <= 32-bit vectors +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void StoreInterleaved4(const Vec128<T, N> part0, + const Vec128<T, N> part1, + const Vec128<T, N> part2, + const Vec128<T, N> part3, Simd<T, N, 0> /*tag*/, + T* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128<T> d_full; + const RepartitionToWide<decltype(d_full)> dw; + const Vec128<T> v0{part0.raw}; + const Vec128<T> v1{part1.raw}; + const Vec128<T> v2{part2.raw}; + const Vec128<T> v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto v3210 = BitCast(d_full, InterleaveLower(dw, v10, v32)); + alignas(16) T buf[16 / sizeof(T)]; + StoreU(v3210, d_full, buf); + CopyBytes<4 * N * sizeof(T)>(buf, unaligned); +} + +#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED + +// ------------------------------ AESRound + +// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +// Define for white-box testing, even if native instructions are available. +namespace detail { + +// Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with +// Vector Permute Instructions" and the accompanying assembly language +// implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan: +// https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html . +// +// A brute-force 256 byte table lookup can also be made constant-time, and +// possibly competitive on NEON, but this is more performance-portable +// especially for x86 and large vectors. +template <class V> // u8 +HWY_INLINE V SubBytes(V state) { + const DFromV<V> du; + const auto mask = Set(du, uint8_t{0xF}); + + // Change polynomial basis to GF(2^4) + { + alignas(16) static constexpr uint8_t basisL[16] = { + 0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2, + 0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA}; + alignas(16) static constexpr uint8_t basisU[16] = { + 0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C, + 0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD}; + const auto sL = And(state, mask); + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto gf4L = TableLookupBytes(LoadDup128(du, basisL), sL); + const auto gf4U = TableLookupBytes(LoadDup128(du, basisU), sU); + state = Xor(gf4L, gf4U); + } + + // Inversion in GF(2^4). Elements 0 represent "infinity" (division by 0) and + // cause TableLookupBytesOr0 to return 0. + alignas(16) static constexpr uint8_t kZetaInv[16] = { + 0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3}; + alignas(16) static constexpr uint8_t kInv[16] = { + 0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4}; + const auto tbl = LoadDup128(du, kInv); + const auto sL = And(state, mask); // L=low nibble, U=upper + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto sX = Xor(sU, sL); + const auto invL = TableLookupBytes(LoadDup128(du, kZetaInv), sL); + const auto invU = TableLookupBytes(tbl, sU); + const auto invX = TableLookupBytes(tbl, sX); + const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU))); + const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX))); + + // Linear skew (cannot bake 0x63 bias into the table because out* indices + // may have the infinity flag set). + alignas(16) static constexpr uint8_t kAffineL[16] = { + 0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0, + 0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15}; + alignas(16) static constexpr uint8_t kAffineU[16] = { + 0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF, + 0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E}; + const auto affL = TableLookupBytesOr0(LoadDup128(du, kAffineL), outL); + const auto affU = TableLookupBytesOr0(LoadDup128(du, kAffineU), outU); + return Xor(Xor(affL, affU), Set(du, uint8_t{0x63})); +} + +} // namespace detail + +#endif // HWY_TARGET != HWY_SCALAR + +// "Include guard": skip if native AES instructions are available. +#if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR + +namespace detail { + +template <class V> // u8 +HWY_API V ShiftRows(const V state) { + const DFromV<V> du; + alignas(16) static constexpr uint8_t kShiftRow[16] = { + 0, 5, 10, 15, // transposed: state is column major + 4, 9, 14, 3, // + 8, 13, 2, 7, // + 12, 1, 6, 11}; + const auto shift_row = LoadDup128(du, kShiftRow); + return TableLookupBytes(state, shift_row); +} + +template <class V> // u8 +HWY_API V MixColumns(const V state) { + const DFromV<V> du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 2 3 1 1 // Let s := state*1, d := state*2, t := state*3. + // 1 2 3 1 // d are on diagonal, no permutation needed. + // 1 1 2 3 // t1230 indicates column indices of threes for the 4 rows. + // 3 1 1 2 // We also need to compute s2301 and s3012 (=1230 o 2301). + alignas(16) static constexpr uint8_t k2301[16] = { + 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13}; + alignas(16) static constexpr uint8_t k1230[16] = { + 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12}; + const RebindToSigned<decltype(du)> di; // can only do signed comparisons + const auto msb = Lt(BitCast(di, state), Zero(di)); + const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, int8_t{0x1B}))); + const auto d = Xor(Add(state, state), overflow); // = state*2 in GF(2^8). + const auto s2301 = TableLookupBytes(state, LoadDup128(du, k2301)); + const auto d_s2301 = Xor(d, s2301); + const auto t_s2301 = Xor(state, d_s2301); // t(s*3) = XOR-sum {s, d(s*2)} + const auto t1230_s3012 = TableLookupBytes(t_s2301, LoadDup128(du, k1230)); + return Xor(d_s2301, t1230_s3012); // XOR-sum of 4 terms +} + +} // namespace detail + +template <class V> // u8 +HWY_API V AESRound(V state, const V round_key) { + // Intel docs swap the first two steps, but it does not matter because + // ShiftRows is a permutation and SubBytes is independent of lane index. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = detail::MixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template <class V> // u8 +HWY_API V AESLastRound(V state, const V round_key) { + // LIke AESRound, but without MixColumns. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +// Constant-time implementation inspired by +// https://www.bearssl.org/constanttime.html, but about half the cost because we +// use 64x64 multiplies and 128-bit XORs. +template <class V> +HWY_API V CLMulLower(V a, V b) { + const DFromV<V> d; + static_assert(IsSame<TFromD<decltype(d)>, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3)); + auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0)); + auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1)); + auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2)); + m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1))); + m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2))); + m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3))); + m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +template <class V> +HWY_API V CLMulUpper(V a, V b) { + const DFromV<V> d; + static_assert(IsSame<TFromD<decltype(d)>, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3)); + auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0)); + auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1)); + auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2)); + m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1))); + m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2))); + m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3))); + m3 = Xor(m3, Xor(MulOdd(a2, b1), MulOdd(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +#endif // HWY_NATIVE_AES +#endif // HWY_TARGET != HWY_SCALAR + +// "Include guard": skip if native POPCNT-related instructions are available. +#if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +#undef HWY_MIN_POW2_FOR_128 +#if HWY_TARGET == HWY_RVV +#define HWY_MIN_POW2_FOR_128 1 +#else +// All other targets except HWY_SCALAR (which is excluded by HWY_IF_GE128_D) +// guarantee 128 bits anyway. +#define HWY_MIN_POW2_FOR_128 0 +#endif + +// This algorithm requires vectors to be at least 16 bytes, which is the case +// for LMUL >= 2. If not, use the fallback below. +template <typename V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 1), + HWY_IF_GE128_D(D), HWY_IF_POW2_GE(D, HWY_MIN_POW2_FOR_128)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame<TFromD<D>, uint8_t>(), "V must be u8"); + const D d; + HWY_ALIGN constexpr uint8_t kLookup[16] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + }; + const auto lo = And(v, Set(d, uint8_t{0xF})); + const auto hi = ShiftRight<4>(v); + const auto lookup = LoadDup128(d, kLookup); + return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo)); +} + +// RVV has a specialization that avoids the Set(). +#if HWY_TARGET != HWY_RVV +// Slower fallback for capped vectors. +template <typename V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 1), + HWY_IF_LT128_D(D)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame<TFromD<D>, uint8_t>(), "V must be u8"); + const D d; + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + const V k33 = Set(d, uint8_t{0x33}); + v = Sub(v, And(ShiftRight<1>(v), Set(d, uint8_t{0x55}))); + v = Add(And(ShiftRight<2>(v), k33), And(v, k33)); + return And(Add(v, ShiftRight<4>(v)), Set(d, uint8_t{0x0F})); +} +#endif // HWY_TARGET != HWY_RVV + +template <typename V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 2)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame<TFromD<D>, uint16_t>(), "V must be u16"); + const D d; + const Repartition<uint8_t, decltype(d)> d8; + const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); + return Add(ShiftRight<8>(vals), And(vals, Set(d, uint16_t{0xFF}))); +} + +template <typename V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 4)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame<TFromD<D>, uint32_t>(), "V must be u32"); + const D d; + Repartition<uint16_t, decltype(d)> d16; + auto vals = BitCast(d, PopulationCount(BitCast(d16, v))); + return Add(ShiftRight<16>(vals), And(vals, Set(d, uint32_t{0xFF}))); +} + +#if HWY_HAVE_INTEGER64 +template <typename V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API V PopulationCount(V v) { + static_assert(IsSame<TFromD<D>, uint64_t>(), "V must be u64"); + const D d; + Repartition<uint32_t, decltype(d)> d32; + auto vals = BitCast(d, PopulationCount(BitCast(d32, v))); + return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFFULL))); +} +#endif + +#endif // HWY_NATIVE_POPCNT + +template <class V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 8), + HWY_IF_LT128_D(D), HWY_IF_FLOAT_D(D)> +HWY_API V operator*(V x, V y) { + return Set(D(), GetLane(x) * GetLane(y)); +} + +template <class V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 8), + HWY_IF_LT128_D(D), HWY_IF_NOT_FLOAT_D(D)> +HWY_API V operator*(V x, V y) { + const DFromV<V> d; + using T = TFromD<decltype(d)>; + using TU = MakeUnsigned<T>; + const TU xu = static_cast<TU>(GetLane(x)); + const TU yu = static_cast<TU>(GetLane(y)); + return Set(d, static_cast<T>(xu * yu)); +} + +// "Include guard": skip if native 64-bit mul instructions are available. +#if (defined(HWY_NATIVE_I64MULLO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +template <class V, class D64 = DFromV<V>, typename T = LaneType<V>, + HWY_IF_LANE_SIZE(T, 8), HWY_IF_UNSIGNED(T), HWY_IF_GE128_D(D64)> +HWY_API V operator*(V x, V y) { + RepartitionToNarrow<D64> d32; + auto x32 = BitCast(d32, x); + auto y32 = BitCast(d32, y); + auto lolo = BitCast(d32, MulEven(x32, y32)); + auto lohi = BitCast(d32, MulEven(x32, BitCast(d32, ShiftRight<32>(y)))); + auto hilo = BitCast(d32, MulEven(BitCast(d32, ShiftRight<32>(x)), y32)); + auto hi = BitCast(d32, ShiftLeft<32>(BitCast(D64{}, lohi + hilo))); + return BitCast(D64{}, lolo + hi); +} +template <class V, class DI64 = DFromV<V>, typename T = LaneType<V>, + HWY_IF_LANE_SIZE(T, 8), HWY_IF_SIGNED(T), HWY_IF_GE128_D(DI64)> +HWY_API V operator*(V x, V y) { + RebindToUnsigned<DI64> du64; + return BitCast(DI64{}, BitCast(du64, x) * BitCast(du64, y)); +} + +#endif // HWY_NATIVE_I64MULLO + +// "Include guard": skip if native 8-bit compress instructions are available. +#if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template <class V, class D, typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API size_t CompressBitsStore(V v, const uint8_t* HWY_RESTRICT bits, D d, + T* unaligned) { + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + const Simd<T, HWY_MIN(MaxLanes(d), 8), 0> d8; + T* HWY_RESTRICT pos = unaligned; + + HWY_ALIGN constexpr T table[2048] = { + 0, 1, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 1, 0, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 2, 0, 1, 3, 4, 5, 6, 7, /**/ 0, 2, 1, 3, 4, 5, 6, 7, // + 1, 2, 0, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 3, 0, 1, 2, 4, 5, 6, 7, /**/ 0, 3, 1, 2, 4, 5, 6, 7, // + 1, 3, 0, 2, 4, 5, 6, 7, /**/ 0, 1, 3, 2, 4, 5, 6, 7, // + 2, 3, 0, 1, 4, 5, 6, 7, /**/ 0, 2, 3, 1, 4, 5, 6, 7, // + 1, 2, 3, 0, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 4, 0, 1, 2, 3, 5, 6, 7, /**/ 0, 4, 1, 2, 3, 5, 6, 7, // + 1, 4, 0, 2, 3, 5, 6, 7, /**/ 0, 1, 4, 2, 3, 5, 6, 7, // + 2, 4, 0, 1, 3, 5, 6, 7, /**/ 0, 2, 4, 1, 3, 5, 6, 7, // + 1, 2, 4, 0, 3, 5, 6, 7, /**/ 0, 1, 2, 4, 3, 5, 6, 7, // + 3, 4, 0, 1, 2, 5, 6, 7, /**/ 0, 3, 4, 1, 2, 5, 6, 7, // + 1, 3, 4, 0, 2, 5, 6, 7, /**/ 0, 1, 3, 4, 2, 5, 6, 7, // + 2, 3, 4, 0, 1, 5, 6, 7, /**/ 0, 2, 3, 4, 1, 5, 6, 7, // + 1, 2, 3, 4, 0, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 5, 0, 1, 2, 3, 4, 6, 7, /**/ 0, 5, 1, 2, 3, 4, 6, 7, // + 1, 5, 0, 2, 3, 4, 6, 7, /**/ 0, 1, 5, 2, 3, 4, 6, 7, // + 2, 5, 0, 1, 3, 4, 6, 7, /**/ 0, 2, 5, 1, 3, 4, 6, 7, // + 1, 2, 5, 0, 3, 4, 6, 7, /**/ 0, 1, 2, 5, 3, 4, 6, 7, // + 3, 5, 0, 1, 2, 4, 6, 7, /**/ 0, 3, 5, 1, 2, 4, 6, 7, // + 1, 3, 5, 0, 2, 4, 6, 7, /**/ 0, 1, 3, 5, 2, 4, 6, 7, // + 2, 3, 5, 0, 1, 4, 6, 7, /**/ 0, 2, 3, 5, 1, 4, 6, 7, // + 1, 2, 3, 5, 0, 4, 6, 7, /**/ 0, 1, 2, 3, 5, 4, 6, 7, // + 4, 5, 0, 1, 2, 3, 6, 7, /**/ 0, 4, 5, 1, 2, 3, 6, 7, // + 1, 4, 5, 0, 2, 3, 6, 7, /**/ 0, 1, 4, 5, 2, 3, 6, 7, // + 2, 4, 5, 0, 1, 3, 6, 7, /**/ 0, 2, 4, 5, 1, 3, 6, 7, // + 1, 2, 4, 5, 0, 3, 6, 7, /**/ 0, 1, 2, 4, 5, 3, 6, 7, // + 3, 4, 5, 0, 1, 2, 6, 7, /**/ 0, 3, 4, 5, 1, 2, 6, 7, // + 1, 3, 4, 5, 0, 2, 6, 7, /**/ 0, 1, 3, 4, 5, 2, 6, 7, // + 2, 3, 4, 5, 0, 1, 6, 7, /**/ 0, 2, 3, 4, 5, 1, 6, 7, // + 1, 2, 3, 4, 5, 0, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 6, 0, 1, 2, 3, 4, 5, 7, /**/ 0, 6, 1, 2, 3, 4, 5, 7, // + 1, 6, 0, 2, 3, 4, 5, 7, /**/ 0, 1, 6, 2, 3, 4, 5, 7, // + 2, 6, 0, 1, 3, 4, 5, 7, /**/ 0, 2, 6, 1, 3, 4, 5, 7, // + 1, 2, 6, 0, 3, 4, 5, 7, /**/ 0, 1, 2, 6, 3, 4, 5, 7, // + 3, 6, 0, 1, 2, 4, 5, 7, /**/ 0, 3, 6, 1, 2, 4, 5, 7, // + 1, 3, 6, 0, 2, 4, 5, 7, /**/ 0, 1, 3, 6, 2, 4, 5, 7, // + 2, 3, 6, 0, 1, 4, 5, 7, /**/ 0, 2, 3, 6, 1, 4, 5, 7, // + 1, 2, 3, 6, 0, 4, 5, 7, /**/ 0, 1, 2, 3, 6, 4, 5, 7, // + 4, 6, 0, 1, 2, 3, 5, 7, /**/ 0, 4, 6, 1, 2, 3, 5, 7, // + 1, 4, 6, 0, 2, 3, 5, 7, /**/ 0, 1, 4, 6, 2, 3, 5, 7, // + 2, 4, 6, 0, 1, 3, 5, 7, /**/ 0, 2, 4, 6, 1, 3, 5, 7, // + 1, 2, 4, 6, 0, 3, 5, 7, /**/ 0, 1, 2, 4, 6, 3, 5, 7, // + 3, 4, 6, 0, 1, 2, 5, 7, /**/ 0, 3, 4, 6, 1, 2, 5, 7, // + 1, 3, 4, 6, 0, 2, 5, 7, /**/ 0, 1, 3, 4, 6, 2, 5, 7, // + 2, 3, 4, 6, 0, 1, 5, 7, /**/ 0, 2, 3, 4, 6, 1, 5, 7, // + 1, 2, 3, 4, 6, 0, 5, 7, /**/ 0, 1, 2, 3, 4, 6, 5, 7, // + 5, 6, 0, 1, 2, 3, 4, 7, /**/ 0, 5, 6, 1, 2, 3, 4, 7, // + 1, 5, 6, 0, 2, 3, 4, 7, /**/ 0, 1, 5, 6, 2, 3, 4, 7, // + 2, 5, 6, 0, 1, 3, 4, 7, /**/ 0, 2, 5, 6, 1, 3, 4, 7, // + 1, 2, 5, 6, 0, 3, 4, 7, /**/ 0, 1, 2, 5, 6, 3, 4, 7, // + 3, 5, 6, 0, 1, 2, 4, 7, /**/ 0, 3, 5, 6, 1, 2, 4, 7, // + 1, 3, 5, 6, 0, 2, 4, 7, /**/ 0, 1, 3, 5, 6, 2, 4, 7, // + 2, 3, 5, 6, 0, 1, 4, 7, /**/ 0, 2, 3, 5, 6, 1, 4, 7, // + 1, 2, 3, 5, 6, 0, 4, 7, /**/ 0, 1, 2, 3, 5, 6, 4, 7, // + 4, 5, 6, 0, 1, 2, 3, 7, /**/ 0, 4, 5, 6, 1, 2, 3, 7, // + 1, 4, 5, 6, 0, 2, 3, 7, /**/ 0, 1, 4, 5, 6, 2, 3, 7, // + 2, 4, 5, 6, 0, 1, 3, 7, /**/ 0, 2, 4, 5, 6, 1, 3, 7, // + 1, 2, 4, 5, 6, 0, 3, 7, /**/ 0, 1, 2, 4, 5, 6, 3, 7, // + 3, 4, 5, 6, 0, 1, 2, 7, /**/ 0, 3, 4, 5, 6, 1, 2, 7, // + 1, 3, 4, 5, 6, 0, 2, 7, /**/ 0, 1, 3, 4, 5, 6, 2, 7, // + 2, 3, 4, 5, 6, 0, 1, 7, /**/ 0, 2, 3, 4, 5, 6, 1, 7, // + 1, 2, 3, 4, 5, 6, 0, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 7, 0, 1, 2, 3, 4, 5, 6, /**/ 0, 7, 1, 2, 3, 4, 5, 6, // + 1, 7, 0, 2, 3, 4, 5, 6, /**/ 0, 1, 7, 2, 3, 4, 5, 6, // + 2, 7, 0, 1, 3, 4, 5, 6, /**/ 0, 2, 7, 1, 3, 4, 5, 6, // + 1, 2, 7, 0, 3, 4, 5, 6, /**/ 0, 1, 2, 7, 3, 4, 5, 6, // + 3, 7, 0, 1, 2, 4, 5, 6, /**/ 0, 3, 7, 1, 2, 4, 5, 6, // + 1, 3, 7, 0, 2, 4, 5, 6, /**/ 0, 1, 3, 7, 2, 4, 5, 6, // + 2, 3, 7, 0, 1, 4, 5, 6, /**/ 0, 2, 3, 7, 1, 4, 5, 6, // + 1, 2, 3, 7, 0, 4, 5, 6, /**/ 0, 1, 2, 3, 7, 4, 5, 6, // + 4, 7, 0, 1, 2, 3, 5, 6, /**/ 0, 4, 7, 1, 2, 3, 5, 6, // + 1, 4, 7, 0, 2, 3, 5, 6, /**/ 0, 1, 4, 7, 2, 3, 5, 6, // + 2, 4, 7, 0, 1, 3, 5, 6, /**/ 0, 2, 4, 7, 1, 3, 5, 6, // + 1, 2, 4, 7, 0, 3, 5, 6, /**/ 0, 1, 2, 4, 7, 3, 5, 6, // + 3, 4, 7, 0, 1, 2, 5, 6, /**/ 0, 3, 4, 7, 1, 2, 5, 6, // + 1, 3, 4, 7, 0, 2, 5, 6, /**/ 0, 1, 3, 4, 7, 2, 5, 6, // + 2, 3, 4, 7, 0, 1, 5, 6, /**/ 0, 2, 3, 4, 7, 1, 5, 6, // + 1, 2, 3, 4, 7, 0, 5, 6, /**/ 0, 1, 2, 3, 4, 7, 5, 6, // + 5, 7, 0, 1, 2, 3, 4, 6, /**/ 0, 5, 7, 1, 2, 3, 4, 6, // + 1, 5, 7, 0, 2, 3, 4, 6, /**/ 0, 1, 5, 7, 2, 3, 4, 6, // + 2, 5, 7, 0, 1, 3, 4, 6, /**/ 0, 2, 5, 7, 1, 3, 4, 6, // + 1, 2, 5, 7, 0, 3, 4, 6, /**/ 0, 1, 2, 5, 7, 3, 4, 6, // + 3, 5, 7, 0, 1, 2, 4, 6, /**/ 0, 3, 5, 7, 1, 2, 4, 6, // + 1, 3, 5, 7, 0, 2, 4, 6, /**/ 0, 1, 3, 5, 7, 2, 4, 6, // + 2, 3, 5, 7, 0, 1, 4, 6, /**/ 0, 2, 3, 5, 7, 1, 4, 6, // + 1, 2, 3, 5, 7, 0, 4, 6, /**/ 0, 1, 2, 3, 5, 7, 4, 6, // + 4, 5, 7, 0, 1, 2, 3, 6, /**/ 0, 4, 5, 7, 1, 2, 3, 6, // + 1, 4, 5, 7, 0, 2, 3, 6, /**/ 0, 1, 4, 5, 7, 2, 3, 6, // + 2, 4, 5, 7, 0, 1, 3, 6, /**/ 0, 2, 4, 5, 7, 1, 3, 6, // + 1, 2, 4, 5, 7, 0, 3, 6, /**/ 0, 1, 2, 4, 5, 7, 3, 6, // + 3, 4, 5, 7, 0, 1, 2, 6, /**/ 0, 3, 4, 5, 7, 1, 2, 6, // + 1, 3, 4, 5, 7, 0, 2, 6, /**/ 0, 1, 3, 4, 5, 7, 2, 6, // + 2, 3, 4, 5, 7, 0, 1, 6, /**/ 0, 2, 3, 4, 5, 7, 1, 6, // + 1, 2, 3, 4, 5, 7, 0, 6, /**/ 0, 1, 2, 3, 4, 5, 7, 6, // + 6, 7, 0, 1, 2, 3, 4, 5, /**/ 0, 6, 7, 1, 2, 3, 4, 5, // + 1, 6, 7, 0, 2, 3, 4, 5, /**/ 0, 1, 6, 7, 2, 3, 4, 5, // + 2, 6, 7, 0, 1, 3, 4, 5, /**/ 0, 2, 6, 7, 1, 3, 4, 5, // + 1, 2, 6, 7, 0, 3, 4, 5, /**/ 0, 1, 2, 6, 7, 3, 4, 5, // + 3, 6, 7, 0, 1, 2, 4, 5, /**/ 0, 3, 6, 7, 1, 2, 4, 5, // + 1, 3, 6, 7, 0, 2, 4, 5, /**/ 0, 1, 3, 6, 7, 2, 4, 5, // + 2, 3, 6, 7, 0, 1, 4, 5, /**/ 0, 2, 3, 6, 7, 1, 4, 5, // + 1, 2, 3, 6, 7, 0, 4, 5, /**/ 0, 1, 2, 3, 6, 7, 4, 5, // + 4, 6, 7, 0, 1, 2, 3, 5, /**/ 0, 4, 6, 7, 1, 2, 3, 5, // + 1, 4, 6, 7, 0, 2, 3, 5, /**/ 0, 1, 4, 6, 7, 2, 3, 5, // + 2, 4, 6, 7, 0, 1, 3, 5, /**/ 0, 2, 4, 6, 7, 1, 3, 5, // + 1, 2, 4, 6, 7, 0, 3, 5, /**/ 0, 1, 2, 4, 6, 7, 3, 5, // + 3, 4, 6, 7, 0, 1, 2, 5, /**/ 0, 3, 4, 6, 7, 1, 2, 5, // + 1, 3, 4, 6, 7, 0, 2, 5, /**/ 0, 1, 3, 4, 6, 7, 2, 5, // + 2, 3, 4, 6, 7, 0, 1, 5, /**/ 0, 2, 3, 4, 6, 7, 1, 5, // + 1, 2, 3, 4, 6, 7, 0, 5, /**/ 0, 1, 2, 3, 4, 6, 7, 5, // + 5, 6, 7, 0, 1, 2, 3, 4, /**/ 0, 5, 6, 7, 1, 2, 3, 4, // + 1, 5, 6, 7, 0, 2, 3, 4, /**/ 0, 1, 5, 6, 7, 2, 3, 4, // + 2, 5, 6, 7, 0, 1, 3, 4, /**/ 0, 2, 5, 6, 7, 1, 3, 4, // + 1, 2, 5, 6, 7, 0, 3, 4, /**/ 0, 1, 2, 5, 6, 7, 3, 4, // + 3, 5, 6, 7, 0, 1, 2, 4, /**/ 0, 3, 5, 6, 7, 1, 2, 4, // + 1, 3, 5, 6, 7, 0, 2, 4, /**/ 0, 1, 3, 5, 6, 7, 2, 4, // + 2, 3, 5, 6, 7, 0, 1, 4, /**/ 0, 2, 3, 5, 6, 7, 1, 4, // + 1, 2, 3, 5, 6, 7, 0, 4, /**/ 0, 1, 2, 3, 5, 6, 7, 4, // + 4, 5, 6, 7, 0, 1, 2, 3, /**/ 0, 4, 5, 6, 7, 1, 2, 3, // + 1, 4, 5, 6, 7, 0, 2, 3, /**/ 0, 1, 4, 5, 6, 7, 2, 3, // + 2, 4, 5, 6, 7, 0, 1, 3, /**/ 0, 2, 4, 5, 6, 7, 1, 3, // + 1, 2, 4, 5, 6, 7, 0, 3, /**/ 0, 1, 2, 4, 5, 6, 7, 3, // + 3, 4, 5, 6, 7, 0, 1, 2, /**/ 0, 3, 4, 5, 6, 7, 1, 2, // + 1, 3, 4, 5, 6, 7, 0, 2, /**/ 0, 1, 3, 4, 5, 6, 7, 2, // + 2, 3, 4, 5, 6, 7, 0, 1, /**/ 0, 2, 3, 4, 5, 6, 7, 1, // + 1, 2, 3, 4, 5, 6, 7, 0, /**/ 0, 1, 2, 3, 4, 5, 6, 7}; + + for (size_t i = 0; i < Lanes(d); i += 8) { + // Each byte worth of bits is the index of one of 256 8-byte ranges, and its + // population count determines how far to advance the write position. + const size_t bits8 = bits[i / 8]; + const auto indices = Load(d8, table + bits8 * 8); + const auto compressed = TableLookupBytes(LoadU(d8, lanes + i), indices); + StoreU(compressed, d8, pos); + pos += PopCount(bits8); + } + return static_cast<size_t>(pos - unaligned); +} + +template <class V, class M, class D, typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API size_t CompressStore(V v, M mask, D d, T* HWY_RESTRICT unaligned) { + uint8_t bits[HWY_MAX(size_t{8}, MaxLanes(d) / 8)]; + (void)StoreMaskBits(d, mask, bits); + return CompressBitsStore(v, bits, d, unaligned); +} + +template <class V, class M, class D, typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API size_t CompressBlendedStore(V v, M mask, D d, + T* HWY_RESTRICT unaligned) { + HWY_ALIGN T buf[MaxLanes(d)]; + const size_t bytes = CompressStore(v, mask, d, buf); + BlendedStore(Load(d, buf), FirstN(d, bytes), d, unaligned); + return bytes; +} + +// For reasons unknown, HWY_IF_LANE_SIZE_V is a compile error in SVE. +template <class V, class M, typename T = TFromV<V>, HWY_IF_LANE_SIZE(T, 1)> +HWY_API V Compress(V v, const M mask) { + const DFromV<V> d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +template <class V, typename T = TFromV<V>, HWY_IF_LANE_SIZE(T, 1)> +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + const DFromV<V> d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressBitsStore(v, bits, d, lanes); + return Load(d, lanes); +} + +template <class V, class M, typename T = TFromV<V>, HWY_IF_LANE_SIZE(T, 1)> +HWY_API V CompressNot(V v, M mask) { + return Compress(v, Not(mask)); +} + +#endif // HWY_NATIVE_COMPRESS8 + +// ================================================== Operator wrapper + +// These targets currently cannot define operators and have already defined +// (only) the corresponding functions such as Add. +#if HWY_TARGET != HWY_RVV && HWY_TARGET != HWY_SVE && \ + HWY_TARGET != HWY_SVE2 && HWY_TARGET != HWY_SVE_256 && \ + HWY_TARGET != HWY_SVE2_128 + +template <class V> +HWY_API V Add(V a, V b) { + return a + b; +} +template <class V> +HWY_API V Sub(V a, V b) { + return a - b; +} + +template <class V> +HWY_API V Mul(V a, V b) { + return a * b; +} +template <class V> +HWY_API V Div(V a, V b) { + return a / b; +} + +template <class V> +V Shl(V a, V b) { + return a << b; +} +template <class V> +V Shr(V a, V b) { + return a >> b; +} + +template <class V> +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template <class V> +HWY_API auto Ne(V a, V b) -> decltype(a == b) { + return a != b; +} +template <class V> +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template <class V> +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template <class V> +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template <class V> +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +#endif // HWY_TARGET for operators + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/rvv-inl.h b/third_party/highway/hwy/ops/rvv-inl.h new file mode 100644 index 0000000000..502611282c --- /dev/null +++ b/third_party/highway/hwy/ops/rvv-inl.h @@ -0,0 +1,3451 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RISC-V V vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include <riscv_vector.h> +#include <stddef.h> +#include <stdint.h> + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <class V> +struct DFromV_t {}; // specialized in macros +template <class V> +using DFromV = typename DFromV_t<RemoveConst<V>>::type; + +template <class V> +using TFromV = TFromD<DFromV<V>>; + +// Enables the overload if Pow2 is in [min, max]. +#define HWY_RVV_IF_POW2_IN(D, min, max) \ + hwy::EnableIf<(min) <= Pow2(D()) && Pow2(D()) <= (max)>* = nullptr + +template <typename T, size_t N, int kPow2> +constexpr size_t MLenFromD(Simd<T, N, kPow2> /* tag */) { + // Returns divisor = type bits / LMUL. Folding *8 into the ScaleByPower + // argument enables fractional LMUL < 1. Limit to 64 because that is the + // largest value for which vbool##_t are defined. + return HWY_MIN(64, sizeof(T) * 8 * 8 / detail::ScaleByPower(8, kPow2)); +} + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// For all mask sizes MLEN: (1/Nth of a register, one bit per lane) +// The first two arguments are SEW and SHIFT such that SEW >> SHIFT = MLEN. +#define HWY_RVV_FOREACH_B(X_MACRO, NAME, OP) \ + X_MACRO(64, 0, 64, NAME, OP) \ + X_MACRO(32, 0, 32, NAME, OP) \ + X_MACRO(16, 0, 16, NAME, OP) \ + X_MACRO(8, 0, 8, NAME, OP) \ + X_MACRO(8, 1, 4, NAME, OP) \ + X_MACRO(8, 2, 2, NAME, OP) \ + X_MACRO(8, 3, 1, NAME, OP) + +// For given SEW, iterate over one of LMULS: _TRUNC, _EXT, _ALL. This allows +// reusing type lists such as HWY_RVV_FOREACH_U for _ALL (the usual case) or +// _EXT (for Combine). To achieve this, we HWY_CONCAT with the LMULS suffix. +// +// Precompute SEW/LMUL => MLEN to allow token-pasting the result. For the same +// reason, also pass the double-width and half SEW and LMUL (suffixed D and H, +// respectively). "__" means there is no corresponding LMUL (e.g. LMULD for m8). +// Args: BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP + +// LMULS = _TRUNC: truncatable (not the smallest LMUL) +#define HWY_RVV_FOREACH_08_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH. +#define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _LE2: <= 2 +#define HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) + +// LMULS = _EXT: not the largest LMUL +#define HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) + +// LMULS = _ALL (2^MinPow2() <= LMUL <= 8) +#define HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// 'Virtual' LMUL. This upholds the Highway guarantee that vectors are at least +// 128 bit and LowerHalf is defined whenever there are at least 2 lanes, even +// though RISC-V LMUL must be at least SEW/64 (notice that this rules out +// LMUL=1/2 for SEW=64). To bridge the gap, we add overloads for kPow2 equal to +// one less than should be supported, with all other parameters (vector type +// etc.) unchanged. For D with the lowest kPow2 ('virtual LMUL'), Lanes() +// returns half of what it usually would. +// +// Notice that we can only add overloads whenever there is a D argument: those +// are unique with respect to non-virtual-LMUL overloads because their kPow2 +// template argument differs. Otherwise, there is no actual vuint64mf2_t, and +// defining another overload with the same LMUL would be an error. Thus we have +// a separate _VIRT category for HWY_RVV_FOREACH*, and the common case is +// _ALL_VIRT (meaning the regular LMUL plus the VIRT overloads), used in most +// functions that take a D. + +#define HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, -1, /*MLEN=*/64, NAME, OP) + +// ALL + VIRT +#define HWY_RVV_FOREACH_08_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// LE2 + VIRT +#define HWY_RVV_FOREACH_08_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// EXT + VIRT +#define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// DEMOTE + VIRT +#define HWY_RVV_FOREACH_08_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// SEW for unsigned: +#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, uint, u, NAME, OP) + +// SEW for signed: +#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, int, i, NAME, OP) + +// SEW for float: +#if HWY_HAVE_FLOAT16 +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, float, f, NAME, OP) +#else +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) +#endif +#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, float, f, NAME, OP) + +// Commonly used type/SEW groups: +#define HWY_RVV_FOREACH_UI08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) + +// For all combinations of SEW: +#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) + +// Commonly used type categories: +#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) + +// Assemble types for use in x-macros +#define HWY_RVV_T(BASE, SEW) BASE##SEW##_t +#define HWY_RVV_D(BASE, SEW, N, SHIFT) Simd<HWY_RVV_T(BASE, SEW), N, SHIFT> +#define HWY_RVV_V(BASE, SEW, LMUL) v##BASE##SEW##LMUL##_t +#define HWY_RVV_M(MLEN) vbool##MLEN##_t + +} // namespace detail + +// Until we have full intrinsic support for fractional LMUL, mixed-precision +// code can use LMUL 1..8 (adequate unless they need many registers). +#define HWY_SPECIALIZE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <> \ + struct DFromV_t<HWY_RVV_V(BASE, SEW, LMUL)> { \ + using Lane = HWY_RVV_T(BASE, SEW); \ + using type = ScalableTag<Lane, SHIFT>; \ + }; + +HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) +#undef HWY_SPECIALIZE + +// ------------------------------ Lanes + +// WARNING: we want to query VLMAX/sizeof(T), but this actually changes VL! +// vlenb is not exposed through intrinsics and vreadvl is not VLMAX. +#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + size_t actual = v##OP##SEW##LMUL(); \ + /* Common case of full vectors: avoid any extra instructions. */ \ + /* actual includes LMUL, so do not shift again. */ \ + if (detail::IsFull(d)) return actual; \ + /* Check for virtual LMUL, e.g. "uint16mf8_t" (not provided by */ \ + /* intrinsics). In this case the actual LMUL is 1/4, so divide by */ \ + /* another factor of two. */ \ + if (detail::ScaleByPower(128 / SEW, SHIFT) == 1) actual >>= 1; \ + return HWY_MIN(actual, N); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL_VIRT) +#undef HWY_RVV_LANES + +template <size_t N, int kPow2> +HWY_API size_t Lanes(Simd<bfloat16_t, N, kPow2> /* tag*/) { + return Lanes(Simd<uint16_t, N, kPow2>()); +} + +// ------------------------------ Common x-macros + +// Last argument to most intrinsics. Use when the op has no d arg of its own, +// which means there is no user-specified cap. +#define HWY_RVV_AVL(SEW, SHIFT) \ + Lanes(ScalableTag<HWY_RVV_T(uint, SEW), SHIFT>()) + +// vector = f(vector), e.g. Not +#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, scalar), e.g. detail::AddS +#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(mask) +#define HWY_RVV_RETM_ARGM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ + return vm##OP##_m_b##MLEN(m, ~0ull); \ + } + +// ================================================== INIT + +// ------------------------------ Set + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_T(BASE, SEW) arg) { \ + return v##OP##_##CHAR##SEW##LMUL(arg, Lanes(d)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f, _ALL_VIRT) +#undef HWY_RVV_SET + +// Treat bfloat16_t as uint16_t (using the previously defined Set overloads); +// required for Zero and VFromD. +template <size_t N, int kPow2> +decltype(Set(Simd<uint16_t, N, kPow2>(), 0)) Set(Simd<bfloat16_t, N, kPow2> d, + bfloat16_t arg) { + return Set(RebindToUnsigned<decltype(d)>(), arg.bits); +} + +template <class D> +using VFromD = decltype(Set(D(), TFromD<D>())); + +// ------------------------------ Zero + +template <class D> +HWY_API VFromD<D> Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +// RVV vundefined is 'poisoned' such that even XORing a _variable_ initialized +// by it gives unpredictable results. It should only be used for maskoff, so +// keep it internal. For the Highway op, just use Zero (single instruction). +namespace detail { +#define HWY_RVV_UNDEFINED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) /* tag */) { \ + return v##OP##_##CHAR##SEW##LMUL(); /* no AVL */ \ + } + +HWY_RVV_FOREACH(HWY_RVV_UNDEFINED, Undefined, undefined, _ALL) +#undef HWY_RVV_UNDEFINED +} // namespace detail + +template <class D> +HWY_API VFromD<D> Undefined(D d) { + return Zero(d); +} + +// ------------------------------ BitCast + +namespace detail { + +// Halves LMUL. (Use LMUL arg for the source so we can use _TRUNC.) +#define HWY_RVV_TRUNC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH(v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +#undef HWY_RVV_TRUNC + +// Doubles LMUL to `d2` (the arg is only necessary for _VIRT). +#define HWY_RVV_EXT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMULD) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULD(v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT, Ext, lmul_ext, _EXT) +#undef HWY_RVV_EXT + +// For virtual LMUL e.g. 'uint32mf4_t', the return type should be mf2, which is +// the same as the actual input type. +#define HWY_RVV_EXT_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v; \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT_VIRT, Ext, lmul_ext, _VIRT) +#undef HWY_RVV_EXT_VIRT + +// For BitCastToByte, the D arg is only to prevent duplicate definitions caused +// by _ALL_VIRT. + +// There is no reinterpret from u8 <-> u8, so just return. +#define HWY_RVV_CAST_U8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <typename T, size_t N> \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd<T, N, SHIFT> /* d */, \ + vuint8##LMUL##_t v) { \ + return v; \ + } \ + template <size_t N> \ + HWY_API vuint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v; \ + } + +// For i8, need a single reinterpret (HWY_RVV_CAST_IF does two). +#define HWY_RVV_CAST_I8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <typename T, size_t N> \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd<T, N, SHIFT> /* d */, \ + vint8##LMUL##_t v) { \ + return vreinterpret_v_i8##LMUL##_u8##LMUL(v); \ + } \ + template <size_t N> \ + HWY_API vint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return vreinterpret_v_u8##LMUL##_i8##LMUL(v); \ + } + +// Separate u/i because clang only provides signed <-> unsigned reinterpret for +// the same SEW. +#define HWY_RVV_CAST_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <typename T, size_t N> \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd<T, N, SHIFT> /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v); \ + } \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <typename T, size_t N> \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd<T, N, SHIFT> /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v)); \ + } \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + v##OP##_v_u8##LMUL##_u##SEW##LMUL(v)); \ + } + +// Additional versions for virtual LMUL using LMULH for byte vectors. +#define HWY_RVV_CAST_VIRT_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <typename T, size_t N> \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd<T, N, SHIFT> /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v)); \ + } \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v2); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_VIRT_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <typename T, size_t N> \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd<T, N, SHIFT> /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v))); \ + } \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + v##OP##_v_u8##LMUL##_u##SEW##LMUL(v2)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_CAST_U8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I08(HWY_RVV_CAST_I8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_VIRT_U, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) + +#undef HWY_RVV_CAST_U8 +#undef HWY_RVV_CAST_I8 +#undef HWY_RVV_CAST_U +#undef HWY_RVV_CAST_IF +#undef HWY_RVV_CAST_VIRT_U +#undef HWY_RVV_CAST_VIRT_IF + +template <size_t N, int kPow2> +HWY_INLINE VFromD<Simd<uint16_t, N, kPow2>> BitCastFromByte( + Simd<bfloat16_t, N, kPow2> /* d */, VFromD<Simd<uint8_t, N, kPow2>> v) { + return BitCastFromByte(Simd<uint16_t, N, kPow2>(), v); +} + +} // namespace detail + +template <class D, class FromV> +HWY_API VFromD<D> BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(d, v)); +} + +namespace detail { + +template <class V, class DU = RebindToUnsigned<DFromV<V>>> +HWY_INLINE VFromD<DU> BitCastToUnsigned(V v) { + return BitCast(DU(), v); +} + +} // namespace detail + +// ------------------------------ Iota + +namespace detail { + +#define HWY_RVV_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + return v##OP##_##CHAR##SEW##LMUL(Lanes(d)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_IOTA, Iota0, id_v, _ALL_VIRT) +#undef HWY_RVV_IOTA + +template <class D, class DU = RebindToUnsigned<D>> +HWY_INLINE VFromD<DU> Iota0(const D /*d*/) { + return BitCastToUnsigned(Iota0(DU())); +} + +} // namespace detail + +// ================================================== LOGICAL + +// ------------------------------ Not + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not, _ALL) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V Not(const V v) { + using DF = DFromV<V>; + using DU = RebindToUnsigned<DF>; + return BitCast(DF(), Not(BitCast(DU(), v))); +} + +// ------------------------------ And + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AndS, and_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and, _ALL) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V And(const V a, const V b) { + using DF = DFromV<V>; + using DU = RebindToUnsigned<DF>; + return BitCast(DF(), And(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Or + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or, _ALL) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V Or(const V a, const V b) { + using DF = DFromV<V>; + using DU = RebindToUnsigned<DF>; + return BitCast(DF(), Or(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Xor + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, XorS, xor_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor, _ALL) + +template <class V, HWY_IF_FLOAT_V(V)> +HWY_API V Xor(const V a, const V b) { + using DF = DFromV<V>; + using DU = RebindToUnsigned<DF>; + return BitCast(DF(), Xor(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ AndNot +template <class V> +HWY_API V AndNot(const V not_a, const V b) { + return And(Not(not_a), b); +} + +// ------------------------------ Xor3 +template <class V> +HWY_API V Xor3(V x1, V x2, V x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 +template <class V> +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template <class V> +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ CopySign + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj, _ALL) + +template <class V> +HWY_API V CopySignToAbs(const V abs, const V sign) { + // RVV can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Add + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AddS, add_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, AddS, fadd_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, ReverseSubS, rsub_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, ReverseSubS, frsub_vf, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd, _ALL) + +// ------------------------------ Sub +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub, _ALL) + +// ------------------------------ SaturatedAdd + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) + +// ------------------------------ SaturatedSub + +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) + +HWY_RVV_FOREACH_I08(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) + +// ------------------------------ AverageRound + +// TODO(janwas): check vxrm rounding mode +HWY_RVV_FOREACH_U08(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, AverageRound, aaddu, _ALL) + +// ------------------------------ ShiftLeft[Same] + +// Intrinsics do not define .vi forms, so use .vx instead. +#define HWY_RVV_SHIFT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <int kBits> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(v, kBits, HWY_RVV_AVL(SEW, SHIFT)); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(v, static_cast<uint8_t>(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll, _ALL) + +// ------------------------------ ShiftRight[Same] + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra, _ALL) + +#undef HWY_RVV_SHIFT + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template <class VU8> +HWY_API VFromD<Repartition<uint64_t, DFromV<VU8>>> SumsOf8(const VU8 v) { + const DFromV<VU8> du8; + const RepartitionToWide<decltype(du8)> du16; + const RepartitionToWide<decltype(du16)> du32; + const RepartitionToWide<decltype(du32)> du64; + using VU16 = VFromD<decltype(du16)>; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = detail::AndS(BitCast(du16, v), 0xFF); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return detail::AndS(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), 0xFFFFull); +} + +// ------------------------------ RotateRight +template <int kBits, class V> +HWY_API V RotateRight(const V v) { + constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<kSizeInBits - kBits>(v)); +} + +// ------------------------------ Shl +#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, bits, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll, _ALL) + +#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, detail::BitCastToUnsigned(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll, _ALL) + +// ------------------------------ Shr + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra, _ALL) + +#undef HWY_RVV_SHIFT_II +#undef HWY_RVV_SHIFT_VV + +// ------------------------------ Min + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin, _ALL) + +// ------------------------------ Max + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MaxS, maxu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MaxS, max_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MaxS, fmax_vf, _ALL) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax, _ALL) + +// ------------------------------ Mul + +HWY_RVV_FOREACH_UI163264(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul, _ALL) + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// ------------------------------ MulHigh + +// Only for internal use (Highway only promises MulHigh for 16-bit inputs). +// Used by MulEven; vwmul does not work for m8. +namespace detail { +HWY_RVV_FOREACH_I32(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) +HWY_RVV_FOREACH_U32(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_U64(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_U16(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) + +// ------------------------------ MulFixedPoint15 +HWY_RVV_FOREACH_I16(HWY_RVV_RETV_ARGVV, MulFixedPoint15, smul, _ALL) + +// ------------------------------ Div +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv, _ALL) + +// ------------------------------ ApproximateReciprocal +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frec7, _ALL) + +// ------------------------------ Sqrt +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt, _ALL) + +// ------------------------------ ApproximateReciprocalSqrt +HWY_RVV_FOREACH_F32(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrt7, _ALL) + +// ------------------------------ MulAdd +// Note: op is still named vv, not vvv. +#define HWY_RVV_FMA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) mul, HWY_RVV_V(BASE, SEW, LMUL) x, \ + HWY_RVV_V(BASE, SEW, LMUL) add) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(add, mul, x, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc, _ALL) + +// ------------------------------ NegMulAdd +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac, _ALL) + +// ------------------------------ MulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac, _ALL) + +// ------------------------------ NegMulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL) + +#undef HWY_RVV_FMA + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. The XX in +// vboolXX_t is a power of two divisor for vector bits. SLEN 8 / LMUL 1 = 1/8th +// of all bits; SLEN 8 / LMUL 4 = half of all bits. + +// mask = f(vector, vector) +#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN(a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(vector, scalar) +#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return v##OP##_##CHAR##SEW##LMUL##_b##MLEN(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// ------------------------------ Eq +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, EqS, mseq_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, EqS, mfeq_vf, _ALL) +} // namespace detail + +// ------------------------------ Ne +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, NeS, msne_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, NeS, mfne_vf, _ALL) +} // namespace detail + +// ------------------------------ Lt +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt, _ALL) + +namespace detail { +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt_vx, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, LtS, msltu_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, LtS, mflt_vf, _ALL) +} // namespace detail + +// ------------------------------ Le +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle, _ALL) + +#undef HWY_RVV_RETM_ARGVV +#undef HWY_RVV_RETM_ARGVS + +// ------------------------------ Gt/Ge + +template <class V> +HWY_API auto Ge(const V a, const V b) -> decltype(Le(a, b)) { + return Le(b, a); +} + +template <class V> +HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) { + return Lt(b, a); +} + +// ------------------------------ TestBit +template <class V> +HWY_API auto TestBit(const V a, const V bit) -> decltype(Eq(a, bit)) { + return detail::NeS(And(a, bit), 0); +} + +// ------------------------------ Not +// NOLINTNEXTLINE +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, Not, not ) + +// ------------------------------ And + +// mask = f(mask_a, mask_b) (note arg2,arg1 order!) +#define HWY_RVV_RETM_ARGMM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) a, HWY_RVV_M(MLEN) b) { \ + return vm##OP##_mm_b##MLEN(b, a, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, And, and) + +// ------------------------------ AndNot +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, AndNot, andn) + +// ------------------------------ Or +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Or, or) + +// ------------------------------ Xor +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Xor, xor) + +// ------------------------------ ExclusiveNeither +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, ExclusiveNeither, xnor) + +#undef HWY_RVV_RETM_ARGMM + +// ------------------------------ IfThenElse +#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) yes, \ + HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_vvm_##CHAR##SEW##LMUL(no, yes, m, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge, _ALL) + +#undef HWY_RVV_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template <class M, class V> +HWY_API V IfThenElseZero(const M mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV<V>())); +} + +// ------------------------------ IfThenZeroElse + +#define HWY_RVV_IF_THEN_ZERO_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return v##OP##_##CHAR##SEW##LMUL(no, 0, m, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, merge_vxm, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL) + +#undef HWY_RVV_IF_THEN_ZERO_ELSE + +// ------------------------------ MaskFromVec + +template <class V> +HWY_API auto MaskFromVec(const V v) -> decltype(Eq(v, v)) { + return detail::NeS(v, 0); +} + +template <class D> +using MFromD = decltype(MaskFromVec(Zero(D()))); + +template <class D, typename MFrom> +HWY_API MFromD<D> RebindMask(const D /*d*/, const MFrom mask) { + // No need to check lane size/LMUL are the same: if not, casting MFrom to + // MFromD<D> would fail. + return mask; +} + +// ------------------------------ VecFromMask + +namespace detail { +#define HWY_RVV_VEC_FROM_MASK(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_M(MLEN) m) { \ + return v##OP##_##CHAR##SEW##LMUL##_m(m, v0, v0, 1, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, SubS, sub_vx, _ALL) +#undef HWY_RVV_VEC_FROM_MASK +} // namespace detail + +template <class D, HWY_IF_NOT_FLOAT_D(D)> +HWY_API VFromD<D> VecFromMask(const D d, MFromD<D> mask) { + return detail::SubS(Zero(d), mask); +} + +template <class D, HWY_IF_FLOAT_D(D)> +HWY_API VFromD<D> VecFromMask(const D d, MFromD<D> mask) { + return BitCast(d, VecFromMask(RebindToUnsigned<D>(), mask)); +} + +// ------------------------------ IfVecThenElse (MaskFromVec) + +template <class V> +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ ZeroIfNegative +template <class V> +HWY_API V ZeroIfNegative(const V v) { + return IfThenZeroElse(detail::LtS(v, 0), v); +} + +// ------------------------------ BroadcastSignBit +template <class V> +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight<sizeof(TFromV<V>) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template <class V> +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float"); + const DFromV<V> d; + const RebindToSigned<decltype(d)> di; + + MFromD<decltype(d)> m = + MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ FindFirstTrue + +#define HWY_RVV_FIND_FIRST_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template <class D> \ + HWY_API intptr_t FindFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return vfirst_m_b##MLEN(m, Lanes(d)); \ + } \ + template <class D> \ + HWY_API size_t FindKnownFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return static_cast<size_t>(vfirst_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_FIND_FIRST_TRUE, , _) +#undef HWY_RVV_FIND_FIRST_TRUE + +// ------------------------------ AllFalse +template <class D> +HWY_API bool AllFalse(D d, MFromD<D> m) { + return FindFirstTrue(d, m) < 0; +} + +// ------------------------------ AllTrue + +#define HWY_RVV_ALL_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template <class D> \ + HWY_API bool AllTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return AllFalse(d, vmnot_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_ALL_TRUE, _, _) +#undef HWY_RVV_ALL_TRUE + +// ------------------------------ CountTrue + +#define HWY_RVV_COUNT_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template <class D> \ + HWY_API size_t CountTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return vcpop_m_b##MLEN(m, Lanes(d)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) +#undef HWY_RVV_COUNT_TRUE + +// ================================================== MEMORY + +// ------------------------------ Load + +#define HWY_RVV_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le, _ALL_VIRT) +#undef HWY_RVV_LOAD + +// There is no native BF16, treat as uint16_t. +template <size_t N, int kPow2> +HWY_API VFromD<Simd<uint16_t, N, kPow2>> Load( + Simd<bfloat16_t, N, kPow2> d, const bfloat16_t* HWY_RESTRICT p) { + return Load(RebindToUnsigned<decltype(d)>(), + reinterpret_cast<const uint16_t * HWY_RESTRICT>(p)); +} + +template <size_t N, int kPow2> +HWY_API void Store(VFromD<Simd<uint16_t, N, kPow2>> v, + Simd<bfloat16_t, N, kPow2> d, bfloat16_t* HWY_RESTRICT p) { + Store(v, RebindToUnsigned<decltype(d)>(), + reinterpret_cast<uint16_t * HWY_RESTRICT>(p)); +} + +// ------------------------------ LoadU + +// RVV only requires lane alignment, not natural alignment of the entire vector. +template <class D> +HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ MaskedLoad + +#define HWY_RVV_MASKED_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, Zero(d), p, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL_VIRT) +#undef HWY_RVV_MASKED_LOAD + +// ------------------------------ Store + +#define HWY_RVV_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL_VIRT) +#undef HWY_RVV_STORE + +// ------------------------------ BlendedStore + +#define HWY_RVV_BLENDED_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL##_m(m, p, v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_BLENDED_STORE, BlendedStore, se, _ALL_VIRT) +#undef HWY_RVV_BLENDED_STORE + +namespace detail { + +#define HWY_RVV_STOREN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return v##OP##SEW##_v_##CHAR##SEW##LMUL(p, v, count); \ + } +HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL_VIRT) +#undef HWY_RVV_STOREN + +} // namespace detail + +// ------------------------------ StoreU + +// RVV only requires lane alignment, not natural alignment of the entire vector. +template <class V, class D> +HWY_API void StoreU(const V v, D d, TFromD<D>* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ Stream +template <class V, class D, typename T> +HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ ScatterOffset + +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + base, detail::BitCastToUnsigned(offset), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux, _ALL_VIRT) +#undef HWY_RVV_SCATTER + +// ------------------------------ ScatterIndex + +template <class D, HWY_IF_LANE_SIZE_D(D, 4)> +HWY_API void ScatterIndex(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT base, + const VFromD<RebindToSigned<D>> index) { + return ScatterOffset(v, d, base, ShiftLeft<2>(index)); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API void ScatterIndex(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT base, + const VFromD<RebindToSigned<D>> index) { + return ScatterOffset(v, d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ GatherOffset + +#define HWY_RVV_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + return v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + base, detail::BitCastToUnsigned(offset), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux, _ALL_VIRT) +#undef HWY_RVV_GATHER + +// ------------------------------ GatherIndex + +template <class D, HWY_IF_LANE_SIZE_D(D, 4)> +HWY_API VFromD<D> GatherIndex(D d, const TFromD<D>* HWY_RESTRICT base, + const VFromD<RebindToSigned<D>> index) { + return GatherOffset(d, base, ShiftLeft<2>(index)); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API VFromD<D> GatherIndex(D d, const TFromD<D>* HWY_RESTRICT base, + const VFromD<RebindToSigned<D>> index) { + return GatherOffset(d, base, ShiftLeft<3>(index)); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_RVV_LOAD2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, unaligned, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD2, LoadInterleaved2, lseg2, _LE2_VIRT) +#undef HWY_RVV_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_RVV_LOAD3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, &v2, unaligned, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD3, LoadInterleaved3, lseg3, _LE2_VIRT) +#undef HWY_RVV_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_RVV_LOAD4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2, HWY_RVV_V(BASE, SEW, LMUL) & v3) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(&v0, &v1, &v2, &v3, aligned, \ + Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD4, LoadInterleaved4, lseg4, _LE2_VIRT) +#undef HWY_RVV_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_RVV_STORE2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, \ + HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(unaligned, v0, v1, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE2, StoreInterleaved2, sseg2, _LE2_VIRT) +#undef HWY_RVV_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_RVV_STORE3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(unaligned, v0, v1, v2, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE3, StoreInterleaved3, sseg3, _LE2_VIRT) +#undef HWY_RVV_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_RVV_STORE4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT aligned) { \ + v##OP##e##SEW##_v_##CHAR##SEW##LMUL(aligned, v0, v1, v2, v3, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE4, StoreInterleaved4, sseg4, _LE2_VIRT) +#undef HWY_RVV_STORE4 + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// SEW is for the input so we can use F16 (no-op if not supported). +#define HWY_RVV_PROMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWD##LMULD(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U16(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_PROMOTE, PromoteTo, vzext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I08(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_PROMOTE, PromoteTo, vsext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_F16(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT_VIRT) +HWY_RVV_FOREACH_F32(HWY_RVV_PROMOTE, PromoteTo, vfwcvt_f_f_v_, _EXT_VIRT) +#undef HWY_RVV_PROMOTE + +// The above X-macro cannot handle 4x promotion nor type switching. +// TODO(janwas): use BASE2 arg to allow the latter. +#define HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, LMUL, LMUL_IN, \ + SHIFT, ADD) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, BITS, LMUL) \ + PromoteTo(HWY_RVV_D(BASE, BITS, N, SHIFT + ADD) d, \ + HWY_RVV_V(BASE_IN, BITS_IN, LMUL_IN) v) { \ + return OP##CHAR##BITS##LMUL(v, Lanes(d)); \ + } + +#define HWY_RVV_PROMOTE_X2(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -2, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, m1, 0, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m2, 1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m4, 2, 1) + +#define HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, mf2, mf8, -3, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf4, -2, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf2, -1, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m1, 0, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m2, 1, 2) + +HWY_RVV_PROMOTE_X4(vzext_vf4_, uint, u, 32, uint, 8) +HWY_RVV_PROMOTE_X4(vsext_vf4_, int, i, 32, int, 8) + +// i32 to f64 +HWY_RVV_PROMOTE_X2(vfwcvt_f_x_v_, float, f, 64, int, 32) + +#undef HWY_RVV_PROMOTE_X4 +#undef HWY_RVV_PROMOTE_X2 +#undef HWY_RVV_PROMOTE + +// Unsigned to signed: cast for unsigned promote. +template <size_t N, int kPow2> +HWY_API auto PromoteTo(Simd<int16_t, N, kPow2> d, + VFromD<Rebind<uint8_t, decltype(d)>> v) + -> VFromD<decltype(d)> { + return BitCast(d, PromoteTo(RebindToUnsigned<decltype(d)>(), v)); +} + +template <size_t N, int kPow2> +HWY_API auto PromoteTo(Simd<int32_t, N, kPow2> d, + VFromD<Rebind<uint8_t, decltype(d)>> v) + -> VFromD<decltype(d)> { + return BitCast(d, PromoteTo(RebindToUnsigned<decltype(d)>(), v)); +} + +template <size_t N, int kPow2> +HWY_API auto PromoteTo(Simd<int32_t, N, kPow2> d, + VFromD<Rebind<uint16_t, decltype(d)>> v) + -> VFromD<decltype(d)> { + return BitCast(d, PromoteTo(RebindToUnsigned<decltype(d)>(), v)); +} + +template <size_t N, int kPow2> +HWY_API auto PromoteTo(Simd<float32_t, N, kPow2> d, + VFromD<Rebind<bfloat16_t, decltype(d)>> v) + -> VFromD<decltype(d)> { + const RebindToSigned<decltype(d)> di32; + const Rebind<uint16_t, decltype(d)> du16; + return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ DemoteTo U + +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 0, Lanes(d)); \ + } \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME##Shr16( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##CHAR##SEWH##LMULH(v, 16, Lanes(d)); \ + } + +// Unsigned -> unsigned (also used for bf16) +namespace detail { +HWY_RVV_FOREACH_U16(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE, DemoteTo, vnclipu_wx_, _DEMOTE_VIRT) +} // namespace detail + +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE_I_TO_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(uint, SEWH, LMULH) NAME( \ + HWY_RVV_D(uint, SEWH, N, SHIFT - 1) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + /* First clamp negative numbers to zero to match x86 packus. */ \ + return detail::DemoteTo(d, detail::BitCastToUnsigned(detail::MaxS(v, 0))); \ + } +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_I_TO_U + +template <size_t N> +HWY_API vuint8mf8_t DemoteTo(Simd<uint8_t, N, -3> d, const vint32mf2_t v) { + return vnclipu_wx_u8mf8(DemoteTo(Simd<uint16_t, N, -2>(), v), 0, Lanes(d)); +} +template <size_t N> +HWY_API vuint8mf4_t DemoteTo(Simd<uint8_t, N, -2> d, const vint32m1_t v) { + return vnclipu_wx_u8mf4(DemoteTo(Simd<uint16_t, N, -1>(), v), 0, Lanes(d)); +} +template <size_t N> +HWY_API vuint8mf2_t DemoteTo(Simd<uint8_t, N, -1> d, const vint32m2_t v) { + return vnclipu_wx_u8mf2(DemoteTo(Simd<uint16_t, N, 0>(), v), 0, Lanes(d)); +} +template <size_t N> +HWY_API vuint8m1_t DemoteTo(Simd<uint8_t, N, 0> d, const vint32m4_t v) { + return vnclipu_wx_u8m1(DemoteTo(Simd<uint16_t, N, 1>(), v), 0, Lanes(d)); +} +template <size_t N> +HWY_API vuint8m2_t DemoteTo(Simd<uint8_t, N, 1> d, const vint32m8_t v) { + return vnclipu_wx_u8m2(DemoteTo(Simd<uint16_t, N, 2>(), v), 0, Lanes(d)); +} + +HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) { + const size_t avl = Lanes(ScalableTag<uint8_t, -3>()); + return vnclipu_wx_u8mf8(vnclipu_wx_u16mf4(v, 0, avl), 0, avl); +} +HWY_API vuint8mf4_t U8FromU32(const vuint32m1_t v) { + const size_t avl = Lanes(ScalableTag<uint8_t, -2>()); + return vnclipu_wx_u8mf4(vnclipu_wx_u16mf2(v, 0, avl), 0, avl); +} +HWY_API vuint8mf2_t U8FromU32(const vuint32m2_t v) { + const size_t avl = Lanes(ScalableTag<uint8_t, -1>()); + return vnclipu_wx_u8mf2(vnclipu_wx_u16m1(v, 0, avl), 0, avl); +} +HWY_API vuint8m1_t U8FromU32(const vuint32m4_t v) { + const size_t avl = Lanes(ScalableTag<uint8_t, 0>()); + return vnclipu_wx_u8m1(vnclipu_wx_u16m2(v, 0, avl), 0, avl); +} +HWY_API vuint8m2_t U8FromU32(const vuint32m8_t v) { + const size_t avl = Lanes(ScalableTag<uint8_t, 1>()); + return vnclipu_wx_u8m2(vnclipu_wx_u16m4(v, 0, avl), 0, avl); +} + +// ------------------------------ Truncations + +template <size_t N> +HWY_API vuint8mf8_t TruncateTo(Simd<uint8_t, N, -3> d, + const VFromD<Simd<uint64_t, N, 0>> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFF, avl); + const vuint32mf2_t v2 = vnclipu_wx_u32mf2(v1, 0, avl); + const vuint16mf4_t v3 = vnclipu_wx_u16mf4(v2, 0, avl); + return vnclipu_wx_u8mf8(v3, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf4_t TruncateTo(Simd<uint8_t, N, -2> d, + const VFromD<Simd<uint64_t, N, 1>> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFF, avl); + const vuint32m1_t v2 = vnclipu_wx_u32m1(v1, 0, avl); + const vuint16mf2_t v3 = vnclipu_wx_u16mf2(v2, 0, avl); + return vnclipu_wx_u8mf4(v3, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf2_t TruncateTo(Simd<uint8_t, N, -1> d, + const VFromD<Simd<uint64_t, N, 2>> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFF, avl); + const vuint32m2_t v2 = vnclipu_wx_u32m2(v1, 0, avl); + const vuint16m1_t v3 = vnclipu_wx_u16m1(v2, 0, avl); + return vnclipu_wx_u8mf2(v3, 0, avl); +} + +template <size_t N> +HWY_API vuint8m1_t TruncateTo(Simd<uint8_t, N, 0> d, + const VFromD<Simd<uint64_t, N, 3>> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFF, avl); + const vuint32m4_t v2 = vnclipu_wx_u32m4(v1, 0, avl); + const vuint16m2_t v3 = vnclipu_wx_u16m2(v2, 0, avl); + return vnclipu_wx_u8m1(v3, 0, avl); +} + +template <size_t N> +HWY_API vuint16mf4_t TruncateTo(Simd<uint16_t, N, -2> d, + const VFromD<Simd<uint64_t, N, 0>> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFFFF, avl); + const vuint32mf2_t v2 = vnclipu_wx_u32mf2(v1, 0, avl); + return vnclipu_wx_u16mf4(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint16mf2_t TruncateTo(Simd<uint16_t, N, -1> d, + const VFromD<Simd<uint64_t, N, 1>> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFFFF, avl); + const vuint32m1_t v2 = vnclipu_wx_u32m1(v1, 0, avl); + return vnclipu_wx_u16mf2(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint16m1_t TruncateTo(Simd<uint16_t, N, 0> d, + const VFromD<Simd<uint64_t, N, 2>> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFFFF, avl); + const vuint32m2_t v2 = vnclipu_wx_u32m2(v1, 0, avl); + return vnclipu_wx_u16m1(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint16m2_t TruncateTo(Simd<uint16_t, N, 1> d, + const VFromD<Simd<uint64_t, N, 3>> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFFFF, avl); + const vuint32m4_t v2 = vnclipu_wx_u32m4(v1, 0, avl); + return vnclipu_wx_u16m2(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint32mf2_t TruncateTo(Simd<uint32_t, N, -1> d, + const VFromD<Simd<uint64_t, N, 0>> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32mf2(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint32m1_t TruncateTo(Simd<uint32_t, N, 0> d, + const VFromD<Simd<uint64_t, N, 1>> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m1(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint32m2_t TruncateTo(Simd<uint32_t, N, 1> d, + const VFromD<Simd<uint64_t, N, 2>> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m2(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint32m4_t TruncateTo(Simd<uint32_t, N, 2> d, + const VFromD<Simd<uint64_t, N, 3>> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = vand(v, 0xFFFFFFFFu, avl); + return vnclipu_wx_u32m4(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf8_t TruncateTo(Simd<uint8_t, N, -3> d, + const VFromD<Simd<uint32_t, N, -1>> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = vand(v, 0xFF, avl); + const vuint16mf4_t v2 = vnclipu_wx_u16mf4(v1, 0, avl); + return vnclipu_wx_u8mf8(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf4_t TruncateTo(Simd<uint8_t, N, -2> d, + const VFromD<Simd<uint32_t, N, 0>> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = vand(v, 0xFF, avl); + const vuint16mf2_t v2 = vnclipu_wx_u16mf2(v1, 0, avl); + return vnclipu_wx_u8mf4(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf2_t TruncateTo(Simd<uint8_t, N, -1> d, + const VFromD<Simd<uint32_t, N, 1>> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = vand(v, 0xFF, avl); + const vuint16m1_t v2 = vnclipu_wx_u16m1(v1, 0, avl); + return vnclipu_wx_u8mf2(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint8m1_t TruncateTo(Simd<uint8_t, N, 0> d, + const VFromD<Simd<uint32_t, N, 2>> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = vand(v, 0xFF, avl); + const vuint16m2_t v2 = vnclipu_wx_u16m2(v1, 0, avl); + return vnclipu_wx_u8m1(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint8m2_t TruncateTo(Simd<uint8_t, N, 1> d, + const VFromD<Simd<uint32_t, N, 3>> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = vand(v, 0xFF, avl); + const vuint16m4_t v2 = vnclipu_wx_u16m4(v1, 0, avl); + return vnclipu_wx_u8m2(v2, 0, avl); +} + +template <size_t N> +HWY_API vuint16mf4_t TruncateTo(Simd<uint16_t, N, -2> d, + const VFromD<Simd<uint32_t, N, -1>> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16mf4(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint16mf2_t TruncateTo(Simd<uint16_t, N, -1> d, + const VFromD<Simd<uint32_t, N, 0>> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16mf2(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint16m1_t TruncateTo(Simd<uint16_t, N, 0> d, + const VFromD<Simd<uint32_t, N, 1>> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m1(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint16m2_t TruncateTo(Simd<uint16_t, N, 1> d, + const VFromD<Simd<uint32_t, N, 2>> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m2(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint16m4_t TruncateTo(Simd<uint16_t, N, 2> d, + const VFromD<Simd<uint32_t, N, 3>> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = vand(v, 0xFFFF, avl); + return vnclipu_wx_u16m4(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf8_t TruncateTo(Simd<uint8_t, N, -3> d, + const VFromD<Simd<uint16_t, N, -2>> v) { + const size_t avl = Lanes(d); + const vuint16mf4_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf8(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf4_t TruncateTo(Simd<uint8_t, N, -2> d, + const VFromD<Simd<uint16_t, N, -1>> v) { + const size_t avl = Lanes(d); + const vuint16mf2_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf4(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint8mf2_t TruncateTo(Simd<uint8_t, N, -1> d, + const VFromD<Simd<uint16_t, N, 0>> v) { + const size_t avl = Lanes(d); + const vuint16m1_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8mf2(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint8m1_t TruncateTo(Simd<uint8_t, N, 0> d, + const VFromD<Simd<uint16_t, N, 1>> v) { + const size_t avl = Lanes(d); + const vuint16m2_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m1(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint8m2_t TruncateTo(Simd<uint8_t, N, 1> d, + const VFromD<Simd<uint16_t, N, 2>> v) { + const size_t avl = Lanes(d); + const vuint16m4_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m2(v1, 0, avl); +} + +template <size_t N> +HWY_API vuint8m4_t TruncateTo(Simd<uint8_t, N, 2> d, + const VFromD<Simd<uint16_t, N, 3>> v) { + const size_t avl = Lanes(d); + const vuint16m8_t v1 = vand(v, 0xFF, avl); + return vnclipu_wx_u8m4(v1, 0, avl); +} + +// ------------------------------ DemoteTo I + +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE, DemoteTo, vnclip_wx_, _DEMOTE_VIRT) + +template <size_t N> +HWY_API vint8mf8_t DemoteTo(Simd<int8_t, N, -3> d, const vint32mf2_t v) { + return DemoteTo(d, DemoteTo(Simd<int16_t, N, -2>(), v)); +} +template <size_t N> +HWY_API vint8mf4_t DemoteTo(Simd<int8_t, N, -2> d, const vint32m1_t v) { + return DemoteTo(d, DemoteTo(Simd<int16_t, N, -1>(), v)); +} +template <size_t N> +HWY_API vint8mf2_t DemoteTo(Simd<int8_t, N, -1> d, const vint32m2_t v) { + return DemoteTo(d, DemoteTo(Simd<int16_t, N, 0>(), v)); +} +template <size_t N> +HWY_API vint8m1_t DemoteTo(Simd<int8_t, N, 0> d, const vint32m4_t v) { + return DemoteTo(d, DemoteTo(Simd<int16_t, N, 1>(), v)); +} +template <size_t N> +HWY_API vint8m2_t DemoteTo(Simd<int8_t, N, 1> d, const vint32m8_t v) { + return DemoteTo(d, DemoteTo(Simd<int16_t, N, 2>(), v)); +} + +#undef HWY_RVV_DEMOTE + +// ------------------------------ DemoteTo F + +// SEW is for the source so we can use _DEMOTE. +#define HWY_RVV_DEMOTE_F(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return OP##SEWH##LMULH(v, Lanes(d)); \ + } + +#if HWY_HAVE_FLOAT16 +HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, + _DEMOTE_VIRT) +#endif +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, vfncvt_rod_f_f_w_f, + _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_F + +// TODO(janwas): add BASE2 arg to allow generating this via DEMOTE_F. +template <size_t N> +HWY_API vint32mf2_t DemoteTo(Simd<int32_t, N, -2> d, const vfloat64m1_t v) { + return vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template <size_t N> +HWY_API vint32mf2_t DemoteTo(Simd<int32_t, N, -1> d, const vfloat64m1_t v) { + return vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template <size_t N> +HWY_API vint32m1_t DemoteTo(Simd<int32_t, N, 0> d, const vfloat64m2_t v) { + return vfncvt_rtz_x_f_w_i32m1(v, Lanes(d)); +} +template <size_t N> +HWY_API vint32m2_t DemoteTo(Simd<int32_t, N, 1> d, const vfloat64m4_t v) { + return vfncvt_rtz_x_f_w_i32m2(v, Lanes(d)); +} +template <size_t N> +HWY_API vint32m4_t DemoteTo(Simd<int32_t, N, 2> d, const vfloat64m8_t v) { + return vfncvt_rtz_x_f_w_i32m4(v, Lanes(d)); +} + +template <size_t N, int kPow2> +HWY_API VFromD<Simd<uint16_t, N, kPow2>> DemoteTo( + Simd<bfloat16_t, N, kPow2> d, VFromD<Simd<float, N, kPow2 + 1>> v) { + const RebindToUnsigned<decltype(d)> du16; + const Rebind<uint32_t, decltype(d)> du32; + return detail::DemoteToShr16(du16, BitCast(du32, v)); +} + +// ------------------------------ ConvertTo F + +#define HWY_RVV_CONVERT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + return vfcvt_f_x_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(uint, SEW, LMUL) v) {\ + return vfcvt_f_xu_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + /* Truncates (rounds toward zero). */ \ + template <size_t N> \ + HWY_API HWY_RVV_V(int, SEW, LMUL) ConvertTo(HWY_RVV_D(int, SEW, N, SHIFT) d, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_rtz_x_f_v_i##SEW##LMUL(v, Lanes(d)); \ + } \ +// API only requires f32 but we provide f64 for internal use. +HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL_VIRT) +#undef HWY_RVV_CONVERT + +// Uses default rounding mode. Must be separate because there is no D arg. +#define HWY_RVV_NEAREST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) NearestInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return vfcvt_x_f_v_i##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_NEAREST, _, _, _ALL) +#undef HWY_RVV_NEAREST + +// ================================================== COMBINE + +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template <typename T, size_t N, int kPow2> +size_t LanesPerBlock(Simd<T, N, kPow2> d) { + size_t lpb = 16 / sizeof(T); + if (IsFull(d)) return lpb; + // Also honor the user-specified (constexpr) N limit. + lpb = HWY_MIN(lpb, N); + // No fraction, we're done. + if (kPow2 >= 0) return lpb; + // Fractional LMUL: Lanes(d) may be smaller than lpb, so honor that. + return HWY_MIN(lpb, Lanes(d)); +} + +template <class D, class V> +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned<TFromD<D>>; + return AndS(iota0, static_cast<T>(~(LanesPerBlock(d) - 1))); +} + +template <size_t kLanes, class D> +HWY_INLINE MFromD<D> FirstNPerBlock(D /* tag */) { + const RebindToUnsigned<D> du; + const RebindToSigned<D> di; + using TU = TFromD<decltype(du)>; + const auto idx_mod = AndS(Iota0(du), static_cast<TU>(LanesPerBlock(du) - 1)); + return LtS(BitCast(di, idx_mod), static_cast<TFromD<decltype(di)>>(kLanes)); +} + +// vector = f(vector, vector, size_t) +#define HWY_RVV_SLIDE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dst, HWY_RVV_V(BASE, SEW, LMUL) src, \ + size_t lanes) { \ + return v##OP##_vx_##CHAR##SEW##LMUL(dst, src, lanes, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideUp, slideup, _ALL) +HWY_RVV_FOREACH(HWY_RVV_SLIDE, SlideDown, slidedown, _ALL) + +#undef HWY_RVV_SLIDE + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template <class D, class V> +HWY_API V ConcatUpperLower(D d, const V hi, const V lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template <class D, class V> +HWY_API V ConcatLowerLower(D d, const V hi, const V lo) { + return detail::SlideUp(lo, hi, Lanes(d) / 2); +} + +// ------------------------------ ConcatUpperUpper +template <class D, class V> +HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) { + // Move upper half into lower + const auto lo_down = detail::SlideDown(lo, lo, Lanes(d) / 2); + return ConcatUpperLower(d, hi, lo_down); +} + +// ------------------------------ ConcatLowerUpper +template <class D, class V> +HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { + // Move half of both inputs to the other half + const auto hi_up = detail::SlideUp(hi, hi, Lanes(d) / 2); + const auto lo_down = detail::SlideDown(lo, lo, Lanes(d) / 2); + return ConcatUpperLower(d, hi_up, lo_down); +} + +// ------------------------------ Combine +template <class D2, class V> +HWY_API VFromD<D2> Combine(D2 d2, const V hi, const V lo) { + return detail::SlideUp(detail::Ext(d2, lo), detail::Ext(d2, hi), + Lanes(d2) / 2); +} + +// ------------------------------ ZeroExtendVector + +template <class D2, class V> +HWY_API VFromD<D2> ZeroExtendVector(D2 d2, const V lo) { + return Combine(d2, Xor(lo, lo), lo); +} + +// ------------------------------ Lower/UpperHalf + +namespace detail { + +// RVV may only support LMUL >= SEW/64; returns whether that holds for D. Note +// that SEW = sizeof(T)*8 and LMUL = 1 << Pow2(). +template <class D> +constexpr bool IsSupportedLMUL(D d) { + return (size_t{1} << (Pow2(d) + 3)) >= sizeof(TFromD<D>); +} + +} // namespace detail + +// If IsSupportedLMUL, just 'truncate' i.e. halve LMUL. +template <class DH, hwy::EnableIf<detail::IsSupportedLMUL(DH())>* = nullptr> +HWY_API VFromD<DH> LowerHalf(const DH /* tag */, const VFromD<Twice<DH>> v) { + return detail::Trunc(v); +} + +// Otherwise, there is no corresponding intrinsic type (e.g. vuint64mf2_t), and +// the hardware may set "vill" if we attempt such an LMUL. However, the V +// extension on application processors requires Zvl128b, i.e. VLEN >= 128, so it +// still makes sense to have half of an SEW=64 vector. We instead just return +// the vector, and rely on the kPow2 in DH to halve the return value of Lanes(). +template <class DH, class V, + hwy::EnableIf<!detail::IsSupportedLMUL(DH())>* = nullptr> +HWY_API V LowerHalf(const DH /* tag */, const V v) { + return v; +} + +// Same, but without D arg +template <class V> +HWY_API VFromD<Half<DFromV<V>>> LowerHalf(const V v) { + return LowerHalf(Half<DFromV<V>>(), v); +} + +template <class DH> +HWY_API VFromD<DH> UpperHalf(const DH d2, const VFromD<Twice<DH>> v) { + return LowerHalf(d2, detail::SlideDown(v, v, Lanes(d2))); +} + +// ================================================== SWIZZLE + +namespace detail { +// Special instruction for 1 lane is presumably faster? +#define HWY_RVV_SLIDE1(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_##CHAR##SEW##LMUL(v, 0, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Up, slide1up_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Up, fslide1up_vf, _ALL) +HWY_RVV_FOREACH_UI3264(HWY_RVV_SLIDE1, Slide1Down, slide1down_vx, _ALL) +HWY_RVV_FOREACH_F3264(HWY_RVV_SLIDE1, Slide1Down, fslide1down_vf, _ALL) +#undef HWY_RVV_SLIDE1 +} // namespace detail + +// ------------------------------ GetLane + +#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_s_##CHAR##SEW##LMUL##_##CHAR##SEW(v); /* no AVL */ \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f, _ALL) +#undef HWY_RVV_GET_LANE + +// ------------------------------ ExtractLane +template <class V> +HWY_API TFromV<V> ExtractLane(const V v, size_t i) { + return GetLane(detail::SlideDown(v, v, i)); +} + +// ------------------------------ InsertLane + +template <class V, HWY_IF_NOT_LANE_SIZE_V(V, 1)> +HWY_API V InsertLane(const V v, size_t i, TFromV<V> t) { + const DFromV<V> d; + const RebindToUnsigned<decltype(d)> du; // Iota0 is unsigned only + using TU = TFromD<decltype(du)>; + const auto is_i = detail::EqS(detail::Iota0(du), static_cast<TU>(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +namespace detail { +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetOnlyFirst, sof) +} // namespace detail + +// For 8-bit lanes, Iota0 might overflow. +template <class V, HWY_IF_LANE_SIZE_V(V, 1)> +HWY_API V InsertLane(const V v, size_t i, TFromV<V> t) { + const DFromV<V> d; + const auto zero = Zero(d); + const auto one = Set(d, 1); + const auto ge_i = Eq(detail::SlideUp(zero, one, i), one); + const auto is_i = detail::SetOnlyFirst(ge_i); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ OddEven +template <class V> +HWY_API V OddEven(const V a, const V b) { + const RebindToUnsigned<DFromV<V>> du; // Iota0 is unsigned only + const auto is_even = detail::EqS(detail::AndS(detail::Iota0(du), 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ DupEven (OddEven) +template <class V> +HWY_API V DupEven(const V v) { + const V up = detail::Slide1Up(v); + return OddEven(up, v); +} + +// ------------------------------ DupOdd (OddEven) +template <class V> +HWY_API V DupOdd(const V v) { + const V down = detail::Slide1Down(v); + return OddEven(v, down); +} + +// ------------------------------ OddEvenBlocks +template <class V> +HWY_API V OddEvenBlocks(const V a, const V b) { + const RebindToUnsigned<DFromV<V>> du; // Iota0 is unsigned only + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV<V>)); + const auto idx_block = ShiftRight<kShift>(detail::Iota0(du)); + const auto is_even = detail::EqS(detail::AndS(idx_block, 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ SwapAdjacentBlocks + +template <class V> +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV<V> d; + const size_t lpb = detail::LanesPerBlock(d); + const V down = detail::SlideDown(v, v, lpb); + const V up = detail::SlideUp(v, v, lpb); + return OddEvenBlocks(up, down); +} + +// ------------------------------ TableLookupLanes + +template <class D, class VI> +HWY_API VFromD<RebindToUnsigned<D>> IndicesFromVec(D d, VI vec) { + static_assert(sizeof(TFromD<D>) == sizeof(TFromV<VI>), "Index != lane"); + const RebindToUnsigned<decltype(d)> du; // instead of <D>: avoids unused d. + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllTrue(du, detail::LtS(indices, Lanes(d)))); +#endif + return indices; +} + +template <class D, typename TI> +HWY_API VFromD<RebindToUnsigned<D>> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind<TI, D>(), idx)); +} + +// <32bit are not part of Highway API, but used in Broadcast. This limits VLMAX +// to 2048! We could instead use vrgatherei16. +#define HWY_RVV_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, idx, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather, _ALL) +#undef HWY_RVV_TABLE + +// ------------------------------ ConcatOdd (TableLookupLanes) +template <class D, class V> +HWY_API V ConcatOdd(D d, const V hi, const V lo) { + const RebindToUnsigned<decltype(d)> du; // Iota0 is unsigned only + const auto iota = detail::Iota0(du); + const auto idx = detail::AddS(Add(iota, iota), 1); + const auto lo_odd = TableLookupLanes(lo, idx); + const auto hi_odd = TableLookupLanes(hi, idx); + return detail::SlideUp(lo_odd, hi_odd, Lanes(d) / 2); +} + +// ------------------------------ ConcatEven (TableLookupLanes) +template <class D, class V> +HWY_API V ConcatEven(D d, const V hi, const V lo) { + const RebindToUnsigned<decltype(d)> du; // Iota0 is unsigned only + const auto iota = detail::Iota0(du); + const auto idx = Add(iota, iota); + const auto lo_even = TableLookupLanes(lo, idx); + const auto hi_even = TableLookupLanes(hi, idx); + return detail::SlideUp(lo_even, hi_even, Lanes(d) / 2); +} + +// ------------------------------ Reverse (TableLookupLanes) +template <class D> +HWY_API VFromD<D> Reverse(D /* tag */, VFromD<D> v) { + const RebindToUnsigned<D> du; + using TU = TFromD<decltype(du)>; + const size_t N = Lanes(du); + const auto idx = + detail::ReverseSubS(detail::Iota0(du), static_cast<TU>(N - 1)); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse2 (RotateRight, OddEven) + +// Shifting and adding requires fewer instructions than blending, but casting to +// u32 only works for LMUL in [1/2, 8]. +template <class D, HWY_IF_LANE_SIZE_D(D, 2), HWY_RVV_IF_POW2_IN(D, -1, 3)> +HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { + const Repartition<uint32_t, D> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} +// For LMUL < 1/2, we can extend and then truncate. +template <class D, HWY_IF_LANE_SIZE_D(D, 2), HWY_RVV_IF_POW2_IN(D, -3, -2)> +HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { + const Twice<decltype(d)> d2; + const Twice<decltype(d2)> d4; + const Repartition<uint32_t, decltype(d4)> du32; + const auto vx = detail::Ext(d4, detail::Ext(d2, v)); + const auto rx = BitCast(d4, RotateRight<16>(BitCast(du32, vx))); + return detail::Trunc(detail::Trunc(rx)); +} + +// Shifting and adding requires fewer instructions than blending, but casting to +// u64 does not work for LMUL < 1. +template <class D, HWY_IF_LANE_SIZE_D(D, 4), HWY_RVV_IF_POW2_IN(D, 0, 3)> +HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { + const Repartition<uint64_t, decltype(d)> du64; + return BitCast(d, RotateRight<32>(BitCast(du64, v))); +} + +// For fractions, we can extend and then truncate. +template <class D, HWY_IF_LANE_SIZE_D(D, 4), HWY_RVV_IF_POW2_IN(D, -2, -1)> +HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) { + const Twice<decltype(d)> d2; + const Twice<decltype(d2)> d4; + const Repartition<uint64_t, decltype(d4)> du64; + const auto vx = detail::Ext(d4, detail::Ext(d2, v)); + const auto rx = BitCast(d4, RotateRight<32>(BitCast(du64, vx))); + return detail::Trunc(detail::Trunc(rx)); +} + +template <class D, class V = VFromD<D>, HWY_IF_LANE_SIZE_D(D, 8)> +HWY_API V Reverse2(D /* tag */, const V v) { + const V up = detail::Slide1Up(v); + const V down = detail::Slide1Down(v); + return OddEven(up, down); +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template <class D> +HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) { + const RebindToUnsigned<D> du; + const auto idx = detail::XorS(detail::Iota0(du), 3); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template <class D> +HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) { + const RebindToUnsigned<D> du; + const auto idx = detail::XorS(detail::Iota0(du), 7); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template <class D, class V = VFromD<D>> +HWY_API V ReverseBlocks(D d, V v) { + const Repartition<uint64_t, D> du64; + const size_t N = Lanes(du64); + const auto rev = + detail::ReverseSubS(detail::Iota0(du64), static_cast<uint64_t>(N - 1)); + // Swap lo/hi u64 within each block + const auto idx = detail::XorS(rev, 1); + return BitCast(d, TableLookupLanes(BitCast(du64, v), idx)); +} + +// ------------------------------ Compress + +// RVV supports all lane types natively. +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template <typename T> +struct CompressIsPartition { + enum { value = 0 }; +}; + +#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) mask) { \ + return v##OP##_vm_##CHAR##SEW##LMUL(v, v, mask, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_COMPRESS, Compress, compress, _ALL) +#undef HWY_RVV_COMPRESS + +// ------------------------------ CompressNot +template <class V, class M> +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +template <class V, class M> +HWY_API V CompressBlocksNot(V v, const M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template <class V, class M, class D> +HWY_API size_t CompressStore(const V v, const M mask, const D d, + TFromD<D>* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template <class V, class M, class D> +HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d, + TFromD<D>* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + detail::StoreN(count, Compress(v, mask), d, unaligned); + return count; +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes +template <size_t kBytes, class D, class V = VFromD<D>> +HWY_API V CombineShiftRightBytes(const D d, const V hi, V lo) { + const Repartition<uint8_t, decltype(d)> d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); + const auto hi_up = detail::SlideUp(hi8, hi8, 16 - kBytes); + const auto lo_down = detail::SlideDown(lo8, lo8, kBytes); + const auto is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +} + +// ------------------------------ CombineShiftRightLanes +template <size_t kLanes, class D, class V = VFromD<D>> +HWY_API V CombineShiftRightLanes(const D d, const V hi, V lo) { + constexpr size_t kLanesUp = 16 / sizeof(TFromV<V>) - kLanes; + const auto hi_up = detail::SlideUp(hi, hi, kLanesUp); + const auto lo_down = detail::SlideDown(lo, lo, kLanes); + const auto is_lo = detail::FirstNPerBlock<kLanesUp>(d); + return IfThenElse(is_lo, lo_down, hi_up); +} + +// ------------------------------ Shuffle2301 (ShiftLeft) +template <class V> +HWY_API V Shuffle2301(const V v) { + const DFromV<V> d; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + const Repartition<uint64_t, decltype(d)> du64; + const auto v64 = BitCast(du64, v); + return BitCast(d, Or(ShiftRight<32>(v64), ShiftLeft<32>(v64))); +} + +// ------------------------------ Shuffle2103 +template <class V> +HWY_API V Shuffle2103(const V v) { + const DFromV<V> d; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<3>(d, v, v); +} + +// ------------------------------ Shuffle0321 +template <class V> +HWY_API V Shuffle0321(const V v) { + const DFromV<V> d; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle1032 +template <class V> +HWY_API V Shuffle1032(const V v) { + const DFromV<V> d; + static_assert(sizeof(TFromD<decltype(d)>) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<2>(d, v, v); +} + +// ------------------------------ Shuffle01 +template <class V> +HWY_API V Shuffle01(const V v) { + const DFromV<V> d; + static_assert(sizeof(TFromD<decltype(d)>) == 8, "Defined for 64-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle0123 +template <class V> +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ TableLookupBytes + +// Extends or truncates a vector to match the given d. +namespace detail { + +template <typename T, size_t N, int kPow2> +HWY_INLINE auto ChangeLMUL(Simd<T, N, kPow2> d, VFromD<Simd<T, N, kPow2 - 3>> v) + -> VFromD<decltype(d)> { + const Simd<T, N, kPow2 - 1> dh; + const Simd<T, N, kPow2 - 2> dhh; + return Ext(d, Ext(dh, Ext(dhh, v))); +} +template <typename T, size_t N, int kPow2> +HWY_INLINE auto ChangeLMUL(Simd<T, N, kPow2> d, VFromD<Simd<T, N, kPow2 - 2>> v) + -> VFromD<decltype(d)> { + const Simd<T, N, kPow2 - 1> dh; + return Ext(d, Ext(dh, v)); +} +template <typename T, size_t N, int kPow2> +HWY_INLINE auto ChangeLMUL(Simd<T, N, kPow2> d, VFromD<Simd<T, N, kPow2 - 1>> v) + -> VFromD<decltype(d)> { + return Ext(d, v); +} + +template <typename T, size_t N, int kPow2> +HWY_INLINE auto ChangeLMUL(Simd<T, N, kPow2> d, VFromD<decltype(d)> v) + -> VFromD<decltype(d)> { + return v; +} + +template <typename T, size_t N, int kPow2> +HWY_INLINE auto ChangeLMUL(Simd<T, N, kPow2> d, VFromD<Simd<T, N, kPow2 + 1>> v) + -> VFromD<decltype(d)> { + return Trunc(v); +} +template <typename T, size_t N, int kPow2> +HWY_INLINE auto ChangeLMUL(Simd<T, N, kPow2> d, VFromD<Simd<T, N, kPow2 + 2>> v) + -> VFromD<decltype(d)> { + return Trunc(Trunc(v)); +} +template <typename T, size_t N, int kPow2> +HWY_INLINE auto ChangeLMUL(Simd<T, N, kPow2> d, VFromD<Simd<T, N, kPow2 + 3>> v) + -> VFromD<decltype(d)> { + return Trunc(Trunc(Trunc(v))); +} + +} // namespace detail + +template <class VT, class VI> +HWY_API VI TableLookupBytes(const VT vt, const VI vi) { + const DFromV<VT> dt; // T=table, I=index. + const DFromV<VI> di; + const Repartition<uint8_t, decltype(dt)> dt8; + const Repartition<uint8_t, decltype(di)> di8; + // Required for producing half-vectors with table lookups from a full vector. + // If we instead run at the LMUL of the index vector, lookups into the table + // would be truncated. Thus we run at the larger of the two LMULs and truncate + // the result vector to the original index LMUL. + constexpr int kPow2T = Pow2(dt8); + constexpr int kPow2I = Pow2(di8); + const Simd<uint8_t, MaxLanes(di8), HWY_MAX(kPow2T, kPow2I)> dm8; // m=max + const auto vmt = detail::ChangeLMUL(dm8, BitCast(dt8, vt)); + const auto vmi = detail::ChangeLMUL(dm8, BitCast(di8, vi)); + auto offsets = detail::OffsetsOf128BitBlocks(dm8, detail::Iota0(dm8)); + // If the table is shorter, wrap around offsets so they do not reference + // undefined lanes in the newly extended vmt. + if (kPow2T < kPow2I) { + offsets = detail::AndS(offsets, static_cast<uint8_t>(Lanes(dt8) - 1)); + } + const auto out = TableLookupLanes(vmt, Add(vmi, offsets)); + return BitCast(di, detail::ChangeLMUL(di8, out)); +} + +template <class VT, class VI> +HWY_API VI TableLookupBytesOr0(const VT vt, const VI idx) { + const DFromV<VI> di; + const Repartition<int8_t, decltype(di)> di8; + const auto idx8 = BitCast(di8, idx); + const auto lookup = TableLookupBytes(vt, idx8); + return BitCast(di, IfThenZeroElse(detail::LtS(idx8, 0), lookup)); +} + +// ------------------------------ Broadcast +template <int kLane, class V> +HWY_API V Broadcast(const V v) { + const DFromV<V> d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + auto idx = detail::OffsetsOf128BitBlocks(d, detail::Iota0(d)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return TableLookupLanes(v, idx); +} + +// ------------------------------ ShiftLeftLanes + +template <size_t kLanes, class D, class V = VFromD<D>> +HWY_API V ShiftLeftLanes(const D d, const V v) { + const RebindToSigned<decltype(d)> di; + using TI = TFromD<decltype(di)>; + const auto shifted = detail::SlideUp(v, v, kLanes); + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(di)), + static_cast<TI>(detail::LanesPerBlock(di) - 1)); + const auto clear = detail::LtS(idx_mod, static_cast<TI>(kLanes)); + return IfThenZeroElse(clear, shifted); +} + +template <size_t kLanes, class V> +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes<kLanes>(DFromV<V>(), v); +} + +// ------------------------------ ShiftLeftBytes + +template <int kBytes, class D> +HWY_API VFromD<D> ShiftLeftBytes(D d, const VFromD<D> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftLanes<kBytes>(BitCast(d8, v))); +} + +template <int kBytes, class V> +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes<kBytes>(DFromV<V>(), v); +} + +// ------------------------------ ShiftRightLanes +template <size_t kLanes, typename T, size_t N, int kPow2, + class V = VFromD<Simd<T, N, kPow2>>> +HWY_API V ShiftRightLanes(const Simd<T, N, kPow2> d, V v) { + const RebindToSigned<decltype(d)> di; + using TI = TFromD<decltype(di)>; + // For partial vectors, clear upper lanes so we shift in zeros. + if (N <= 16 / sizeof(T)) { + v = IfThenElseZero(FirstN(d, N), v); + } + + const auto shifted = detail::SlideDown(v, v, kLanes); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + const size_t lpb = detail::LanesPerBlock(di); + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(di)), static_cast<TI>(lpb - 1)); + const auto keep = detail::LtS(idx_mod, static_cast<TI>(lpb - kLanes)); + return IfThenElseZero(keep, shifted); +} + +// ------------------------------ ShiftRightBytes +template <int kBytes, class D, class V = VFromD<D>> +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightLanes<kBytes>(d8, BitCast(d8, v))); +} + +// ------------------------------ InterleaveLower + +template <class D, class V> +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch"); + const RebindToUnsigned<decltype(d)> du; + using TU = TFromD<decltype(du)>; + const auto i = detail::Iota0(du); + const auto idx_mod = ShiftRight<1>( + detail::AndS(i, static_cast<TU>(detail::LanesPerBlock(du) - 1))); + const auto idx = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +template <class V> +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV<V>(), a, b); +} + +// ------------------------------ InterleaveUpper + +template <class D, class V> +HWY_API V InterleaveUpper(const D d, const V a, const V b) { + static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch"); + const RebindToUnsigned<decltype(d)> du; + using TU = TFromD<decltype(du)>; + const size_t lpb = detail::LanesPerBlock(du); + const auto i = detail::Iota0(du); + const auto idx_mod = ShiftRight<1>(detail::AndS(i, static_cast<TU>(lpb - 1))); + const auto idx_lower = Add(idx_mod, detail::OffsetsOf128BitBlocks(d, i)); + const auto idx = detail::AddS(idx_lower, static_cast<TU>(lpb / 2)); + const auto is_even = detail::EqS(detail::AndS(i, 1), 0u); + return IfThenElse(is_even, TableLookupLanes(a, idx), + TableLookupLanes(b, idx)); +} + +// ------------------------------ ZipLower + +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow<DW> dn; + static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} + +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} + +// ------------------------------ ZipUpper +template <class DW, class V> +HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow<DW> dn; + static_assert(IsSame<TFromD<decltype(dn)>, TFromV<V>>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== REDUCE + +// vector = f(vector, zero_m1) +#define HWY_RVV_REDUCE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <class D> \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(D d, HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, m1) v0) { \ + return Set(d, GetLane(v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ + v0, v, v0, Lanes(d)))); \ + } + +// ------------------------------ SumOfLanes + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL) +} // namespace detail + +template <class D> +HWY_API VFromD<D> SumOfLanes(D d, const VFromD<D> v) { + const auto v0 = Zero(ScalableTag<TFromD<D>>()); // always m1 + return detail::RedSum(d, v, v0); +} + +// ------------------------------ MinOfLanes +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL) +} // namespace detail + +template <class D> +HWY_API VFromD<D> MinOfLanes(D d, const VFromD<D> v) { + using T = TFromD<D>; + const ScalableTag<T> d1; // always m1 + const auto neutral = Set(d1, HighestValue<T>()); + return detail::RedMin(d, v, neutral); +} + +// ------------------------------ MaxOfLanes +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL) +} // namespace detail + +template <class D> +HWY_API VFromD<D> MaxOfLanes(D d, const VFromD<D> v) { + using T = TFromD<D>; + const ScalableTag<T> d1; // always m1 + const auto neutral = Set(d1, LowestValue<T>()); + return detail::RedMax(d, v, neutral); +} + +#undef HWY_RVV_REDUCE + +// ================================================== Ops with dependencies + +// ------------------------------ PopulationCount (ShiftRight) + +// Handles LMUL >= 2 or capped vectors, which generic_ops-inl cannot. +template <typename V, class D = DFromV<V>, HWY_IF_LANE_SIZE_D(D, 1), + hwy::EnableIf<Pow2(D()) < 1 || MaxLanes(D()) < 16>* = nullptr> +HWY_API V PopulationCount(V v) { + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + v = Sub(v, detail::AndS(ShiftRight<1>(v), 0x55)); + v = Add(detail::AndS(ShiftRight<2>(v), 0x33), detail::AndS(v, 0x33)); + return detail::AndS(Add(v, ShiftRight<4>(v)), 0x0F); +} + +// ------------------------------ LoadDup128 + +template <class D> +HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* const HWY_RESTRICT p) { + const VFromD<D> loaded = Load(d, p); + // idx must be unsigned for TableLookupLanes. + using TU = MakeUnsigned<TFromD<D>>; + const TU mask = static_cast<TU>(detail::LanesPerBlock(d) - 1); + // Broadcast the first block. + const VFromD<RebindToUnsigned<D>> idx = detail::AndS(detail::Iota0(d), mask); + return TableLookupLanes(loaded, idx); +} + +// ------------------------------ LoadMaskBits + +// Support all combinations of T and SHIFT(LMUL) without explicit overloads for +// each. First overload for MLEN=1..64. +namespace detail { + +// Maps D to MLEN (wrapped in SizeTag), such that #mask_bits = VLEN/MLEN. MLEN +// increases with lane size and decreases for increasing LMUL. Cap at 64, the +// largest supported by HWY_RVV_FOREACH_B (and intrinsics), for virtual LMUL +// e.g. vuint16mf8_t: (8*2 << 3) == 128. +template <class D> +using MaskTag = hwy::SizeTag<HWY_MIN( + 64, detail::ScaleByPower(8 * sizeof(TFromD<D>), -Pow2(D())))>; + +#define HWY_RVV_LOAD_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_INLINE HWY_RVV_M(MLEN) \ + NAME(hwy::SizeTag<MLEN> /* tag */, const uint8_t* bits, size_t N) { \ + return OP##_v_b##MLEN(bits, N); \ + } +HWY_RVV_FOREACH_B(HWY_RVV_LOAD_MASK_BITS, LoadMaskBits, vlm) +#undef HWY_RVV_LOAD_MASK_BITS +} // namespace detail + +template <class D, class MT = detail::MaskTag<D>> +HWY_API auto LoadMaskBits(D d, const uint8_t* bits) + -> decltype(detail::LoadMaskBits(MT(), bits, Lanes(d))) { + return detail::LoadMaskBits(MT(), bits, Lanes(d)); +} + +// ------------------------------ StoreMaskBits +#define HWY_RVV_STORE_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + template <class D> \ + HWY_API size_t NAME(D d, HWY_RVV_M(MLEN) m, uint8_t* bits) { \ + const size_t N = Lanes(d); \ + OP##_v_b##MLEN(bits, m, N); \ + /* Non-full byte, need to clear the undefined upper bits. */ \ + /* Use MaxLanes and sizeof(T) to move some checks to compile-time. */ \ + constexpr bool kLessThan8 = \ + detail::ScaleByPower(16 / sizeof(TFromD<D>), Pow2(d)) < 8; \ + if (MaxLanes(d) < 8 || (kLessThan8 && N < 8)) { \ + const int mask = (1 << N) - 1; \ + bits[0] = static_cast<uint8_t>(bits[0] & mask); \ + } \ + return (N + 7) / 8; \ + } +HWY_RVV_FOREACH_B(HWY_RVV_STORE_MASK_BITS, StoreMaskBits, vsm) +#undef HWY_RVV_STORE_MASK_BITS + +// ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits) + +template <class V> +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV<V>(), bits)); +} + +template <class D> +HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD<D>* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ FirstN (Iota0, Lt, RebindMask, SlideUp) + +// Disallow for 8-bit because Iota is likely to overflow. +template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 1)> +HWY_API MFromD<D> FirstN(const D d, const size_t n) { + const RebindToSigned<D> di; + using TI = TFromD<decltype(di)>; + return RebindMask( + d, detail::LtS(BitCast(di, detail::Iota0(d)), static_cast<TI>(n))); +} + +template <class D, HWY_IF_LANE_SIZE_D(D, 1)> +HWY_API MFromD<D> FirstN(const D d, const size_t n) { + const auto zero = Zero(d); + const auto one = Set(d, 1); + return Eq(detail::SlideUp(one, zero, n), one); +} + +// ------------------------------ Neg (Sub) + +template <class V, HWY_IF_SIGNED_V(V)> +HWY_API V Neg(const V v) { + return detail::ReverseSubS(v, 0); +} + +// vector = f(vector), but argument is repeated +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v##OP##_vv_##CHAR##SEW##LMUL(v, v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) + +// ------------------------------ Abs (Max, Neg) + +template <class V, HWY_IF_SIGNED_V(V)> +HWY_API V Abs(const V v) { + return Max(v, Neg(v)); +} + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx, _ALL) + +#undef HWY_RVV_RETV_ARGV2 + +// ------------------------------ AbsDiff (Abs, Sub) +template <class V> +HWY_API V AbsDiff(const V a, const V b) { + return Abs(Sub(a, b)); +} + +// ------------------------------ Round (NearestInt, ConvertTo, CopySign) + +// IEEE-754 roundToIntegralTiesToEven returns floating-point, but we do not have +// a dedicated instruction for that. Rounding to integer and converting back to +// float is correct except when the input magnitude is large, in which case the +// input was already an integer (because mantissa >> exponent is zero). + +namespace detail { +enum RoundingModes { kNear, kTrunc, kDown, kUp }; + +template <class V> +HWY_INLINE auto UseInt(const V v) -> decltype(MaskFromVec(v)) { + return detail::LtS(Abs(v), MantissaEnd<TFromV<V>>()); +} + +} // namespace detail + +template <class V> +HWY_API V Round(const V v) { + const DFromV<V> df; + + const auto integer = NearestInt(v); // round using current mode + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Trunc (ConvertTo) +template <class V> +HWY_API V Trunc(const V v) { + const DFromV<V> df; + const RebindToSigned<decltype(df)> di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Ceil +template <class V> +HWY_API V Ceil(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kUp)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Floor +template <class V> +HWY_API V Floor(const V v) { + asm volatile("fsrm %0" ::"r"(detail::kDown)); + const auto ret = Round(v); + asm volatile("fsrm %0" ::"r"(detail::kNear)); + return ret; +} + +// ------------------------------ Floating-point classification (Ne) + +// vfclass does not help because it would require 3 instructions (to AND and +// then compare the bits), whereas these are just 1-3 integer instructions. + +template <class V> +HWY_API MFromD<DFromV<V>> IsNaN(const V v) { + return Ne(v, v); +} + +template <class V, class D = DFromV<V>> +HWY_API MFromD<D> IsInf(const V v) { + const D d; + const RebindToSigned<decltype(d)> di; + using T = TFromD<D>; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqS(Add(vi, vi), hwy::MaxExponentTimes2<T>())); +} + +// Returns whether normal/subnormal/zero. +template <class V, class D = DFromV<V>> +HWY_API MFromD<D> IsFinite(const V v) { + const D d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + using T = TFromD<D>; + const VFromD<decltype(du)> vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD<decltype(di)> exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtS(exp, hwy::MaxExponentField<T>())); +} + +// ------------------------------ Iota (ConvertTo) + +template <class D, HWY_IF_UNSIGNED_D(D)> +HWY_API VFromD<D> Iota(const D d, TFromD<D> first) { + return detail::AddS(detail::Iota0(d), first); +} + +template <class D, HWY_IF_SIGNED_D(D)> +HWY_API VFromD<D> Iota(const D d, TFromD<D> first) { + const RebindToUnsigned<D> du; + return detail::AddS(BitCast(d, detail::Iota0(du)), first); +} + +template <class D, HWY_IF_FLOAT_D(D)> +HWY_API VFromD<D> Iota(const D d, TFromD<D> first) { + const RebindToUnsigned<D> du; + const RebindToSigned<D> di; + return detail::AddS(ConvertTo(d, BitCast(di, detail::Iota0(du))), first); +} + +// ------------------------------ MulEven/Odd (Mul, OddEven) + +template <class V, HWY_IF_LANE_SIZE_V(V, 4), class D = DFromV<V>, + class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return BitCast(DW(), OddEven(detail::Slide1Up(hi), lo)); +} + +// There is no 64x64 vwmul. +template <class V, HWY_IF_LANE_SIZE_V(V, 8)> +HWY_INLINE V MulEven(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return OddEven(detail::Slide1Up(hi), lo); +} + +template <class V, HWY_IF_LANE_SIZE_V(V, 8)> +HWY_INLINE V MulOdd(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = detail::MulHigh(a, b); + return OddEven(hi, detail::Slide1Down(lo)); +} + +// ------------------------------ ReorderDemote2To (OddEven, Combine) + +template <size_t N, int kPow2> +HWY_API VFromD<Simd<uint16_t, N, kPow2>> ReorderDemote2To( + Simd<bfloat16_t, N, kPow2> dbf16, + VFromD<RepartitionToWide<decltype(dbf16)>> a, + VFromD<RepartitionToWide<decltype(dbf16)>> b) { + const RebindToUnsigned<decltype(dbf16)> du16; + const RebindToUnsigned<DFromV<decltype(a)>> du32; + const VFromD<decltype(du32)> b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// If LMUL is not the max, Combine first to avoid another DemoteTo. +template <size_t N, int kPow2, hwy::EnableIf<(kPow2 < 3)>* = nullptr, + class D32 = RepartitionToWide<Simd<int16_t, N, kPow2>>> +HWY_API VFromD<Simd<int16_t, N, kPow2>> ReorderDemote2To( + Simd<int16_t, N, kPow2> d16, VFromD<D32> a, VFromD<D32> b) { + const Twice<D32> d32t; + const VFromD<decltype(d32t)> ab = Combine(d32t, a, b); + return DemoteTo(d16, ab); +} + +// Max LMUL: must DemoteTo first, then Combine. +template <size_t N, class V32 = VFromD<RepartitionToWide<Simd<int16_t, N, 3>>>> +HWY_API VFromD<Simd<int16_t, N, 3>> ReorderDemote2To(Simd<int16_t, N, 3> d16, + V32 a, V32 b) { + const Half<decltype(d16)> d16h; + const VFromD<decltype(d16h)> a16 = DemoteTo(d16h, a); + const VFromD<decltype(d16h)> b16 = DemoteTo(d16h, b); + return Combine(d16, a16, b16); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +namespace detail { + +// Non-overloaded wrapper function so we can define DF32 in template args. +template < + size_t N, int kPow2, class DF32 = Simd<float, N, kPow2>, + class VF32 = VFromD<DF32>, + class DU16 = RepartitionToNarrow<RebindToUnsigned<Simd<float, N, kPow2>>>> +HWY_API VF32 ReorderWidenMulAccumulateBF16(Simd<float, N, kPow2> df32, + VFromD<DU16> a, VFromD<DU16> b, + const VF32 sum0, VF32& sum1) { + const RebindToUnsigned<DF32> du32; + using VU32 = VFromD<decltype(du32)>; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Using shift/and instead of Zip leads to the odd/even order that + // RearrangeToOddPlusEven prefers. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +#define HWY_RVV_WIDEN_MACC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template <size_t N> \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEWD, LMULD) sum, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return OP##CHAR##SEWD##LMULD(sum, a, b, Lanes(d)); \ + } + +HWY_RVV_FOREACH_I16(HWY_RVV_WIDEN_MACC, WidenMulAcc, vwmacc_vv_, _EXT_VIRT) +#undef HWY_RVV_WIDEN_MACC + +// If LMUL is not the max, we can WidenMul first (3 instructions). +template <size_t N, int kPow2, hwy::EnableIf<(kPow2 < 3)>* = nullptr, + class D32 = Simd<int32_t, N, kPow2>, class V32 = VFromD<D32>, + class D16 = RepartitionToNarrow<D32>> +HWY_API VFromD<D32> ReorderWidenMulAccumulateI16(Simd<int32_t, N, kPow2> d32, + VFromD<D16> a, VFromD<D16> b, + const V32 sum0, V32& sum1) { + const Twice<decltype(d32)> d32t; + using V32T = VFromD<decltype(d32t)>; + V32T sum = Combine(d32t, sum1, sum0); + sum = detail::WidenMulAcc(d32t, sum, a, b); + sum1 = UpperHalf(d32, sum); + return LowerHalf(d32, sum); +} + +// Max LMUL: must LowerHalf first (4 instructions). +template <size_t N, class D32 = Simd<int32_t, N, 3>, class V32 = VFromD<D32>, + class D16 = RepartitionToNarrow<D32>> +HWY_API VFromD<D32> ReorderWidenMulAccumulateI16(Simd<int32_t, N, 3> d32, + VFromD<D16> a, VFromD<D16> b, + const V32 sum0, V32& sum1) { + const Half<D16> d16h; + using V16H = VFromD<decltype(d16h)>; + const V16H a0 = LowerHalf(d16h, a); + const V16H a1 = UpperHalf(d16h, a); + const V16H b0 = LowerHalf(d16h, b); + const V16H b1 = UpperHalf(d16h, b); + sum1 = detail::WidenMulAcc(d32, sum1, a1, b1); + return detail::WidenMulAcc(d32, sum0, a0, b0); +} + +} // namespace detail + +template <size_t N, int kPow2, class VN, class VW> +HWY_API VW ReorderWidenMulAccumulate(Simd<float, N, kPow2> d32, VN a, VN b, + const VW sum0, VW& sum1) { + return detail::ReorderWidenMulAccumulateBF16(d32, a, b, sum0, sum1); +} + +template <size_t N, int kPow2, class VN, class VW> +HWY_API VW ReorderWidenMulAccumulate(Simd<int32_t, N, kPow2> d32, VN a, VN b, + const VW sum0, VW& sum1) { + return detail::ReorderWidenMulAccumulateI16(d32, a, b, sum0, sum1); +} + +// ------------------------------ RearrangeToOddPlusEven + +template <class VW, HWY_IF_SIGNED_V(VW)> // vint32_t* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // vwmacc doubles LMUL, so we require a pairwise sum here. This op is + // expected to be less frequent than ReorderWidenMulAccumulate, hence it's + // preferable to do the extra work here rather than do manual odd/even + // extraction there. + const DFromV<VW> di32; + const RebindToUnsigned<decltype(di32)> du32; + const Twice<decltype(di32)> di32x2; + const RepartitionToWide<decltype(di32x2)> di64x2; + const RebindToUnsigned<decltype(di64x2)> du64x2; + const auto combined = BitCast(di64x2, Combine(di32x2, sum1, sum0)); + // Isolate odd/even int32 in int64 lanes. + const auto even = ShiftRight<32>(ShiftLeft<32>(combined)); // sign extend + const auto odd = ShiftRight<32>(combined); + return BitCast(di32, TruncateTo(du32, BitCast(du64x2, Add(even, odd)))); +} + +// For max LMUL, we cannot Combine again and instead manually unroll. +HWY_API vint32m8_t RearrangeToOddPlusEven(vint32m8_t sum0, vint32m8_t sum1) { + const DFromV<vint32m8_t> d; + const Half<decltype(d)> dh; + const vint32m4_t lo = + RearrangeToOddPlusEven(LowerHalf(sum0), UpperHalf(dh, sum0)); + const vint32m4_t hi = + RearrangeToOddPlusEven(LowerHalf(sum1), UpperHalf(dh, sum1)); + return Combine(d, hi, lo); +} + +template <class VW, HWY_IF_FLOAT_V(VW)> // vfloat* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); // invariant already holds +} + +// ------------------------------ Lt128 +template <class D> +HWY_INLINE MFromD<D> Lt128(D d, const VFromD<D> a, const VFromD<D> b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + // Truth table of Eq and Compare for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const VFromD<D> eqHL = VecFromMask(d, Eq(a, b)); + const VFromD<D> ltHL = VecFromMask(d, Lt(a, b)); + // Shift leftward so L can influence H. + const VFromD<D> ltLx = detail::Slide1Up(ltHL); + const VFromD<D> vecHx = OrAnd(ltHL, eqHL, ltLx); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(vecHx, detail::Slide1Down(vecHx))); +} + +// ------------------------------ Lt128Upper +template <class D> +HWY_INLINE MFromD<D> Lt128Upper(D d, const VFromD<D> a, const VFromD<D> b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const VFromD<D> ltHL = VecFromMask(d, Lt(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(ltHL, detail::Slide1Down(ltHL))); +} + +// ------------------------------ Eq128 +template <class D> +HWY_INLINE MFromD<D> Eq128(D d, const VFromD<D> a, const VFromD<D> b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const VFromD<D> eqHL = VecFromMask(d, Eq(a, b)); + const VFromD<D> eqLH = Reverse2(d, eqHL); + return MaskFromVec(And(eqHL, eqLH)); +} + +// ------------------------------ Eq128Upper +template <class D> +HWY_INLINE MFromD<D> Eq128Upper(D d, const VFromD<D> a, const VFromD<D> b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const VFromD<D> eqHL = VecFromMask(d, Eq(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(eqHL, detail::Slide1Down(eqHL))); +} + +// ------------------------------ Ne128 +template <class D> +HWY_INLINE MFromD<D> Ne128(D d, const VFromD<D> a, const VFromD<D> b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const VFromD<D> neHL = VecFromMask(d, Ne(a, b)); + const VFromD<D> neLH = Reverse2(d, neHL); + return MaskFromVec(Or(neHL, neLH)); +} + +// ------------------------------ Ne128Upper +template <class D> +HWY_INLINE MFromD<D> Ne128Upper(D d, const VFromD<D> a, const VFromD<D> b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const VFromD<D> neHL = VecFromMask(d, Ne(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(neHL, detail::Slide1Down(neHL))); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template <class D> +HWY_INLINE VFromD<D> Min128(D /* tag */, const VFromD<D> a, const VFromD<D> b) { + const VFromD<D> aXH = detail::Slide1Down(a); + const VFromD<D> bXH = detail::Slide1Down(b); + const VFromD<D> minHL = Min(a, b); + const MFromD<D> ltXH = Lt(aXH, bXH); + const MFromD<D> eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD<D> lo = IfThenElse(ltXH, a, b); + // The upper lane is just minHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(minHL, IfThenElse(eqXH, minHL, lo)); +} + +template <class D> +HWY_INLINE VFromD<D> Max128(D /* tag */, const VFromD<D> a, const VFromD<D> b) { + const VFromD<D> aXH = detail::Slide1Down(a); + const VFromD<D> bXH = detail::Slide1Down(b); + const VFromD<D> maxHL = Max(a, b); + const MFromD<D> ltXH = Lt(aXH, bXH); + const MFromD<D> eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD<D> lo = IfThenElse(ltXH, b, a); + // The upper lane is just maxHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(maxHL, IfThenElse(eqXH, maxHL, lo)); +} + +template <class D> +HWY_INLINE VFromD<D> Min128Upper(D d, VFromD<D> a, VFromD<D> b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template <class D> +HWY_INLINE VFromD<D> Max128Upper(D d, VFromD<D> a, VFromD<D> b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +namespace detail { // for code folding +#undef HWY_RVV_AVL +#undef HWY_RVV_D +#undef HWY_RVV_FOREACH +#undef HWY_RVV_FOREACH_08_ALL +#undef HWY_RVV_FOREACH_08_ALL_VIRT +#undef HWY_RVV_FOREACH_08_DEMOTE +#undef HWY_RVV_FOREACH_08_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_08_EXT +#undef HWY_RVV_FOREACH_08_EXT_VIRT +#undef HWY_RVV_FOREACH_08_TRUNC +#undef HWY_RVV_FOREACH_08_VIRT +#undef HWY_RVV_FOREACH_16_ALL +#undef HWY_RVV_FOREACH_16_ALL_VIRT +#undef HWY_RVV_FOREACH_16_DEMOTE +#undef HWY_RVV_FOREACH_16_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_16_EXT +#undef HWY_RVV_FOREACH_16_EXT_VIRT +#undef HWY_RVV_FOREACH_16_TRUNC +#undef HWY_RVV_FOREACH_16_VIRT +#undef HWY_RVV_FOREACH_32_ALL +#undef HWY_RVV_FOREACH_32_ALL_VIRT +#undef HWY_RVV_FOREACH_32_DEMOTE +#undef HWY_RVV_FOREACH_32_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_32_EXT +#undef HWY_RVV_FOREACH_32_EXT_VIRT +#undef HWY_RVV_FOREACH_32_TRUNC +#undef HWY_RVV_FOREACH_32_VIRT +#undef HWY_RVV_FOREACH_64_ALL +#undef HWY_RVV_FOREACH_64_ALL_VIRT +#undef HWY_RVV_FOREACH_64_DEMOTE +#undef HWY_RVV_FOREACH_64_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_64_EXT +#undef HWY_RVV_FOREACH_64_EXT_VIRT +#undef HWY_RVV_FOREACH_64_TRUNC +#undef HWY_RVV_FOREACH_64_VIRT +#undef HWY_RVV_FOREACH_B +#undef HWY_RVV_FOREACH_F +#undef HWY_RVV_FOREACH_F16 +#undef HWY_RVV_FOREACH_F32 +#undef HWY_RVV_FOREACH_F3264 +#undef HWY_RVV_FOREACH_F64 +#undef HWY_RVV_FOREACH_I +#undef HWY_RVV_FOREACH_I08 +#undef HWY_RVV_FOREACH_I16 +#undef HWY_RVV_FOREACH_I163264 +#undef HWY_RVV_FOREACH_I32 +#undef HWY_RVV_FOREACH_I64 +#undef HWY_RVV_FOREACH_U +#undef HWY_RVV_FOREACH_U08 +#undef HWY_RVV_FOREACH_U16 +#undef HWY_RVV_FOREACH_U163264 +#undef HWY_RVV_FOREACH_U32 +#undef HWY_RVV_FOREACH_U64 +#undef HWY_RVV_FOREACH_UI +#undef HWY_RVV_FOREACH_UI08 +#undef HWY_RVV_FOREACH_UI16 +#undef HWY_RVV_FOREACH_UI163264 +#undef HWY_RVV_FOREACH_UI32 +#undef HWY_RVV_FOREACH_UI3264 +#undef HWY_RVV_FOREACH_UI64 +#undef HWY_RVV_M +#undef HWY_RVV_RETM_ARGM +#undef HWY_RVV_RETV_ARGV +#undef HWY_RVV_RETV_ARGVS +#undef HWY_RVV_RETV_ARGVV +#undef HWY_RVV_T +#undef HWY_RVV_V +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/scalar-inl.h b/third_party/highway/hwy/ops/scalar-inl.h new file mode 100644 index 0000000000..c28f7b510f --- /dev/null +++ b/third_party/highway/hwy/ops/scalar-inl.h @@ -0,0 +1,1626 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include <stddef.h> +#include <stdint.h> + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Single instruction, single data. +template <typename T> +using Sisd = Simd<T, 1, 0>; + +// (Wrapper class required for overloading comparison operators.) +template <typename T> +struct Vec1 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 1; // only for DFromV + + HWY_INLINE Vec1() = default; + Vec1(const Vec1&) = default; + Vec1& operator=(const Vec1&) = default; + HWY_INLINE explicit Vec1(const T t) : raw(t) {} + + HWY_INLINE Vec1& operator*=(const Vec1 other) { + return *this = (*this * other); + } + HWY_INLINE Vec1& operator/=(const Vec1 other) { + return *this = (*this / other); + } + HWY_INLINE Vec1& operator+=(const Vec1 other) { + return *this = (*this + other); + } + HWY_INLINE Vec1& operator-=(const Vec1 other) { + return *this = (*this - other); + } + HWY_INLINE Vec1& operator&=(const Vec1 other) { + return *this = (*this & other); + } + HWY_INLINE Vec1& operator|=(const Vec1 other) { + return *this = (*this | other); + } + HWY_INLINE Vec1& operator^=(const Vec1 other) { + return *this = (*this ^ other); + } + + T raw; +}; + +// 0 or FF..FF, same size as Vec1. +template <typename T> +class Mask1 { + using Raw = hwy::MakeUnsigned<T>; + + public: + static HWY_INLINE Mask1<T> FromBool(bool b) { + Mask1<T> mask; + mask.bits = b ? static_cast<Raw>(~Raw{0}) : 0; + return mask; + } + + Raw bits; +}; + +template <class V> +using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>; + +template <class V> +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +template <typename T, typename FromT> +HWY_API Vec1<T> BitCast(Sisd<T> /* tag */, Vec1<FromT> v) { + static_assert(sizeof(T) <= sizeof(FromT), "Promoting is undefined"); + T to; + CopyBytes<sizeof(FromT)>(&v.raw, &to); // not same size - ok to shrink + return Vec1<T>(to); +} + +// ------------------------------ Set + +template <typename T> +HWY_API Vec1<T> Zero(Sisd<T> /* tag */) { + return Vec1<T>(T(0)); +} + +template <typename T, typename T2> +HWY_API Vec1<T> Set(Sisd<T> /* tag */, const T2 t) { + return Vec1<T>(static_cast<T>(t)); +} + +template <typename T> +HWY_API Vec1<T> Undefined(Sisd<T> d) { + return Zero(d); +} + +template <typename T, typename T2> +HWY_API Vec1<T> Iota(const Sisd<T> /* tag */, const T2 first) { + return Vec1<T>(static_cast<T>(first)); +} + +template <class D> +using VFromD = decltype(Zero(D())); + +// ================================================== LOGICAL + +// ------------------------------ Not + +template <typename T> +HWY_API Vec1<T> Not(const Vec1<T> v) { + using TU = MakeUnsigned<T>; + const Sisd<TU> du; + return BitCast(Sisd<T>(), Vec1<TU>(static_cast<TU>(~BitCast(du, v).raw))); +} + +// ------------------------------ And + +template <typename T> +HWY_API Vec1<T> And(const Vec1<T> a, const Vec1<T> b) { + using TU = MakeUnsigned<T>; + const Sisd<TU> du; + return BitCast(Sisd<T>(), Vec1<TU>(BitCast(du, a).raw & BitCast(du, b).raw)); +} +template <typename T> +HWY_API Vec1<T> operator&(const Vec1<T> a, const Vec1<T> b) { + return And(a, b); +} + +// ------------------------------ AndNot + +template <typename T> +HWY_API Vec1<T> AndNot(const Vec1<T> a, const Vec1<T> b) { + using TU = MakeUnsigned<T>; + const Sisd<TU> du; + return BitCast(Sisd<T>(), Vec1<TU>(static_cast<TU>(~BitCast(du, a).raw & + BitCast(du, b).raw))); +} + +// ------------------------------ Or + +template <typename T> +HWY_API Vec1<T> Or(const Vec1<T> a, const Vec1<T> b) { + using TU = MakeUnsigned<T>; + const Sisd<TU> du; + return BitCast(Sisd<T>(), Vec1<TU>(BitCast(du, a).raw | BitCast(du, b).raw)); +} +template <typename T> +HWY_API Vec1<T> operator|(const Vec1<T> a, const Vec1<T> b) { + return Or(a, b); +} + +// ------------------------------ Xor + +template <typename T> +HWY_API Vec1<T> Xor(const Vec1<T> a, const Vec1<T> b) { + using TU = MakeUnsigned<T>; + const Sisd<TU> du; + return BitCast(Sisd<T>(), Vec1<TU>(BitCast(du, a).raw ^ BitCast(du, b).raw)); +} +template <typename T> +HWY_API Vec1<T> operator^(const Vec1<T> a, const Vec1<T> b) { + return Xor(a, b); +} + +// ------------------------------ Xor3 + +template <typename T> +HWY_API Vec1<T> Xor3(Vec1<T> x1, Vec1<T> x2, Vec1<T> x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template <typename T> +HWY_API Vec1<T> Or3(Vec1<T> o1, Vec1<T> o2, Vec1<T> o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template <typename T> +HWY_API Vec1<T> OrAnd(const Vec1<T> o, const Vec1<T> a1, const Vec1<T> a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template <typename T> +HWY_API Vec1<T> IfVecThenElse(Vec1<T> mask, Vec1<T> yes, Vec1<T> no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ CopySign + +template <typename T> +HWY_API Vec1<T> CopySign(const Vec1<T> magn, const Vec1<T> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + const auto msb = SignBit(Sisd<T>()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template <typename T> +HWY_API Vec1<T> CopySignToAbs(const Vec1<T> abs, const Vec1<T> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Sisd<T>()), sign)); +} + +// ------------------------------ BroadcastSignBit + +template <typename T> +HWY_API Vec1<T> BroadcastSignBit(const Vec1<T> v) { + // This is used inside ShiftRight, so we cannot implement in terms of it. + return v.raw < 0 ? Vec1<T>(T(-1)) : Vec1<T>(0); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template <typename T> +HWY_API Vec1<T> PopulationCount(Vec1<T> v) { + return Vec1<T>(static_cast<T>(PopCount(v.raw))); +} + +// ------------------------------ Mask + +template <typename TFrom, typename TTo> +HWY_API Mask1<TTo> RebindMask(Sisd<TTo> /*tag*/, Mask1<TFrom> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask1<TTo>{m.bits}; +} + +// v must be 0 or FF..FF. +template <typename T> +HWY_API Mask1<T> MaskFromVec(const Vec1<T> v) { + Mask1<T> mask; + CopySameSize(&v, &mask); + return mask; +} + +template <typename T> +Vec1<T> VecFromMask(const Mask1<T> mask) { + Vec1<T> v; + CopySameSize(&mask, &v); + return v; +} + +template <typename T> +Vec1<T> VecFromMask(Sisd<T> /* tag */, const Mask1<T> mask) { + Vec1<T> v; + CopySameSize(&mask, &v); + return v; +} + +template <typename T> +HWY_API Mask1<T> FirstN(Sisd<T> /*tag*/, size_t n) { + return Mask1<T>::FromBool(n != 0); +} + +// Returns mask ? yes : no. +template <typename T> +HWY_API Vec1<T> IfThenElse(const Mask1<T> mask, const Vec1<T> yes, + const Vec1<T> no) { + return mask.bits ? yes : no; +} + +template <typename T> +HWY_API Vec1<T> IfThenElseZero(const Mask1<T> mask, const Vec1<T> yes) { + return mask.bits ? yes : Vec1<T>(0); +} + +template <typename T> +HWY_API Vec1<T> IfThenZeroElse(const Mask1<T> mask, const Vec1<T> no) { + return mask.bits ? Vec1<T>(0) : no; +} + +template <typename T> +HWY_API Vec1<T> IfNegativeThenElse(Vec1<T> v, Vec1<T> yes, Vec1<T> no) { + return v.raw < 0 ? yes : no; +} + +template <typename T> +HWY_API Vec1<T> ZeroIfNegative(const Vec1<T> v) { + return v.raw < 0 ? Vec1<T>(0) : v; +} + +// ------------------------------ Mask logical + +template <typename T> +HWY_API Mask1<T> Not(const Mask1<T> m) { + return MaskFromVec(Not(VecFromMask(Sisd<T>(), m))); +} + +template <typename T> +HWY_API Mask1<T> And(const Mask1<T> a, Mask1<T> b) { + const Sisd<T> d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask1<T> AndNot(const Mask1<T> a, Mask1<T> b) { + const Sisd<T> d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask1<T> Or(const Mask1<T> a, Mask1<T> b) { + const Sisd<T> d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask1<T> Xor(const Mask1<T> a, Mask1<T> b) { + const Sisd<T> d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask1<T> ExclusiveNeither(const Mask1<T> a, Mask1<T> b) { + const Sisd<T> d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template <int kBits, typename T> +HWY_API Vec1<T> ShiftLeft(const Vec1<T> v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1<T>( + static_cast<T>(static_cast<hwy::MakeUnsigned<T>>(v.raw) << kBits)); +} + +template <int kBits, typename T> +HWY_API Vec1<T> ShiftRight(const Vec1<T> v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1<T>(static_cast<T>(v.raw >> kBits)); +#else + if (IsSigned<T>()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned<T>; + const Sisd<TU> du; + const TU shifted = static_cast<TU>(BitCast(du, v).raw >> kBits); + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast<size_t>(static_cast<int>(sizeof(TU)) * 8 - 1 - kBits); + const TU upper = static_cast<TU>(sign << sign_shift); + return BitCast(Sisd<T>(), Vec1<TU>(shifted | upper)); + } else { // T is unsigned + return Vec1<T>(static_cast<T>(v.raw >> kBits)); + } +#endif +} + +// ------------------------------ RotateRight (ShiftRight) + +namespace detail { + +// For partial specialization: kBits == 0 results in an invalid shift count +template <int kBits> +struct RotateRight { + template <typename T> + HWY_INLINE Vec1<T> operator()(const Vec1<T> v) const { + return Or(ShiftRight<kBits>(v), ShiftLeft<sizeof(T) * 8 - kBits>(v)); + } +}; + +template <> +struct RotateRight<0> { + template <typename T> + HWY_INLINE Vec1<T> operator()(const Vec1<T> v) const { + return v; + } +}; + +} // namespace detail + +template <int kBits, typename T> +HWY_API Vec1<T> RotateRight(const Vec1<T> v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return detail::RotateRight<kBits>()(v); +} + +// ------------------------------ ShiftLeftSame (BroadcastSignBit) + +template <typename T> +HWY_API Vec1<T> ShiftLeftSame(const Vec1<T> v, int bits) { + return Vec1<T>( + static_cast<T>(static_cast<hwy::MakeUnsigned<T>>(v.raw) << bits)); +} + +template <typename T> +HWY_API Vec1<T> ShiftRightSame(const Vec1<T> v, int bits) { +#if __cplusplus >= 202002L + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + return Vec1<T>(static_cast<T>(v.raw >> bits)); +#else + if (IsSigned<T>()) { + // Emulate arithmetic shift using only logical (unsigned) shifts, because + // signed shifts are still implementation-defined. + using TU = hwy::MakeUnsigned<T>; + const Sisd<TU> du; + const TU shifted = static_cast<TU>(BitCast(du, v).raw >> bits); + const TU sign = BitCast(du, BroadcastSignBit(v)).raw; + const size_t sign_shift = + static_cast<size_t>(static_cast<int>(sizeof(TU)) * 8 - 1 - bits); + const TU upper = static_cast<TU>(sign << sign_shift); + return BitCast(Sisd<T>(), Vec1<TU>(shifted | upper)); + } else { // T is unsigned + return Vec1<T>(static_cast<T>(v.raw >> bits)); + } +#endif +} + +// ------------------------------ Shl + +// Single-lane => same as ShiftLeftSame except for the argument type. +template <typename T> +HWY_API Vec1<T> operator<<(const Vec1<T> v, const Vec1<T> bits) { + return ShiftLeftSame(v, static_cast<int>(bits.raw)); +} + +template <typename T> +HWY_API Vec1<T> operator>>(const Vec1<T> v, const Vec1<T> bits) { + return ShiftRightSame(v, static_cast<int>(bits.raw)); +} + +// ================================================== ARITHMETIC + +template <typename T> +HWY_API Vec1<T> operator+(Vec1<T> a, Vec1<T> b) { + const uint64_t a64 = static_cast<uint64_t>(a.raw); + const uint64_t b64 = static_cast<uint64_t>(b.raw); + return Vec1<T>(static_cast<T>((a64 + b64) & static_cast<uint64_t>(~T(0)))); +} +HWY_API Vec1<float> operator+(const Vec1<float> a, const Vec1<float> b) { + return Vec1<float>(a.raw + b.raw); +} +HWY_API Vec1<double> operator+(const Vec1<double> a, const Vec1<double> b) { + return Vec1<double>(a.raw + b.raw); +} + +template <typename T> +HWY_API Vec1<T> operator-(Vec1<T> a, Vec1<T> b) { + const uint64_t a64 = static_cast<uint64_t>(a.raw); + const uint64_t b64 = static_cast<uint64_t>(b.raw); + return Vec1<T>(static_cast<T>((a64 - b64) & static_cast<uint64_t>(~T(0)))); +} +HWY_API Vec1<float> operator-(const Vec1<float> a, const Vec1<float> b) { + return Vec1<float>(a.raw - b.raw); +} +HWY_API Vec1<double> operator-(const Vec1<double> a, const Vec1<double> b) { + return Vec1<double>(a.raw - b.raw); +} + +// ------------------------------ SumsOf8 + +HWY_API Vec1<uint64_t> SumsOf8(const Vec1<uint8_t> v) { + return Vec1<uint64_t>(v.raw); +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec1<uint8_t> SaturatedAdd(const Vec1<uint8_t> a, + const Vec1<uint8_t> b) { + return Vec1<uint8_t>( + static_cast<uint8_t>(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255))); +} +HWY_API Vec1<uint16_t> SaturatedAdd(const Vec1<uint16_t> a, + const Vec1<uint16_t> b) { + return Vec1<uint16_t>( + static_cast<uint16_t>(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 65535))); +} + +// Signed +HWY_API Vec1<int8_t> SaturatedAdd(const Vec1<int8_t> a, const Vec1<int8_t> b) { + return Vec1<int8_t>( + static_cast<int8_t>(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127))); +} +HWY_API Vec1<int16_t> SaturatedAdd(const Vec1<int16_t> a, + const Vec1<int16_t> b) { + return Vec1<int16_t>( + static_cast<int16_t>(HWY_MIN(HWY_MAX(-32768, a.raw + b.raw), 32767))); +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec1<uint8_t> SaturatedSub(const Vec1<uint8_t> a, + const Vec1<uint8_t> b) { + return Vec1<uint8_t>( + static_cast<uint8_t>(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255))); +} +HWY_API Vec1<uint16_t> SaturatedSub(const Vec1<uint16_t> a, + const Vec1<uint16_t> b) { + return Vec1<uint16_t>( + static_cast<uint16_t>(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 65535))); +} + +// Signed +HWY_API Vec1<int8_t> SaturatedSub(const Vec1<int8_t> a, const Vec1<int8_t> b) { + return Vec1<int8_t>( + static_cast<int8_t>(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127))); +} +HWY_API Vec1<int16_t> SaturatedSub(const Vec1<int16_t> a, + const Vec1<int16_t> b) { + return Vec1<int16_t>( + static_cast<int16_t>(HWY_MIN(HWY_MAX(-32768, a.raw - b.raw), 32767))); +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +HWY_API Vec1<uint8_t> AverageRound(const Vec1<uint8_t> a, + const Vec1<uint8_t> b) { + return Vec1<uint8_t>(static_cast<uint8_t>((a.raw + b.raw + 1) / 2)); +} +HWY_API Vec1<uint16_t> AverageRound(const Vec1<uint16_t> a, + const Vec1<uint16_t> b) { + return Vec1<uint16_t>(static_cast<uint16_t>((a.raw + b.raw + 1) / 2)); +} + +// ------------------------------ Absolute value + +template <typename T> +HWY_API Vec1<T> Abs(const Vec1<T> a) { + const T i = a.raw; + if (i >= 0 || i == hwy::LimitsMin<T>()) return a; + return Vec1<T>(static_cast<T>(-i & T{-1})); +} +HWY_API Vec1<float> Abs(Vec1<float> a) { + int32_t i; + CopyBytes<sizeof(i)>(&a.raw, &i); + i &= 0x7FFFFFFF; + CopyBytes<sizeof(i)>(&i, &a.raw); + return a; +} +HWY_API Vec1<double> Abs(Vec1<double> a) { + int64_t i; + CopyBytes<sizeof(i)>(&a.raw, &i); + i &= 0x7FFFFFFFFFFFFFFFL; + CopyBytes<sizeof(i)>(&i, &a.raw); + return a; +} + +// ------------------------------ Min/Max + +// <cmath> may be unavailable, so implement our own. +namespace detail { + +static inline float Abs(float f) { + uint32_t i; + CopyBytes<4>(&f, &i); + i &= 0x7FFFFFFFu; + CopyBytes<4>(&i, &f); + return f; +} +static inline double Abs(double f) { + uint64_t i; + CopyBytes<8>(&f, &i); + i &= 0x7FFFFFFFFFFFFFFFull; + CopyBytes<8>(&i, &f); + return f; +} + +static inline bool SignBit(float f) { + uint32_t i; + CopyBytes<4>(&f, &i); + return (i >> 31) != 0; +} +static inline bool SignBit(double f) { + uint64_t i; + CopyBytes<8>(&f, &i); + return (i >> 63) != 0; +} + +} // namespace detail + +template <typename T, HWY_IF_NOT_FLOAT(T)> +HWY_API Vec1<T> Min(const Vec1<T> a, const Vec1<T> b) { + return Vec1<T>(HWY_MIN(a.raw, b.raw)); +} + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Vec1<T> Min(const Vec1<T> a, const Vec1<T> b) { + if (isnan(a.raw)) return b; + if (isnan(b.raw)) return a; + return Vec1<T>(HWY_MIN(a.raw, b.raw)); +} + +template <typename T, HWY_IF_NOT_FLOAT(T)> +HWY_API Vec1<T> Max(const Vec1<T> a, const Vec1<T> b) { + return Vec1<T>(HWY_MAX(a.raw, b.raw)); +} + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Vec1<T> Max(const Vec1<T> a, const Vec1<T> b) { + if (isnan(a.raw)) return b; + if (isnan(b.raw)) return a; + return Vec1<T>(HWY_MAX(a.raw, b.raw)); +} + +// ------------------------------ Floating-point negate + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Vec1<T> Neg(const Vec1<T> v) { + return Xor(v, SignBit(Sisd<T>())); +} + +template <typename T, HWY_IF_NOT_FLOAT(T)> +HWY_API Vec1<T> Neg(const Vec1<T> v) { + return Zero(Sisd<T>()) - v; +} + +// ------------------------------ mul/div + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Vec1<T> operator*(const Vec1<T> a, const Vec1<T> b) { + return Vec1<T>(static_cast<T>(double{a.raw} * b.raw)); +} + +template <typename T, HWY_IF_SIGNED(T)> +HWY_API Vec1<T> operator*(const Vec1<T> a, const Vec1<T> b) { + return Vec1<T>(static_cast<T>(static_cast<uint64_t>(a.raw) * + static_cast<uint64_t>(b.raw))); +} + +template <typename T, HWY_IF_UNSIGNED(T)> +HWY_API Vec1<T> operator*(const Vec1<T> a, const Vec1<T> b) { + return Vec1<T>(static_cast<T>(static_cast<uint64_t>(a.raw) * + static_cast<uint64_t>(b.raw))); +} + +template <typename T> +HWY_API Vec1<T> operator/(const Vec1<T> a, const Vec1<T> b) { + return Vec1<T>(a.raw / b.raw); +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec1<int16_t> MulHigh(const Vec1<int16_t> a, const Vec1<int16_t> b) { + return Vec1<int16_t>(static_cast<int16_t>((a.raw * b.raw) >> 16)); +} +HWY_API Vec1<uint16_t> MulHigh(const Vec1<uint16_t> a, const Vec1<uint16_t> b) { + // Cast to uint32_t first to prevent overflow. Otherwise the result of + // uint16_t * uint16_t is in "int" which may overflow. In practice the result + // is the same but this way it is also defined. + return Vec1<uint16_t>(static_cast<uint16_t>( + (static_cast<uint32_t>(a.raw) * static_cast<uint32_t>(b.raw)) >> 16)); +} + +HWY_API Vec1<int16_t> MulFixedPoint15(Vec1<int16_t> a, Vec1<int16_t> b) { + return Vec1<int16_t>(static_cast<int16_t>((2 * a.raw * b.raw + 32768) >> 16)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +HWY_API Vec1<int64_t> MulEven(const Vec1<int32_t> a, const Vec1<int32_t> b) { + const int64_t a64 = a.raw; + return Vec1<int64_t>(a64 * b.raw); +} +HWY_API Vec1<uint64_t> MulEven(const Vec1<uint32_t> a, const Vec1<uint32_t> b) { + const uint64_t a64 = a.raw; + return Vec1<uint64_t>(a64 * b.raw); +} + +// Approximate reciprocal +HWY_API Vec1<float> ApproximateReciprocal(const Vec1<float> v) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The return value is arbitrary. + if (v.raw == 0.0f) return Vec1<float>(0.0f); + return Vec1<float>(1.0f / v.raw); +} + +// Absolute value of difference. +HWY_API Vec1<float> AbsDiff(const Vec1<float> a, const Vec1<float> b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template <typename T> +HWY_API Vec1<T> MulAdd(const Vec1<T> mul, const Vec1<T> x, const Vec1<T> add) { + return mul * x + add; +} + +template <typename T> +HWY_API Vec1<T> NegMulAdd(const Vec1<T> mul, const Vec1<T> x, + const Vec1<T> add) { + return add - mul * x; +} + +template <typename T> +HWY_API Vec1<T> MulSub(const Vec1<T> mul, const Vec1<T> x, const Vec1<T> sub) { + return mul * x - sub; +} + +template <typename T> +HWY_API Vec1<T> NegMulSub(const Vec1<T> mul, const Vec1<T> x, + const Vec1<T> sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_API Vec1<float> ApproximateReciprocalSqrt(const Vec1<float> v) { + float f = v.raw; + const float half = f * 0.5f; + uint32_t bits; + CopySameSize(&f, &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &f); + // One Newton-Raphson iteration + return Vec1<float>(f * (1.5f - (half * f * f))); +} + +// Square root +HWY_API Vec1<float> Sqrt(const Vec1<float> v) { +#if HWY_COMPILER_GCC && defined(HWY_NO_LIBCXX) + return Vec1<float>(__builtin_sqrt(v.raw)); +#else + return Vec1<float>(sqrtf(v.raw)); +#endif +} +HWY_API Vec1<double> Sqrt(const Vec1<double> v) { +#if HWY_COMPILER_GCC && defined(HWY_NO_LIBCXX) + return Vec1<float>(__builtin_sqrt(v.raw)); +#else + return Vec1<double>(sqrt(v.raw)); +#endif +} + +// ------------------------------ Floating-point rounding + +template <typename T> +HWY_API Vec1<T> Round(const Vec1<T> v) { + using TI = MakeSigned<T>; + if (!(Abs(v).raw < MantissaEnd<T>())) { // Huge or NaN + return v; + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast<TI>(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1<T>(0), v); + // Round to even + if ((rounded & 1) && detail::Abs(static_cast<T>(rounded) - v.raw) == T(0.5)) { + return Vec1<T>(static_cast<T>(rounded - (v.raw < T(0) ? -1 : 1))); + } + return Vec1<T>(static_cast<T>(rounded)); +} + +// Round-to-nearest even. +HWY_API Vec1<int32_t> NearestInt(const Vec1<float> v) { + using T = float; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool is_sign = detail::SignBit(v.raw); + + if (!(abs < MantissaEnd<T>())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs <= static_cast<T>(LimitsMax<TI>()))) { + return Vec1<TI>(is_sign ? LimitsMin<TI>() : LimitsMax<TI>()); + } + return Vec1<int32_t>(static_cast<TI>(v.raw)); + } + const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5); + const TI rounded = static_cast<TI>(v.raw + bias); + if (rounded == 0) return Vec1<int32_t>(0); + // Round to even + if ((rounded & 1) && detail::Abs(static_cast<T>(rounded) - v.raw) == T(0.5)) { + return Vec1<TI>(rounded - (is_sign ? -1 : 1)); + } + return Vec1<TI>(rounded); +} + +template <typename T> +HWY_API Vec1<T> Trunc(const Vec1<T> v) { + using TI = MakeSigned<T>; + if (!(Abs(v).raw <= MantissaEnd<T>())) { // Huge or NaN + return v; + } + const TI truncated = static_cast<TI>(v.raw); + if (truncated == 0) return CopySignToAbs(Vec1<T>(0), v); + return Vec1<T>(static_cast<T>(truncated)); +} + +template <typename Float, typename Bits, int kMantissaBits, int kExponentBits, + class V> +V Ceiling(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool positive = f > Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast<int>(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => 0 or 1. + if (exponent < 0) return positive ? V(1) : V(-0.0); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +template <typename Float, typename Bits, int kMantissaBits, int kExponentBits, + class V> +V Floor(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool negative = f < Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast<int>(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => -1 or 0. + if (exponent < 0) return V(negative ? Float(-1.0) : Float(0.0)); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +// Toward +infinity, aka ceiling +HWY_API Vec1<float> Ceil(const Vec1<float> v) { + return Ceiling<float, uint32_t, 23, 8>(v); +} +HWY_API Vec1<double> Ceil(const Vec1<double> v) { + return Ceiling<double, uint64_t, 52, 11>(v); +} + +// Toward -infinity, aka floor +HWY_API Vec1<float> Floor(const Vec1<float> v) { + return Floor<float, uint32_t, 23, 8>(v); +} +HWY_API Vec1<double> Floor(const Vec1<double> v) { + return Floor<double, uint64_t, 52, 11>(v); +} + +// ================================================== COMPARE + +template <typename T> +HWY_API Mask1<T> operator==(const Vec1<T> a, const Vec1<T> b) { + return Mask1<T>::FromBool(a.raw == b.raw); +} + +template <typename T> +HWY_API Mask1<T> operator!=(const Vec1<T> a, const Vec1<T> b) { + return Mask1<T>::FromBool(a.raw != b.raw); +} + +template <typename T> +HWY_API Mask1<T> TestBit(const Vec1<T> v, const Vec1<T> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template <typename T> +HWY_API Mask1<T> operator<(const Vec1<T> a, const Vec1<T> b) { + return Mask1<T>::FromBool(a.raw < b.raw); +} +template <typename T> +HWY_API Mask1<T> operator>(const Vec1<T> a, const Vec1<T> b) { + return Mask1<T>::FromBool(a.raw > b.raw); +} + +template <typename T> +HWY_API Mask1<T> operator<=(const Vec1<T> a, const Vec1<T> b) { + return Mask1<T>::FromBool(a.raw <= b.raw); +} +template <typename T> +HWY_API Mask1<T> operator>=(const Vec1<T> a, const Vec1<T> b) { + return Mask1<T>::FromBool(a.raw >= b.raw); +} + +// ------------------------------ Floating-point classification (==) + +template <typename T> +HWY_API Mask1<T> IsNaN(const Vec1<T> v) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + MakeUnsigned<T> bits; + CopySameSize(&v, &bits); + bits += bits; + bits >>= 1; // clear sign bit + // NaN if all exponent bits are set and the mantissa is not zero. + return Mask1<T>::FromBool(bits > ExponentMask<T>()); +} + +HWY_API Mask1<float> IsInf(const Vec1<float> v) { + const Sisd<float> d; + const RebindToUnsigned<decltype(d)> du; + const Vec1<uint32_t> vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFF000000u)); +} +HWY_API Mask1<double> IsInf(const Vec1<double> v) { + const Sisd<double> d; + const RebindToUnsigned<decltype(d)> du; + const Vec1<uint64_t> vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFFE0000000000000ull)); +} + +HWY_API Mask1<float> IsFinite(const Vec1<float> v) { + const Vec1<uint32_t> vu = BitCast(Sisd<uint32_t>(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1<float>::FromBool((vu.raw << 1) < 0xFF000000u); +} +HWY_API Mask1<double> IsFinite(const Vec1<double> v) { + const Vec1<uint64_t> vu = BitCast(Sisd<uint64_t>(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1<double>::FromBool((vu.raw << 1) < 0xFFE0000000000000ull); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template <typename T> +HWY_API Vec1<T> Load(Sisd<T> /* tag */, const T* HWY_RESTRICT aligned) { + T t; + CopySameSize(aligned, &t); + return Vec1<T>(t); +} + +template <typename T> +HWY_API Vec1<T> MaskedLoad(Mask1<T> m, Sisd<T> d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template <typename T> +HWY_API Vec1<T> LoadU(Sisd<T> d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template <typename T> +HWY_API Vec1<T> LoadDup128(Sisd<T> d, const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +// ------------------------------ Store + +template <typename T> +HWY_API void Store(const Vec1<T> v, Sisd<T> /* tag */, + T* HWY_RESTRICT aligned) { + CopySameSize(&v.raw, aligned); +} + +template <typename T> +HWY_API void StoreU(const Vec1<T> v, Sisd<T> d, T* HWY_RESTRICT p) { + return Store(v, d, p); +} + +template <typename T> +HWY_API void BlendedStore(const Vec1<T> v, Mask1<T> m, Sisd<T> d, + T* HWY_RESTRICT p) { + if (!m.bits) return; + StoreU(v, d, p); +} + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining StoreInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template <typename T> +HWY_API void LoadInterleaved2(Sisd<T> d, const T* HWY_RESTRICT unaligned, + Vec1<T>& v0, Vec1<T>& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +template <typename T> +HWY_API void LoadInterleaved3(Sisd<T> d, const T* HWY_RESTRICT unaligned, + Vec1<T>& v0, Vec1<T>& v1, Vec1<T>& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +template <typename T> +HWY_API void LoadInterleaved4(Sisd<T> d, const T* HWY_RESTRICT unaligned, + Vec1<T>& v0, Vec1<T>& v1, Vec1<T>& v2, + Vec1<T>& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template <typename T> +HWY_API void StoreInterleaved2(const Vec1<T> v0, const Vec1<T> v1, Sisd<T> d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); +} + +template <typename T> +HWY_API void StoreInterleaved3(const Vec1<T> v0, const Vec1<T> v1, + const Vec1<T> v2, Sisd<T> d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +template <typename T> +HWY_API void StoreInterleaved4(const Vec1<T> v0, const Vec1<T> v1, + const Vec1<T> v2, const Vec1<T> v3, Sisd<T> d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); + StoreU(v3, d, unaligned + 3); +} + +// ------------------------------ Stream + +template <typename T> +HWY_API void Stream(const Vec1<T> v, Sisd<T> d, T* HWY_RESTRICT aligned) { + return Store(v, d, aligned); +} + +// ------------------------------ Scatter + +template <typename T, typename Offset> +HWY_API void ScatterOffset(Vec1<T> v, Sisd<T> d, T* base, + const Vec1<Offset> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + uint8_t* const base8 = reinterpret_cast<uint8_t*>(base) + offset.raw; + return Store(v, d, reinterpret_cast<T*>(base8)); +} + +template <typename T, typename Index> +HWY_API void ScatterIndex(Vec1<T> v, Sisd<T> d, T* HWY_RESTRICT base, + const Vec1<Index> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Store(v, d, base + index.raw); +} + +// ------------------------------ Gather + +template <typename T, typename Offset> +HWY_API Vec1<T> GatherOffset(Sisd<T> d, const T* base, + const Vec1<Offset> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + const intptr_t addr = + reinterpret_cast<intptr_t>(base) + static_cast<intptr_t>(offset.raw); + return Load(d, reinterpret_cast<const T*>(addr)); +} + +template <typename T, typename Index> +HWY_API Vec1<T> GatherIndex(Sisd<T> d, const T* HWY_RESTRICT base, + const Vec1<Index> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return Load(d, base + index.raw); +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +template <typename FromT, typename ToT> +HWY_API Vec1<ToT> PromoteTo(Sisd<ToT> /* tag */, Vec1<FromT> from) { + static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting"); + // For bits Y > X, floatX->floatY and intX->intY are always representable. + return Vec1<ToT>(static_cast<ToT>(from.raw)); +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here, +// so we overload for FromT=double and ToT={float,int32_t}. +HWY_API Vec1<float> DemoteTo(Sisd<float> /* tag */, Vec1<double> from) { + // Prevent ubsan errors when converting float to narrower integer/float + if (IsInf(from).bits || + Abs(from).raw > static_cast<double>(HighestValue<float>())) { + return Vec1<float>(detail::SignBit(from.raw) ? LowestValue<float>() + : HighestValue<float>()); + } + return Vec1<float>(static_cast<float>(from.raw)); +} +HWY_API Vec1<int32_t> DemoteTo(Sisd<int32_t> /* tag */, Vec1<double> from) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + if (IsInf(from).bits || + Abs(from).raw > static_cast<double>(HighestValue<int32_t>())) { + return Vec1<int32_t>(detail::SignBit(from.raw) ? LowestValue<int32_t>() + : HighestValue<int32_t>()); + } + return Vec1<int32_t>(static_cast<int32_t>(from.raw)); +} + +template <typename FromT, typename ToT> +HWY_API Vec1<ToT> DemoteTo(Sisd<ToT> /* tag */, Vec1<FromT> from) { + static_assert(!IsFloat<FromT>(), "FromT=double are handled above"); + static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting"); + + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw = HWY_MIN(HWY_MAX(LimitsMin<ToT>(), from.raw), LimitsMax<ToT>()); + return Vec1<ToT>(static_cast<ToT>(from.raw)); +} + +HWY_API Vec1<float> PromoteTo(Sisd<float> /* tag */, const Vec1<float16_t> v) { + uint16_t bits16; + CopySameSize(&v.raw, &bits16); + const uint32_t sign = static_cast<uint32_t>(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast<float>(mantissa) * (1.0f / 1024)); + return Vec1<float>(sign ? -subnormal : subnormal); + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + float out; + CopySameSize(&bits32, &out); + return Vec1<float>(out); +} + +HWY_API Vec1<float> PromoteTo(Sisd<float> d, const Vec1<bfloat16_t> v) { + return Set(d, F32FromBF16(v.raw)); +} + +HWY_API Vec1<float16_t> DemoteTo(Sisd<float16_t> /* tag */, + const Vec1<float> v) { + uint32_t bits32; + CopySameSize(&v.raw, &bits32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = HWY_MIN(static_cast<int32_t>(biased_exp32) - 127, 15); + + // Tiny or zero => zero. + Vec1<float16_t> out; + if (exp < -24) { + const uint16_t zero = 0; + CopySameSize(&zero, &out.raw); + return out; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (exp < -14) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast<uint32_t>(-14 - exp); + HWY_DASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = static_cast<uint32_t>((1u << (10 - sub_exp)) + + (mantissa32 >> (13 + sub_exp))); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast<uint32_t>(exp + 15); + HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + HWY_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast<uint16_t>(bits16); // big-endian safe + CopySameSize(&narrowed, &out.raw); + return out; +} + +HWY_API Vec1<bfloat16_t> DemoteTo(Sisd<bfloat16_t> d, const Vec1<float> v) { + return Set(d, BF16FromF32(v.raw)); +} + +template <typename FromT, typename ToT, HWY_IF_FLOAT(FromT)> +HWY_API Vec1<ToT> ConvertTo(Sisd<ToT> /* tag */, Vec1<FromT> from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // float## -> int##: return closest representable value. We cannot exactly + // represent LimitsMax<ToT> in FromT, so use double. + const double f = static_cast<double>(from.raw); + if (IsInf(from).bits || + Abs(Vec1<double>(f)).raw > static_cast<double>(LimitsMax<ToT>())) { + return Vec1<ToT>(detail::SignBit(from.raw) ? LimitsMin<ToT>() + : LimitsMax<ToT>()); + } + return Vec1<ToT>(static_cast<ToT>(from.raw)); +} + +template <typename FromT, typename ToT, HWY_IF_NOT_FLOAT(FromT)> +HWY_API Vec1<ToT> ConvertTo(Sisd<ToT> /* tag */, Vec1<FromT> from) { + static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size"); + // int## -> float##: no check needed + return Vec1<ToT>(static_cast<ToT>(from.raw)); +} + +HWY_API Vec1<uint8_t> U8FromU32(const Vec1<uint32_t> v) { + return DemoteTo(Sisd<uint8_t>(), v); +} + +// ------------------------------ Truncations + +HWY_API Vec1<uint8_t> TruncateTo(Sisd<uint8_t> /* tag */, + const Vec1<uint64_t> v) { + return Vec1<uint8_t>{static_cast<uint8_t>(v.raw & 0xFF)}; +} + +HWY_API Vec1<uint16_t> TruncateTo(Sisd<uint16_t> /* tag */, + const Vec1<uint64_t> v) { + return Vec1<uint16_t>{static_cast<uint16_t>(v.raw & 0xFFFF)}; +} + +HWY_API Vec1<uint32_t> TruncateTo(Sisd<uint32_t> /* tag */, + const Vec1<uint64_t> v) { + return Vec1<uint32_t>{static_cast<uint32_t>(v.raw & 0xFFFFFFFFu)}; +} + +HWY_API Vec1<uint8_t> TruncateTo(Sisd<uint8_t> /* tag */, + const Vec1<uint32_t> v) { + return Vec1<uint8_t>{static_cast<uint8_t>(v.raw & 0xFF)}; +} + +HWY_API Vec1<uint16_t> TruncateTo(Sisd<uint16_t> /* tag */, + const Vec1<uint32_t> v) { + return Vec1<uint16_t>{static_cast<uint16_t>(v.raw & 0xFFFF)}; +} + +HWY_API Vec1<uint8_t> TruncateTo(Sisd<uint8_t> /* tag */, + const Vec1<uint16_t> v) { + return Vec1<uint8_t>{static_cast<uint8_t>(v.raw & 0xFF)}; +} + +// ================================================== COMBINE +// UpperHalf, ZeroExtendVector, Combine, Concat* are unsupported. + +template <typename T> +HWY_API Vec1<T> LowerHalf(Vec1<T> v) { + return v; +} + +template <typename T> +HWY_API Vec1<T> LowerHalf(Sisd<T> /* tag */, Vec1<T> v) { + return v; +} + +// ================================================== SWIZZLE + +template <typename T> +HWY_API T GetLane(const Vec1<T> v) { + return v.raw; +} + +template <typename T> +HWY_API T ExtractLane(const Vec1<T> v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return v.raw; +} + +template <typename T> +HWY_API Vec1<T> InsertLane(Vec1<T> v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + v.raw = t; + return v; +} + +template <typename T> +HWY_API Vec1<T> DupEven(Vec1<T> v) { + return v; +} +// DupOdd is unsupported. + +template <typename T> +HWY_API Vec1<T> OddEven(Vec1<T> /* odd */, Vec1<T> even) { + return even; +} + +template <typename T> +HWY_API Vec1<T> OddEvenBlocks(Vec1<T> /* odd */, Vec1<T> even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template <typename T> +HWY_API Vec1<T> SwapAdjacentBlocks(Vec1<T> v) { + return v; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template <typename T> +struct Indices1 { + MakeSigned<T> raw; +}; + +template <typename T, typename TI> +HWY_API Indices1<T> IndicesFromVec(Sisd<T>, Vec1<TI> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + HWY_DASSERT(vec.raw == 0); + return Indices1<T>{vec.raw}; +} + +template <typename T, typename TI> +HWY_API Indices1<T> SetTableIndices(Sisd<T> d, const TI* idx) { + return IndicesFromVec(d, LoadU(Sisd<TI>(), idx)); +} + +template <typename T> +HWY_API Vec1<T> TableLookupLanes(const Vec1<T> v, const Indices1<T> /* idx */) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template <typename T> +HWY_API Vec1<T> ReverseBlocks(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} + +// ------------------------------ Reverse + +template <typename T> +HWY_API Vec1<T> Reverse(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} + +// Must not be called: +template <typename T> +HWY_API Vec1<T> Reverse2(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} + +template <typename T> +HWY_API Vec1<T> Reverse4(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} + +template <typename T> +HWY_API Vec1<T> Reverse8(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} + +// ================================================== BLOCKWISE +// Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported. + +// ------------------------------ Broadcast/splat any lane + +template <int kLane, typename T> +HWY_API Vec1<T> Broadcast(const Vec1<T> v) { + static_assert(kLane == 0, "Scalar only has one lane"); + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template <typename T, typename TI> +HWY_API Vec1<TI> TableLookupBytes(const Vec1<T> in, const Vec1<TI> indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes<sizeof(T)>(&in, &in_bytes); // copy to bytes + CopyBytes<sizeof(T)>(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes<sizeof(TI)>(&out_bytes, &out); + return Vec1<TI>{out}; +} + +template <typename T, typename TI> +HWY_API Vec1<TI> TableLookupBytesOr0(const Vec1<T> in, const Vec1<TI> indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes<sizeof(T)>(&in, &in_bytes); // copy to bytes + CopyBytes<sizeof(T)>(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = idx_bytes[i] & 0x80 ? 0 : in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes<sizeof(TI)>(&out_bytes, &out); + return Vec1<TI>{out}; +} + +// ------------------------------ ZipLower + +HWY_API Vec1<uint16_t> ZipLower(const Vec1<uint8_t> a, const Vec1<uint8_t> b) { + return Vec1<uint16_t>(static_cast<uint16_t>((uint32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1<uint32_t> ZipLower(const Vec1<uint16_t> a, + const Vec1<uint16_t> b) { + return Vec1<uint32_t>((uint32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1<uint64_t> ZipLower(const Vec1<uint32_t> a, + const Vec1<uint32_t> b) { + return Vec1<uint64_t>((uint64_t{b.raw} << 32) + a.raw); +} +HWY_API Vec1<int16_t> ZipLower(const Vec1<int8_t> a, const Vec1<int8_t> b) { + return Vec1<int16_t>(static_cast<int16_t>((int32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1<int32_t> ZipLower(const Vec1<int16_t> a, const Vec1<int16_t> b) { + return Vec1<int32_t>((int32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1<int64_t> ZipLower(const Vec1<int32_t> a, const Vec1<int32_t> b) { + return Vec1<int64_t>((int64_t{b.raw} << 32) + a.raw); +} + +template <typename T, typename TW = MakeWide<T>, class VW = Vec1<TW>> +HWY_API VW ZipLower(Sisd<TW> /* tag */, Vec1<T> a, Vec1<T> b) { + return VW(static_cast<TW>((TW{b.raw} << (sizeof(T) * 8)) + a.raw)); +} + +// ================================================== MASK + +template <typename T> +HWY_API bool AllFalse(Sisd<T> /* tag */, const Mask1<T> mask) { + return mask.bits == 0; +} + +template <typename T> +HWY_API bool AllTrue(Sisd<T> /* tag */, const Mask1<T> mask) { + return mask.bits != 0; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T> +HWY_API Mask1<T> LoadMaskBits(Sisd<T> /* tag */, + const uint8_t* HWY_RESTRICT bits) { + return Mask1<T>::FromBool((bits[0] & 1) != 0); +} + +// `p` points to at least 8 writable bytes. +template <typename T> +HWY_API size_t StoreMaskBits(Sisd<T> d, const Mask1<T> mask, uint8_t* bits) { + *bits = AllTrue(d, mask); + return 1; +} + +template <typename T> +HWY_API size_t CountTrue(Sisd<T> /* tag */, const Mask1<T> mask) { + return mask.bits == 0 ? 0 : 1; +} + +template <typename T> +HWY_API intptr_t FindFirstTrue(Sisd<T> /* tag */, const Mask1<T> mask) { + return mask.bits == 0 ? -1 : 0; +} + +template <typename T> +HWY_API size_t FindKnownFirstTrue(Sisd<T> /* tag */, const Mask1<T> /* m */) { + return 0; // There is only one lane and we know it is true. +} + +// ------------------------------ Compress, CompressBits + +template <typename T> +struct CompressIsPartition { + enum { value = 1 }; +}; + +template <typename T> +HWY_API Vec1<T> Compress(Vec1<T> v, const Mask1<T> /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +template <typename T> +HWY_API Vec1<T> CompressNot(Vec1<T> v, const Mask1<T> /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +// ------------------------------ CompressStore +template <typename T> +HWY_API size_t CompressStore(Vec1<T> v, const Mask1<T> mask, Sisd<T> d, + T* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template <typename T> +HWY_API size_t CompressBlendedStore(Vec1<T> v, const Mask1<T> mask, Sisd<T> d, + T* HWY_RESTRICT unaligned) { + if (!mask.bits) return 0; + StoreU(v, d, unaligned); + return 1; +} + +// ------------------------------ CompressBits +template <typename T> +HWY_API Vec1<T> CompressBits(Vec1<T> v, const uint8_t* HWY_RESTRICT /*bits*/) { + return v; +} + +// ------------------------------ CompressBitsStore +template <typename T> +HWY_API size_t CompressBitsStore(Vec1<T> v, const uint8_t* HWY_RESTRICT bits, + Sisd<T> d, T* HWY_RESTRICT unaligned) { + const Mask1<T> mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +HWY_API Vec1<float> ReorderWidenMulAccumulate(Sisd<float> /* tag */, + Vec1<bfloat16_t> a, + Vec1<bfloat16_t> b, + const Vec1<float> sum0, + Vec1<float>& /* sum1 */) { + return MulAdd(Vec1<float>(F32FromBF16(a.raw)), + Vec1<float>(F32FromBF16(b.raw)), sum0); +} + +HWY_API Vec1<int32_t> ReorderWidenMulAccumulate(Sisd<int32_t> /* tag */, + Vec1<int16_t> a, + Vec1<int16_t> b, + const Vec1<int32_t> sum0, + Vec1<int32_t>& /* sum1 */) { + return Vec1<int32_t>(a.raw * b.raw + sum0.raw); +} + +// ------------------------------ RearrangeToOddPlusEven +template <typename TW> +HWY_API Vec1<TW> RearrangeToOddPlusEven(const Vec1<TW> sum0, + Vec1<TW> /* sum1 */) { + return sum0; // invariant already holds +} + +// ================================================== REDUCTIONS + +// Sum of all lanes, i.e. the only one. +template <typename T> +HWY_API Vec1<T> SumOfLanes(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} +template <typename T> +HWY_API Vec1<T> MinOfLanes(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} +template <typename T> +HWY_API Vec1<T> MaxOfLanes(Sisd<T> /* tag */, const Vec1<T> v) { + return v; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/set_macros-inl.h b/third_party/highway/hwy/ops/set_macros-inl.h new file mode 100644 index 0000000000..051dbb3348 --- /dev/null +++ b/third_party/highway/hwy/ops/set_macros-inl.h @@ -0,0 +1,444 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Sets macros based on HWY_TARGET. + +// This include guard is toggled by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. +#if defined(HWY_SET_MACROS_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif + +#endif // HWY_SET_MACROS_PER_TARGET + +#include "hwy/detect_targets.h" + +#undef HWY_NAMESPACE +#undef HWY_ALIGN +#undef HWY_MAX_BYTES +#undef HWY_LANES + +#undef HWY_HAVE_SCALABLE +#undef HWY_HAVE_INTEGER64 +#undef HWY_HAVE_FLOAT16 +#undef HWY_HAVE_FLOAT64 +#undef HWY_MEM_OPS_MIGHT_FAULT +#undef HWY_NATIVE_FMA +#undef HWY_CAP_GE256 +#undef HWY_CAP_GE512 + +#undef HWY_TARGET_STR + +#if defined(HWY_DISABLE_PCLMUL_AES) +#define HWY_TARGET_STR_PCLMUL_AES "" +#else +#define HWY_TARGET_STR_PCLMUL_AES ",pclmul,aes" +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) +#define HWY_TARGET_STR_BMI2_FMA "" +#else +#define HWY_TARGET_STR_BMI2_FMA ",bmi,bmi2,fma" +#endif + +#if defined(HWY_DISABLE_F16C) +#define HWY_TARGET_STR_F16C "" +#else +#define HWY_TARGET_STR_F16C ",f16c" +#endif + +#define HWY_TARGET_STR_SSSE3 "sse2,ssse3" + +#define HWY_TARGET_STR_SSE4 \ + HWY_TARGET_STR_SSSE3 ",sse4.1,sse4.2" HWY_TARGET_STR_PCLMUL_AES +// Include previous targets, which are the half-vectors of the next target. +#define HWY_TARGET_STR_AVX2 \ + HWY_TARGET_STR_SSE4 ",avx,avx2" HWY_TARGET_STR_BMI2_FMA HWY_TARGET_STR_F16C +#define HWY_TARGET_STR_AVX3 \ + HWY_TARGET_STR_AVX2 ",avx512f,avx512vl,avx512dq,avx512bw" + +// Before include guard so we redefine HWY_TARGET_STR on each include, +// governed by the current HWY_TARGET. + +//----------------------------------------------------------------------------- +// SSSE3 +#if HWY_TARGET == HWY_SSSE3 + +#define HWY_NAMESPACE N_SSSE3 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSSE3 + +//----------------------------------------------------------------------------- +// SSE4 +#elif HWY_TARGET == HWY_SSE4 + +#define HWY_NAMESPACE N_SSE4 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_SSE4 + +//----------------------------------------------------------------------------- +// AVX2 +#elif HWY_TARGET == HWY_AVX2 + +#define HWY_NAMESPACE N_AVX2 +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#ifdef HWY_DISABLE_BMI2_FMA +#define HWY_NATIVE_FMA 0 +#else +#define HWY_NATIVE_FMA 1 +#endif + +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#define HWY_TARGET_STR HWY_TARGET_STR_AVX2 + +//----------------------------------------------------------------------------- +// AVX3[_DL] +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL + +#define HWY_ALIGN alignas(64) +#define HWY_MAX_BYTES 64 +#define HWY_LANES(T) (64 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 1 + +#if HWY_TARGET == HWY_AVX3 + +#define HWY_NAMESPACE N_AVX3 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3 + +#elif HWY_TARGET == HWY_AVX3_DL + +#define HWY_NAMESPACE N_AVX3_DL +#define HWY_TARGET_STR \ + HWY_TARGET_STR_AVX3 \ + ",vpclmulqdq,avx512vbmi,avx512vbmi2,vaes,avxvnni,avx512bitalg," \ + "avx512vpopcntdq" + +#else +#error "Logic error" +#endif // HWY_TARGET == HWY_AVX3_DL + +//----------------------------------------------------------------------------- +// PPC8 +#elif HWY_TARGET == HWY_PPC8 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_PPC8 + +#define HWY_TARGET_STR "altivec,vsx" + +//----------------------------------------------------------------------------- +// NEON +#elif HWY_TARGET == HWY_NEON + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 + +#if HWY_ARCH_ARM_A64 +#define HWY_HAVE_FLOAT64 1 +#else +#define HWY_HAVE_FLOAT64 0 +#endif + +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#if defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#define HWY_NATIVE_FMA 1 +#else +#define HWY_NATIVE_FMA 0 +#endif + +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_NEON + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_ARCH_ARM_V7 +#define HWY_TARGET_STR "+neon-vfpv4" +#else +#define HWY_TARGET_STR "+crypto" +#endif // HWY_ARCH_ARM_V7 +#else +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// SVE[2] +#elif HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE || \ + HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + +// SVE only requires lane alignment, not natural alignment of the entire vector. +#define HWY_ALIGN alignas(8) + +// Value ensures MaxLanes() is the tightest possible upper bound to reduce +// overallocation. +#define HWY_LANES(T) ((HWY_MAX_BYTES) / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if HWY_TARGET == HWY_SVE2 +#define HWY_NAMESPACE N_SVE2 +#define HWY_MAX_BYTES 256 +#elif HWY_TARGET == HWY_SVE_256 +#define HWY_NAMESPACE N_SVE_256 +#define HWY_MAX_BYTES 32 +#elif HWY_TARGET == HWY_SVE2_128 +#define HWY_NAMESPACE N_SVE2_128 +#define HWY_MAX_BYTES 16 +#else +#define HWY_NAMESPACE N_SVE +#define HWY_MAX_BYTES 256 +#endif + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +#define HWY_TARGET_STR "+sve2-aes" +#else +#define HWY_TARGET_STR "+sve" +#endif +#else +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// WASM +#elif HWY_TARGET == HWY_WASM + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// WASM_EMU256 +#elif HWY_TARGET == HWY_WASM_EMU256 + +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_LANES(T) (32 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 0 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 1 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_WASM_EMU256 + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// RVV +#elif HWY_TARGET == HWY_RVV + +// RVV only requires lane alignment, not natural alignment of the entire vector, +// and the compiler already aligns builtin types, so nothing to do here. +#define HWY_ALIGN + +// The spec requires VLEN <= 2^16 bits, so the limit is 2^16 bytes (LMUL=8). +#define HWY_MAX_BYTES 65536 + +// = HWY_MAX_BYTES divided by max LMUL=8 because MaxLanes includes the actual +// LMUL. This is the tightest possible upper bound. +#define HWY_LANES(T) (8192 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#if defined(__riscv_zvfh) +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif + +#define HWY_NAMESPACE N_RVV + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. +// (rv64gcv is not a valid target) + +//----------------------------------------------------------------------------- +// EMU128 +#elif HWY_TARGET == HWY_EMU128 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_LANES(T) (16 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_EMU128 + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +//----------------------------------------------------------------------------- +// SCALAR +#elif HWY_TARGET == HWY_SCALAR + +#define HWY_ALIGN +#define HWY_MAX_BYTES 8 +#define HWY_LANES(T) 1 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 0 +#define HWY_CAP_GE256 0 +#define HWY_CAP_GE512 0 + +#define HWY_NAMESPACE N_SCALAR + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +// Override this to 1 in asan/msan builds, which will still fault. +#if HWY_IS_ASAN || HWY_IS_MSAN +#undef HWY_MEM_OPS_MIGHT_FAULT +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#endif + +// Clang <9 requires this be invoked at file scope, before any namespace. +#undef HWY_BEFORE_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_BEFORE_NAMESPACE() \ + HWY_PUSH_ATTRIBUTES(HWY_TARGET_STR) \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_BEFORE_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +// Clang <9 requires any namespaces be closed before this macro. +#undef HWY_AFTER_NAMESPACE +#if defined(HWY_TARGET_STR) +#define HWY_AFTER_NAMESPACE() \ + HWY_POP_ATTRIBUTES \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_AFTER_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +#undef HWY_ATTR +#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target) +#define HWY_ATTR __attribute__((target(HWY_TARGET_STR))) +#else +#define HWY_ATTR +#endif diff --git a/third_party/highway/hwy/ops/shared-inl.h b/third_party/highway/hwy/ops/shared-inl.h new file mode 100644 index 0000000000..02246bfa4f --- /dev/null +++ b/third_party/highway/hwy/ops/shared-inl.h @@ -0,0 +1,332 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target definitions shared by ops/*.h and user code. + +// We are covered by the highway.h include guard, but generic_ops-inl.h +// includes this again #if HWY_IDE. +#if defined(HIGHWAY_HWY_OPS_SHARED_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#undef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_OPS_SHARED_TOGGLE +#endif + +#ifndef HWY_NO_LIBCXX +#include <math.h> +#endif + +#include "hwy/base.h" + +// Separate header because foreach_target.h re-enables its include guard. +#include "hwy/ops/set_macros-inl.h" + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Highway operations are implemented as overloaded functions selected using an +// internal-only tag type D := Simd<T, N, kPow2>. T is the lane type. kPow2 is a +// shift count applied to scalable vectors. Instead of referring to Simd<> +// directly, users create D via aliases ScalableTag<T[, kPow2]>() (defaults to a +// full vector, or fractions/groups if the argument is negative/positive), +// CappedTag<T, kLimit> or FixedTag<T, kNumLanes>. The actual number of lanes is +// Lanes(D()), a power of two. For scalable vectors, N is either HWY_LANES or a +// cap. For constexpr-size vectors, N is the actual number of lanes. This +// ensures Half<Full512<T>> is the same type as Full256<T>, as required by x86. +template <typename Lane, size_t N, int kPow2> +struct Simd { + constexpr Simd() = default; + using T = Lane; + static_assert((N & (N - 1)) == 0 && N != 0, "N must be a power of two"); + + // Only for use by MaxLanes, required by MSVC. Cannot be enum because GCC + // warns when using enums and non-enums in the same expression. Cannot be + // static constexpr function (another MSVC limitation). + static constexpr size_t kPrivateN = N; + static constexpr int kPrivatePow2 = kPow2; + + template <typename NewT> + static constexpr size_t NewN() { + // Round up to correctly handle scalars with N=1. + return (N * sizeof(T) + sizeof(NewT) - 1) / sizeof(NewT); + } + +#if HWY_HAVE_SCALABLE + template <typename NewT> + static constexpr int Pow2Ratio() { + return (sizeof(NewT) > sizeof(T)) + ? static_cast<int>(CeilLog2(sizeof(NewT) / sizeof(T))) + : -static_cast<int>(CeilLog2(sizeof(T) / sizeof(NewT))); + } +#endif + + // Widening/narrowing ops change the number of lanes and/or their type. + // To initialize such vectors, we need the corresponding tag types: + +// PromoteTo/DemoteTo() with another lane type, but same number of lanes. +#if HWY_HAVE_SCALABLE + template <typename NewT> + using Rebind = Simd<NewT, N, kPow2 + Pow2Ratio<NewT>()>; +#else + template <typename NewT> + using Rebind = Simd<NewT, N, kPow2>; +#endif + + // Change lane type while keeping the same vector size, e.g. for MulEven. + template <typename NewT> + using Repartition = Simd<NewT, NewN<NewT>(), kPow2>; + +// Half the lanes while keeping the same lane type, e.g. for LowerHalf. +// Round up to correctly handle scalars with N=1. +#if HWY_HAVE_SCALABLE + // Reducing the cap (N) is required for SVE - if N is the limiter for f32xN, + // then we expect Half<Rebind<u16>> to have N/2 lanes (rounded up). + using Half = Simd<T, (N + 1) / 2, kPow2 - 1>; +#else + using Half = Simd<T, (N + 1) / 2, kPow2>; +#endif + +// Twice the lanes while keeping the same lane type, e.g. for Combine. +#if HWY_HAVE_SCALABLE + using Twice = Simd<T, 2 * N, kPow2 + 1>; +#else + using Twice = Simd<T, 2 * N, kPow2>; +#endif +}; + +namespace detail { + +template <typename T, size_t N, int kPow2> +constexpr bool IsFull(Simd<T, N, kPow2> /* d */) { + return N == HWY_LANES(T) && kPow2 == 0; +} + +// Returns the number of lanes (possibly zero) after applying a shift: +// - 0: no change; +// - [1,3]: a group of 2,4,8 [fractional] vectors; +// - [-3,-1]: a fraction of a vector from 1/8 to 1/2. +constexpr size_t ScaleByPower(size_t N, int pow2) { +#if HWY_TARGET == HWY_RVV + return pow2 >= 0 ? (N << pow2) : (N >> (-pow2)); +#else + return pow2 >= 0 ? N : (N >> (-pow2)); +#endif +} + +// Struct wrappers enable validation of arguments via static_assert. +template <typename T, int kPow2> +struct ScalableTagChecker { + static_assert(-3 <= kPow2 && kPow2 <= 3, "Fraction must be 1/8 to 8"); +#if HWY_TARGET == HWY_RVV + // Only RVV supports register groups. + using type = Simd<T, HWY_LANES(T), kPow2>; +#elif HWY_HAVE_SCALABLE + // For SVE[2], only allow full or fractions. + using type = Simd<T, HWY_LANES(T), HWY_MIN(kPow2, 0)>; +#elif HWY_TARGET == HWY_SCALAR + using type = Simd<T, /*N=*/1, 0>; +#else + // Only allow full or fractions. + using type = Simd<T, ScaleByPower(HWY_LANES(T), HWY_MIN(kPow2, 0)), 0>; +#endif +}; + +template <typename T, size_t kLimit> +struct CappedTagChecker { + static_assert(kLimit != 0, "Does not make sense to have zero lanes"); + // Safely handle non-power-of-two inputs by rounding down, which is allowed by + // CappedTag. Otherwise, Simd<T, 3, 0> would static_assert. + static constexpr size_t kLimitPow2 = size_t{1} << hwy::FloorLog2(kLimit); + using type = Simd<T, HWY_MIN(kLimitPow2, HWY_LANES(T)), 0>; +}; + +template <typename T, size_t kNumLanes> +struct FixedTagChecker { + static_assert(kNumLanes != 0, "Does not make sense to have zero lanes"); + static_assert(kNumLanes <= HWY_LANES(T), "Too many lanes"); + using type = Simd<T, kNumLanes, 0>; +}; + +} // namespace detail + +// Alias for a tag describing a full vector (kPow2 == 0: the most common usage, +// e.g. 1D loops where the application does not care about the vector size) or a +// fraction/multiple of one. Multiples are the same as full vectors for all +// targets except RVV. Fractions (kPow2 < 0) are useful as the argument/return +// value of type promotion and demotion. +template <typename T, int kPow2 = 0> +using ScalableTag = typename detail::ScalableTagChecker<T, kPow2>::type; + +// Alias for a tag describing a vector with *up to* kLimit active lanes, even on +// targets with scalable vectors and HWY_SCALAR. The runtime lane count +// `Lanes(tag)` may be less than kLimit, and is 1 on HWY_SCALAR. This alias is +// typically used for 1D loops with a relatively low application-defined upper +// bound, e.g. for 8x8 DCTs. However, it is better if data structures are +// designed to be vector-length-agnostic (e.g. a hybrid SoA where there are +// chunks of `M >= MaxLanes(d)` DC components followed by M AC1, .., and M AC63; +// this would enable vector-length-agnostic loops using ScalableTag). +template <typename T, size_t kLimit> +using CappedTag = typename detail::CappedTagChecker<T, kLimit>::type; + +// Alias for a tag describing a vector with *exactly* kNumLanes active lanes, +// even on targets with scalable vectors. Requires `kNumLanes` to be a power of +// two not exceeding `HWY_LANES(T)`. +// +// NOTE: if the application does not need to support HWY_SCALAR (+), use this +// instead of CappedTag to emphasize that there will be exactly kNumLanes lanes. +// This is useful for data structures that rely on exactly 128-bit SIMD, but +// these are discouraged because they cannot benefit from wider vectors. +// Instead, applications would ideally define a larger problem size and loop +// over it with the (unknown size) vectors from ScalableTag. +// +// + e.g. if the baseline is known to support SIMD, or the application requires +// ops such as TableLookupBytes not supported by HWY_SCALAR. +template <typename T, size_t kNumLanes> +using FixedTag = typename detail::FixedTagChecker<T, kNumLanes>::type; + +template <class D> +using TFromD = typename D::T; + +// Tag for the same number of lanes as D, but with the LaneType T. +template <class T, class D> +using Rebind = typename D::template Rebind<T>; + +template <class D> +using RebindToSigned = Rebind<MakeSigned<TFromD<D>>, D>; +template <class D> +using RebindToUnsigned = Rebind<MakeUnsigned<TFromD<D>>, D>; +template <class D> +using RebindToFloat = Rebind<MakeFloat<TFromD<D>>, D>; + +// Tag for the same total size as D, but with the LaneType T. +template <class T, class D> +using Repartition = typename D::template Repartition<T>; + +template <class D> +using RepartitionToWide = Repartition<MakeWide<TFromD<D>>, D>; +template <class D> +using RepartitionToNarrow = Repartition<MakeNarrow<TFromD<D>>, D>; + +// Tag for the same lane type as D, but half the lanes. +template <class D> +using Half = typename D::Half; + +// Tag for the same lane type as D, but twice the lanes. +template <class D> +using Twice = typename D::Twice; + +template <typename T> +using Full16 = Simd<T, 2 / sizeof(T), 0>; + +template <typename T> +using Full32 = Simd<T, 4 / sizeof(T), 0>; + +template <typename T> +using Full64 = Simd<T, 8 / sizeof(T), 0>; + +template <typename T> +using Full128 = Simd<T, 16 / sizeof(T), 0>; + +// Same as base.h macros but with a Simd<T, N, kPow2> argument instead of T. +#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(TFromD<D>) +#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(TFromD<D>) +#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(TFromD<D>) +#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(TFromD<D>) +#define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_LANE_SIZE(TFromD<D>, bytes) +#define HWY_IF_NOT_LANE_SIZE_D(D, bytes) HWY_IF_NOT_LANE_SIZE(TFromD<D>, bytes) +#define HWY_IF_LANE_SIZE_ONE_OF_D(D, bit_array) \ + HWY_IF_LANE_SIZE_ONE_OF(TFromD<D>, bit_array) + +// MSVC workaround: use PrivateN directly instead of MaxLanes. +#define HWY_IF_LT128_D(D) \ + hwy::EnableIf<D::kPrivateN * sizeof(TFromD<D>) < 16>* = nullptr +#define HWY_IF_GE128_D(D) \ + hwy::EnableIf<D::kPrivateN * sizeof(TFromD<D>) >= 16>* = nullptr + +// Same, but with a vector argument. ops/*-inl.h define their own TFromV. +#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(TFromV<V>) +#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(TFromV<V>) +#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(TFromV<V>) +#define HWY_IF_LANE_SIZE_V(V, bytes) HWY_IF_LANE_SIZE(TFromV<V>, bytes) +#define HWY_IF_NOT_LANE_SIZE_V(V, bytes) HWY_IF_NOT_LANE_SIZE(TFromV<V>, bytes) +#define HWY_IF_LANE_SIZE_ONE_OF_V(V, bit_array) \ + HWY_IF_LANE_SIZE_ONE_OF(TFromV<V>, bit_array) + +template <class D> +HWY_INLINE HWY_MAYBE_UNUSED constexpr int Pow2(D /* d */) { + return D::kPrivatePow2; +} + +// MSVC requires the explicit <D>. +#define HWY_IF_POW2_GE(D, MIN) hwy::EnableIf<Pow2<D>(D()) >= (MIN)>* = nullptr + +#if HWY_HAVE_SCALABLE + +// Upper bound on the number of lanes. Intended for template arguments and +// reducing code size (e.g. for SSE4, we know at compile-time that vectors will +// not exceed 16 bytes). WARNING: this may be a loose bound, use Lanes() as the +// actual size for allocating storage. WARNING: MSVC might not be able to deduce +// arguments if this is used in EnableIf. See HWY_IF_LT128_D above. +template <class D> +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return detail::ScaleByPower(HWY_MIN(D::kPrivateN, HWY_LANES(TFromD<D>)), + D::kPrivatePow2); +} + +#else +// Workaround for MSVC 2017: T,N,kPow2 argument deduction fails, so returning N +// is not an option, nor does a member function work. +template <class D> +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return D::kPrivateN; +} + +// (Potentially) non-constant actual size of the vector at runtime, subject to +// the limit imposed by the Simd. Useful for advancing loop counters. +// Targets with scalable vectors define this themselves. +template <typename T, size_t N, int kPow2> +HWY_INLINE HWY_MAYBE_UNUSED size_t Lanes(Simd<T, N, kPow2>) { + return N; +} + +#endif // !HWY_HAVE_SCALABLE + +// NOTE: GCC generates incorrect code for vector arguments to non-inlined +// functions in two situations: +// - on Windows and GCC 10.3, passing by value crashes due to unaligned loads: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412. +// - on ARM64 and GCC 9.3.0 or 11.2.1, passing by value causes many (but not +// all) tests to fail. +// +// We therefore pass by const& only on GCC and (Windows or ARM64). This alias +// must be used for all vector/mask parameters of functions marked HWY_NOINLINE, +// and possibly also other functions that are not inlined. +#if HWY_COMPILER_GCC_ACTUAL && (HWY_OS_WIN || HWY_ARCH_ARM_A64) +template <class V> +using VecArg = const V&; +#else +template <class V> +using VecArg = V; +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_OPS_SHARED_TOGGLE diff --git a/third_party/highway/hwy/ops/wasm_128-inl.h b/third_party/highway/hwy/ops/wasm_128-inl.h new file mode 100644 index 0000000000..095fd4f1f0 --- /dev/null +++ b/third_party/highway/hwy/ops/wasm_128-inl.h @@ -0,0 +1,4591 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit WASM vectors and operations. +// External include guard in highway.h - see comment there. + +#include <stddef.h> +#include <stdint.h> +#include <wasm_simd128.h> + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +#ifdef HWY_WASM_OLD_NAMES +#define wasm_i8x16_shuffle wasm_v8x16_shuffle +#define wasm_i16x8_shuffle wasm_v16x8_shuffle +#define wasm_i32x4_shuffle wasm_v32x4_shuffle +#define wasm_i64x2_shuffle wasm_v64x2_shuffle +#define wasm_u16x8_extend_low_u8x16 wasm_i16x8_widen_low_u8x16 +#define wasm_u32x4_extend_low_u16x8 wasm_i32x4_widen_low_u16x8 +#define wasm_i32x4_extend_low_i16x8 wasm_i32x4_widen_low_i16x8 +#define wasm_i16x8_extend_low_i8x16 wasm_i16x8_widen_low_i8x16 +#define wasm_u32x4_extend_high_u16x8 wasm_i32x4_widen_high_u16x8 +#define wasm_i32x4_extend_high_i16x8 wasm_i32x4_widen_high_i16x8 +#define wasm_i32x4_trunc_sat_f32x4 wasm_i32x4_trunc_saturate_f32x4 +#define wasm_u8x16_add_sat wasm_u8x16_add_saturate +#define wasm_u8x16_sub_sat wasm_u8x16_sub_saturate +#define wasm_u16x8_add_sat wasm_u16x8_add_saturate +#define wasm_u16x8_sub_sat wasm_u16x8_sub_saturate +#define wasm_i8x16_add_sat wasm_i8x16_add_saturate +#define wasm_i8x16_sub_sat wasm_i8x16_sub_saturate +#define wasm_i16x8_add_sat wasm_i16x8_add_saturate +#define wasm_i16x8_sub_sat wasm_i16x8_sub_saturate +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#if HWY_TARGET == HWY_WASM_EMU256 +template <typename T> +using Full256 = Simd<T, 32 / sizeof(T), 0>; +#endif + +namespace detail { + +template <typename T> +struct Raw128 { + using type = __v128_u; +}; +template <> +struct Raw128<float> { + using type = __f32x4; +}; + +} // namespace detail + +template <typename T, size_t N = 16 / sizeof(T)> +class Vec128 { + using Raw = typename detail::Raw128<T>::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template <typename T> +using Vec64 = Vec128<T, 8 / sizeof(T)>; + +template <typename T> +using Vec32 = Vec128<T, 4 / sizeof(T)>; + +template <typename T> +using Vec16 = Vec128<T, 2 / sizeof(T)>; + +// FF..FF or 0. +template <typename T, size_t N = 16 / sizeof(T)> +struct Mask128 { + typename detail::Raw128<T>::type raw; +}; + +template <class V> +using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>; + +template <class V> +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __v128_u BitCastToInteger(__v128_u v) { return v; } +HWY_INLINE __v128_u BitCastToInteger(__f32x4 v) { + return static_cast<__v128_u>(v); +} +HWY_INLINE __v128_u BitCastToInteger(__f64x2 v) { + return static_cast<__v128_u>(v); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<uint8_t, N * sizeof(T)> BitCastToByte(Vec128<T, N> v) { + return Vec128<uint8_t, N * sizeof(T)>{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template <typename T> +struct BitCastFromInteger128 { + HWY_INLINE __v128_u operator()(__v128_u v) { return v; } +}; +template <> +struct BitCastFromInteger128<float> { + HWY_INLINE __f32x4 operator()(__v128_u v) { return static_cast<__f32x4>(v); } +}; + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> BitCastFromByte(Simd<T, N, 0> /* tag */, + Vec128<uint8_t, N * sizeof(T)> v) { + return Vec128<T, N>{BitCastFromInteger128<T>()(v.raw)}; +} + +} // namespace detail + +template <typename T, size_t N, typename FromT> +HWY_API Vec128<T, N> BitCast(Simd<T, N, 0> d, + Vec128<FromT, N * sizeof(T) / sizeof(FromT)> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> Zero(Simd<T, N, 0> /* tag */) { + return Vec128<T, N>{wasm_i32x4_splat(0)}; +} +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> Zero(Simd<float, N, 0> /* tag */) { + return Vec128<float, N>{wasm_f32x4_splat(0.0f)}; +} + +template <class D> +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template <size_t N, HWY_IF_LE128(uint8_t, N)> +HWY_API Vec128<uint8_t, N> Set(Simd<uint8_t, N, 0> /* tag */, const uint8_t t) { + return Vec128<uint8_t, N>{wasm_i8x16_splat(static_cast<int8_t>(t))}; +} +template <size_t N, HWY_IF_LE128(uint16_t, N)> +HWY_API Vec128<uint16_t, N> Set(Simd<uint16_t, N, 0> /* tag */, + const uint16_t t) { + return Vec128<uint16_t, N>{wasm_i16x8_splat(static_cast<int16_t>(t))}; +} +template <size_t N, HWY_IF_LE128(uint32_t, N)> +HWY_API Vec128<uint32_t, N> Set(Simd<uint32_t, N, 0> /* tag */, + const uint32_t t) { + return Vec128<uint32_t, N>{wasm_i32x4_splat(static_cast<int32_t>(t))}; +} +template <size_t N, HWY_IF_LE128(uint64_t, N)> +HWY_API Vec128<uint64_t, N> Set(Simd<uint64_t, N, 0> /* tag */, + const uint64_t t) { + return Vec128<uint64_t, N>{wasm_i64x2_splat(static_cast<int64_t>(t))}; +} + +template <size_t N, HWY_IF_LE128(int8_t, N)> +HWY_API Vec128<int8_t, N> Set(Simd<int8_t, N, 0> /* tag */, const int8_t t) { + return Vec128<int8_t, N>{wasm_i8x16_splat(t)}; +} +template <size_t N, HWY_IF_LE128(int16_t, N)> +HWY_API Vec128<int16_t, N> Set(Simd<int16_t, N, 0> /* tag */, const int16_t t) { + return Vec128<int16_t, N>{wasm_i16x8_splat(t)}; +} +template <size_t N, HWY_IF_LE128(int32_t, N)> +HWY_API Vec128<int32_t, N> Set(Simd<int32_t, N, 0> /* tag */, const int32_t t) { + return Vec128<int32_t, N>{wasm_i32x4_splat(t)}; +} +template <size_t N, HWY_IF_LE128(int64_t, N)> +HWY_API Vec128<int64_t, N> Set(Simd<int64_t, N, 0> /* tag */, const int64_t t) { + return Vec128<int64_t, N>{wasm_i64x2_splat(t)}; +} + +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> Set(Simd<float, N, 0> /* tag */, const float t) { + return Vec128<float, N>{wasm_f32x4_splat(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> Undefined(Simd<T, N, 0> d) { + return Zero(d); +} + +HWY_DIAGNOSTICS(pop) + +// Returns a vector with lane i=[0, N) set to "first" + i. +template <typename T, size_t N, typename T2, HWY_IF_LE128(T, N)> +Vec128<T, N> Iota(const Simd<T, N, 0> d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i); + } + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> operator+(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_i8x16_add(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> operator+(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_i16x8_add(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> operator+(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{wasm_i32x4_add(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> operator+(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> operator+(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_add(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> operator+(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_add(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> operator+(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{wasm_i32x4_add(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> operator+(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Vec128<int64_t, N>{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> operator+(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{wasm_f32x4_add(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> operator-(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_i8x16_sub(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> operator-(Vec128<uint16_t, N> a, + Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_i16x8_sub(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> operator-(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{wasm_i32x4_sub(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> operator-(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> operator-(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_sub(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> operator-(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_sub(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> operator-(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{wasm_i32x4_sub(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> operator-(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Vec128<int64_t, N>{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> operator-(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{wasm_f32x4_sub(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> SaturatedAdd(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_u8x16_add_sat(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> SaturatedAdd(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_u16x8_add_sat(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> SaturatedAdd(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_add_sat(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> SaturatedAdd(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_add_sat(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> SaturatedSub(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_u8x16_sub_sat(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> SaturatedSub(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_u16x8_sub_sat(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> SaturatedSub(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_sub_sat(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> SaturatedSub(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_sub_sat(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> AverageRound(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_u8x16_avgr(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> AverageRound(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_u16x8_avgr(a.raw, b.raw)}; +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template <size_t N> +HWY_API Vec128<int8_t, N> Abs(const Vec128<int8_t, N> v) { + return Vec128<int8_t, N>{wasm_i8x16_abs(v.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> Abs(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>{wasm_i16x8_abs(v.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> Abs(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>{wasm_i32x4_abs(v.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> Abs(const Vec128<int64_t, N> v) { + return Vec128<int64_t, N>{wasm_i64x2_abs(v.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> Abs(const Vec128<float, N> v) { + return Vec128<float, N>{wasm_f32x4_abs(v.raw)}; +} + +// ------------------------------ Shift lanes by constant #bits + +// Unsigned +template <int kBits, size_t N> +HWY_API Vec128<uint16_t, N> ShiftLeft(const Vec128<uint16_t, N> v) { + return Vec128<uint16_t, N>{wasm_i16x8_shl(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<uint16_t, N> ShiftRight(const Vec128<uint16_t, N> v) { + return Vec128<uint16_t, N>{wasm_u16x8_shr(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<uint32_t, N> ShiftLeft(const Vec128<uint32_t, N> v) { + return Vec128<uint32_t, N>{wasm_i32x4_shl(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<uint64_t, N> ShiftLeft(const Vec128<uint64_t, N> v) { + return Vec128<uint64_t, N>{wasm_i64x2_shl(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<uint32_t, N> ShiftRight(const Vec128<uint32_t, N> v) { + return Vec128<uint32_t, N>{wasm_u32x4_shr(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<uint64_t, N> ShiftRight(const Vec128<uint64_t, N> v) { + return Vec128<uint64_t, N>{wasm_u64x2_shr(v.raw, kBits)}; +} + +// Signed +template <int kBits, size_t N> +HWY_API Vec128<int16_t, N> ShiftLeft(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>{wasm_i16x8_shl(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int16_t, N> ShiftRight(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>{wasm_i16x8_shr(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int32_t, N> ShiftLeft(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>{wasm_i32x4_shl(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int64_t, N> ShiftLeft(const Vec128<int64_t, N> v) { + return Vec128<int64_t, N>{wasm_i64x2_shl(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int32_t, N> ShiftRight(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>{wasm_i32x4_shr(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int64_t, N> ShiftRight(const Vec128<int64_t, N> v) { + return Vec128<int64_t, N>{wasm_i64x2_shr(v.raw, kBits)}; +} + +// 8-bit +template <int kBits, typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> ShiftLeft(const Vec128<T, N> v) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<T, N> shifted{ShiftLeft<kBits>(Vec128<MakeWide<T>>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF))); +} + +template <int kBits, size_t N> +HWY_API Vec128<uint8_t, N> ShiftRight(const Vec128<uint8_t, N> v) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<uint8_t, N> shifted{ + ShiftRight<kBits>(Vec128<uint16_t>{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template <int kBits, size_t N> +HWY_API Vec128<int8_t, N> ShiftRight(const Vec128<int8_t, N> v) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template <int kBits, typename T, size_t N> +HWY_API Vec128<T, N> RotateRight(const Vec128<T, N> v) { + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<kSizeInBits - kBits>(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +// After https://reviews.llvm.org/D108415 shift argument became unsigned. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unsigned +template <size_t N> +HWY_API Vec128<uint16_t, N> ShiftLeftSame(const Vec128<uint16_t, N> v, + const int bits) { + return Vec128<uint16_t, N>{wasm_i16x8_shl(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> ShiftRightSame(const Vec128<uint16_t, N> v, + const int bits) { + return Vec128<uint16_t, N>{wasm_u16x8_shr(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> ShiftLeftSame(const Vec128<uint32_t, N> v, + const int bits) { + return Vec128<uint32_t, N>{wasm_i32x4_shl(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> ShiftRightSame(const Vec128<uint32_t, N> v, + const int bits) { + return Vec128<uint32_t, N>{wasm_u32x4_shr(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> ShiftLeftSame(const Vec128<uint64_t, N> v, + const int bits) { + return Vec128<uint64_t, N>{wasm_i64x2_shl(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> ShiftRightSame(const Vec128<uint64_t, N> v, + const int bits) { + return Vec128<uint64_t, N>{wasm_u64x2_shr(v.raw, bits)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int16_t, N> ShiftLeftSame(const Vec128<int16_t, N> v, + const int bits) { + return Vec128<int16_t, N>{wasm_i16x8_shl(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> ShiftRightSame(const Vec128<int16_t, N> v, + const int bits) { + return Vec128<int16_t, N>{wasm_i16x8_shr(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> ShiftLeftSame(const Vec128<int32_t, N> v, + const int bits) { + return Vec128<int32_t, N>{wasm_i32x4_shl(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> ShiftRightSame(const Vec128<int32_t, N> v, + const int bits) { + return Vec128<int32_t, N>{wasm_i32x4_shr(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> ShiftLeftSame(const Vec128<int64_t, N> v, + const int bits) { + return Vec128<int64_t, N>{wasm_i64x2_shl(v.raw, bits)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> ShiftRightSame(const Vec128<int64_t, N> v, + const int bits) { + return Vec128<int64_t, N>{wasm_i64x2_shr(v.raw, bits)}; +} + +// 8-bit +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> ShiftLeftSame(const Vec128<T, N> v, const int bits) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<T, N> shifted{ + ShiftLeftSame(Vec128<MakeWide<T>>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF)); +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> ShiftRightSame(Vec128<uint8_t, N> v, + const int bits) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<uint8_t, N> shifted{ + ShiftRightSame(Vec128<uint16_t>{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +template <size_t N> +HWY_API Vec128<int8_t, N> ShiftRightSame(Vec128<int8_t, N> v, const int bits) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ignore Wsign-conversion +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Minimum + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> Min(Vec128<uint8_t, N> a, Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_u8x16_min(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> Min(Vec128<uint16_t, N> a, Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_u16x8_min(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> Min(Vec128<uint32_t, N> a, Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{wasm_u32x4_min(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> Min(Vec128<uint64_t, N> a, Vec128<uint64_t, N> b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. + const uint64_t a0 = static_cast<uint64_t>(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast<uint64_t>(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast<uint64_t>(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast<uint64_t>(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t min[2] = {HWY_MIN(a0, b0), HWY_MIN(a1, b1)}; + return Vec128<uint64_t, N>{wasm_v128_load(min)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> Min(Vec128<int8_t, N> a, Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_min(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> Min(Vec128<int16_t, N> a, Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_min(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> Min(Vec128<int32_t, N> a, Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{wasm_i32x4_min(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> Min(Vec128<int64_t, N> a, Vec128<int64_t, N> b) { + alignas(16) int64_t min[4]; + min[0] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + min[1] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128<int64_t, N>{wasm_v128_load(min)}; +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> Min(Vec128<float, N> a, Vec128<float, N> b) { + // Equivalent to a < b ? a : b (taking into account our swapped arg order, + // so that Min(NaN, x) is x to match x86). + return Vec128<float, N>{wasm_f32x4_pmin(b.raw, a.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> Max(Vec128<uint8_t, N> a, Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_u8x16_max(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> Max(Vec128<uint16_t, N> a, Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_u16x8_max(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> Max(Vec128<uint32_t, N> a, Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{wasm_u32x4_max(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> Max(Vec128<uint64_t, N> a, Vec128<uint64_t, N> b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. + const uint64_t a0 = static_cast<uint64_t>(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast<uint64_t>(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast<uint64_t>(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast<uint64_t>(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t max[2] = {HWY_MAX(a0, b0), HWY_MAX(a1, b1)}; + return Vec128<uint64_t, N>{wasm_v128_load(max)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> Max(Vec128<int8_t, N> a, Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_max(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> Max(Vec128<int16_t, N> a, Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_max(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> Max(Vec128<int32_t, N> a, Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{wasm_i32x4_max(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> Max(Vec128<int64_t, N> a, Vec128<int64_t, N> b) { + alignas(16) int64_t max[2]; + max[0] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + max[1] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128<int64_t, N>{wasm_v128_load(max)}; +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> Max(Vec128<float, N> a, Vec128<float, N> b) { + // Equivalent to b < a ? a : b (taking into account our swapped arg order, + // so that Max(NaN, x) is x to match x86). + return Vec128<float, N>{wasm_f32x4_pmax(b.raw, a.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +template <size_t N> +HWY_API Vec128<uint16_t, N> operator*(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{wasm_i16x8_mul(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> operator*(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int16_t, N> operator*(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_mul(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> operator*(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template <size_t N> +HWY_API Vec128<uint16_t, N> MulHigh(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + const auto l = wasm_u32x4_extmul_low_u16x8(a.raw, b.raw); + const auto h = wasm_u32x4_extmul_high_u16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128<uint16_t, N>{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> MulHigh(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + const auto l = wasm_i32x4_extmul_low_i16x8(a.raw, b.raw); + const auto h = wasm_i32x4_extmul_high_i16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128<int16_t, N>{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +template <size_t N> +HWY_API Vec128<int16_t, N> MulFixedPoint15(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{wasm_i16x8_q15mulr_sat(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-width result. +template <size_t N> +HWY_API Vec128<int64_t, (N + 1) / 2> MulEven(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128<int64_t, (N + 1) / 2>{wasm_i64x2_mul(ae, be)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, (N + 1) / 2> MulEven(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128<uint64_t, (N + 1) / 2>{wasm_i64x2_mul(ae, be)}; +} + +// ------------------------------ Negate + +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Vec128<T, N> Neg(const Vec128<T, N> v) { + return Xor(v, SignBit(DFromV<decltype(v)>())); +} + +template <size_t N> +HWY_API Vec128<int8_t, N> Neg(const Vec128<int8_t, N> v) { + return Vec128<int8_t, N>{wasm_i8x16_neg(v.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> Neg(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>{wasm_i16x8_neg(v.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> Neg(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>{wasm_i32x4_neg(v.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> Neg(const Vec128<int64_t, N> v) { + return Vec128<int64_t, N>{wasm_i64x2_neg(v.raw)}; +} + +// ------------------------------ Floating-point mul / div + +template <size_t N> +HWY_API Vec128<float, N> operator*(Vec128<float, N> a, Vec128<float, N> b) { + return Vec128<float, N>{wasm_f32x4_mul(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> operator/(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{wasm_f32x4_div(a.raw, b.raw)}; +} + +// Approximate reciprocal +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocal(const Vec128<float, N> v) { + const Vec128<float, N> one = Vec128<float, N>{wasm_f32x4_splat(1.0f)}; + return one / v; +} + +// Absolute value of difference. +template <size_t N> +HWY_API Vec128<float, N> AbsDiff(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template <size_t N> +HWY_API Vec128<float, N> MulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { + return mul * x + add; +} + +// Returns add - mul * x +template <size_t N> +HWY_API Vec128<float, N> NegMulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { + return add - mul * x; +} + +// Returns mul * x - sub +template <size_t N> +HWY_API Vec128<float, N> MulSub(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> sub) { + return mul * x - sub; +} + +// Returns -mul * x - sub +template <size_t N> +HWY_API Vec128<float, N> NegMulSub(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template <size_t N> +HWY_API Vec128<float, N> Sqrt(const Vec128<float, N> v) { + return Vec128<float, N>{wasm_f32x4_sqrt(v.raw)}; +} + +// Approximate reciprocal square root +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocalSqrt(const Vec128<float, N> v) { + // TODO(eustas): find cheaper a way to calculate this. + const Vec128<float, N> one = Vec128<float, N>{wasm_f32x4_splat(1.0f)}; + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template <size_t N> +HWY_API Vec128<float, N> Round(const Vec128<float, N> v) { + return Vec128<float, N>{wasm_f32x4_nearest(v.raw)}; +} + +// Toward zero, aka truncate +template <size_t N> +HWY_API Vec128<float, N> Trunc(const Vec128<float, N> v) { + return Vec128<float, N>{wasm_f32x4_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +template <size_t N> +HWY_API Vec128<float, N> Ceil(const Vec128<float, N> v) { + return Vec128<float, N>{wasm_f32x4_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +template <size_t N> +HWY_API Vec128<float, N> Floor(const Vec128<float, N> v) { + return Vec128<float, N>{wasm_f32x4_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification +template <typename T, size_t N> +HWY_API Mask128<T, N> IsNaN(const Vec128<T, N> v) { + return v != v; +} + +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Mask128<T, N> IsInf(const Vec128<T, N> v) { + const Simd<T, N, 0> d; + const RebindToSigned<decltype(d)> di; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2<T>()))); +} + +// Returns whether normal/subnormal/zero. +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Mask128<T, N> IsFinite(const Vec128<T, N> v) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + const VFromD<decltype(du)> vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD<decltype(di)> exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template <typename TFrom, typename TTo, size_t N> +HWY_API Mask128<TTo, N> RebindMask(Simd<TTo, N, 0> /*tag*/, + Mask128<TFrom, N> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128<TTo, N>{m.raw}; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> TestBit(Vec128<T, N> v, Vec128<T, N> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template <size_t N> +HWY_API Mask128<uint8_t, N> operator==(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Mask128<uint8_t, N>{wasm_i8x16_eq(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint16_t, N> operator==(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Mask128<uint16_t, N>{wasm_i16x8_eq(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint32_t, N> operator==(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Mask128<uint32_t, N>{wasm_i32x4_eq(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint64_t, N> operator==(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + return Mask128<uint64_t, N>{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Mask128<int8_t, N> operator==(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Mask128<int8_t, N>{wasm_i8x16_eq(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int16_t, N> operator==(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Mask128<int16_t, N>{wasm_i16x8_eq(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int32_t, N> operator==(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Mask128<int32_t, N>{wasm_i32x4_eq(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int64_t, N> operator==(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Mask128<int64_t, N>{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Float +template <size_t N> +HWY_API Mask128<float, N> operator==(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{wasm_f32x4_eq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// Unsigned +template <size_t N> +HWY_API Mask128<uint8_t, N> operator!=(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Mask128<uint8_t, N>{wasm_i8x16_ne(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint16_t, N> operator!=(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Mask128<uint16_t, N>{wasm_i16x8_ne(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint32_t, N> operator!=(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Mask128<uint32_t, N>{wasm_i32x4_ne(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint64_t, N> operator!=(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + return Mask128<uint64_t, N>{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Mask128<int8_t, N> operator!=(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Mask128<int8_t, N>{wasm_i8x16_ne(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int16_t, N> operator!=(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Mask128<int16_t, N>{wasm_i16x8_ne(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int32_t, N> operator!=(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Mask128<int32_t, N>{wasm_i32x4_ne(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int64_t, N> operator!=(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Mask128<int64_t, N>{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Float +template <size_t N> +HWY_API Mask128<float, N> operator!=(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{wasm_f32x4_ne(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +template <size_t N> +HWY_API Mask128<int8_t, N> operator>(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Mask128<int8_t, N>{wasm_i8x16_gt(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int16_t, N> operator>(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Mask128<int16_t, N>{wasm_i16x8_gt(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int32_t, N> operator>(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Mask128<int32_t, N>{wasm_i32x4_gt(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int64_t, N> operator>(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Mask128<int64_t, N>{wasm_i64x2_gt(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Mask128<uint8_t, N> operator>(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Mask128<uint8_t, N>{wasm_u8x16_gt(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint16_t, N> operator>(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Mask128<uint16_t, N>{wasm_u16x8_gt(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint32_t, N> operator>(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Mask128<uint32_t, N>{wasm_u32x4_gt(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint64_t, N> operator>(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + const DFromV<decltype(a)> d; + const Repartition<uint32_t, decltype(d)> d32; + const auto a32 = BitCast(d32, a); + const auto b32 = BitCast(d32, b); + // If the upper halves are not equal, this is the answer. + const auto m_gt = a32 > b32; + + // Otherwise, the lower half decides. + const auto m_eq = a32 == b32; + const auto lo_in_hi = wasm_i32x4_shuffle(m_gt.raw, m_gt.raw, 0, 0, 2, 2); + const auto lo_gt = And(m_eq, MaskFromVec(VFromD<decltype(d32)>{lo_in_hi})); + + const auto gt = Or(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. + return Mask128<uint64_t, N>{wasm_i32x4_shuffle(gt.raw, gt.raw, 1, 1, 3, 3)}; +} + +template <size_t N> +HWY_API Mask128<float, N> operator>(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{wasm_f32x4_gt(a.raw, b.raw)}; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator<(const Vec128<T, N> a, const Vec128<T, N> b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float <= >= +template <size_t N> +HWY_API Mask128<float, N> operator<=(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{wasm_f32x4_le(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<float, N> operator>=(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{wasm_f32x4_ge(a.raw, b.raw)}; +} + +// ------------------------------ FirstN (Iota, Lt) + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Mask128<T, N> FirstN(const Simd<T, N, 0> d, size_t num) { + const RebindToSigned<decltype(d)> di; // Signed comparisons may be cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast<MakeSigned<T>>(num))); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template <typename T, size_t N> +HWY_API Vec128<T, N> Not(Vec128<T, N> v) { + return Vec128<T, N>{wasm_v128_not(v.raw)}; +} + +// ------------------------------ And + +template <typename T, size_t N> +HWY_API Vec128<T, N> And(Vec128<T, N> a, Vec128<T, N> b) { + return Vec128<T, N>{wasm_v128_and(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template <typename T, size_t N> +HWY_API Vec128<T, N> AndNot(Vec128<T, N> not_mask, Vec128<T, N> mask) { + return Vec128<T, N>{wasm_v128_andnot(mask.raw, not_mask.raw)}; +} + +// ------------------------------ Or + +template <typename T, size_t N> +HWY_API Vec128<T, N> Or(Vec128<T, N> a, Vec128<T, N> b) { + return Vec128<T, N>{wasm_v128_or(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template <typename T, size_t N> +HWY_API Vec128<T, N> Xor(Vec128<T, N> a, Vec128<T, N> b) { + return Vec128<T, N>{wasm_v128_xor(a.raw, b.raw)}; +} + +// ------------------------------ Xor3 + +template <typename T, size_t N> +HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) { + return Xor(x1, Xor(x2, x3)); +} + +// ------------------------------ Or3 + +template <typename T, size_t N> +HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template <typename T, size_t N> +HWY_API Vec128<T, N> OrAnd(Vec128<T, N> o, Vec128<T, N> a1, Vec128<T, N> a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator&(const Vec128<T, N> a, const Vec128<T, N> b) { + return And(a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator|(const Vec128<T, N> a, const Vec128<T, N> b) { + return Or(a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator^(const Vec128<T, N> a, const Vec128<T, N> b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySign(const Vec128<T, N> magn, + const Vec128<T, N> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + const auto msb = SignBit(DFromV<decltype(magn)>()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySignToAbs(const Vec128<T, N> abs, + const Vec128<T, N> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(DFromV<decltype(abs)>()), sign)); +} + +// ------------------------------ BroadcastSignBit (compare) + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> BroadcastSignBit(const Vec128<T, N> v) { + return ShiftRight<sizeof(T) * 8 - 1>(v); +} +template <size_t N> +HWY_API Vec128<int8_t, N> BroadcastSignBit(const Vec128<int8_t, N> v) { + const DFromV<decltype(v)> d; + return VecFromMask(d, v < Zero(d)); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template <typename T, size_t N> +HWY_API Mask128<T, N> MaskFromVec(const Vec128<T, N> v) { + return Mask128<T, N>{v.raw}; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> VecFromMask(Simd<T, N, 0> /* tag */, Mask128<T, N> v) { + return Vec128<T, N>{v.raw}; +} + +// mask ? yes : no +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElse(Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return Vec128<T, N>{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; +} + +// mask ? yes : 0 +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElseZero(Mask128<T, N> mask, Vec128<T, N> yes) { + return yes & VecFromMask(DFromV<decltype(yes)>(), mask); +} + +// mask ? 0 : no +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenZeroElse(Mask128<T, N> mask, Vec128<T, N> no) { + return AndNot(VecFromMask(DFromV<decltype(no)>(), mask), no); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfNegativeThenElse(Vec128<T, N> v, Vec128<T, N> yes, + Vec128<T, N> no) { + static_assert(IsSigned<T>(), "Only works for signed/float"); + const DFromV<decltype(v)> d; + const RebindToSigned<decltype(d)> di; + + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template <typename T, size_t N, HWY_IF_FLOAT(T)> +HWY_API Vec128<T, N> ZeroIfNegative(Vec128<T, N> v) { + const DFromV<decltype(v)> d; + const auto zero = Zero(d); + return IfThenElse(Mask128<T, N>{(v > zero).raw}, v, zero); +} + +// ------------------------------ Mask logical + +template <typename T, size_t N> +HWY_API Mask128<T, N> Not(const Mask128<T, N> m) { + return MaskFromVec(Not(VecFromMask(Simd<T, N, 0>(), m))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> And(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> AndNot(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Or(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Xor(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> ExclusiveNeither(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) + +// The x86 multiply-by-Pow2() trick will not work because WASM saturates +// float->int correctly to 2^31-1 (not 2^31). Because WASM's shifts take a +// scalar count operand, per-lane shift instructions would require extract_lane +// for each lane, and hoping that shuffle is correctly mapped to a native +// instruction. Using non-vector shifts would incur a store-load forwarding +// stall when loading the result vector. We instead test bits of the shift +// count to "predicate" a shift of the entire vector by a constant. + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, const Vec128<T, N> bits) { + const DFromV<decltype(v)> d; + Mask128<T, N> mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned<decltype(d)>(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, const Vec128<T, N> bits) { + const DFromV<decltype(v)> d; + Mask128<T, N> mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned<decltype(d)>(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, const Vec128<T, N> bits) { + const DFromV<decltype(v)> d; + alignas(16) T lanes[2]; + alignas(16) T bits_lanes[2]; + Store(v, d, lanes); + Store(bits, d, bits_lanes); + lanes[0] <<= bits_lanes[0]; + lanes[1] <<= bits_lanes[1]; + return Load(d, lanes); +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> operator>>(Vec128<T, N> v, const Vec128<T, N> bits) { + const DFromV<decltype(v)> d; + Mask128<T, N> mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned<decltype(d)>(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> operator>>(Vec128<T, N> v, const Vec128<T, N> bits) { + const DFromV<decltype(v)> d; + Mask128<T, N> mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned<decltype(d)>(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template <typename T> +HWY_API Vec128<T> Load(Full128<T> /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128<T>{wasm_v128_load(aligned)}; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// Partial load. +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> Load(Simd<T, N, 0> /* tag */, const T* HWY_RESTRICT p) { + Vec128<T, N> v; + CopyBytes<sizeof(T) * N>(p, &v); + return v; +} + +// LoadU == Load. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> LoadU(Simd<T, N, 0> d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> LoadDup128(Simd<T, N, 0> d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ Store + +template <typename T> +HWY_API void Store(Vec128<T> v, Full128<T> /* tag */, T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// Partial store. +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API void Store(Vec128<T, N> v, Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + CopyBytes<sizeof(T) * N>(&v, p); +} + +HWY_API void Store(const Vec128<float, 1> v, Simd<float, 1, 0> /* tag */, + float* HWY_RESTRICT p) { + *p = wasm_f32x4_extract_lane(v.raw, 0); +} + +// StoreU == Store. +template <typename T, size_t N> +HWY_API void StoreU(Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template <typename T, size_t N> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, Simd<T, N, 0> d, + T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template <typename T, size_t N> +HWY_API void Stream(Vec128<T, N> v, Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// ------------------------------ Scatter (Store) + +template <typename T, size_t N, typename Offset, HWY_IF_LE128(T, N)> +HWY_API void ScatterOffset(Vec128<T, N> v, Simd<T, N, 0> d, + T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind<Offset, decltype(d)>(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast<uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template <typename T, size_t N, typename Index, HWY_IF_LE128(T, N)> +HWY_API void ScatterIndex(Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind<Index, decltype(d)>(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template <typename T, size_t N, typename Offset> +HWY_API Vec128<T, N> GatherOffset(const Simd<T, N, 0> d, + const T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind<Offset, decltype(d)>(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast<const uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template <typename T, size_t N, typename Index> +HWY_API Vec128<T, N> GatherIndex(const Simd<T, N, 0> d, + const T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind<Index, decltype(d)>(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane + +namespace detail { + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + return static_cast<T>(wasm_i8x16_extract_lane(v.raw, kLane)); +} +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + return static_cast<T>(wasm_i16x8_extract_lane(v.raw, kLane)); +} +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + return static_cast<T>(wasm_i32x4_extract_lane(v.raw, kLane)); +} +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + return static_cast<T>(wasm_i64x2_extract_lane(v.raw, kLane)); +} + +template <size_t kLane, size_t N> +HWY_INLINE float ExtractLane(const Vec128<float, N> v) { + return wasm_f32x4_extract_lane(v.raw, kLane); +} + +} // namespace detail + +// One overload per vector length just in case *_extract_lane raise compile +// errors if their argument is out of bounds (even if that would never be +// reached at runtime). +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 1> v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 2> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 4> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 8> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 16> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +// ------------------------------ GetLane +template <typename T, size_t N> +HWY_API T GetLane(const Vec128<T, N> v) { + return detail::ExtractLane<0>(v); +} + +// ------------------------------ InsertLane + +namespace detail { + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128<T, N>{ + wasm_i8x16_replace_lane(v.raw, kLane, static_cast<int8_t>(t))}; +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128<T, N>{ + wasm_i16x8_replace_lane(v.raw, kLane, static_cast<int16_t>(t))}; +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128<T, N>{ + wasm_i32x4_replace_lane(v.raw, kLane, static_cast<int32_t>(t))}; +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128<T, N>{ + wasm_i64x2_replace_lane(v.raw, kLane, static_cast<int64_t>(t))}; +} + +template <size_t kLane, size_t N> +HWY_INLINE Vec128<float, N> InsertLane(const Vec128<float, N> v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128<float, N>{wasm_f32x4_replace_lane(v.raw, kLane, t)}; +} + +template <size_t kLane, size_t N> +HWY_INLINE Vec128<double, N> InsertLane(const Vec128<double, N> v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + return Vec128<double, N>{wasm_f64x2_replace_lane(v.raw, kLane, t)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls wasm_f64x2_replace_lane. + +template <typename T> +HWY_API Vec128<T, 1> InsertLane(const Vec128<T, 1> v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV<decltype(v)>(), t); +} + +template <typename T> +HWY_API Vec128<T, 2> InsertLane(const Vec128<T, 2> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 4> InsertLane(const Vec128<T, 4> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 8> InsertLane(const Vec128<T, 8> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 16> InsertLane(const Vec128<T, 16> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> LowerHalf(Simd<T, N / 2, 0> /* tag */, + Vec128<T, N> v) { + return Vec128<T, N / 2>{v.raw}; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> LowerHalf(Vec128<T, N> v) { + return LowerHalf(Simd<T, N / 2, 0>(), v); +} + +// ------------------------------ ShiftLeftBytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14)}; + + case 2: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13)}; + + case 3: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)}; + + case 4: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)}; + + case 5: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)}; + + case 6: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + + case 7: + return Vec128<T, N>{wasm_i8x16_shuffle( + v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; + + case 8: + return Vec128<T, N>{wasm_i8x16_shuffle( + v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; + + case 9: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, + 6)}; + + case 10: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, + 5)}; + + case 11: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, + 4)}; + + case 12: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 0, 1, + 2, 3)}; + + case 13: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 0, + 1, 2)}; + + case 14: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0, 1)}; + + case 15: + return Vec128<T, N>{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 0)}; + } + return Vec128<T, N>{zero}; +} + +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(Vec128<T, N> v) { + return ShiftLeftBytes<kBytes>(Simd<T, N, 0>(), v); +} + +// ------------------------------ ShiftLeftLanes + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); +} + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(const Vec128<T, N> v) { + return ShiftLeftLanes<kLanes>(DFromV<decltype(v)>(), v); +} + +// ------------------------------ ShiftRightBytes +namespace detail { + +// Helper function allows zeroing invalid lanes in caller. +template <int kBytes, typename T, size_t N> +HWY_API __i8x16 ShrBytes(const Vec128<T, N> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + + switch (kBytes) { + case 0: + return v.raw; + + case 1: + return wasm_i8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16); + + case 2: + return wasm_i8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16); + + case 3: + return wasm_i8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 16); + + case 4: + return wasm_i8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 16, 16, 16); + + case 5: + return wasm_i8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16); + + case 6: + return wasm_i8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16); + + case 7: + return wasm_i8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16, 16); + + case 8: + return wasm_i8x16_shuffle(v.raw, zero, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 9: + return wasm_i8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, 15, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 10: + return wasm_i8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 11: + return wasm_i8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 12: + return wasm_i8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 13: + return wasm_i8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 14: + return wasm_i8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 15: + return wasm_i8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + case 16: + return zero; + } +} + +} // namespace detail + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + // For partial vectors, clear upper lanes so we shift in zeros. + if (N != 16 / sizeof(T)) { + const Vec128<T> vfull{v.raw}; + v = Vec128<T, N>{IfThenElseZero(FirstN(Full128<T>(), N), vfull).raw}; + } + return Vec128<T, N>{detail::ShrBytes<kBytes>(v)}; +} + +// ------------------------------ ShiftRightLanes +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). +template <typename T> +HWY_API Vec64<T> UpperHalf(Full64<T> /* tag */, const Vec128<T> v) { + return Vec64<T>{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} +HWY_API Vec64<float> UpperHalf(Full64<float> /* tag */, const Vec128<float> v) { + return Vec64<float>{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} + +// Partial +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, (N + 1) / 2> UpperHalf(Half<Simd<T, N, 0>> /* tag */, + Vec128<T, N> v) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes<N * sizeof(T) / 2>(du, vu)); + return Vec128<T, (N + 1) / 2>{upper.raw}; +} + +// ------------------------------ CombineShiftRightBytes + +template <int kBytes, typename T, class V = Vec128<T>> +HWY_API V CombineShiftRightBytes(Full128<T> /* tag */, V hi, V lo) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + switch (kBytes) { + case 0: + return lo; + + case 1: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16)}; + + case 2: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17)}; + + case 3: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18)}; + + case 4: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19)}; + + case 5: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20)}; + + case 6: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21)}; + + case 7: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22)}; + + case 8: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23)}; + + case 9: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24)}; + + case 10: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25)}; + + case 11: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26)}; + + case 12: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27)}; + + case 13: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28)}; + + case 14: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29)}; + + case 15: + return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30)}; + } + return hi; +} + +template <int kBytes, typename T, size_t N, HWY_IF_LE64(T, N), + class V = Vec128<T, N>> +HWY_API V CombineShiftRightBytes(Simd<T, N, 0> d, V hi, V lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition<uint8_t, decltype(d)> d8; + const Full128<uint8_t> d_full8; + using V8 = VFromD<decltype(d_full8)>; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); + return V{BitCast(Full128<T>(), r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template <int kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Broadcast(const Vec128<T, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<T, N>{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; +} + +template <int kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Broadcast(const Vec128<T, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<T, N>{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +template <int kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Broadcast(const Vec128<T, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<T, N>{wasm_i64x2_shuffle(v.raw, v.raw, kLane, kLane)}; +} + +// ------------------------------ TableLookupBytes + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template <typename T, size_t N, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytes(const Vec128<T, N> bytes, + const Vec128<TI, NI> from) { +// Not yet available in all engines, see +// https://github.com/WebAssembly/simd/blob/bdcc304b2d379f4601c2c44ea9b44ed9484fde7e/proposals/simd/ImplementationStatus.md +// V8 implementation of this had a bug, fixed on 2021-04-03: +// https://chromium-review.googlesource.com/c/v8/v8/+/2822951 +#if 0 + return Vec128<TI, NI>{wasm_i8x16_swizzle(bytes.raw, from.raw)}; +#else + alignas(16) uint8_t control[16]; + alignas(16) uint8_t input[16]; + alignas(16) uint8_t output[16]; + wasm_v128_store(control, from.raw); + wasm_v128_store(input, bytes.raw); + for (size_t i = 0; i < 16; ++i) { + output[i] = control[i] < 16 ? input[control[i]] : 0; + } + return Vec128<TI, NI>{wasm_v128_load(output)}; +#endif +} + +template <typename T, size_t N, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytesOr0(const Vec128<T, N> bytes, + const Vec128<TI, NI> from) { + const Simd<TI, NI, 0> d; + // Mask size must match vector type, so cast everything to this type. + Repartition<int8_t, decltype(d)> di8; + Repartition<int8_t, Simd<T, N, 0>> d_bytes8; + const auto msb = BitCast(di8, from) < Zero(di8); + const auto lookup = + TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128<int32_t> have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template <typename T, size_t N> +HWY_API Vec128<T, N> Shuffle2301(const Vec128<T, N> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. +namespace detail { + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> Shuffle2301(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i8x16_shuffle(a.raw, b.raw, 1, 0, 3 + 16, 2 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Shuffle2301(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i16x8_shuffle(a.raw, b.raw, 1, 0, 3 + 8, 2 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Shuffle2301(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i32x4_shuffle(a.raw, b.raw, 1, 0, 3 + 4, 2 + 4)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> Shuffle1230(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i8x16_shuffle(a.raw, b.raw, 0, 3, 2 + 16, 1 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Shuffle1230(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i16x8_shuffle(a.raw, b.raw, 0, 3, 2 + 8, 1 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Shuffle1230(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i32x4_shuffle(a.raw, b.raw, 0, 3, 2 + 4, 1 + 4)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> Shuffle3012(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i8x16_shuffle(a.raw, b.raw, 2, 1, 0 + 16, 3 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Shuffle3012(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i16x8_shuffle(a.raw, b.raw, 2, 1, 0 + 8, 3 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Shuffle3012(const Vec128<T, N> a, const Vec128<T, N> b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{wasm_i32x4_shuffle(a.raw, b.raw, 2, 1, 0 + 4, 3 + 4)}; +} + +} // namespace detail + +// Swap 64-bit halves +template <typename T> +HWY_API Vec128<T> Shuffle01(const Vec128<T> v) { + static_assert(sizeof(T) == 8, "Only for 64-bit lanes"); + return Vec128<T>{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +template <typename T> +HWY_API Vec128<T> Shuffle1032(const Vec128<T> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128<T>{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} + +// Rotate right 32 bits +template <typename T> +HWY_API Vec128<T> Shuffle0321(const Vec128<T> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128<T>{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} + +// Rotate left 32 bits +template <typename T> +HWY_API Vec128<T> Shuffle2103(const Vec128<T> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128<T>{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} + +// Reverse +template <typename T> +HWY_API Vec128<T> Shuffle0123(const Vec128<T> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128<T>{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template <typename T, size_t N = 16 / sizeof(T)> +struct Indices128 { + __v128_u raw; +}; + +template <typename T, size_t N, typename TI, HWY_IF_LE128(T, N)> +HWY_API Indices128<T, N> IndicesFromVec(Simd<T, N, 0> d, Vec128<TI, N> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind<TI, decltype(d)> di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast<TI>(N))))); +#endif + + const Repartition<uint8_t, decltype(d)> d8; + using V8 = VFromD<decltype(d8)>; + const Repartition<uint16_t, decltype(d)> d16; + + // Broadcast each lane index to all bytes of T and shift to bytes + static_assert(sizeof(T) == 4 || sizeof(T) == 8, ""); + if (sizeof(T) == 4) { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + return Indices128<T, N>{Add(byte_indices, Load(d8, kByteOffsets)).raw}; + } else { + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + const V8 lane_indices = + TableLookupBytes(BitCast(d8, vec), Load(d8, kBroadcastLaneBytes)); + const V8 byte_indices = + BitCast(d8, ShiftLeft<3>(BitCast(d16, lane_indices))); + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + return Indices128<T, N>{Add(byte_indices, Load(d8, kByteOffsets)).raw}; + } +} + +template <typename T, size_t N, typename TI, HWY_IF_LE128(T, N)> +HWY_API Indices128<T, N> SetTableIndices(Simd<T, N, 0> d, const TI* idx) { + const Rebind<TI, decltype(d)> di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> TableLookupLanes(Vec128<T, N> v, Indices128<T, N> idx) { + using TI = MakeSigned<T>; + const DFromV<decltype(v)> d; + const Rebind<TI, decltype(d)> di; + return BitCast(d, TableLookupBytes(BitCast(di, v), Vec128<TI, N>{idx.raw})); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template <typename T> +HWY_API Vec128<T, 1> Reverse(Simd<T, 1, 0> /* tag */, const Vec128<T, 1> v) { + return v; +} + +// Two lanes: shuffle +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, 2> Reverse(Simd<T, 2, 0> /* tag */, const Vec128<T, 2> v) { + return Vec128<T, 2>{Shuffle2301(Vec128<T>{v.raw}).raw}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> Reverse(Full128<T> /* tag */, const Vec128<T> v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> Reverse(Full128<T> /* tag */, const Vec128<T> v) { + return Shuffle0123(v); +} + +// 16-bit +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse(Simd<T, N, 0> d, const Vec128<T, N> v) { + const RepartitionToWide<RebindToUnsigned<decltype(d)>> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------ Reverse2 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> d, const Vec128<T, N> v) { + const RepartitionToWide<RebindToUnsigned<decltype(d)>> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return Shuffle2301(v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> d, const Vec128<T, N> v) { + return BitCast(d, Vec128<uint16_t, N>{wasm_i16x8_shuffle(v.raw, v.raw, 3, 2, + 1, 0, 7, 6, 5, 4)}); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return Shuffle0123(v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> /* tag */, const Vec128<T, N>) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse8(Simd<T, N, 0> d, const Vec128<T, N> v) { + return Reverse(d, v); +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse8(Simd<T, N, 0>, const Vec128<T, N>) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ InterleaveLower + +template <size_t N> +HWY_API Vec128<uint8_t, N> InterleaveLower(Vec128<uint8_t, N> a, + Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> InterleaveLower(Vec128<uint16_t, N> a, + Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> InterleaveLower(Vec128<uint32_t, N> a, + Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> InterleaveLower(Vec128<uint64_t, N> a, + Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template <size_t N> +HWY_API Vec128<int8_t, N> InterleaveLower(Vec128<int8_t, N> a, + Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> InterleaveLower(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> InterleaveLower(Vec128<int32_t, N> a, + Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> InterleaveLower(Vec128<int64_t, N> a, + Vec128<int64_t, N> b) { + return Vec128<int64_t, N>{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template <size_t N> +HWY_API Vec128<float, N> InterleaveLower(Vec128<float, N> a, + Vec128<float, N> b) { + return Vec128<float, N>{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} + +template <size_t N> +HWY_API Vec128<double, N> InterleaveLower(Vec128<double, N> a, + Vec128<double, N> b) { + return Vec128<double, N>{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +// Additional overload for the optional tag. +template <class V> +HWY_API V InterleaveLower(DFromV<V> /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { + +template <size_t N> +HWY_API Vec128<uint8_t, N> InterleaveUpper(Vec128<uint8_t, N> a, + Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> InterleaveUpper(Vec128<uint16_t, N> a, + Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> InterleaveUpper(Vec128<uint32_t, N> a, + Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> InterleaveUpper(Vec128<uint64_t, N> a, + Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template <size_t N> +HWY_API Vec128<int8_t, N> InterleaveUpper(Vec128<int8_t, N> a, + Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> InterleaveUpper(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> InterleaveUpper(Vec128<int32_t, N> a, + Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> InterleaveUpper(Vec128<int64_t, N> a, + Vec128<int64_t, N> b) { + return Vec128<int64_t, N>{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template <size_t N> +HWY_API Vec128<float, N> InterleaveUpper(Vec128<float, N> a, + Vec128<float, N> b) { + return Vec128<float, N>{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} + +template <size_t N> +HWY_API Vec128<double, N> InterleaveUpper(Vec128<double, N> a, + Vec128<double, N> b) { + return Vec128<double, N>{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +} // namespace detail + +// Full +template <typename T, class V = Vec128<T>> +HWY_API V InterleaveUpper(Full128<T> /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template <typename T, size_t N, HWY_IF_LE64(T, N), class V = Vec128<T, N>> +HWY_API V InterleaveUpper(Simd<T, N, 0> d, V a, V b) { + const Half<decltype(d)> d2; + return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> Combine(Simd<T, N, 0> d, Vec128<T, N / 2> hi_half, + Vec128<T, N / 2> lo_half) { + const Half<decltype(d)> d2; + const RebindToUnsigned<decltype(d2)> du2; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128<UnsignedFromSize<N * sizeof(T) / 2>, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> ZeroExtendVector(Simd<T, N, 0> d, Vec128<T, N / 2> lo) { + return IfThenElseZero(FirstN(d, N / 2), Vec128<T, N>{lo.raw}); +} + +// ------------------------------ ConcatLowerLower + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template <typename T> +HWY_API Vec128<T> ConcatLowerLower(Full128<T> /* tag */, const Vec128<T> hi, + const Vec128<T> lo) { + return Vec128<T>{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ConcatLowerLower(Simd<T, N, 0> d, const Vec128<T, N> hi, + const Vec128<T, N> lo) { + const Half<decltype(d)> d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatUpperUpper + +template <typename T> +HWY_API Vec128<T> ConcatUpperUpper(Full128<T> /* tag */, const Vec128<T> hi, + const Vec128<T> lo) { + return Vec128<T>{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ConcatUpperUpper(Simd<T, N, 0> d, const Vec128<T, N> hi, + const Vec128<T, N> lo) { + const Half<decltype(d)> d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +// ------------------------------ ConcatLowerUpper + +template <typename T> +HWY_API Vec128<T> ConcatLowerUpper(Full128<T> d, const Vec128<T> hi, + const Vec128<T> lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ConcatLowerUpper(Simd<T, N, 0> d, const Vec128<T, N> hi, + const Vec128<T, N> lo) { + const Half<decltype(d)> d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +// ------------------------------ ConcatUpperLower +template <typename T, size_t N> +HWY_API Vec128<T, N> ConcatUpperLower(Simd<T, N, 0> d, const Vec128<T, N> hi, + const Vec128<T, N> lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T> ConcatOdd(Full128<T> /* tag */, Vec128<T> hi, Vec128<T> lo) { + return Vec128<T>{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} + +// 8-bit x8 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 8> ConcatOdd(Simd<T, 8, 0> /* tag */, Vec128<T, 8> hi, + Vec128<T, 8> lo) { + // Don't care about upper half. + return Vec128<T, 8>{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 17, 19, 21, + 23, 1, 3, 5, 7, 17, 19, 21, 23)}; +} + +// 8-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 4> ConcatOdd(Simd<T, 4, 0> /* tag */, Vec128<T, 4> hi, + Vec128<T, 4> lo) { + // Don't care about upper 3/4. + return Vec128<T, 4>{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 17, 19, 1, 3, 17, + 19, 1, 3, 17, 19, 1, 3, 17, 19)}; +} + +// 16-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T> ConcatOdd(Full128<T> /* tag */, Vec128<T> hi, Vec128<T> lo) { + return Vec128<T>{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +// 16-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, 4> ConcatOdd(Simd<T, 4, 0> /* tag */, Vec128<T, 4> hi, + Vec128<T, 4> lo) { + // Don't care about upper half. + return Vec128<T, 4>{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 9, 11, 1, 3, 9, 11)}; +} + +// 32-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> ConcatOdd(Full128<T> /* tag */, Vec128<T> hi, Vec128<T> lo) { + return Vec128<T>{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; +} + +// Any T x2 +template <typename T> +HWY_API Vec128<T, 2> ConcatOdd(Simd<T, 2, 0> d, Vec128<T, 2> hi, + Vec128<T, 2> lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T> ConcatEven(Full128<T> /* tag */, Vec128<T> hi, Vec128<T> lo) { + return Vec128<T>{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30)}; +} + +// 8-bit x8 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 8> ConcatEven(Simd<T, 8, 0> /* tag */, Vec128<T, 8> hi, + Vec128<T, 8> lo) { + // Don't care about upper half. + return Vec128<T, 8>{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 16, 18, 20, + 22, 0, 2, 4, 6, 16, 18, 20, 22)}; +} + +// 8-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 4> ConcatEven(Simd<T, 4, 0> /* tag */, Vec128<T, 4> hi, + Vec128<T, 4> lo) { + // Don't care about upper 3/4. + return Vec128<T, 4>{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 16, 18, 0, 2, 16, + 18, 0, 2, 16, 18, 0, 2, 16, 18)}; +} + +// 16-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T> ConcatEven(Full128<T> /* tag */, Vec128<T> hi, Vec128<T> lo) { + return Vec128<T>{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14)}; +} + +// 16-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, 4> ConcatEven(Simd<T, 4, 0> /* tag */, Vec128<T, 4> hi, + Vec128<T, 4> lo) { + // Don't care about upper half. + return Vec128<T, 4>{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 8, 10, 0, 2, 8, 10)}; +} + +// 32-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> ConcatEven(Full128<T> /* tag */, Vec128<T> hi, Vec128<T> lo) { + return Vec128<T>{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; +} + +// Any T x2 +template <typename T> +HWY_API Vec128<T, 2> ConcatEven(Simd<T, 2, 0> d, Vec128<T, 2> hi, + Vec128<T, 2> lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> DupEven(Vec128<T, N> v) { + return Vec128<T, N>{wasm_i32x4_shuffle(v.raw, v.raw, 0, 0, 2, 2)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> DupEven(const Vec128<T, N> v) { + return InterleaveLower(DFromV<decltype(v)>(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> DupOdd(Vec128<T, N> v) { + return Vec128<T, N>{wasm_i32x4_shuffle(v.raw, v.raw, 1, 1, 3, 3)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> DupOdd(const Vec128<T, N> v) { + return InterleaveUpper(DFromV<decltype(v)>(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> OddEven(hwy::SizeTag<1> /* tag */, const Vec128<T, N> a, + const Vec128<T, N> b) { + const DFromV<decltype(a)> d; + const Repartition<uint8_t, decltype(d)> d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> OddEven(hwy::SizeTag<2> /* tag */, const Vec128<T, N> a, + const Vec128<T, N> b) { + return Vec128<T, N>{ + wasm_i16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> OddEven(hwy::SizeTag<4> /* tag */, const Vec128<T, N> a, + const Vec128<T, N> b) { + return Vec128<T, N>{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> OddEven(hwy::SizeTag<8> /* tag */, const Vec128<T, N> a, + const Vec128<T, N> b) { + return Vec128<T, N>{wasm_i64x2_shuffle(a.raw, b.raw, 2, 1)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) { + return detail::OddEven(hwy::SizeTag<sizeof(T)>(), a, b); +} +template <size_t N> +HWY_API Vec128<float, N> OddEven(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} + +// ------------------------------ OddEvenBlocks +template <typename T, size_t N> +HWY_API Vec128<T, N> OddEvenBlocks(Vec128<T, N> /* odd */, Vec128<T, N> even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template <typename T, size_t N> +HWY_API Vec128<T, N> SwapAdjacentBlocks(Vec128<T, N> v) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template <typename T> +HWY_API Vec128<T> ReverseBlocks(Full128<T> /* tag */, const Vec128<T> v) { + return v; +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template <size_t N, HWY_IF_LE128(uint16_t, N)> +HWY_API Vec128<uint16_t, N> PromoteTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { + return Vec128<uint16_t, N>{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template <size_t N, HWY_IF_LE128(uint32_t, N)> +HWY_API Vec128<uint32_t, N> PromoteTo(Simd<uint32_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { + return Vec128<uint32_t, N>{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +template <size_t N, HWY_IF_LE128(int16_t, N)> +HWY_API Vec128<int16_t, N> PromoteTo(Simd<int16_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { + return Vec128<int16_t, N>{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template <size_t N, HWY_IF_LE128(int32_t, N)> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { + return Vec128<int32_t, N>{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} +template <size_t N, HWY_IF_LE128(uint32_t, N)> +HWY_API Vec128<uint32_t, N> PromoteTo(Simd<uint32_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + return Vec128<uint32_t, N>{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +template <size_t N, HWY_IF_LE128(uint64_t, N)> +HWY_API Vec128<uint64_t, N> PromoteTo(Simd<uint64_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + return Vec128<uint64_t, N>{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template <size_t N, HWY_IF_LE128(int32_t, N)> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + return Vec128<int32_t, N>{wasm_u32x4_extend_low_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +template <size_t N, HWY_IF_LE128(int16_t, N)> +HWY_API Vec128<int16_t, N> PromoteTo(Simd<int16_t, N, 0> /* tag */, + const Vec128<int8_t, N> v) { + return Vec128<int16_t, N>{wasm_i16x8_extend_low_i8x16(v.raw)}; +} +template <size_t N, HWY_IF_LE128(int32_t, N)> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<int8_t, N> v) { + return Vec128<int32_t, N>{ + wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; +} +template <size_t N, HWY_IF_LE128(int32_t, N)> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<int32_t, N>{wasm_i32x4_extend_low_i16x8(v.raw)}; +} +template <size_t N, HWY_IF_LE128(int64_t, N)> +HWY_API Vec128<int64_t, N> PromoteTo(Simd<int64_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<int64_t, N>{wasm_i64x2_extend_low_i32x4(v.raw)}; +} + +template <size_t N, HWY_IF_LE128(double, N)> +HWY_API Vec128<double, N> PromoteTo(Simd<double, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<double, N>{wasm_f64x2_convert_low_i32x4(v.raw)}; +} + +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> df32, + const Vec128<float16_t, N> v) { + const RebindToSigned<decltype(df32)> di32; + const RebindToUnsigned<decltype(df32)> du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128<uint16_t, N>{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> df32, + const Vec128<bfloat16_t, N> v) { + const Rebind<uint16_t, decltype(df32)> du16; + const RebindToSigned<decltype(df32)> di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template <size_t N> +HWY_API Vec128<uint16_t, N> DemoteTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<uint16_t, N>{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template <size_t N> +HWY_API Vec128<int16_t, N> DemoteTo(Simd<int16_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<int16_t, N>{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> DemoteTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128<uint8_t, N>{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> DemoteTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<uint8_t, N>{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template <size_t N> +HWY_API Vec128<int8_t, N> DemoteTo(Simd<int8_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128<int8_t, N>{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template <size_t N> +HWY_API Vec128<int8_t, N> DemoteTo(Simd<int8_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<int8_t, N>{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template <size_t N> +HWY_API Vec128<int32_t, N> DemoteTo(Simd<int32_t, N, 0> /* di */, + const Vec128<double, N> v) { + return Vec128<int32_t, N>{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +template <size_t N> +HWY_API Vec128<float16_t, N> DemoteTo(Simd<float16_t, N, 0> df16, + const Vec128<float, N> v) { + const RebindToUnsigned<decltype(df16)> du16; + const Rebind<uint32_t, decltype(du16)> du; + const RebindToSigned<decltype(du)> di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return Vec128<float16_t, N>{DemoteTo(du16, bits16).raw}; +} + +template <size_t N> +HWY_API Vec128<bfloat16_t, N> DemoteTo(Simd<bfloat16_t, N, 0> dbf16, + const Vec128<float, N> v) { + const Rebind<int32_t, decltype(dbf16)> di32; + const Rebind<uint32_t, decltype(dbf16)> du32; // for logical shift right + const Rebind<uint16_t, decltype(dbf16)> du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +template <size_t N> +HWY_API Vec128<bfloat16_t, 2 * N> ReorderDemote2To( + Simd<bfloat16_t, 2 * N, 0> dbf16, Vec128<float, N> a, Vec128<float, N> b) { + const RebindToUnsigned<decltype(dbf16)> du16; + const Repartition<uint32_t, decltype(dbf16)> du32; + const Vec128<uint32_t, N> b_in_even = ShiftRight<16>(BitCast(du32, b)); + const auto u16 = OddEven(BitCast(du16, a), BitCast(du16, b_in_even)); + return BitCast(dbf16, u16); +} + +// Specializations for partial vectors because i16x8_narrow_i32x4 sets lanes +// above 2*N. +HWY_API Vec128<int16_t, 2> ReorderDemote2To(Simd<int16_t, 2, 0> dn, + Vec128<int32_t, 1> a, + Vec128<int32_t, 1> b) { + const Half<decltype(dn)> dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128<int16_t, 2> an{DemoteTo(dnh, a).raw}; + const Vec128<int16_t, 2> bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128<int16_t, 4> ReorderDemote2To(Simd<int16_t, 4, 0> dn, + Vec128<int32_t, 2> a, + Vec128<int32_t, 2> b) { + const Half<decltype(dn)> dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128<int16_t, 4> an{DemoteTo(dnh, a).raw}; + const Vec128<int16_t, 4> bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128<int16_t> ReorderDemote2To(Full128<int16_t> /*d16*/, + Vec128<int32_t> a, Vec128<int32_t> b) { + return Vec128<int16_t>{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; +} + +// For already range-limited input [0, 255]. +template <size_t N> +HWY_API Vec128<uint8_t, N> U8FromU32(const Vec128<uint32_t, N> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128<uint8_t, N>{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +// ------------------------------ Truncations + +template <typename From, typename To, HWY_IF_UNSIGNED(From), + HWY_IF_UNSIGNED(To), + hwy::EnableIf<(sizeof(To) < sizeof(From))>* = nullptr> +HWY_API Vec128<To, 1> TruncateTo(Simd<To, 1, 0> /* tag */, + const Vec128<From, 1> v) { + const Repartition<To, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + return Vec128<To, 1>{v1.raw}; +} + +HWY_API Vec16<uint8_t> TruncateTo(Full16<uint8_t> /* tag */, + const Vec128<uint64_t> v) { + const Full128<uint8_t> d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + const auto v4 = ConcatEven(d, v2, v2); + return LowerHalf(LowerHalf(LowerHalf(ConcatEven(d, v4, v4)))); +} + +HWY_API Vec32<uint16_t> TruncateTo(Full32<uint16_t> /* tag */, + const Vec128<uint64_t> v) { + const Full128<uint16_t> d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + return LowerHalf(LowerHalf(ConcatEven(d, v2, v2))); +} + +HWY_API Vec64<uint32_t> TruncateTo(Full64<uint32_t> /* tag */, + const Vec128<uint64_t> v) { + const Full128<uint32_t> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + const Full128<uint8_t> d; + const auto v1 = Vec128<uint8_t>{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + const auto v3 = ConcatEven(d, v2, v2); + return Vec128<uint8_t, N>{v3.raw}; +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint16_t, N> TruncateTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + const Full128<uint16_t> d; + const auto v1 = Vec128<uint16_t>{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return Vec128<uint16_t, N>{v2.raw}; +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + const Full128<uint8_t> d; + const auto v1 = Vec128<uint8_t>{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return Vec128<uint8_t, N>{v2.raw}; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template <size_t N> +HWY_API Vec128<float, N> ConvertTo(Simd<float, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<float, N>{wasm_f32x4_convert_i32x4(v.raw)}; +} +template <size_t N> +HWY_API Vec128<float, N> ConvertTo(Simd<float, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + return Vec128<float, N>{wasm_f32x4_convert_u32x4(v.raw)}; +} +// Truncates (rounds toward zero). +template <size_t N> +HWY_API Vec128<int32_t, N> ConvertTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<float, N> v) { + return Vec128<int32_t, N>{wasm_i32x4_trunc_sat_f32x4(v.raw)}; +} + +template <size_t N> +HWY_API Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) { + return ConvertTo(Simd<int32_t, N, 0>(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template <size_t N> +HWY_API Vec128<uint64_t, N / 8> SumsOf8(const Vec128<uint8_t, N> v) { + const DFromV<decltype(v)> du8; + const RepartitionToWide<decltype(du8)> du16; + const RepartitionToWide<decltype(du16)> du32; + const RepartitionToWide<decltype(du32)> du64; + using VU16 = VFromD<decltype(du16)>; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = And(BitCast(du16, v), Set(du16, 0xFF)); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return And(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), Set(du64, 0xFFFF)); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t bits) { + const RebindToUnsigned<decltype(d)> du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec128<T, N> vbits{wasm_i32x4_splat(static_cast<int32_t>(bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask( + d, TestBit(Set(du, static_cast<uint16_t>(bits)), Load(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + return RebindMask( + d, TestBit(Set(du, static_cast<uint32_t>(bits)), Load(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(N + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Full +template <typename T> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128<T> mask) { + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, mask.raw); + + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t lo = ((lanes[0] * kMagic) >> 56); + const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; + return (hi + lo); +} + +// 64-bit +template <typename T> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128<T, 8> mask) { + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return (static_cast<uint64_t>(wasm_i64x2_extract_lane(mask.raw, 0)) * + kMagic) >> + 56; +} + +// 32-bit or less: need masking +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128<T, N> mask) { + uint64_t bytes = static_cast<uint64_t>(wasm_i64x2_extract_lane(mask.raw, 0)); + // Clear potentially undefined bytes. + bytes &= (1ULL << (N * 8)) - 1; + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return (bytes * kMagic) >> 56; +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128<T, N> mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const __i16x8 zero = wasm_i16x8_splat(0); + const Mask128<uint8_t, N> mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + return BitsFromMask(hwy::SizeTag<1>(), mask8); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128<T, N> mask) { + const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); + const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint32_t lanes[4]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1] | lanes[2] | lanes[3]; +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128<T, N> mask) { + const __i64x2 mask_i = static_cast<__i64x2>(mask.raw); + const __i64x2 slice = wasm_i64x2_make(1, 2); + const __i64x2 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, sliced_mask); + return lanes[0] | lanes[1]; +} + +// Returns the lowest N bits for the BitsFromMask result. +template <typename T, size_t N> +constexpr uint64_t OnlyActive(uint64_t bits) { + return ((N * sizeof(T)) == 16) ? bits : bits & ((1ull << N) - 1); +} + +// Returns 0xFF for bytes with index >= N, otherwise 0. +template <size_t N> +constexpr __i8x16 BytesAbove() { + return /**/ + (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) + : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) + : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) + : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) + : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) + : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) + : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) + : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) + : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) + : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1) + : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1) + : (N == 9) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, + -1, -1, -1) + : (N == 11) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) + : (N == 13) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) + : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(const Mask128<T, N> mask) { + return OnlyActive<T, N>(BitsFromMask(hwy::SizeTag<sizeof(T)>(), mask)); +} + +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> tag, const Mask128<T> m) { + return PopCount(BitsFromMask(tag, m)); +} + +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> tag, const Mask128<T> m) { + return PopCount(BitsFromMask(tag, m)); +} + +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128<T> m) { + const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, shifted_bits); + return PopCount(lanes[0] | lanes[1]); +} + +template <typename T> +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, const Mask128<T> m) { + alignas(16) int64_t lanes[2]; + wasm_v128_store(lanes, m.raw); + return static_cast<size_t>(-(lanes[0] + lanes[1])); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template <typename T, size_t N> +HWY_API size_t StoreMaskBits(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask, uint8_t* bits) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(&mask_bits, bits); + return kNumBytes; +} + +template <typename T, size_t N> +HWY_API size_t CountTrue(const Simd<T, N, 0> /* tag */, const Mask128<T> m) { + return detail::CountTrue(hwy::SizeTag<sizeof(T)>(), m); +} + +// Partial vector +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API size_t CountTrue(const Simd<T, N, 0> d, const Mask128<T, N> m) { + // Ensure all undefined bytes are 0. + const Mask128<T, N> mask{detail::BytesAbove<N * sizeof(T)>()}; + return CountTrue(d, Mask128<T>{AndNot(mask, m).raw}); +} + +// Full vector +template <typename T> +HWY_API bool AllFalse(const Full128<T> d, const Mask128<T> m) { +#if 0 + // Casting followed by wasm_i8x16_any_true results in wasm error: + // i32.eqz[0] expected type i32, found i8x16.popcnt of type s128 + const auto v8 = BitCast(Full128<int8_t>(), VecFromMask(d, m)); + return !wasm_i8x16_any_true(v8.raw); +#else + (void)d; + return (wasm_i64x2_extract_lane(m.raw, 0) | + wasm_i64x2_extract_lane(m.raw, 1)) == 0; +#endif +} + +// Full vector +namespace detail { +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128<T> m) { + return wasm_i8x16_all_true(m.raw); +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128<T> m) { + return wasm_i16x8_all_true(m.raw); +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128<T> m) { + return wasm_i32x4_all_true(m.raw); +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask128<T> m) { + return wasm_i64x2_all_true(m.raw); +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API bool AllTrue(const Simd<T, N, 0> /* tag */, const Mask128<T> m) { + return detail::AllTrue(hwy::SizeTag<sizeof(T)>(), m); +} + +// Partial vectors + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API bool AllFalse(Simd<T, N, 0> /* tag */, const Mask128<T, N> m) { + // Ensure all undefined bytes are 0. + const Mask128<T, N> mask{detail::BytesAbove<N * sizeof(T)>()}; + return AllFalse(Full128<T>(), Mask128<T>{AndNot(mask, m).raw}); +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API bool AllTrue(const Simd<T, N, 0> /* d */, const Mask128<T, N> m) { + // Ensure all undefined bytes are FF. + const Mask128<T, N> mask{detail::BytesAbove<N * sizeof(T)>()}; + return AllTrue(Full128<T>(), Mask128<T>{Or(mask, m).raw}); +} + +template <typename T, size_t N> +HWY_API size_t FindKnownFirstTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(bits); +} + +template <typename T, size_t N> +HWY_API intptr_t FindFirstTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + const uint64_t bits = detail::BitsFromMask(mask); + return bits ? static_cast<intptr_t>(Num0BitsBelowLS1Bit_Nonzero64(bits)) : -1; +} + +// ------------------------------ Compress + +namespace detail { + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<T, N> IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd<T, N, 0> d; + const Rebind<uint8_t, decltype(d)> d8; + const Simd<uint16_t, N, 0> du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128<uint8_t, 2 * N> byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128<uint16_t, N> pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<T, N> IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd<T, N, 0> d; + const Rebind<uint8_t, decltype(d)> d8; + const Simd<uint16_t, N, 0> du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128<uint8_t, 2 * N> byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128<uint16_t, N> pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec128<T, N> IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec128<T, N> IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Vec128<T, N> IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Vec128<T, N> IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd<T, N, 0> d; + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +// Helper functions called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Compress(Vec128<T, N> v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromBits<T, N>(mask_bits); + const DFromV<decltype(v)> d; + const RebindToSigned<decltype(d)> di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> CompressNot(Vec128<T, N> v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromNotBits<T, N>(mask_bits); + const DFromV<decltype(v)> d; + const RebindToSigned<decltype(d)> di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +template <typename T> +struct CompressIsPartition { +#if HWY_TARGET == HWY_WASM_EMU256 + enum { value = 0 }; +#else + enum { value = (sizeof(T) != 1) }; +#endif +}; + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> Compress(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +// Two lanes: conditional swap +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> Compress(Vec128<T> v, Mask128<T> mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. + const Full128<T> d; + const Vec128<T> m = VecFromMask(d, mask); + const Vec128<T> maskL = DupEven(m); + const Vec128<T> maskH = DupOdd(m); + const Vec128<T> swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API Vec128<T, N> Compress(Vec128<T, N> v, Mask128<T, N> mask) { + return detail::Compress(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> CompressNot(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +// Two lanes: conditional swap +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128<T> d; + const Vec128<T> m = VecFromMask(d, mask); + const Vec128<T> maskL = DupEven(m); + const Vec128<T> maskH = DupOdd(m); + const Vec128<T> swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, Mask128<T, N> mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::Compress(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNot(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v, + Mask128<uint64_t> /* m */) { + return v; +} + +// ------------------------------ CompressBits +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> CompressBits(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressStore(Vec128<T, N> v, const Mask128<T, N> mask, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; // so we can support fp16/bf16 + using TU = TFromD<decltype(du)>; + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Vec128<TU, N> compressed = detail::Compress(BitCast(du, v), mask_bits); + const Mask128<T, N> store_mask = RebindMask(d, FirstN(du, count)); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBitsStore(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ MulEven/Odd (Load) + +HWY_INLINE Vec128<uint64_t> MulEven(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + alignas(16) uint64_t mul[2]; + mul[0] = + Mul128(static_cast<uint64_t>(wasm_i64x2_extract_lane(a.raw, 0)), + static_cast<uint64_t>(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); + return Load(Full128<uint64_t>(), mul); +} + +HWY_INLINE Vec128<uint64_t> MulOdd(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + alignas(16) uint64_t mul[2]; + mul[0] = + Mul128(static_cast<uint64_t>(wasm_i64x2_extract_lane(a.raw, 1)), + static_cast<uint64_t>(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); + return Load(Full128<uint64_t>(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template <size_t N> +HWY_API Vec128<float, N> ReorderWidenMulAccumulate(Simd<float, N, 0> df32, + Vec128<bfloat16_t, 2 * N> a, + Vec128<bfloat16_t, 2 * N> b, + const Vec128<float, N> sum0, + Vec128<float, N>& sum1) { + const Rebind<uint32_t, decltype(df32)> du32; + using VU32 = VFromD<decltype(du32)>; + const VU32 odd = Set(du32, 0xFFFF0000u); // bfloat16 is the upper half of f32 + // Using shift/and instead of Zip leads to the odd/even order that + // RearrangeToOddPlusEven prefers. + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +// Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is +// safe. +template <size_t N> +HWY_API Vec128<int32_t, N> ReorderWidenMulAccumulate( + Simd<int32_t, N, 0> /*d32*/, Vec128<int16_t, 2 * N> a, + Vec128<int16_t, 2 * N> b, const Vec128<int32_t, N> sum0, + Vec128<int32_t, N>& /*sum1*/) { + return sum0 + Vec128<int32_t, N>{wasm_i32x4_dot_i16x8(a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +template <size_t N> +HWY_API Vec128<int32_t, N> RearrangeToOddPlusEven( + const Vec128<int32_t, N> sum0, const Vec128<int32_t, N> /*sum1*/) { + return sum0; // invariant already holds +} + +template <size_t N> +HWY_API Vec128<float, N> RearrangeToOddPlusEven(const Vec128<float, N> sum0, + const Vec128<float, N> sum1) { + return Add(sum0, sum1); +} + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template <typename T> +HWY_INLINE Vec128<T, 1> SumOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} +template <typename T> +HWY_INLINE Vec128<T, 1> MinOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} +template <typename T> +HWY_INLINE Vec128<T, 1> MaxOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} + +// u32/i32/f32: + +// N=2 +template <typename T> +HWY_INLINE Vec128<T, 2> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return v10 + Vec128<T, 2>{Shuffle2301(Vec128<T>{v10.raw}).raw}; +} +template <typename T> +HWY_INLINE Vec128<T, 2> MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return Min(v10, Vec128<T, 2>{Shuffle2301(Vec128<T>{v10.raw}).raw}); +} +template <typename T> +HWY_INLINE Vec128<T, 2> MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return Max(v10, Vec128<T, 2>{Shuffle2301(Vec128<T>{v10.raw}).raw}); +} + +// N=4 (full) +template <typename T> +HWY_INLINE Vec128<T> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = v3210 + v1032; + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template <typename T> +HWY_INLINE Vec128<T> MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = Min(v3210, v1032); + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template <typename T> +HWY_INLINE Vec128<T> MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = Max(v3210, v1032); + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +// N=2 (full) +template <typename T> +HWY_INLINE Vec128<T> SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<T> v10) { + const Vec128<T> v01 = Shuffle01(v10); + return v10 + v01; +} +template <typename T> +HWY_INLINE Vec128<T> MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<T> v10) { + const Vec128<T> v01 = Shuffle01(v10); + return Min(v10, v01); +} +template <typename T> +HWY_INLINE Vec128<T> MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<T> v10) { + const Vec128<T> v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. +template <typename T, size_t N> +HWY_API Vec128<T, N> SumOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::SumOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MinOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::MinOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MaxOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::MaxOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} + +// ------------------------------ Lt128 + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Lt128(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const Mask128<T, N> eqHL = Eq(a, b); + const Vec128<T, N> ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). + const Vec128<T, N> ltLx = DupEven(ltHL); + const Vec128<T, N> outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Lt128Upper(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + const Vec128<T, N> ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Eq128(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64"); + const Vec128<T, N> eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Eq128Upper(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + const Vec128<T, N> eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Ne128(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + static_assert(!IsSigned<T>() && sizeof(T) == 8, "T must be u64"); + const Vec128<T, N> neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_INLINE Mask128<T, N> Ne128Upper(Simd<T, N, 0> d, Vec128<T, N> a, + Vec128<T, N> b) { + const Vec128<T, N> neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. +template <class D> +HWY_INLINE VFromD<D> Min128(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template <class D> +HWY_INLINE VFromD<D> Max128(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template <class D> +HWY_INLINE VFromD<D> Min128Upper(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template <class D> +HWY_INLINE VFromD<D> Max128Upper(D d, const VFromD<D> a, const VFromD<D> b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/wasm_256-inl.h b/third_party/highway/hwy/ops/wasm_256-inl.h new file mode 100644 index 0000000000..aa62f05e00 --- /dev/null +++ b/third_party/highway/hwy/ops/wasm_256-inl.h @@ -0,0 +1,2003 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit WASM vectors and operations. Experimental. +// External include guard in highway.h - see comment there. + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/wasm_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <typename T> +class Vec256 { + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Vec128<T> v0; + Vec128<T> v1; +}; + +template <typename T> +struct Mask256 { + Mask128<T> m0; + Mask128<T> m1; +}; + +// ------------------------------ BitCast + +template <typename T, typename FromT> +HWY_API Vec256<T> BitCast(Full256<T> d, Vec256<FromT> v) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = BitCast(dh, v.v0); + ret.v1 = BitCast(dh, v.v1); + return ret; +} + +// ------------------------------ Zero + +template <typename T> +HWY_API Vec256<T> Zero(Full256<T> d) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = ret.v1 = Zero(dh); + return ret; +} + +template <class D> +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template <typename T, typename T2> +HWY_API Vec256<T> Set(Full256<T> d, const T2 t) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = ret.v1 = Set(dh, static_cast<T>(t)); + return ret; +} + +template <typename T> +HWY_API Vec256<T> Undefined(Full256<T> d) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = ret.v1 = Undefined(dh); + return ret; +} + +template <typename T, typename T2> +Vec256<T> Iota(const Full256<T> d, const T2 first) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Iota(dh, first); + // NB: for floating types the gap between parts might be a bit uneven. + ret.v1 = Iota(dh, AddWithWraparound(hwy::IsFloatTag<T>(), + static_cast<T>(first), Lanes(dh))); + return ret; +} + +// ================================================== ARITHMETIC + +template <typename T> +HWY_API Vec256<T> operator+(Vec256<T> a, const Vec256<T> b) { + a.v0 += b.v0; + a.v1 += b.v1; + return a; +} + +template <typename T> +HWY_API Vec256<T> operator-(Vec256<T> a, const Vec256<T> b) { + a.v0 -= b.v0; + a.v1 -= b.v1; + return a; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256<uint64_t> SumsOf8(const Vec256<uint8_t> v) { + Vec256<uint64_t> ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + +template <typename T> +HWY_API Vec256<T> SaturatedAdd(Vec256<T> a, const Vec256<T> b) { + a.v0 = SaturatedAdd(a.v0, b.v0); + a.v1 = SaturatedAdd(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> SaturatedSub(Vec256<T> a, const Vec256<T> b) { + a.v0 = SaturatedSub(a.v0, b.v0); + a.v1 = SaturatedSub(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> AverageRound(Vec256<T> a, const Vec256<T> b) { + a.v0 = AverageRound(a.v0, b.v0); + a.v1 = AverageRound(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> Abs(Vec256<T> v) { + v.v0 = Abs(v.v0); + v.v1 = Abs(v.v1); + return v; +} + +// ------------------------------ Shift lanes by constant #bits + +template <int kBits, typename T> +HWY_API Vec256<T> ShiftLeft(Vec256<T> v) { + v.v0 = ShiftLeft<kBits>(v.v0); + v.v1 = ShiftLeft<kBits>(v.v1); + return v; +} + +template <int kBits, typename T> +HWY_API Vec256<T> ShiftRight(Vec256<T> v) { + v.v0 = ShiftRight<kBits>(v.v0); + v.v1 = ShiftRight<kBits>(v.v1); + return v; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template <int kBits, typename T> +HWY_API Vec256<T> RotateRight(const Vec256<T> v) { + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<kSizeInBits - kBits>(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +template <typename T> +HWY_API Vec256<T> ShiftLeftSame(Vec256<T> v, const int bits) { + v.v0 = ShiftLeftSame(v.v0, bits); + v.v1 = ShiftLeftSame(v.v1, bits); + return v; +} + +template <typename T> +HWY_API Vec256<T> ShiftRightSame(Vec256<T> v, const int bits) { + v.v0 = ShiftRightSame(v.v0, bits); + v.v1 = ShiftRightSame(v.v1, bits); + return v; +} + +// ------------------------------ Min, Max +template <typename T> +HWY_API Vec256<T> Min(Vec256<T> a, const Vec256<T> b) { + a.v0 = Min(a.v0, b.v0); + a.v1 = Min(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> Max(Vec256<T> a, const Vec256<T> b) { + a.v0 = Max(a.v0, b.v0); + a.v1 = Max(a.v1, b.v1); + return a; +} +// ------------------------------ Integer multiplication + +template <typename T> +HWY_API Vec256<T> operator*(Vec256<T> a, const Vec256<T> b) { + a.v0 *= b.v0; + a.v1 *= b.v1; + return a; +} + +template <typename T> +HWY_API Vec256<T> MulHigh(Vec256<T> a, const Vec256<T> b) { + a.v0 = MulHigh(a.v0, b.v0); + a.v1 = MulHigh(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> MulFixedPoint15(Vec256<T> a, const Vec256<T> b) { + a.v0 = MulFixedPoint15(a.v0, b.v0); + a.v1 = MulFixedPoint15(a.v1, b.v1); + return a; +} + +// Cannot use MakeWide because that returns uint128_t for uint64_t, but we want +// uint64_t. +HWY_API Vec256<uint64_t> MulEven(Vec256<uint32_t> a, const Vec256<uint32_t> b) { + Vec256<uint64_t> ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +HWY_API Vec256<int64_t> MulEven(Vec256<int32_t> a, const Vec256<int32_t> b) { + Vec256<int64_t> ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} + +HWY_API Vec256<uint64_t> MulEven(Vec256<uint64_t> a, const Vec256<uint64_t> b) { + Vec256<uint64_t> ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +HWY_API Vec256<uint64_t> MulOdd(Vec256<uint64_t> a, const Vec256<uint64_t> b) { + Vec256<uint64_t> ret; + ret.v0 = MulOdd(a.v0, b.v0); + ret.v1 = MulOdd(a.v1, b.v1); + return ret; +} + +// ------------------------------ Negate +template <typename T> +HWY_API Vec256<T> Neg(Vec256<T> v) { + v.v0 = Neg(v.v0); + v.v1 = Neg(v.v1); + return v; +} + +// ------------------------------ Floating-point division +template <typename T> +HWY_API Vec256<T> operator/(Vec256<T> a, const Vec256<T> b) { + a.v0 /= b.v0; + a.v1 /= b.v1; + return a; +} + +// Approximate reciprocal +HWY_API Vec256<float> ApproximateReciprocal(const Vec256<float> v) { + const Vec256<float> one = Set(Full256<float>(), 1.0f); + return one / v; +} + +// Absolute value of difference. +HWY_API Vec256<float> AbsDiff(const Vec256<float> a, const Vec256<float> b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256<float> MulAdd(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> add) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfma? + return mul * x + add; +} + +// Returns add - mul * x +HWY_API Vec256<float> NegMulAdd(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> add) { + // TODO(eustas): replace, when implemented in WASM. + return add - mul * x; +} + +// Returns mul * x - sub +HWY_API Vec256<float> MulSub(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> sub) { + // TODO(eustas): replace, when implemented in WASM. + // TODO(eustas): is it wasm_f32x4_qfms? + return mul * x - sub; +} + +// Returns -mul * x - sub +HWY_API Vec256<float> NegMulSub(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> sub) { + // TODO(eustas): replace, when implemented in WASM. + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template <typename T> +HWY_API Vec256<T> Sqrt(Vec256<T> v) { + v.v0 = Sqrt(v.v0); + v.v1 = Sqrt(v.v1); + return v; +} + +// Approximate reciprocal square root +HWY_API Vec256<float> ApproximateReciprocalSqrt(const Vec256<float> v) { + // TODO(eustas): find cheaper a way to calculate this. + const Vec256<float> one = Set(Full256<float>(), 1.0f); + return one / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +HWY_API Vec256<float> Round(Vec256<float> v) { + v.v0 = Round(v.v0); + v.v1 = Round(v.v1); + return v; +} + +// Toward zero, aka truncate +HWY_API Vec256<float> Trunc(Vec256<float> v) { + v.v0 = Trunc(v.v0); + v.v1 = Trunc(v.v1); + return v; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256<float> Ceil(Vec256<float> v) { + v.v0 = Ceil(v.v0); + v.v1 = Ceil(v.v1); + return v; +} + +// Toward -infinity, aka floor +HWY_API Vec256<float> Floor(Vec256<float> v) { + v.v0 = Floor(v.v0); + v.v1 = Floor(v.v1); + return v; +} + +// ------------------------------ Floating-point classification + +template <typename T> +HWY_API Mask256<T> IsNaN(const Vec256<T> v) { + return v != v; +} + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Mask256<T> IsInf(const Vec256<T> v) { + const Full256<T> d; + const RebindToSigned<decltype(d)> di; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2<T>()))); +} + +// Returns whether normal/subnormal/zero. +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Mask256<T> IsFinite(const Vec256<T> v) { + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + const VFromD<decltype(du)> vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD<decltype(di)> exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template <typename TFrom, typename TTo> +HWY_API Mask256<TTo> RebindMask(Full256<TTo> /*tag*/, Mask256<TFrom> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask256<TTo>{Mask128<TTo>{m.m0.raw}, Mask128<TTo>{m.m1.raw}}; +} + +template <typename T> +HWY_API Mask256<T> TestBit(Vec256<T> v, Vec256<T> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template <typename T> +HWY_API Mask256<T> operator==(Vec256<T> a, const Vec256<T> b) { + Mask256<T> m; + m.m0 = operator==(a.v0, b.v0); + m.m1 = operator==(a.v1, b.v1); + return m; +} + +template <typename T> +HWY_API Mask256<T> operator!=(Vec256<T> a, const Vec256<T> b) { + Mask256<T> m; + m.m0 = operator!=(a.v0, b.v0); + m.m1 = operator!=(a.v1, b.v1); + return m; +} + +template <typename T> +HWY_API Mask256<T> operator<(Vec256<T> a, const Vec256<T> b) { + Mask256<T> m; + m.m0 = operator<(a.v0, b.v0); + m.m1 = operator<(a.v1, b.v1); + return m; +} + +template <typename T> +HWY_API Mask256<T> operator>(Vec256<T> a, const Vec256<T> b) { + Mask256<T> m; + m.m0 = operator>(a.v0, b.v0); + m.m1 = operator>(a.v1, b.v1); + return m; +} + +template <typename T> +HWY_API Mask256<T> operator<=(Vec256<T> a, const Vec256<T> b) { + Mask256<T> m; + m.m0 = operator<=(a.v0, b.v0); + m.m1 = operator<=(a.v1, b.v1); + return m; +} + +template <typename T> +HWY_API Mask256<T> operator>=(Vec256<T> a, const Vec256<T> b) { + Mask256<T> m; + m.m0 = operator>=(a.v0, b.v0); + m.m1 = operator>=(a.v1, b.v1); + return m; +} + +// ------------------------------ FirstN (Iota, Lt) + +template <typename T> +HWY_API Mask256<T> FirstN(const Full256<T> d, size_t num) { + const RebindToSigned<decltype(d)> di; // Signed comparisons may be cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast<MakeSigned<T>>(num))); +} + +// ================================================== LOGICAL + +template <typename T> +HWY_API Vec256<T> Not(Vec256<T> v) { + v.v0 = Not(v.v0); + v.v1 = Not(v.v1); + return v; +} + +template <typename T> +HWY_API Vec256<T> And(Vec256<T> a, Vec256<T> b) { + a.v0 = And(a.v0, b.v0); + a.v1 = And(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> AndNot(Vec256<T> not_mask, Vec256<T> mask) { + not_mask.v0 = AndNot(not_mask.v0, mask.v0); + not_mask.v1 = AndNot(not_mask.v1, mask.v1); + return not_mask; +} + +template <typename T> +HWY_API Vec256<T> Or(Vec256<T> a, Vec256<T> b) { + a.v0 = Or(a.v0, b.v0); + a.v1 = Or(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) { + a.v0 = Xor(a.v0, b.v0); + a.v1 = Xor(a.v1, b.v1); + return a; +} + +template <typename T> +HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) { + return Xor(x1, Xor(x2, x3)); +} + +template <typename T> +HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) { + return Or(o1, Or(o2, o3)); +} + +template <typename T> +HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) { + return Or(o, And(a1, a2)); +} + +template <typename T> +HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template <typename T> +HWY_API Vec256<T> operator&(const Vec256<T> a, const Vec256<T> b) { + return And(a, b); +} + +template <typename T> +HWY_API Vec256<T> operator|(const Vec256<T> a, const Vec256<T> b) { + return Or(a, b); +} + +template <typename T> +HWY_API Vec256<T> operator^(const Vec256<T> a, const Vec256<T> b) { + return Xor(a, b); +} + +// ------------------------------ CopySign + +template <typename T> +HWY_API Vec256<T> CopySign(const Vec256<T> magn, const Vec256<T> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + const auto msb = SignBit(Full256<T>()); + return Or(AndNot(msb, magn), And(msb, sign)); +} + +template <typename T> +HWY_API Vec256<T> CopySignToAbs(const Vec256<T> abs, const Vec256<T> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + return Or(abs, And(SignBit(Full256<T>()), sign)); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template <typename T> +HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) { + Mask256<T> m; + m.m0 = MaskFromVec(v.v0); + m.m1 = MaskFromVec(v.v1); + return m; +} + +template <typename T> +HWY_API Vec256<T> VecFromMask(Full256<T> d, Mask256<T> m) { + const Half<decltype(d)> dh; + Vec256<T> v; + v.v0 = VecFromMask(dh, m.m0); + v.v1 = VecFromMask(dh, m.m1); + return v; +} + +// mask ? yes : no +template <typename T> +HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) { + yes.v0 = IfThenElse(mask.m0, yes.v0, no.v0); + yes.v1 = IfThenElse(mask.m1, yes.v1, no.v1); + return yes; +} + +// mask ? yes : 0 +template <typename T> +HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) { + return yes & VecFromMask(Full256<T>(), mask); +} + +// mask ? 0 : no +template <typename T> +HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) { + return AndNot(VecFromMask(Full256<T>(), mask), no); +} + +template <typename T> +HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes, Vec256<T> no) { + v.v0 = IfNegativeThenElse(v.v0, yes.v0, no.v0); + v.v1 = IfNegativeThenElse(v.v1, yes.v1, no.v1); + return v; +} + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Vec256<T> ZeroIfNegative(Vec256<T> v) { + return IfThenZeroElse(v < Zero(Full256<T>()), v); +} + +// ------------------------------ Mask logical + +template <typename T> +HWY_API Mask256<T> Not(const Mask256<T> m) { + return MaskFromVec(Not(VecFromMask(Full256<T>(), m))); +} + +template <typename T> +HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) +template <typename T> +HWY_API Vec256<T> operator<<(Vec256<T> v, const Vec256<T> bits) { + v.v0 = operator<<(v.v0, bits.v0); + v.v1 = operator<<(v.v1, bits.v1); + return v; +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) +template <typename T> +HWY_API Vec256<T> operator>>(Vec256<T> v, const Vec256<T> bits) { + v.v0 = operator>>(v.v0, bits.v0); + v.v1 = operator>>(v.v1, bits.v1); + return v; +} + +// ------------------------------ BroadcastSignBit (compare, VecFromMask) + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec256<T> BroadcastSignBit(const Vec256<T> v) { + return ShiftRight<sizeof(T) * 8 - 1>(v); +} +HWY_API Vec256<int8_t> BroadcastSignBit(const Vec256<int8_t> v) { + const Full256<int8_t> d; + return VecFromMask(d, v < Zero(d)); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template <typename T> +HWY_API Vec256<T> Load(Full256<T> d, const T* HWY_RESTRICT aligned) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Load(dh, aligned); + ret.v1 = Load(dh, aligned + Lanes(dh)); + return ret; +} + +template <typename T> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> d, + const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +// LoadU == Load. +template <typename T> +HWY_API Vec256<T> LoadU(Full256<T> d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +template <typename T> +HWY_API Vec256<T> LoadDup128(Full256<T> d, const T* HWY_RESTRICT p) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = ret.v1 = Load(dh, p); + return ret; +} + +// ------------------------------ Store + +template <typename T> +HWY_API void Store(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT aligned) { + const Half<decltype(d)> dh; + Store(v.v0, dh, aligned); + Store(v.v1, dh, aligned + Lanes(dh)); +} + +// StoreU == Store. +template <typename T> +HWY_API void StoreU(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template <typename T> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> d, + T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Stream +template <typename T> +HWY_API void Stream(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT aligned) { + // Same as aligned stores. + Store(v, d, aligned); +} + +// ------------------------------ Scatter (Store) + +template <typename T, typename Offset> +HWY_API void ScatterOffset(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT base, + const Vec256<Offset> offset) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256<Offset>(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast<uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template <typename T, typename Index> +HWY_API void ScatterIndex(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT base, + const Vec256<Index> index) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[N]; + Store(index, Full256<Index>(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +// ------------------------------ Gather (Load/Store) + +template <typename T, typename Offset> +HWY_API Vec256<T> GatherOffset(const Full256<T> d, const T* HWY_RESTRICT base, + const Vec256<Offset> offset) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256<Offset>(), offset_lanes); + + alignas(32) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast<const uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template <typename T, typename Index> +HWY_API Vec256<T> GatherIndex(const Full256<T> d, const T* HWY_RESTRICT base, + const Vec256<Index> index) { + constexpr size_t N = 32 / sizeof(T); + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(32) Index index_lanes[N]; + Store(index, Full256<Index>(), index_lanes); + + alignas(32) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane +template <typename T> +HWY_API T ExtractLane(const Vec256<T> v, size_t i) { + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, Full256<T>(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane +template <typename T> +HWY_API Vec256<T> InsertLane(const Vec256<T> v, size_t i, T t) { + Full256<T> d; + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template <typename T> +HWY_API Vec128<T> LowerHalf(Full128<T> /* tag */, Vec256<T> v) { + return v.v0; +} + +template <typename T> +HWY_API Vec128<T> LowerHalf(Vec256<T> v) { + return v.v0; +} + +// ------------------------------ GetLane (LowerHalf) +template <typename T> +HWY_API T GetLane(const Vec256<T> v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ShiftLeftBytes + +template <int kBytes, typename T> +HWY_API Vec256<T> ShiftLeftBytes(Full256<T> d, Vec256<T> v) { + const Half<decltype(d)> dh; + v.v0 = ShiftLeftBytes<kBytes>(dh, v.v0); + v.v1 = ShiftLeftBytes<kBytes>(dh, v.v1); + return v; +} + +template <int kBytes, typename T> +HWY_API Vec256<T> ShiftLeftBytes(Vec256<T> v) { + return ShiftLeftBytes<kBytes>(Full256<T>(), v); +} + +// ------------------------------ ShiftLeftLanes + +template <int kLanes, typename T> +HWY_API Vec256<T> ShiftLeftLanes(Full256<T> d, const Vec256<T> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); +} + +template <int kLanes, typename T> +HWY_API Vec256<T> ShiftLeftLanes(const Vec256<T> v) { + return ShiftLeftLanes<kLanes>(Full256<T>(), v); +} + +// ------------------------------ ShiftRightBytes +template <int kBytes, typename T> +HWY_API Vec256<T> ShiftRightBytes(Full256<T> d, Vec256<T> v) { + const Half<decltype(d)> dh; + v.v0 = ShiftRightBytes<kBytes>(dh, v.v0); + v.v1 = ShiftRightBytes<kBytes>(dh, v.v1); + return v; +} + +// ------------------------------ ShiftRightLanes +template <int kLanes, typename T> +HWY_API Vec256<T> ShiftRightLanes(Full256<T> d, const Vec256<T> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template <typename T> +HWY_API Vec128<T> UpperHalf(Full128<T> /* tag */, const Vec256<T> v) { + return v.v1; +} + +// ------------------------------ CombineShiftRightBytes + +template <int kBytes, typename T, class V = Vec256<T>> +HWY_API V CombineShiftRightBytes(Full256<T> d, V hi, V lo) { + const Half<decltype(d)> dh; + hi.v0 = CombineShiftRightBytes<kBytes>(dh, hi.v0, lo.v0); + hi.v1 = CombineShiftRightBytes<kBytes>(dh, hi.v1, lo.v1); + return hi; +} + +// ------------------------------ Broadcast/splat any lane + +template <int kLane, typename T> +HWY_API Vec256<T> Broadcast(const Vec256<T> v) { + Vec256<T> ret; + ret.v0 = Broadcast<kLane>(v.v0); + ret.v1 = Broadcast<kLane>(v.v1); + return ret; +} + +// ------------------------------ TableLookupBytes + +// Both full +template <typename T, typename TI> +HWY_API Vec256<TI> TableLookupBytes(const Vec256<T> bytes, Vec256<TI> from) { + from.v0 = TableLookupBytes(bytes.v0, from.v0); + from.v1 = TableLookupBytes(bytes.v1, from.v1); + return from; +} + +// Partial index vector +template <typename T, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytes(const Vec256<T> bytes, + const Vec128<TI, NI> from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256<TI>(), Vec128<TI>{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128<TI, NI>{LowerHalf(Full128<TI>(), tbl_full).raw}; +} + +// Partial table vector +template <typename T, size_t N, typename TI> +HWY_API Vec256<TI> TableLookupBytes(const Vec128<T, N> bytes, + const Vec256<TI> from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256<T>(), Vec128<T>{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by wasm_128. + +template <class V, class VI> +HWY_API VI TableLookupBytesOr0(const V bytes, VI from) { + // wasm out-of-bounds policy already zeros, so TableLookupBytes is fine. + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Hard-coded shuffles + +template <typename T> +HWY_API Vec256<T> Shuffle01(Vec256<T> v) { + v.v0 = Shuffle01(v.v0); + v.v1 = Shuffle01(v.v1); + return v; +} + +template <typename T> +HWY_API Vec256<T> Shuffle2301(Vec256<T> v) { + v.v0 = Shuffle2301(v.v0); + v.v1 = Shuffle2301(v.v1); + return v; +} + +template <typename T> +HWY_API Vec256<T> Shuffle1032(Vec256<T> v) { + v.v0 = Shuffle1032(v.v0); + v.v1 = Shuffle1032(v.v1); + return v; +} + +template <typename T> +HWY_API Vec256<T> Shuffle0321(Vec256<T> v) { + v.v0 = Shuffle0321(v.v0); + v.v1 = Shuffle0321(v.v1); + return v; +} + +template <typename T> +HWY_API Vec256<T> Shuffle2103(Vec256<T> v) { + v.v0 = Shuffle2103(v.v0); + v.v1 = Shuffle2103(v.v1); + return v; +} + +template <typename T> +HWY_API Vec256<T> Shuffle0123(Vec256<T> v) { + v.v0 = Shuffle0123(v.v0); + v.v1 = Shuffle0123(v.v1); + return v; +} + +// Used by generic_ops-inl.h +namespace detail { + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Shuffle2301(Vec256<T> a, const Vec256<T> b) { + a.v0 = Shuffle2301(a.v0, b.v0); + a.v1 = Shuffle2301(a.v1, b.v1); + return a; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Shuffle1230(Vec256<T> a, const Vec256<T> b) { + a.v0 = Shuffle1230(a.v0, b.v0); + a.v1 = Shuffle1230(a.v1, b.v1); + return a; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Shuffle3012(Vec256<T> a, const Vec256<T> b) { + a.v0 = Shuffle3012(a.v0, b.v0); + a.v1 = Shuffle3012(a.v1, b.v1); + return a; +} + +} // namespace detail + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template <typename T> +struct Indices256 { + __v128_u i0; + __v128_u i1; +}; + +template <typename T, typename TI> +HWY_API Indices256<T> IndicesFromVec(Full256<T> /* tag */, Vec256<TI> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + Indices256<T> ret; + ret.i0 = vec.v0.raw; + ret.i1 = vec.v1.raw; + return ret; +} + +template <typename T, typename TI> +HWY_API Indices256<T> SetTableIndices(Full256<T> d, const TI* idx) { + const Rebind<TI, decltype(d)> di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template <typename T> +HWY_API Vec256<T> TableLookupLanes(const Vec256<T> v, Indices256<T> idx) { + using TU = MakeUnsigned<T>; + const Full128<T> dh; + const Full128<TU> duh; + constexpr size_t kLanesPerHalf = 16 / sizeof(TU); + + const Vec128<TU> vi0{idx.i0}; + const Vec128<TU> vi1{idx.i1}; + const Vec128<TU> mask = Set(duh, static_cast<TU>(kLanesPerHalf - 1)); + const Vec128<TU> vmod0 = vi0 & mask; + const Vec128<TU> vmod1 = vi1 & mask; + // If ANDing did not change the index, it is for the lower half. + const Mask128<T> is_lo0 = RebindMask(dh, vi0 == vmod0); + const Mask128<T> is_lo1 = RebindMask(dh, vi1 == vmod1); + const Indices128<T> mod0 = IndicesFromVec(dh, vmod0); + const Indices128<T> mod1 = IndicesFromVec(dh, vmod1); + + Vec256<T> ret; + ret.v0 = IfThenElse(is_lo0, TableLookupLanes(v.v0, mod0), + TableLookupLanes(v.v1, mod0)); + ret.v1 = IfThenElse(is_lo1, TableLookupLanes(v.v0, mod1), + TableLookupLanes(v.v1, mod1)); + return ret; +} + +template <typename T> +HWY_API Vec256<T> TableLookupLanesOr0(Vec256<T> v, Indices256<T> idx) { + // The out of bounds behavior will already zero lanes. + return TableLookupLanesOr0(v, idx); +} + +// ------------------------------ Reverse +template <typename T> +HWY_API Vec256<T> Reverse(Full256<T> d, const Vec256<T> v) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v1 = Reverse(dh, v.v0); // note reversed v1 member order + ret.v0 = Reverse(dh, v.v1); + return ret; +} + +// ------------------------------ Reverse2 +template <typename T> +HWY_API Vec256<T> Reverse2(Full256<T> d, Vec256<T> v) { + const Half<decltype(d)> dh; + v.v0 = Reverse2(dh, v.v0); + v.v1 = Reverse2(dh, v.v1); + return v; +} + +// ------------------------------ Reverse4 + +// Each block has only 2 lanes, so swap blocks and their lanes. +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Reverse4(Full256<T> d, const Vec256<T> v) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Reverse2(dh, v.v1); // swapped + ret.v1 = Reverse2(dh, v.v0); + return ret; +} + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Reverse4(Full256<T> d, Vec256<T> v) { + const Half<decltype(d)> dh; + v.v0 = Reverse4(dh, v.v0); + v.v1 = Reverse4(dh, v.v1); + return v; +} + +// ------------------------------ Reverse8 + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Reverse8(Full256<T> /* tag */, Vec256<T> /* v */) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// Each block has only 4 lanes, so swap blocks and their lanes. +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Reverse8(Full256<T> d, const Vec256<T> v) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Reverse4(dh, v.v1); // swapped + ret.v1 = Reverse4(dh, v.v0); + return ret; +} + +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x6)> // 1 or 2 bytes +HWY_API Vec256<T> Reverse8(Full256<T> d, Vec256<T> v) { + const Half<decltype(d)> dh; + v.v0 = Reverse8(dh, v.v0); + v.v1 = Reverse8(dh, v.v1); + return v; +} + +// ------------------------------ InterleaveLower + +template <typename T> +HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) { + a.v0 = InterleaveLower(a.v0, b.v0); + a.v1 = InterleaveLower(a.v1, b.v1); + return a; +} + +// wasm_128 already defines a template with D, V, V args. + +// ------------------------------ InterleaveUpper (UpperHalf) + +template <typename T, class V = Vec256<T>> +HWY_API V InterleaveUpper(Full256<T> d, V a, V b) { + const Half<decltype(d)> dh; + a.v0 = InterleaveUpper(dh, a.v0, b.v0); + a.v1 = InterleaveUpper(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template <typename T, class DW = RepartitionToWide<Full256<T>>> +HWY_API VFromD<DW> ZipLower(Vec256<T> a, Vec256<T> b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template <typename T, class D = Full256<T>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipLower(DW dw, Vec256<T> a, Vec256<T> b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template <typename T, class D = Full256<T>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipUpper(DW dw, Vec256<T> a, Vec256<T> b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) +template <typename T> +HWY_API Vec256<T> Combine(Full256<T> /* d */, Vec128<T> hi, Vec128<T> lo) { + Vec256<T> ret; + ret.v1 = hi; + ret.v0 = lo; + return ret; +} + +// ------------------------------ ZeroExtendVector (Combine) +template <typename T> +HWY_API Vec256<T> ZeroExtendVector(Full256<T> d, Vec128<T> lo) { + const Half<decltype(d)> dh; + return Combine(d, Zero(dh), lo); +} + +// ------------------------------ ConcatLowerLower +template <typename T> +HWY_API Vec256<T> ConcatLowerLower(Full256<T> /* tag */, const Vec256<T> hi, + const Vec256<T> lo) { + Vec256<T> ret; + ret.v1 = hi.v0; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatUpperUpper +template <typename T> +HWY_API Vec256<T> ConcatUpperUpper(Full256<T> /* tag */, const Vec256<T> hi, + const Vec256<T> lo) { + Vec256<T> ret; + ret.v1 = hi.v1; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatLowerUpper +template <typename T> +HWY_API Vec256<T> ConcatLowerUpper(Full256<T> /* tag */, const Vec256<T> hi, + const Vec256<T> lo) { + Vec256<T> ret; + ret.v1 = hi.v0; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatUpperLower +template <typename T> +HWY_API Vec256<T> ConcatUpperLower(Full256<T> /* tag */, const Vec256<T> hi, + const Vec256<T> lo) { + Vec256<T> ret; + ret.v1 = hi.v1; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatOdd +template <typename T> +HWY_API Vec256<T> ConcatOdd(Full256<T> d, const Vec256<T> hi, + const Vec256<T> lo) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = ConcatOdd(dh, lo.v1, lo.v0); + ret.v1 = ConcatOdd(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ ConcatEven +template <typename T> +HWY_API Vec256<T> ConcatEven(Full256<T> d, const Vec256<T> hi, + const Vec256<T> lo) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = ConcatEven(dh, lo.v1, lo.v0); + ret.v1 = ConcatEven(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ DupEven +template <typename T> +HWY_API Vec256<T> DupEven(Vec256<T> v) { + v.v0 = DupEven(v.v0); + v.v1 = DupEven(v.v1); + return v; +} + +// ------------------------------ DupOdd +template <typename T> +HWY_API Vec256<T> DupOdd(Vec256<T> v) { + v.v0 = DupOdd(v.v0); + v.v1 = DupOdd(v.v1); + return v; +} + +// ------------------------------ OddEven +template <typename T> +HWY_API Vec256<T> OddEven(Vec256<T> a, const Vec256<T> b) { + a.v0 = OddEven(a.v0, b.v0); + a.v1 = OddEven(a.v1, b.v1); + return a; +} + +// ------------------------------ OddEvenBlocks +template <typename T> +HWY_API Vec256<T> OddEvenBlocks(Vec256<T> odd, Vec256<T> even) { + odd.v0 = even.v0; + return odd; +} + +// ------------------------------ SwapAdjacentBlocks +template <typename T> +HWY_API Vec256<T> SwapAdjacentBlocks(Vec256<T> v) { + Vec256<T> ret; + ret.v0 = v.v1; // swapped order + ret.v1 = v.v0; + return ret; +} + +// ------------------------------ ReverseBlocks +template <typename T> +HWY_API Vec256<T> ReverseBlocks(Full256<T> /* tag */, const Vec256<T> v) { + return SwapAdjacentBlocks(v); // 2 blocks, so Swap = Reverse +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +namespace detail { + +// Unsigned: zero-extend. +HWY_API Vec128<uint16_t> PromoteUpperTo(Full128<uint16_t> /* tag */, + const Vec128<uint8_t> v) { + return Vec128<uint16_t>{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +HWY_API Vec128<uint32_t> PromoteUpperTo(Full128<uint32_t> /* tag */, + const Vec128<uint8_t> v) { + return Vec128<uint32_t>{ + wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))}; +} +HWY_API Vec128<int16_t> PromoteUpperTo(Full128<int16_t> /* tag */, + const Vec128<uint8_t> v) { + return Vec128<int16_t>{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +HWY_API Vec128<int32_t> PromoteUpperTo(Full128<int32_t> /* tag */, + const Vec128<uint8_t> v) { + return Vec128<int32_t>{ + wasm_u32x4_extend_high_u16x8(wasm_u16x8_extend_high_u8x16(v.raw))}; +} +HWY_API Vec128<uint32_t> PromoteUpperTo(Full128<uint32_t> /* tag */, + const Vec128<uint16_t> v) { + return Vec128<uint32_t>{wasm_u32x4_extend_high_u16x8(v.raw)}; +} +HWY_API Vec128<uint64_t> PromoteUpperTo(Full128<uint64_t> /* tag */, + const Vec128<uint32_t> v) { + return Vec128<uint64_t>{wasm_u64x2_extend_high_u32x4(v.raw)}; +} +HWY_API Vec128<int32_t> PromoteUpperTo(Full128<int32_t> /* tag */, + const Vec128<uint16_t> v) { + return Vec128<int32_t>{wasm_u32x4_extend_high_u16x8(v.raw)}; +} + +// Signed: replicate sign bit. +HWY_API Vec128<int16_t> PromoteUpperTo(Full128<int16_t> /* tag */, + const Vec128<int8_t> v) { + return Vec128<int16_t>{wasm_i16x8_extend_high_i8x16(v.raw)}; +} +HWY_API Vec128<int32_t> PromoteUpperTo(Full128<int32_t> /* tag */, + const Vec128<int8_t> v) { + return Vec128<int32_t>{ + wasm_i32x4_extend_high_i16x8(wasm_i16x8_extend_high_i8x16(v.raw))}; +} +HWY_API Vec128<int32_t> PromoteUpperTo(Full128<int32_t> /* tag */, + const Vec128<int16_t> v) { + return Vec128<int32_t>{wasm_i32x4_extend_high_i16x8(v.raw)}; +} +HWY_API Vec128<int64_t> PromoteUpperTo(Full128<int64_t> /* tag */, + const Vec128<int32_t> v) { + return Vec128<int64_t>{wasm_i64x2_extend_high_i32x4(v.raw)}; +} + +HWY_API Vec128<double> PromoteUpperTo(Full128<double> dd, + const Vec128<int32_t> v) { + // There is no wasm_f64x2_convert_high_i32x4. + const Full64<int32_t> di32h; + return PromoteTo(dd, UpperHalf(di32h, v)); +} + +HWY_API Vec128<float> PromoteUpperTo(Full128<float> df32, + const Vec128<float16_t> v) { + const RebindToSigned<decltype(df32)> di32; + const RebindToUnsigned<decltype(df32)> du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteUpperTo(du32, Vec128<uint16_t>{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +} + +HWY_API Vec128<float> PromoteUpperTo(Full128<float> df32, + const Vec128<bfloat16_t> v) { + const Full128<uint16_t> du16; + const RebindToSigned<decltype(df32)> di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +} // namespace detail + +template <typename T, typename TN> +HWY_API Vec256<T> PromoteTo(Full256<T> d, const Vec128<TN> v) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = PromoteTo(dh, LowerHalf(v)); + ret.v1 = detail::PromoteUpperTo(dh, v); + return ret; +} + +// This is the only 4x promotion from 8 to 32-bit. +template <typename TW, typename TN> +HWY_API Vec256<TW> PromoteTo(Full256<TW> d, const Vec64<TN> v) { + const Half<decltype(d)> dh; + const Rebind<MakeWide<TN>, decltype(d)> d2; // 16-bit lanes + const auto v16 = PromoteTo(d2, v); + Vec256<TW> ret; + ret.v0 = PromoteTo(dh, LowerHalf(v16)); + ret.v1 = detail::PromoteUpperTo(dh, v16); + return ret; +} + +// ------------------------------ DemoteTo + +HWY_API Vec128<uint16_t> DemoteTo(Full128<uint16_t> /* tag */, + const Vec256<int32_t> v) { + return Vec128<uint16_t>{wasm_u16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec128<int16_t> DemoteTo(Full128<int16_t> /* tag */, + const Vec256<int32_t> v) { + return Vec128<int16_t>{wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec64<uint8_t> DemoteTo(Full64<uint8_t> /* tag */, + const Vec256<int32_t> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64<uint8_t>{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128<uint8_t> DemoteTo(Full128<uint8_t> /* tag */, + const Vec256<int16_t> v) { + return Vec128<uint8_t>{wasm_u8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec64<int8_t> DemoteTo(Full64<int8_t> /* tag */, + const Vec256<int32_t> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64<int8_t>{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +HWY_API Vec128<int8_t> DemoteTo(Full128<int8_t> /* tag */, + const Vec256<int16_t> v) { + return Vec128<int8_t>{wasm_i8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +HWY_API Vec128<int32_t> DemoteTo(Full128<int32_t> di, const Vec256<double> v) { + const Vec64<int32_t> lo{wasm_i32x4_trunc_sat_f64x2_zero(v.v0.raw)}; + const Vec64<int32_t> hi{wasm_i32x4_trunc_sat_f64x2_zero(v.v1.raw)}; + return Combine(di, hi, lo); +} + +HWY_API Vec128<float16_t> DemoteTo(Full128<float16_t> d16, + const Vec256<float> v) { + const Half<decltype(d16)> d16h; + const Vec64<float16_t> lo = DemoteTo(d16h, v.v0); + const Vec64<float16_t> hi = DemoteTo(d16h, v.v1); + return Combine(d16, hi, lo); +} + +HWY_API Vec128<bfloat16_t> DemoteTo(Full128<bfloat16_t> dbf16, + const Vec256<float> v) { + const Half<decltype(dbf16)> dbf16h; + const Vec64<bfloat16_t> lo = DemoteTo(dbf16h, v.v0); + const Vec64<bfloat16_t> hi = DemoteTo(dbf16h, v.v1); + return Combine(dbf16, hi, lo); +} + +// For already range-limited input [0, 255]. +HWY_API Vec64<uint8_t> U8FromU32(const Vec256<uint32_t> v) { + const Full64<uint8_t> du8; + const Full256<int32_t> di32; // no unsigned DemoteTo + return DemoteTo(du8, BitCast(di32, v)); +} + +// ------------------------------ Truncations + +HWY_API Vec32<uint8_t> TruncateTo(Full32<uint8_t> /* tag */, + const Vec256<uint64_t> v) { + return Vec32<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24, 0, + 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, + 24)}; +} + +HWY_API Vec64<uint16_t> TruncateTo(Full64<uint16_t> /* tag */, + const Vec256<uint64_t> v) { + return Vec64<uint16_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9, 16, + 17, 24, 25, 0, 1, 8, 9, 16, 17, 24, + 25)}; +} + +HWY_API Vec128<uint32_t> TruncateTo(Full128<uint32_t> /* tag */, + const Vec256<uint64_t> v) { + return Vec128<uint32_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3, 8, + 9, 10, 11, 16, 17, 18, 19, 24, 25, + 26, 27)}; +} + +HWY_API Vec64<uint8_t> TruncateTo(Full64<uint8_t> /* tag */, + const Vec256<uint32_t> v) { + return Vec64<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12, 16, + 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, + 28)}; +} + +HWY_API Vec128<uint16_t> TruncateTo(Full128<uint16_t> /* tag */, + const Vec256<uint32_t> v) { + return Vec128<uint16_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5, 8, + 9, 12, 13, 16, 17, 20, 21, 24, 25, + 28, 29)}; +} + +HWY_API Vec128<uint8_t> TruncateTo(Full128<uint8_t> /* tag */, + const Vec256<uint16_t> v) { + return Vec128<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, 8, + 10, 12, 14, 16, 18, 20, 22, 24, 26, + 28, 30)}; +} + +// ------------------------------ ReorderDemote2To +HWY_API Vec256<bfloat16_t> ReorderDemote2To(Full256<bfloat16_t> dbf16, + Vec256<float> a, Vec256<float> b) { + const RebindToUnsigned<decltype(dbf16)> du16; + return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a))); +} + +HWY_API Vec256<int16_t> ReorderDemote2To(Full256<int16_t> d16, + Vec256<int32_t> a, Vec256<int32_t> b) { + const Half<decltype(d16)> d16h; + Vec256<int16_t> demoted; + demoted.v0 = DemoteTo(d16h, a); + demoted.v1 = DemoteTo(d16h, b); + return demoted; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template <typename TTo, typename TFrom> +HWY_API Vec256<TTo> ConvertTo(Full256<TTo> d, const Vec256<TFrom> v) { + const Half<decltype(d)> dh; + Vec256<TTo> ret; + ret.v0 = ConvertTo(dh, v.v0); + ret.v1 = ConvertTo(dh, v.v1); + return ret; +} + +HWY_API Vec256<int32_t> NearestInt(const Vec256<float> v) { + return ConvertTo(Full256<int32_t>(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x110)> // 4 or 8 bytes +HWY_API Mask256<T> LoadMaskBits(Full256<T> d, + const uint8_t* HWY_RESTRICT bits) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = LoadMaskBits(dh, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + const uint8_t bits_upper[8] = {static_cast<uint8_t>(bits[0] >> kBitsPerHalf)}; + ret.m1 = LoadMaskBits(dh, bits_upper); + return ret; +} + +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x6)> // 1 or 2 bytes +HWY_API Mask256<T> LoadMaskBits(Full256<T> d, + const uint8_t* HWY_RESTRICT bits) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = LoadMaskBits(dh, bits); + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + ret.m1 = LoadMaskBits(dh, bits + kBytesPerHalf); + return ret; +} + +// ------------------------------ Mask + +// `p` points to at least 8 writable bytes. +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x110)> // 4 or 8 bytes +HWY_API size_t StoreMaskBits(const Full256<T> d, const Mask256<T> mask, + uint8_t* bits) { + const Half<decltype(d)> dh; + StoreMaskBits(dh, mask.m0, bits); + const uint8_t lo = bits[0]; + StoreMaskBits(dh, mask.m1, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + bits[0] = static_cast<uint8_t>(lo | (bits[0] << kBitsPerHalf)); + return (kBitsPerHalf * 2 + 7) / 8; +} + +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x6)> // 1 or 2 bytes +HWY_API size_t StoreMaskBits(const Full256<T> d, const Mask256<T> mask, + uint8_t* bits) { + const Half<decltype(d)> dh; + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + StoreMaskBits(dh, mask.m0, bits); + StoreMaskBits(dh, mask.m1, bits + kBytesPerHalf); + return kBytesPerHalf * 2; +} + +template <typename T> +HWY_API size_t CountTrue(const Full256<T> d, const Mask256<T> m) { + const Half<decltype(d)> dh; + return CountTrue(dh, m.m0) + CountTrue(dh, m.m1); +} + +template <typename T> +HWY_API bool AllFalse(const Full256<T> d, const Mask256<T> m) { + const Half<decltype(d)> dh; + return AllFalse(dh, m.m0) && AllFalse(dh, m.m1); +} + +template <typename T> +HWY_API bool AllTrue(const Full256<T> d, const Mask256<T> m) { + const Half<decltype(d)> dh; + return AllTrue(dh, m.m0) && AllTrue(dh, m.m1); +} + +template <typename T> +HWY_API size_t FindKnownFirstTrue(const Full256<T> d, const Mask256<T> mask) { + const Half<decltype(d)> dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); // not known + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? static_cast<size_t>(lo) + : kLanesPerHalf + FindKnownFirstTrue(dh, mask.m1); +} + +template <typename T> +HWY_API intptr_t FindFirstTrue(const Full256<T> d, const Mask256<T> mask) { + const Half<decltype(d)> dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); + const intptr_t hi = FindFirstTrue(dh, mask.m1); + if (lo < 0 && hi < 0) return lo; + constexpr int kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? lo : hi + kLanesPerHalf; +} + +// ------------------------------ CompressStore +template <typename T> +HWY_API size_t CompressStore(const Vec256<T> v, const Mask256<T> mask, + Full256<T> d, T* HWY_RESTRICT unaligned) { + const Half<decltype(d)> dh; + const size_t count = CompressStore(v.v0, mask.m0, dh, unaligned); + const size_t count2 = CompressStore(v.v1, mask.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBlendedStore +template <typename T> +HWY_API size_t CompressBlendedStore(const Vec256<T> v, const Mask256<T> m, + Full256<T> d, T* HWY_RESTRICT unaligned) { + const Half<decltype(d)> dh; + const size_t count = CompressBlendedStore(v.v0, m.m0, dh, unaligned); + const size_t count2 = CompressBlendedStore(v.v1, m.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBitsStore + +template <typename T> +HWY_API size_t CompressBitsStore(const Vec256<T> v, + const uint8_t* HWY_RESTRICT bits, Full256<T> d, + T* HWY_RESTRICT unaligned) { + const Mask256<T> m = LoadMaskBits(d, bits); + return CompressStore(v, m, d, unaligned); +} + +// ------------------------------ Compress + +template <typename T> +HWY_API Vec256<T> Compress(const Vec256<T> v, const Mask256<T> mask) { + const Full256<T> d; + alignas(32) T lanes[32 / sizeof(T)] = {}; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +// ------------------------------ CompressNot +template <typename T> +HWY_API Vec256<T> CompressNot(Vec256<T> v, const Mask256<T> mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec256<uint64_t> CompressBlocksNot(Vec256<uint64_t> v, + Mask256<uint64_t> mask) { + const Full128<uint64_t> dh; + // Because the non-selected (mask=1) blocks are undefined, we can return the + // input unless mask = 01, in which case we must bring down the upper block. + return AllTrue(dh, AndNot(mask.m1, mask.m0)) ? SwapAdjacentBlocks(v) : v; +} + +// ------------------------------ CompressBits + +template <typename T> +HWY_API Vec256<T> CompressBits(Vec256<T> v, const uint8_t* HWY_RESTRICT bits) { + const Mask256<T> m = LoadMaskBits(Full256<T>(), bits); + return Compress(v, m); +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. + +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template <typename T> +HWY_API void LoadTransposedBlocks3(Full256<T> d, + const T* HWY_RESTRICT unaligned, + Vec256<T>& A, Vec256<T>& B, Vec256<T>& C) { + constexpr size_t N = 32 / sizeof(T); + const Vec256<T> v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const Vec256<T> v32 = LoadU(d, unaligned + 1 * N); + const Vec256<T> v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template <typename T> +HWY_API void LoadTransposedBlocks4(Full256<T> d, + const T* HWY_RESTRICT unaligned, + Vec256<T>& A, Vec256<T>& B, Vec256<T>& C, + Vec256<T>& D) { + constexpr size_t N = 32 / sizeof(T); + const Vec256<T> v10 = LoadU(d, unaligned + 0 * N); + const Vec256<T> v32 = LoadU(d, unaligned + 1 * N); + const Vec256<T> v54 = LoadU(d, unaligned + 2 * N); + const Vec256<T> v76 = LoadU(d, unaligned + 3 * N); + + A = ConcatLowerLower(d, v54, v10); + B = ConcatUpperUpper(d, v54, v10); + C = ConcatLowerLower(d, v76, v32); + D = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template <typename T> +HWY_API void StoreTransposedBlocks2(const Vec256<T> i, const Vec256<T> j, + const Full256<T> d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template <typename T> +HWY_API void StoreTransposedBlocks3(const Vec256<T> i, const Vec256<T> j, + const Vec256<T> k, Full256<T> d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template <typename T> +HWY_API void StoreTransposedBlocks4(const Vec256<T> i, const Vec256<T> j, + const Vec256<T> k, const Vec256<T> l, + Full256<T> d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + // Write lower halves, then upper. + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ ReorderWidenMulAccumulate +template <typename TN, typename TW> +HWY_API Vec256<TW> ReorderWidenMulAccumulate(Full256<TW> d, Vec256<TN> a, + Vec256<TN> b, Vec256<TW> sum0, + Vec256<TW>& sum1) { + const Half<decltype(d)> dh; + sum0.v0 = ReorderWidenMulAccumulate(dh, a.v0, b.v0, sum0.v0, sum1.v0); + sum0.v1 = ReorderWidenMulAccumulate(dh, a.v1, b.v1, sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ RearrangeToOddPlusEven +template <typename TW> +HWY_API Vec256<TW> RearrangeToOddPlusEven(Vec256<TW> sum0, Vec256<TW> sum1) { + sum0.v0 = RearrangeToOddPlusEven(sum0.v0, sum1.v0); + sum0.v1 = RearrangeToOddPlusEven(sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ Reductions + +template <typename T> +HWY_API Vec256<T> SumOfLanes(Full256<T> d, const Vec256<T> v) { + const Half<decltype(d)> dh; + const Vec128<T> lo = SumOfLanes(dh, Add(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +template <typename T> +HWY_API Vec256<T> MinOfLanes(Full256<T> d, const Vec256<T> v) { + const Half<decltype(d)> dh; + const Vec128<T> lo = MinOfLanes(dh, Min(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +template <typename T> +HWY_API Vec256<T> MaxOfLanes(Full256<T> d, const Vec256<T> v) { + const Half<decltype(d)> dh; + const Vec128<T> lo = MaxOfLanes(dh, Max(v.v0, v.v1)); + return Combine(d, lo, lo); +} + +// ------------------------------ Lt128 + +template <typename T> +HWY_INLINE Mask256<T> Lt128(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = Lt128(dh, a.v0, b.v0); + ret.m1 = Lt128(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Mask256<T> Lt128Upper(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = Lt128Upper(dh, a.v0, b.v0); + ret.m1 = Lt128Upper(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Mask256<T> Eq128(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = Eq128(dh, a.v0, b.v0); + ret.m1 = Eq128(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Mask256<T> Eq128Upper(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = Eq128Upper(dh, a.v0, b.v0); + ret.m1 = Eq128Upper(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Mask256<T> Ne128(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = Ne128(dh, a.v0, b.v0); + ret.m1 = Ne128(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Mask256<T> Ne128Upper(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Mask256<T> ret; + ret.m0 = Ne128Upper(dh, a.v0, b.v0); + ret.m1 = Ne128Upper(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Vec256<T> Min128(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Min128(dh, a.v0, b.v0); + ret.v1 = Min128(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Vec256<T> Max128(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Max128(dh, a.v0, b.v0); + ret.v1 = Max128(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Vec256<T> Min128Upper(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Min128Upper(dh, a.v0, b.v0); + ret.v1 = Min128Upper(dh, a.v1, b.v1); + return ret; +} + +template <typename T> +HWY_INLINE Vec256<T> Max128Upper(Full256<T> d, Vec256<T> a, Vec256<T> b) { + const Half<decltype(d)> dh; + Vec256<T> ret; + ret.v0 = Max128Upper(dh, a.v0, b.v0); + ret.v1 = Max128Upper(dh, a.v1, b.v1); + return ret; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/third_party/highway/hwy/ops/x86_128-inl.h b/third_party/highway/hwy/ops/x86_128-inl.h new file mode 100644 index 0000000000..ba8d581984 --- /dev/null +++ b/third_party/highway/hwy/ops/x86_128-inl.h @@ -0,0 +1,7432 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors and SSE4 instructions, plus some AVX2 and AVX512-VL +// operations when compiling for those targets. +// External include guard in highway.h - see comment there. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_GCC_ACTUAL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's emmintrin.h - see +// https://github.com/google/highway/issues/710 and pull/902 +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +#include <emmintrin.h> +#include <stdio.h> +#if HWY_TARGET == HWY_SSSE3 +#include <tmmintrin.h> // SSSE3 +#else +#include <smmintrin.h> // SSE4 +#include <wmmintrin.h> // CLMUL +#endif +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcpy + +#include "hwy/ops/shared-inl.h" + +#if HWY_IS_MSAN +#include <sanitizer/msan_interface.h> +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template <typename T> +struct Raw128 { + using type = __m128i; +}; +template <> +struct Raw128<float> { + using type = __m128; +}; +template <> +struct Raw128<double> { + using type = __m128d; +}; + +} // namespace detail + +template <typename T, size_t N = 16 / sizeof(T)> +class Vec128 { + using Raw = typename detail::Raw128<T>::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template <typename T> +using Vec64 = Vec128<T, 8 / sizeof(T)>; + +template <typename T> +using Vec32 = Vec128<T, 4 / sizeof(T)>; + +#if HWY_TARGET <= HWY_AVX3 + +namespace detail { + +// Template arg: sizeof(lane type) +template <size_t size> +struct RawMask128 {}; +template <> +struct RawMask128<1> { + using type = __mmask16; +}; +template <> +struct RawMask128<2> { + using type = __mmask8; +}; +template <> +struct RawMask128<4> { + using type = __mmask8; +}; +template <> +struct RawMask128<8> { + using type = __mmask8; +}; + +} // namespace detail + +template <typename T, size_t N = 16 / sizeof(T)> +struct Mask128 { + using Raw = typename detail::RawMask128<sizeof(T)>::type; + + static Mask128<T, N> FromBits(uint64_t mask_bits) { + return Mask128<T, N>{static_cast<Raw>(mask_bits)}; + } + + Raw raw; +}; + +#else // AVX2 or below + +// FF..FF or 0. +template <typename T, size_t N = 16 / sizeof(T)> +struct Mask128 { + typename detail::Raw128<T>::type raw; +}; + +#endif // HWY_TARGET <= HWY_AVX3 + +template <class V> +using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>; + +template <class V> +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; } +HWY_INLINE __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } +HWY_INLINE __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } + +template <typename T, size_t N> +HWY_INLINE Vec128<uint8_t, N * sizeof(T)> BitCastToByte(Vec128<T, N> v) { + return Vec128<uint8_t, N * sizeof(T)>{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template <typename T> +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +template <> +struct BitCastFromInteger128<float> { + HWY_INLINE __m128 operator()(__m128i v) { return _mm_castsi128_ps(v); } +}; +template <> +struct BitCastFromInteger128<double> { + HWY_INLINE __m128d operator()(__m128i v) { return _mm_castsi128_pd(v); } +}; + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> BitCastFromByte(Simd<T, N, 0> /* tag */, + Vec128<uint8_t, N * sizeof(T)> v) { + return Vec128<T, N>{BitCastFromInteger128<T>()(v.raw)}; +} + +} // namespace detail + +template <typename T, size_t N, typename FromT> +HWY_API Vec128<T, N> BitCast(Simd<T, N, 0> d, + Vec128<FromT, N * sizeof(T) / sizeof(FromT)> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Zero + +// Returns an all-zero vector/part. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> Zero(Simd<T, N, 0> /* tag */) { + return Vec128<T, N>{_mm_setzero_si128()}; +} +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> Zero(Simd<float, N, 0> /* tag */) { + return Vec128<float, N>{_mm_setzero_ps()}; +} +template <size_t N, HWY_IF_LE128(double, N)> +HWY_API Vec128<double, N> Zero(Simd<double, N, 0> /* tag */) { + return Vec128<double, N>{_mm_setzero_pd()}; +} + +template <class D> +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template <size_t N, HWY_IF_LE128(uint8_t, N)> +HWY_API Vec128<uint8_t, N> Set(Simd<uint8_t, N, 0> /* tag */, const uint8_t t) { + return Vec128<uint8_t, N>{_mm_set1_epi8(static_cast<char>(t))}; // NOLINT +} +template <size_t N, HWY_IF_LE128(uint16_t, N)> +HWY_API Vec128<uint16_t, N> Set(Simd<uint16_t, N, 0> /* tag */, + const uint16_t t) { + return Vec128<uint16_t, N>{_mm_set1_epi16(static_cast<short>(t))}; // NOLINT +} +template <size_t N, HWY_IF_LE128(uint32_t, N)> +HWY_API Vec128<uint32_t, N> Set(Simd<uint32_t, N, 0> /* tag */, + const uint32_t t) { + return Vec128<uint32_t, N>{_mm_set1_epi32(static_cast<int>(t))}; +} +template <size_t N, HWY_IF_LE128(uint64_t, N)> +HWY_API Vec128<uint64_t, N> Set(Simd<uint64_t, N, 0> /* tag */, + const uint64_t t) { + return Vec128<uint64_t, N>{ + _mm_set1_epi64x(static_cast<long long>(t))}; // NOLINT +} +template <size_t N, HWY_IF_LE128(int8_t, N)> +HWY_API Vec128<int8_t, N> Set(Simd<int8_t, N, 0> /* tag */, const int8_t t) { + return Vec128<int8_t, N>{_mm_set1_epi8(static_cast<char>(t))}; // NOLINT +} +template <size_t N, HWY_IF_LE128(int16_t, N)> +HWY_API Vec128<int16_t, N> Set(Simd<int16_t, N, 0> /* tag */, const int16_t t) { + return Vec128<int16_t, N>{_mm_set1_epi16(static_cast<short>(t))}; // NOLINT +} +template <size_t N, HWY_IF_LE128(int32_t, N)> +HWY_API Vec128<int32_t, N> Set(Simd<int32_t, N, 0> /* tag */, const int32_t t) { + return Vec128<int32_t, N>{_mm_set1_epi32(t)}; +} +template <size_t N, HWY_IF_LE128(int64_t, N)> +HWY_API Vec128<int64_t, N> Set(Simd<int64_t, N, 0> /* tag */, const int64_t t) { + return Vec128<int64_t, N>{ + _mm_set1_epi64x(static_cast<long long>(t))}; // NOLINT +} +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> Set(Simd<float, N, 0> /* tag */, const float t) { + return Vec128<float, N>{_mm_set1_ps(t)}; +} +template <size_t N, HWY_IF_LE128(double, N)> +HWY_API Vec128<double, N> Set(Simd<double, N, 0> /* tag */, const double t) { + return Vec128<double, N>{_mm_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> Undefined(Simd<T, N, 0> /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec128<T, N>{_mm_undefined_si128()}; +} +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> Undefined(Simd<float, N, 0> /* tag */) { + return Vec128<float, N>{_mm_undefined_ps()}; +} +template <size_t N, HWY_IF_LE128(double, N)> +HWY_API Vec128<double, N> Undefined(Simd<double, N, 0> /* tag */) { + return Vec128<double, N>{_mm_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +// Gets the single value stored in a vector/part. +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API T GetLane(const Vec128<T, N> v) { + return static_cast<T>(_mm_cvtsi128_si32(v.raw) & 0xFF); +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API T GetLane(const Vec128<T, N> v) { + return static_cast<T>(_mm_cvtsi128_si32(v.raw) & 0xFFFF); +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API T GetLane(const Vec128<T, N> v) { + return static_cast<T>(_mm_cvtsi128_si32(v.raw)); +} +template <size_t N> +HWY_API float GetLane(const Vec128<float, N> v) { + return _mm_cvtss_f32(v.raw); +} +template <size_t N> +HWY_API uint64_t GetLane(const Vec128<uint64_t, N> v) { +#if HWY_ARCH_X86_32 + alignas(16) uint64_t lanes[2]; + Store(v, Simd<uint64_t, N, 0>(), lanes); + return lanes[0]; +#else + return static_cast<uint64_t>(_mm_cvtsi128_si64(v.raw)); +#endif +} +template <size_t N> +HWY_API int64_t GetLane(const Vec128<int64_t, N> v) { +#if HWY_ARCH_X86_32 + alignas(16) int64_t lanes[2]; + Store(v, Simd<int64_t, N, 0>(), lanes); + return lanes[0]; +#else + return _mm_cvtsi128_si64(v.raw); +#endif +} +template <size_t N> +HWY_API double GetLane(const Vec128<double, N> v) { + return _mm_cvtsd_f64(v.raw); +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template <typename T, size_t N> +HWY_API Vec128<T, N> And(Vec128<T, N> a, Vec128<T, N> b) { + return Vec128<T, N>{_mm_and_si128(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<float, N> And(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_and_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> And(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template <typename T, size_t N> +HWY_API Vec128<T, N> AndNot(Vec128<T, N> not_mask, Vec128<T, N> mask) { + return Vec128<T, N>{_mm_andnot_si128(not_mask.raw, mask.raw)}; +} +template <size_t N> +HWY_API Vec128<float, N> AndNot(const Vec128<float, N> not_mask, + const Vec128<float, N> mask) { + return Vec128<float, N>{_mm_andnot_ps(not_mask.raw, mask.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> AndNot(const Vec128<double, N> not_mask, + const Vec128<double, N> mask) { + return Vec128<double, N>{_mm_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template <typename T, size_t N> +HWY_API Vec128<T, N> Or(Vec128<T, N> a, Vec128<T, N> b) { + return Vec128<T, N>{_mm_or_si128(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> Or(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_or_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> Or(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template <typename T, size_t N> +HWY_API Vec128<T, N> Xor(Vec128<T, N> a, Vec128<T, N> b) { + return Vec128<T, N>{_mm_xor_si128(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> Xor(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_xor_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> Xor(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not +template <typename T, size_t N> +HWY_API Vec128<T, N> Not(const Vec128<T, N> v) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; +#if HWY_TARGET <= HWY_AVX3 + const __m128i vu = BitCast(du, v).raw; + return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(d, VU{_mm_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Xor3 +template <typename T, size_t N> +HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV<decltype(x1)> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template <typename T, size_t N> +HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV<decltype(o1)> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template <typename T, size_t N> +HWY_API Vec128<T, N> OrAnd(Vec128<T, N> o, Vec128<T, N> a1, Vec128<T, N> a2) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV<decltype(o)> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template <typename T, size_t N> +HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { +#if HWY_TARGET <= HWY_AVX3 + const DFromV<decltype(no)> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + return BitCast( + d, VU{_mm_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator&(const Vec128<T, N> a, const Vec128<T, N> b) { + return And(a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator|(const Vec128<T, N> a, const Vec128<T, N> b) { + return Or(a, b); +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator^(const Vec128<T, N> a, const Vec128<T, N> b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128<T, N> v) { + return Vec128<T, N>{_mm_popcnt_epi8(v.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128<T, N> v) { + return Vec128<T, N>{_mm_popcnt_epi16(v.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128<T, N> v) { + return Vec128<T, N>{_mm_popcnt_epi32(v.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128<T, N> v) { + return Vec128<T, N>{_mm_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> PopulationCount(Vec128<T, N> v) { + return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Neg(hwy::FloatTag /*tag*/, const Vec128<T, N> v) { + return Xor(v, SignBit(DFromV<decltype(v)>())); +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Neg(hwy::NonFloatTag /*tag*/, const Vec128<T, N> v) { + return Zero(DFromV<decltype(v)>()) - v; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> Neg(const Vec128<T, N> v) { + return detail::Neg(hwy::IsFloatTag<T>(), v); +} + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template <size_t N> +HWY_API Vec128<int8_t, N> Abs(const Vec128<int8_t, N> v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (reaches breakpoint) + const auto zero = Zero(DFromV<decltype(v)>()); + return Vec128<int8_t, N>{_mm_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec128<int8_t, N>{_mm_abs_epi8(v.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int16_t, N> Abs(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>{_mm_abs_epi16(v.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> Abs(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>{_mm_abs_epi32(v.raw)}; +} +// i64 is implemented after BroadcastSignBit. +template <size_t N> +HWY_API Vec128<float, N> Abs(const Vec128<float, N> v) { + const Vec128<int32_t, N> mask{_mm_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(DFromV<decltype(v)>(), mask); +} +template <size_t N> +HWY_API Vec128<double, N> Abs(const Vec128<double, N> v) { + const Vec128<int64_t, N> mask{_mm_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(DFromV<decltype(v)>(), mask); +} + +// ------------------------------ CopySign + +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySign(const Vec128<T, N> magn, + const Vec128<T, N> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + + const DFromV<decltype(magn)> d; + const auto msb = SignBit(d); + +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned<decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m128i out = _mm_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, VFromD<decltype(du)>{out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> CopySignToAbs(const Vec128<T, N> abs, + const Vec128<T, N> sign) { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(DFromV<decltype(abs)>()), sign)); +#endif +} + +// ================================================== MASK + +namespace detail { + +template <typename T> +HWY_INLINE void MaybeUnpoison(T* HWY_RESTRICT unaligned, size_t count) { + // Workaround for MSAN not marking compressstore as initialized (b/233326619) +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElse(hwy::SizeTag<1> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return Vec128<T, N>{_mm_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElse(hwy::SizeTag<2> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return Vec128<T, N>{_mm_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElse(hwy::SizeTag<4> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return Vec128<T, N>{_mm_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElse(hwy::SizeTag<8> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return Vec128<T, N>{_mm_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElse(Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no); +} + +template <size_t N> +HWY_API Vec128<float, N> IfThenElse(Mask128<float, N> mask, + Vec128<float, N> yes, Vec128<float, N> no) { + return Vec128<float, N>{_mm_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} + +template <size_t N> +HWY_API Vec128<double, N> IfThenElse(Mask128<double, N> mask, + Vec128<double, N> yes, + Vec128<double, N> no) { + return Vec128<double, N>{_mm_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElseZero(hwy::SizeTag<1> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes) { + return Vec128<T, N>{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElseZero(hwy::SizeTag<2> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes) { + return Vec128<T, N>{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElseZero(hwy::SizeTag<4> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes) { + return Vec128<T, N>{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenElseZero(hwy::SizeTag<8> /* tag */, + Mask128<T, N> mask, Vec128<T, N> yes) { + return Vec128<T, N>{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElseZero(Mask128<T, N> mask, Vec128<T, N> yes) { + return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes); +} + +template <size_t N> +HWY_API Vec128<float, N> IfThenElseZero(Mask128<float, N> mask, + Vec128<float, N> yes) { + return Vec128<float, N>{_mm_maskz_mov_ps(mask.raw, yes.raw)}; +} + +template <size_t N> +HWY_API Vec128<double, N> IfThenElseZero(Mask128<double, N> mask, + Vec128<double, N> yes) { + return Vec128<double, N>{_mm_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenZeroElse(hwy::SizeTag<1> /* tag */, + Mask128<T, N> mask, Vec128<T, N> no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec128<T, N>{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenZeroElse(hwy::SizeTag<2> /* tag */, + Mask128<T, N> mask, Vec128<T, N> no) { + return Vec128<T, N>{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenZeroElse(hwy::SizeTag<4> /* tag */, + Mask128<T, N> mask, Vec128<T, N> no) { + return Vec128<T, N>{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> IfThenZeroElse(hwy::SizeTag<8> /* tag */, + Mask128<T, N> mask, Vec128<T, N> no) { + return Vec128<T, N>{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenZeroElse(Mask128<T, N> mask, Vec128<T, N> no) { + return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no); +} + +template <size_t N> +HWY_API Vec128<float, N> IfThenZeroElse(Mask128<float, N> mask, + Vec128<float, N> no) { + return Vec128<float, N>{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} + +template <size_t N> +HWY_API Vec128<double, N> IfThenZeroElse(Mask128<double, N> mask, + Vec128<double, N> no) { + return Vec128<double, N>{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// ------------------------------ Mask logical + +// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. +#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) +#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ + HWY_COMPILER_CLANG >= 800 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 +#else +#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 +#endif +#endif // HWY_COMPILER_HAS_MASK_INTRINSICS + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> And(hwy::SizeTag<1> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kand_mask16(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> And(hwy::SizeTag<2> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> And(hwy::SizeTag<4> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> And(hwy::SizeTag<8> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> AndNot(hwy::SizeTag<1> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> AndNot(hwy::SizeTag<2> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> AndNot(hwy::SizeTag<4> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> AndNot(hwy::SizeTag<8> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Or(hwy::SizeTag<1> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kor_mask16(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Or(hwy::SizeTag<2> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Or(hwy::SizeTag<4> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Or(hwy::SizeTag<8> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Xor(hwy::SizeTag<1> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Xor(hwy::SizeTag<2> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Xor(hwy::SizeTag<4> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Xor(hwy::SizeTag<8> /*tag*/, const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask128<T, N> a, + const Mask128<T, N> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128<T, N>{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; +#else + return Mask128<T, N>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; +#endif +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Mask128<T, N> And(const Mask128<T, N> a, Mask128<T, N> b) { + return detail::And(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> AndNot(const Mask128<T, N> a, Mask128<T, N> b) { + return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Or(const Mask128<T, N> a, Mask128<T, N> b) { + return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Xor(const Mask128<T, N> a, Mask128<T, N> b) { + return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Not(const Mask128<T, N> m) { + // Flip only the valid bits. + // TODO(janwas): use _knot intrinsics if N >= 8. + return Xor(m, Mask128<T, N>::FromBits((1ull << N) - 1)); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> ExclusiveNeither(const Mask128<T, N> a, Mask128<T, N> b) { + return detail::ExclusiveNeither(hwy::SizeTag<sizeof(T)>(), a, b); +} + +#else // AVX2 or below + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template <typename T, size_t N> +HWY_API Mask128<T, N> MaskFromVec(const Vec128<T, N> v) { + return Mask128<T, N>{v.raw}; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> VecFromMask(const Mask128<T, N> v) { + return Vec128<T, N>{v.raw}; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> VecFromMask(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> v) { + return Vec128<T, N>{v.raw}; +} + +#if HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElse(Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + const auto vmask = VecFromMask(DFromV<decltype(no)>(), mask); + return Or(And(vmask, yes), AndNot(vmask, no)); +} + +#else // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : no +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElse(Mask128<T, N> mask, Vec128<T, N> yes, + Vec128<T, N> no) { + return Vec128<T, N>{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +template <size_t N> +HWY_API Vec128<float, N> IfThenElse(const Mask128<float, N> mask, + const Vec128<float, N> yes, + const Vec128<float, N> no) { + return Vec128<float, N>{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> IfThenElse(const Mask128<double, N> mask, + const Vec128<double, N> yes, + const Vec128<double, N> no) { + return Vec128<double, N>{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +// mask ? yes : 0 +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenElseZero(Mask128<T, N> mask, Vec128<T, N> yes) { + return yes & VecFromMask(DFromV<decltype(yes)>(), mask); +} + +// mask ? 0 : no +template <typename T, size_t N> +HWY_API Vec128<T, N> IfThenZeroElse(Mask128<T, N> mask, Vec128<T, N> no) { + return AndNot(VecFromMask(DFromV<decltype(no)>(), mask), no); +} + +// ------------------------------ Mask logical + +template <typename T, size_t N> +HWY_API Mask128<T, N> Not(const Mask128<T, N> m) { + return MaskFromVec(Not(VecFromMask(Simd<T, N, 0>(), m))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> And(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> AndNot(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Or(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> Xor(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> ExclusiveNeither(const Mask128<T, N> a, Mask128<T, N> b) { + const Simd<T, N, 0> d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ ShiftLeft + +template <int kBits, size_t N> +HWY_API Vec128<uint16_t, N> ShiftLeft(const Vec128<uint16_t, N> v) { + return Vec128<uint16_t, N>{_mm_slli_epi16(v.raw, kBits)}; +} + +template <int kBits, size_t N> +HWY_API Vec128<uint32_t, N> ShiftLeft(const Vec128<uint32_t, N> v) { + return Vec128<uint32_t, N>{_mm_slli_epi32(v.raw, kBits)}; +} + +template <int kBits, size_t N> +HWY_API Vec128<uint64_t, N> ShiftLeft(const Vec128<uint64_t, N> v) { + return Vec128<uint64_t, N>{_mm_slli_epi64(v.raw, kBits)}; +} + +template <int kBits, size_t N> +HWY_API Vec128<int16_t, N> ShiftLeft(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>{_mm_slli_epi16(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int32_t, N> ShiftLeft(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>{_mm_slli_epi32(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int64_t, N> ShiftLeft(const Vec128<int64_t, N> v) { + return Vec128<int64_t, N>{_mm_slli_epi64(v.raw, kBits)}; +} + +template <int kBits, typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> ShiftLeft(const Vec128<T, N> v) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<T, N> shifted{ShiftLeft<kBits>(Vec128<MakeWide<T>>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template <int kBits, size_t N> +HWY_API Vec128<uint16_t, N> ShiftRight(const Vec128<uint16_t, N> v) { + return Vec128<uint16_t, N>{_mm_srli_epi16(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<uint32_t, N> ShiftRight(const Vec128<uint32_t, N> v) { + return Vec128<uint32_t, N>{_mm_srli_epi32(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<uint64_t, N> ShiftRight(const Vec128<uint64_t, N> v) { + return Vec128<uint64_t, N>{_mm_srli_epi64(v.raw, kBits)}; +} + +template <int kBits, size_t N> +HWY_API Vec128<uint8_t, N> ShiftRight(const Vec128<uint8_t, N> v) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<uint8_t, N> shifted{ + ShiftRight<kBits>(Vec128<uint16_t>{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template <int kBits, size_t N> +HWY_API Vec128<int16_t, N> ShiftRight(const Vec128<int16_t, N> v) { + return Vec128<int16_t, N>{_mm_srai_epi16(v.raw, kBits)}; +} +template <int kBits, size_t N> +HWY_API Vec128<int32_t, N> ShiftRight(const Vec128<int32_t, N> v) { + return Vec128<int32_t, N>{_mm_srai_epi32(v.raw, kBits)}; +} + +template <int kBits, size_t N> +HWY_API Vec128<int8_t, N> ShiftRight(const Vec128<int8_t, N> v) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. + +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template <typename T, size_t N, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytes(const Vec128<T, N> bytes, + const Vec128<TI, NI> from) { + return Vec128<TI, NI>{_mm_shuffle_epi8(bytes.raw, from.raw)}; +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; x86 anyway zeroes if >= 0x80. +template <class V, class VI> +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + return TableLookupBytes(bytes, from); +} + +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) + +// Notation: let Vec128<int32_t> have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template <typename T, size_t N> +HWY_API Vec128<T, N> Shuffle2301(const Vec128<T, N> v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<T, N>{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +template <size_t N> +HWY_API Vec128<float, N> Shuffle2301(const Vec128<float, N> v) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128<float, N>{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. +namespace detail { + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 4> Shuffle2301(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const Twice<DFromV<decltype(a)>> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {1, 0, 7, 6}; + return Vec128<T, 4>{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, 4> Shuffle2301(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const Twice<DFromV<decltype(a)>> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0302, 0x0100, 0x0f0e, 0x0d0c}; + return Vec128<T, 4>{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, 4> Shuffle2301(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const DFromV<decltype(a)> d; + const RebindToFloat<decltype(d)> df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec128<float, 4>{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 4> Shuffle1230(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const Twice<DFromV<decltype(a)>> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0, 3, 6, 5}; + return Vec128<T, 4>{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, 4> Shuffle1230(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const Twice<DFromV<decltype(a)>> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0100, 0x0706, 0x0d0c, 0x0b0a}; + return Vec128<T, 4>{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, 4> Shuffle1230(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const DFromV<decltype(a)> d; + const RebindToFloat<decltype(d)> df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec128<float, 4>{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, 4> Shuffle3012(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const Twice<DFromV<decltype(a)>> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {2, 1, 4, 7}; + return Vec128<T, 4>{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, 4> Shuffle3012(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const Twice<DFromV<decltype(a)>> d2; + const auto ba = Combine(d2, b, a); + alignas(16) const T kShuffle[8] = {0x0504, 0x0302, 0x0908, 0x0f0e}; + return Vec128<T, 4>{TableLookupBytes(ba, Load(d2, kShuffle)).raw}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, 4> Shuffle3012(const Vec128<T, 4> a, const Vec128<T, 4> b) { + const DFromV<decltype(a)> d; + const RebindToFloat<decltype(d)> df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec128<float, 4>{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec128<uint32_t> Shuffle1032(const Vec128<uint32_t> v) { + return Vec128<uint32_t>{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128<int32_t> Shuffle1032(const Vec128<int32_t> v) { + return Vec128<int32_t>{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128<float> Shuffle1032(const Vec128<float> v) { + return Vec128<float>{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec128<uint64_t> Shuffle01(const Vec128<uint64_t> v) { + return Vec128<uint64_t>{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128<int64_t> Shuffle01(const Vec128<int64_t> v) { + return Vec128<int64_t>{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128<double> Shuffle01(const Vec128<double> v) { + return Vec128<double>{_mm_shuffle_pd(v.raw, v.raw, 1)}; +} + +// Rotate right 32 bits +HWY_API Vec128<uint32_t> Shuffle0321(const Vec128<uint32_t> v) { + return Vec128<uint32_t>{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128<int32_t> Shuffle0321(const Vec128<int32_t> v) { + return Vec128<int32_t>{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128<float> Shuffle0321(const Vec128<float> v) { + return Vec128<float>{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128<uint32_t> Shuffle2103(const Vec128<uint32_t> v) { + return Vec128<uint32_t>{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128<int32_t> Shuffle2103(const Vec128<int32_t> v) { + return Vec128<int32_t>{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128<float> Shuffle2103(const Vec128<float> v) { + return Vec128<float>{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec128<uint32_t> Shuffle0123(const Vec128<uint32_t> v) { + return Vec128<uint32_t>{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128<int32_t> Shuffle0123(const Vec128<int32_t> v) { + return Vec128<int32_t>{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128<float> Shuffle0123(const Vec128<float> v) { + return Vec128<float>{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template <typename TFrom, size_t NFrom, typename TTo, size_t NTo> +HWY_API Mask128<TTo, NTo> RebindMask(Simd<TTo, NTo, 0> /*tag*/, + Mask128<TFrom, NFrom> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask128<TTo, NTo>{m.raw}; +} + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> TestBit(hwy::SizeTag<1> /*tag*/, const Vec128<T, N> v, + const Vec128<T, N> bit) { + return Mask128<T, N>{_mm_test_epi8_mask(v.raw, bit.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> TestBit(hwy::SizeTag<2> /*tag*/, const Vec128<T, N> v, + const Vec128<T, N> bit) { + return Mask128<T, N>{_mm_test_epi16_mask(v.raw, bit.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> TestBit(hwy::SizeTag<4> /*tag*/, const Vec128<T, N> v, + const Vec128<T, N> bit) { + return Mask128<T, N>{_mm_test_epi32_mask(v.raw, bit.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> TestBit(hwy::SizeTag<8> /*tag*/, const Vec128<T, N> v, + const Vec128<T, N> bit) { + return Mask128<T, N>{_mm_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Mask128<T, N> TestBit(const Vec128<T, N> v, const Vec128<T, N> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit); +} + +// ------------------------------ Equality + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Mask128<T, N> operator==(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Mask128<T, N> operator==(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Mask128<T, N> operator==(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Mask128<T, N> operator==(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Mask128<float, N> operator==(Vec128<float, N> a, Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +template <size_t N> +HWY_API Mask128<double, N> operator==(Vec128<double, N> a, + Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Mask128<T, N> operator!=(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Mask128<T, N> operator!=(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Mask128<T, N> operator!=(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Mask128<T, N> operator!=(const Vec128<T, N> a, const Vec128<T, N> b) { + return Mask128<T, N>{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Mask128<float, N> operator!=(Vec128<float, N> a, Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +template <size_t N> +HWY_API Mask128<double, N> operator!=(Vec128<double, N> a, + Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template <size_t N> +HWY_API Mask128<int8_t, N> operator>(Vec128<int8_t, N> a, Vec128<int8_t, N> b) { + return Mask128<int8_t, N>{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int16_t, N> operator>(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Mask128<int16_t, N>{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int32_t, N> operator>(Vec128<int32_t, N> a, + Vec128<int32_t, N> b) { + return Mask128<int32_t, N>{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int64_t, N> operator>(Vec128<int64_t, N> a, + Vec128<int64_t, N> b) { + return Mask128<int64_t, N>{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Mask128<uint8_t, N> operator>(Vec128<uint8_t, N> a, + Vec128<uint8_t, N> b) { + return Mask128<uint8_t, N>{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint16_t, N> operator>(Vec128<uint16_t, N> a, + Vec128<uint16_t, N> b) { + return Mask128<uint16_t, N>{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint32_t, N> operator>(Vec128<uint32_t, N> a, + Vec128<uint32_t, N> b) { + return Mask128<uint32_t, N>{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint64_t, N> operator>(Vec128<uint64_t, N> a, + Vec128<uint64_t, N> b) { + return Mask128<uint64_t, N>{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Mask128<float, N> operator>(Vec128<float, N> a, Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +template <size_t N> +HWY_API Mask128<double, N> operator>(Vec128<double, N> a, Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +template <size_t N> +HWY_API Mask128<float, N> operator>=(Vec128<float, N> a, Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +template <size_t N> +HWY_API Mask128<double, N> operator>=(Vec128<double, N> a, + Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Mask + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> MaskFromVec(hwy::SizeTag<1> /*tag*/, + const Vec128<T, N> v) { + return Mask128<T, N>{_mm_movepi8_mask(v.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> MaskFromVec(hwy::SizeTag<2> /*tag*/, + const Vec128<T, N> v) { + return Mask128<T, N>{_mm_movepi16_mask(v.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> MaskFromVec(hwy::SizeTag<4> /*tag*/, + const Vec128<T, N> v) { + return Mask128<T, N>{_mm_movepi32_mask(v.raw)}; +} +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> MaskFromVec(hwy::SizeTag<8> /*tag*/, + const Vec128<T, N> v) { + return Mask128<T, N>{_mm_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Mask128<T, N> MaskFromVec(const Vec128<T, N> v) { + return detail::MaskFromVec(hwy::SizeTag<sizeof(T)>(), v); +} +// There do not seem to be native floating-point versions of these instructions. +template <size_t N> +HWY_API Mask128<float, N> MaskFromVec(const Vec128<float, N> v) { + const RebindToSigned<DFromV<decltype(v)>> di; + return Mask128<float, N>{MaskFromVec(BitCast(di, v)).raw}; +} +template <size_t N> +HWY_API Mask128<double, N> MaskFromVec(const Vec128<double, N> v) { + const RebindToSigned<DFromV<decltype(v)>> di; + return Mask128<double, N>{MaskFromVec(BitCast(di, v)).raw}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> VecFromMask(const Mask128<T, N> v) { + return Vec128<T, N>{_mm_movm_epi8(v.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> VecFromMask(const Mask128<T, N> v) { + return Vec128<T, N>{_mm_movm_epi16(v.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> VecFromMask(const Mask128<T, N> v) { + return Vec128<T, N>{_mm_movm_epi32(v.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> VecFromMask(const Mask128<T, N> v) { + return Vec128<T, N>{_mm_movm_epi64(v.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> VecFromMask(const Mask128<float, N> v) { + return Vec128<float, N>{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +} + +template <size_t N> +HWY_API Vec128<double, N> VecFromMask(const Mask128<double, N> v) { + return Vec128<double, N>{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N> VecFromMask(Simd<T, N, 0> /* tag */, + const Mask128<T, N> v) { + return VecFromMask(v); +} + +#else // AVX2 or below + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template <typename TFrom, typename TTo, size_t N> +HWY_API Mask128<TTo, N> RebindMask(Simd<TTo, N, 0> /*tag*/, + Mask128<TFrom, N> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + const Simd<TFrom, N, 0> d; + return MaskFromVec(BitCast(Simd<TTo, N, 0>(), VecFromMask(d, m))); +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> TestBit(Vec128<T, N> v, Vec128<T, N> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template <size_t N> +HWY_API Mask128<uint8_t, N> operator==(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Mask128<uint8_t, N>{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint16_t, N> operator==(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Mask128<uint16_t, N>{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint32_t, N> operator==(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Mask128<uint32_t, N>{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<uint64_t, N> operator==(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + const Simd<uint32_t, N * 2, 0> d32; + const Simd<uint64_t, N, 0> d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +#else + return Mask128<uint64_t, N>{_mm_cmpeq_epi64(a.raw, b.raw)}; +#endif +} + +// Signed +template <size_t N> +HWY_API Mask128<int8_t, N> operator==(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Mask128<int8_t, N>{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int16_t, N> operator==(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Mask128<int16_t, N>{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int32_t, N> operator==(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Mask128<int32_t, N>{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<int64_t, N> operator==(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + // Same as signed ==; avoid duplicating the SSSE3 version. + const DFromV<decltype(a)> d; + RebindToUnsigned<decltype(d)> du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +} + +// Float +template <size_t N> +HWY_API Mask128<float, N> operator==(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<double, N> operator==(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. +template <size_t N> +HWY_API Mask128<uint8_t, N> operator!=(Vec128<uint8_t, N> a, + Vec128<uint8_t, N> b) { + return Not(a == b); +} +template <size_t N> +HWY_API Mask128<uint16_t, N> operator!=(Vec128<uint16_t, N> a, + Vec128<uint16_t, N> b) { + return Not(a == b); +} +template <size_t N> +HWY_API Mask128<uint32_t, N> operator!=(Vec128<uint32_t, N> a, + Vec128<uint32_t, N> b) { + return Not(a == b); +} +template <size_t N> +HWY_API Mask128<uint64_t, N> operator!=(Vec128<uint64_t, N> a, + Vec128<uint64_t, N> b) { + return Not(a == b); +} +template <size_t N> +HWY_API Mask128<int8_t, N> operator!=(Vec128<int8_t, N> a, + Vec128<int8_t, N> b) { + return Not(a == b); +} +template <size_t N> +HWY_API Mask128<int16_t, N> operator!=(Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Not(a == b); +} +template <size_t N> +HWY_API Mask128<int32_t, N> operator!=(Vec128<int32_t, N> a, + Vec128<int32_t, N> b) { + return Not(a == b); +} +template <size_t N> +HWY_API Mask128<int64_t, N> operator!=(Vec128<int64_t, N> a, + Vec128<int64_t, N> b) { + return Not(a == b); +} + +template <size_t N> +HWY_API Mask128<float, N> operator!=(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmpneq_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<double, N> operator!=(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmpneq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template <size_t N> +HWY_INLINE Mask128<int8_t, N> Gt(hwy::SignedTag /*tag*/, Vec128<int8_t, N> a, + Vec128<int8_t, N> b) { + return Mask128<int8_t, N>{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_INLINE Mask128<int16_t, N> Gt(hwy::SignedTag /*tag*/, Vec128<int16_t, N> a, + Vec128<int16_t, N> b) { + return Mask128<int16_t, N>{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_INLINE Mask128<int32_t, N> Gt(hwy::SignedTag /*tag*/, Vec128<int32_t, N> a, + Vec128<int32_t, N> b) { + return Mask128<int32_t, N>{_mm_cmpgt_epi32(a.raw, b.raw)}; +} + +template <size_t N> +HWY_INLINE Mask128<int64_t, N> Gt(hwy::SignedTag /*tag*/, + const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + // See https://stackoverflow.com/questions/65166174/: + const Simd<int64_t, N, 0> d; + const RepartitionToNarrow<decltype(d)> d32; + const Vec128<int64_t, N> m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; + const Vec128<int64_t, N> m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; + // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: + // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. + const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; + // Duplicate upper to lower half. + return Mask128<int64_t, N>{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128<int64_t, N>{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> Gt(hwy::UnsignedTag /*tag*/, Vec128<T, N> a, + Vec128<T, N> b) { + const DFromV<decltype(a)> du; + const RebindToSigned<decltype(du)> di; + const Vec128<T, N> msb = Set(du, (LimitsMax<T>() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template <size_t N> +HWY_INLINE Mask128<float, N> Gt(hwy::FloatTag /*tag*/, Vec128<float, N> a, + Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_INLINE Mask128<double, N> Gt(hwy::FloatTag /*tag*/, Vec128<double, N> a, + Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> operator>(Vec128<T, N> a, Vec128<T, N> b) { + return detail::Gt(hwy::TypeTag<T>(), a, b); +} + +// ------------------------------ Weak inequality + +template <size_t N> +HWY_API Mask128<float, N> operator>=(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Mask128<float, N>{_mm_cmpge_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Mask128<double, N> operator>=(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Mask128<double, N>{_mm_cmpge_pd(a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator<(Vec128<T, N> a, Vec128<T, N> b) { + return b > a; +} + +template <typename T, size_t N> +HWY_API Mask128<T, N> operator<=(Vec128<T, N> a, Vec128<T, N> b) { + return b >= a; +} + +// ------------------------------ FirstN (Iota, Lt) + +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Mask128<T, N> FirstN(const Simd<T, N, 0> d, size_t num) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + const uint64_t all = (1ull << N) - 1; + // BZHI only looks at the lower 8 bits of num! + const uint64_t bits = (num > 255) ? all : _bzhi_u64(all, num); + return Mask128<T, N>::FromBits(bits); +#else + const RebindToSigned<decltype(d)> di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast<MakeSigned<T>>(num))); +#endif +} + +template <class D> +using MFromD = decltype(FirstN(D(), 0)); + +// ================================================== MEMORY (1) + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". This is a false alarm because msan does not +// raise errors. We work around this by using CopyBytes instead of intrinsics, +// but only for the analyzer to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. +#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load + +template <typename T> +HWY_API Vec128<T> Load(Full128<T> /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128<T>{_mm_load_si128(reinterpret_cast<const __m128i*>(aligned))}; +} +HWY_API Vec128<float> Load(Full128<float> /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec128<float>{_mm_load_ps(aligned)}; +} +HWY_API Vec128<double> Load(Full128<double> /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec128<double>{_mm_load_pd(aligned)}; +} + +template <typename T> +HWY_API Vec128<T> LoadU(Full128<T> /* tag */, const T* HWY_RESTRICT p) { + return Vec128<T>{_mm_loadu_si128(reinterpret_cast<const __m128i*>(p))}; +} +HWY_API Vec128<float> LoadU(Full128<float> /* tag */, + const float* HWY_RESTRICT p) { + return Vec128<float>{_mm_loadu_ps(p)}; +} +HWY_API Vec128<double> LoadU(Full128<double> /* tag */, + const double* HWY_RESTRICT p) { + return Vec128<double>{_mm_loadu_pd(p)}; +} + +template <typename T> +HWY_API Vec64<T> Load(Full64<T> /* tag */, const T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); // not same size + return Vec64<T>{v}; +#else + return Vec64<T>{_mm_loadl_epi64(reinterpret_cast<const __m128i*>(p))}; +#endif +} + +HWY_API Vec128<float, 2> Load(Full64<float> /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); // not same size + return Vec128<float, 2>{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec128<float, 2>{_mm_loadl_pi(hi, reinterpret_cast<const __m64*>(p))}; +#endif +} + +HWY_API Vec64<double> Load(Full64<double> /* tag */, + const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); // not same size + return Vec64<double>{v}; +#else + return Vec64<double>{_mm_load_sd(p)}; +#endif +} + +HWY_API Vec128<float, 1> Load(Full32<float> /* tag */, + const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); // not same size + return Vec128<float, 1>{v}; +#else + return Vec128<float, 1>{_mm_load_ss(p)}; +#endif +} + +// Any <= 32 bit except <float, 1> +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API Vec128<T, N> Load(Simd<T, N, 0> /* tag */, const T* HWY_RESTRICT p) { + constexpr size_t kSize = sizeof(T) * N; +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<kSize>(p, &v); // not same size + return Vec128<T, N>{v}; +#else + int32_t bits = 0; + CopyBytes<kSize>(p, &bits); // not same size + return Vec128<T, N>{_mm_cvtsi32_si128(bits)}; +#endif +} + +// For < 128 bit, LoadU == Load. +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> LoadU(Simd<T, N, 0> d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> LoadDup128(Simd<T, N, 0> d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// Returns a vector with lane i=[0, N) set to "first" + i. +template <typename T, size_t N, typename T2, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> Iota(const Simd<T, N, 0> d, const T2 first) { + HWY_ALIGN T lanes[16 / sizeof(T)]; + for (size_t i = 0; i < 16 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT p) { + return Vec128<T, N>{_mm_maskz_loadu_epi8(m.raw, p)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT p) { + return Vec128<T, N>{_mm_maskz_loadu_epi16(m.raw, p)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT p) { + return Vec128<T, N>{_mm_maskz_loadu_epi32(m.raw, p)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT p) { + return Vec128<T, N>{_mm_maskz_loadu_epi64(m.raw, p)}; +} + +template <size_t N> +HWY_API Vec128<float, N> MaskedLoad(Mask128<float, N> m, + Simd<float, N, 0> /* tag */, + const float* HWY_RESTRICT p) { + return Vec128<float, N>{_mm_maskz_loadu_ps(m.raw, p)}; +} + +template <size_t N> +HWY_API Vec128<double, N> MaskedLoad(Mask128<double, N> m, + Simd<double, N, 0> /* tag */, + const double* HWY_RESTRICT p) { + return Vec128<double, N>{_mm_maskz_loadu_pd(m.raw, p)}; +} + +#elif HWY_TARGET == HWY_AVX2 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast<const int*>(p); // NOLINT + return Vec128<T, N>{_mm_maskload_epi32(p_p, m.raw)}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> /* tag */, + const T* HWY_RESTRICT p) { + auto p_p = reinterpret_cast<const long long*>(p); // NOLINT + return Vec128<T, N>{_mm_maskload_epi64(p_p, m.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> MaskedLoad(Mask128<float, N> m, Simd<float, N, 0> d, + const float* HWY_RESTRICT p) { + const Vec128<int32_t, N> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + return Vec128<float, N>{_mm_maskload_ps(p, mi.raw)}; +} + +template <size_t N> +HWY_API Vec128<double, N> MaskedLoad(Mask128<double, N> m, Simd<double, N, 0> d, + const double* HWY_RESTRICT p) { + const Vec128<int64_t, N> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + return Vec128<double, N>{_mm_maskload_pd(p, mi.raw)}; +} + +// There is no maskload_epi8/16, so blend instead. +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 6)> // 1 or 2 bytes +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#else // <= SSE4 + +// Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). +template <typename T, size_t N> +HWY_API Vec128<T, N> MaskedLoad(Mask128<T, N> m, Simd<T, N, 0> d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, Load(d, p)); +} + +#endif + +// ------------------------------ Store + +template <typename T> +HWY_API void Store(Vec128<T> v, Full128<T> /* tag */, T* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +HWY_API void Store(const Vec128<float> v, Full128<float> /* tag */, + float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec128<double> v, Full128<double> /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); +} + +template <typename T> +HWY_API void StoreU(Vec128<T> v, Full128<T> /* tag */, T* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); +} +HWY_API void StoreU(const Vec128<float> v, Full128<float> /* tag */, + float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec128<double> v, Full128<double> /* tag */, + double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); +} + +template <typename T> +HWY_API void Store(Vec64<T> v, Full64<T> /* tag */, T* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec128<float, 2> v, Full64<float> /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif +} +HWY_API void Store(const Vec64<double> v, Full64<double> /* tag */, + double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pd(p, v.raw); +#endif +} + +// Any <= 32 bit except <float, 1> +template <typename T, size_t N, HWY_IF_LE32(T, N)> +HWY_API void Store(Vec128<T, N> v, Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + CopyBytes<sizeof(T) * N>(&v, p); // not same size +} +HWY_API void Store(const Vec128<float, 1> v, Full32<float> /* tag */, + float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); // not same size +#else + _mm_store_ss(p, v.raw); +#endif +} + +// For < 128 bit, StoreU == Store. +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API void StoreU(const Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ------------------------------ BlendedStore + +namespace detail { + +// There is no maskload_epi8/16 with which we could safely implement +// BlendedStore. Manual blending is also unsafe because loading a full vector +// that crosses the array end causes asan faults. Resort to scalar code; the +// caller should instead use memcpy, assuming m is FirstN(d, n). +template <typename T, size_t N> +HWY_API void ScalarMaskedStore(Vec128<T, N> v, Mask128<T, N> m, Simd<T, N, 0> d, + T* HWY_RESTRICT p) { + const RebindToSigned<decltype(d)> di; // for testing mask if T=bfloat16_t. + using TI = TFromD<decltype(di)>; + alignas(16) TI buf[N]; + alignas(16) TI mask[N]; + Store(BitCast(di, v), di, buf); + Store(BitCast(di, VecFromMask(d, m)), di, mask); + for (size_t i = 0; i < N; ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi8(p, m.raw, v.raw); +} +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + _mm_mask_storeu_epi16(p, m.raw, v.raw); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast<int*>(p); // NOLINT + _mm_mask_storeu_epi32(pi, m.raw, v.raw); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + auto pi = reinterpret_cast<long long*>(p); // NOLINT + _mm_mask_storeu_epi64(pi, m.raw, v.raw); +} + +template <size_t N> +HWY_API void BlendedStore(Vec128<float, N> v, Mask128<float, N> m, + Simd<float, N, 0>, float* HWY_RESTRICT p) { + _mm_mask_storeu_ps(p, m.raw, v.raw); +} + +template <size_t N> +HWY_API void BlendedStore(Vec128<double, N> v, Mask128<double, N> m, + Simd<double, N, 0>, double* HWY_RESTRICT p) { + _mm_mask_storeu_pd(p, m.raw, v.raw); +} + +#elif HWY_TARGET == HWY_AVX2 + +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 6)> // 1 or 2 bytes +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, Simd<T, N, 0> d, + T* HWY_RESTRICT p) { + detail::ScalarMaskedStore(v, m, d, p); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128<T> df; + const Mask128<T> mf{m.raw}; + m = Mask128<T, N>{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast<int*>(p); // NOLINT + _mm_maskstore_epi32(pi, m.raw, v.raw); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT p) { + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128<T> df; + const Mask128<T> mf{m.raw}; + m = Mask128<T, N>{And(mf, FirstN(df, N)).raw}; + } + + auto pi = reinterpret_cast<long long*>(p); // NOLINT + _mm_maskstore_epi64(pi, m.raw, v.raw); +} + +template <size_t N> +HWY_API void BlendedStore(Vec128<float, N> v, Mask128<float, N> m, + Simd<float, N, 0> d, float* HWY_RESTRICT p) { + using T = float; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 4) { + const Full128<T> df; + const Mask128<T> mf{m.raw}; + m = Mask128<T, N>{And(mf, FirstN(df, N)).raw}; + } + + const Vec128<MakeSigned<T>, N> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + _mm_maskstore_ps(p, mi.raw, v.raw); +} + +template <size_t N> +HWY_API void BlendedStore(Vec128<double, N> v, Mask128<double, N> m, + Simd<double, N, 0> d, double* HWY_RESTRICT p) { + using T = double; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (N < 2) { + const Full128<T> df; + const Mask128<T> mf{m.raw}; + m = Mask128<T, N>{And(mf, FirstN(df, N)).raw}; + } + + const Vec128<MakeSigned<T>, N> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + _mm_maskstore_pd(p, mi.raw, v.raw); +} + +#else // <= SSE4 + +template <typename T, size_t N> +HWY_API void BlendedStore(Vec128<T, N> v, Mask128<T, N> m, Simd<T, N, 0> d, + T* HWY_RESTRICT p) { + // Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). + detail::ScalarMaskedStore(v, m, d, p); +} + +#endif // SSE4 + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> operator+(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_add_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> operator+(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_add_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> operator+(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{_mm_add_epi32(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> operator+(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{_mm_add_epi64(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> operator+(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{_mm_add_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> operator+(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_add_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> operator+(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{_mm_add_epi32(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> operator+(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Vec128<int64_t, N>{_mm_add_epi64(a.raw, b.raw)}; +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> operator+(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_add_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> operator+(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> operator-(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_sub_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> operator-(Vec128<uint16_t, N> a, + Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_sub_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> operator-(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{_mm_sub_epi32(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> operator-(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> operator-(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{_mm_sub_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> operator-(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_sub_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> operator-(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{_mm_sub_epi32(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> operator-(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Vec128<int64_t, N>{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> operator-(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_sub_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> operator-(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +template <size_t N> +HWY_API Vec128<uint64_t, N / 8> SumsOf8(const Vec128<uint8_t, N> v) { + return Vec128<uint64_t, N / 8>{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> SaturatedAdd(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_adds_epu8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> SaturatedAdd(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_adds_epu16(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> SaturatedAdd(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{_mm_adds_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> SaturatedAdd(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> SaturatedSub(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_subs_epu8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> SaturatedSub(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_subs_epu16(a.raw, b.raw)}; +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> SaturatedSub(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{_mm_subs_epi8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> SaturatedSub(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> AverageRound(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_avg_epu8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> AverageRound(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +template <size_t N> +HWY_API Vec128<uint16_t, N> operator*(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_mullo_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> operator*(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_mullo_epi16(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +template <size_t N> +HWY_API Vec128<uint16_t, N> MulHigh(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_mulhi_epu16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int16_t, N> MulHigh(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_mulhi_epi16(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Vec128<int16_t, N> MulFixedPoint15(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +template <size_t N> +HWY_API Vec128<uint64_t, (N + 1) / 2> MulEven(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint64_t, (N + 1) / 2>{_mm_mul_epu32(a.raw, b.raw)}; +} + +#if HWY_TARGET == HWY_SSSE3 + +template <size_t N, HWY_IF_LE64(int32_t, N)> // N=1 or 2 +HWY_API Vec128<int64_t, (N + 1) / 2> MulEven(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Set(Simd<int64_t, (N + 1) / 2, 0>(), + static_cast<int64_t>(GetLane(a)) * GetLane(b)); +} +HWY_API Vec128<int64_t> MulEven(const Vec128<int32_t> a, + const Vec128<int32_t> b) { + alignas(16) int32_t a_lanes[4]; + alignas(16) int32_t b_lanes[4]; + const Full128<int32_t> di32; + Store(a, di32, a_lanes); + Store(b, di32, b_lanes); + alignas(16) int64_t mul[2]; + mul[0] = static_cast<int64_t>(a_lanes[0]) * b_lanes[0]; + mul[1] = static_cast<int64_t>(a_lanes[2]) * b_lanes[2]; + return Load(Full128<int64_t>(), mul); +} + +#else // HWY_TARGET == HWY_SSSE3 + +template <size_t N> +HWY_API Vec128<int64_t, (N + 1) / 2> MulEven(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int64_t, (N + 1) / 2>{_mm_mul_epi32(a.raw, b.raw)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 + +template <size_t N> +HWY_API Vec128<uint32_t, N> operator*(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + // Not as inefficient as it looks: _mm_mullo_epi32 has 10 cycle latency. + // 64-bit right shift would also work but also needs port 5, so no benefit. + // Notation: x=don't care, z=0. + const __m128i a_x3x1 = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x2x0 = MulEven(a, b); + const __m128i b_x3x1 = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x3x1 = + MulEven(Vec128<uint32_t, N>{a_x3x1}, Vec128<uint32_t, N>{b_x3x1}); + // We could _mm_slli_epi64 by 32 to get 3z1z and OR with z2z0, but generating + // the latter requires one more instruction or a constant. + const __m128i mul_20 = + _mm_shuffle_epi32(mullo_x2x0.raw, _MM_SHUFFLE(2, 0, 2, 0)); + const __m128i mul_31 = + _mm_shuffle_epi32(mullo_x3x1.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128<uint32_t, N>{_mm_unpacklo_epi32(mul_20, mul_31)}; +#else + return Vec128<uint32_t, N>{_mm_mullo_epi32(a.raw, b.raw)}; +#endif +} + +template <size_t N> +HWY_API Vec128<int32_t, N> operator*(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + // Same as unsigned; avoid duplicating the SSSE3 code. + const DFromV<decltype(a)> d; + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +// ------------------------------ RotateRight (ShiftRight, Or) + +template <int kBits, size_t N> +HWY_API Vec128<uint32_t, N> RotateRight(const Vec128<uint32_t, N> v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128<uint32_t, N>{_mm_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(31, 32 - kBits)>(v)); +#endif +} + +template <int kBits, size_t N> +HWY_API Vec128<uint64_t, N> RotateRight(const Vec128<uint64_t, N> v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128<uint64_t, N>{_mm_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(63, 64 - kBits)>(v)); +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +template <size_t N> +HWY_API Vec128<int8_t, N> BroadcastSignBit(const Vec128<int8_t, N> v) { + const DFromV<decltype(v)> d; + return VecFromMask(v < Zero(d)); +} + +template <size_t N> +HWY_API Vec128<int16_t, N> BroadcastSignBit(const Vec128<int16_t, N> v) { + return ShiftRight<15>(v); +} + +template <size_t N> +HWY_API Vec128<int32_t, N> BroadcastSignBit(const Vec128<int32_t, N> v) { + return ShiftRight<31>(v); +} + +template <size_t N> +HWY_API Vec128<int64_t, N> BroadcastSignBit(const Vec128<int64_t, N> v) { + const DFromV<decltype(v)> d; +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Vec128<int64_t, N>{_mm_srai_epi64(v.raw, 63)}; +#elif HWY_TARGET == HWY_AVX2 || HWY_TARGET == HWY_SSE4 + return VecFromMask(v < Zero(d)); +#else + // Efficient Lt() requires SSE4.2 and BLENDVPD requires SSE4.1. 32-bit shift + // avoids generating a zero. + const RepartitionToNarrow<decltype(d)> d32; + const auto sign = ShiftRight<31>(BitCast(d32, v)); + return Vec128<int64_t, N>{ + _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +#endif +} + +template <size_t N> +HWY_API Vec128<int64_t, N> Abs(const Vec128<int64_t, N> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int64_t, N>{_mm_abs_epi64(v.raw)}; +#else + const auto zero = Zero(DFromV<decltype(v)>()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +template <int kBits, size_t N> +HWY_API Vec128<int64_t, N> ShiftRight(const Vec128<int64_t, N> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int64_t, N>{_mm_srai_epi64(v.raw, kBits)}; +#else + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + const auto right = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ ZeroIfNegative (BroadcastSignBit) +template <typename T, size_t N> +HWY_API Vec128<T, N> ZeroIfNegative(Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only works for float"); + const DFromV<decltype(v)> d; +#if HWY_TARGET == HWY_SSSE3 + const RebindToSigned<decltype(d)> di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#else + const auto mask = MaskFromVec(v); // MSB is sufficient for BLENDVPS +#endif + return IfThenElse(mask, Zero(d), v); +} + +// ------------------------------ IfNegativeThenElse +template <size_t N> +HWY_API Vec128<int8_t, N> IfNegativeThenElse(const Vec128<int8_t, N> v, + const Vec128<int8_t, N> yes, + const Vec128<int8_t, N> no) { + // int8: IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> IfNegativeThenElse(Vec128<T, N> v, Vec128<T, N> yes, + Vec128<T, N> no) { + static_assert(IsSigned<T>(), "Only works for signed/float"); + const DFromV<decltype(v)> d; + const RebindToSigned<decltype(d)> di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> IfNegativeThenElse(Vec128<T, N> v, Vec128<T, N> yes, + Vec128<T, N> no) { + static_assert(IsSigned<T>(), "Only works for signed/float"); + const DFromV<decltype(v)> d; + const RebindToFloat<decltype(d)> df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + return BitCast(d, IfThenElse(MaskFromVec(BitCast(df, v)), BitCast(df, yes), + BitCast(df, no))); +} + +// ------------------------------ ShiftLeftSame + +template <size_t N> +HWY_API Vec128<uint16_t, N> ShiftLeftSame(const Vec128<uint16_t, N> v, + const int bits) { + return Vec128<uint16_t, N>{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> ShiftLeftSame(const Vec128<uint32_t, N> v, + const int bits) { + return Vec128<uint32_t, N>{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> ShiftLeftSame(const Vec128<uint64_t, N> v, + const int bits) { + return Vec128<uint64_t, N>{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <size_t N> +HWY_API Vec128<int16_t, N> ShiftLeftSame(const Vec128<int16_t, N> v, + const int bits) { + return Vec128<int16_t, N>{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <size_t N> +HWY_API Vec128<int32_t, N> ShiftLeftSame(const Vec128<int32_t, N> v, + const int bits) { + return Vec128<int32_t, N>{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <size_t N> +HWY_API Vec128<int64_t, N> ShiftLeftSame(const Vec128<int64_t, N> v, + const int bits) { + return Vec128<int64_t, N>{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> ShiftLeftSame(const Vec128<T, N> v, const int bits) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<T, N> shifted{ + ShiftLeftSame(Vec128<MakeWide<T>>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +template <size_t N> +HWY_API Vec128<uint16_t, N> ShiftRightSame(const Vec128<uint16_t, N> v, + const int bits) { + return Vec128<uint16_t, N>{_mm_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template <size_t N> +HWY_API Vec128<uint32_t, N> ShiftRightSame(const Vec128<uint32_t, N> v, + const int bits) { + return Vec128<uint32_t, N>{_mm_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template <size_t N> +HWY_API Vec128<uint64_t, N> ShiftRightSame(const Vec128<uint64_t, N> v, + const int bits) { + return Vec128<uint64_t, N>{_mm_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> ShiftRightSame(Vec128<uint8_t, N> v, + const int bits) { + const DFromV<decltype(v)> d8; + // Use raw instead of BitCast to support N=1. + const Vec128<uint8_t, N> shifted{ + ShiftRightSame(Vec128<uint16_t>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits)); +} + +template <size_t N> +HWY_API Vec128<int16_t, N> ShiftRightSame(const Vec128<int16_t, N> v, + const int bits) { + return Vec128<int16_t, N>{_mm_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <size_t N> +HWY_API Vec128<int32_t, N> ShiftRightSame(const Vec128<int32_t, N> v, + const int bits) { + return Vec128<int32_t, N>{_mm_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template <size_t N> +HWY_API Vec128<int64_t, N> ShiftRightSame(const Vec128<int64_t, N> v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int64_t, N>{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +template <size_t N> +HWY_API Vec128<int8_t, N> ShiftRightSame(Vec128<int8_t, N> v, const int bits) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Floating-point mul / div + +template <size_t N> +HWY_API Vec128<float, N> operator*(Vec128<float, N> a, Vec128<float, N> b) { + return Vec128<float, N>{_mm_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec128<float, 1> operator*(const Vec128<float, 1> a, + const Vec128<float, 1> b) { + return Vec128<float, 1>{_mm_mul_ss(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> operator*(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_mul_pd(a.raw, b.raw)}; +} +HWY_API Vec64<double> operator*(const Vec64<double> a, const Vec64<double> b) { + return Vec64<double>{_mm_mul_sd(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> operator/(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_div_ps(a.raw, b.raw)}; +} +HWY_API Vec128<float, 1> operator/(const Vec128<float, 1> a, + const Vec128<float, 1> b) { + return Vec128<float, 1>{_mm_div_ss(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> operator/(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_div_pd(a.raw, b.raw)}; +} +HWY_API Vec64<double> operator/(const Vec64<double> a, const Vec64<double> b) { + return Vec64<double>{_mm_div_sd(a.raw, b.raw)}; +} + +// Approximate reciprocal +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocal(const Vec128<float, N> v) { + return Vec128<float, N>{_mm_rcp_ps(v.raw)}; +} +HWY_API Vec128<float, 1> ApproximateReciprocal(const Vec128<float, 1> v) { + return Vec128<float, 1>{_mm_rcp_ss(v.raw)}; +} + +// Absolute value of difference. +template <size_t N> +HWY_API Vec128<float, N> AbsDiff(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template <size_t N> +HWY_API Vec128<float, N> MulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128<float, N>{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<double, N> MulAdd(const Vec128<double, N> mul, + const Vec128<double, N> x, + const Vec128<double, N> add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x + add; +#else + return Vec128<double, N>{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +template <size_t N> +HWY_API Vec128<float, N> NegMulAdd(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128<float, N>{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<double, N> NegMulAdd(const Vec128<double, N> mul, + const Vec128<double, N> x, + const Vec128<double, N> add) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return add - mul * x; +#else + return Vec128<double, N>{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +template <size_t N> +HWY_API Vec128<float, N> MulSub(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128<float, N>{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<double, N> MulSub(const Vec128<double, N> mul, + const Vec128<double, N> x, + const Vec128<double, N> sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return mul * x - sub; +#else + return Vec128<double, N>{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +template <size_t N> +HWY_API Vec128<float, N> NegMulSub(const Vec128<float, N> mul, + const Vec128<float, N> x, + const Vec128<float, N> sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128<float, N>{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<double, N> NegMulSub(const Vec128<double, N> mul, + const Vec128<double, N> x, + const Vec128<double, N> sub) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return Neg(mul) * x - sub; +#else + return Vec128<double, N>{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template <size_t N> +HWY_API Vec128<float, N> Sqrt(const Vec128<float, N> v) { + return Vec128<float, N>{_mm_sqrt_ps(v.raw)}; +} +HWY_API Vec128<float, 1> Sqrt(const Vec128<float, 1> v) { + return Vec128<float, 1>{_mm_sqrt_ss(v.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> Sqrt(const Vec128<double, N> v) { + return Vec128<double, N>{_mm_sqrt_pd(v.raw)}; +} +HWY_API Vec64<double> Sqrt(const Vec64<double> v) { + return Vec64<double>{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +} + +// Approximate reciprocal square root +template <size_t N> +HWY_API Vec128<float, N> ApproximateReciprocalSqrt(const Vec128<float, N> v) { + return Vec128<float, N>{_mm_rsqrt_ps(v.raw)}; +} +HWY_API Vec128<float, 1> ApproximateReciprocalSqrt(const Vec128<float, 1> v) { + return Vec128<float, 1>{_mm_rsqrt_ss(v.raw)}; +} + +// ------------------------------ Min (Gt, IfThenElse) + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE HWY_MAYBE_UNUSED Vec128<T, N> MinU(const Vec128<T, N> a, + const Vec128<T, N> b) { + const DFromV<decltype(a)> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; + const auto msb = Set(du, static_cast<T>(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +} + +} // namespace detail + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> Min(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_min_epu8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> Min(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128<uint16_t, N>{_mm_min_epu16(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<uint32_t, N> Min(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128<uint32_t, N>{_mm_min_epu32(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<uint64_t, N> Min(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<uint64_t, N>{_mm_min_epu64(a.raw, b.raw)}; +#else + return detail::MinU(a, b); +#endif +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> Min(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128<int8_t, N>{_mm_min_epi8(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int16_t, N> Min(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_min_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> Min(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128<int32_t, N>{_mm_min_epi32(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int64_t, N> Min(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int64_t, N>{_mm_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> Min(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_min_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> Min(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +namespace detail { +template <typename T, size_t N> +HWY_INLINE HWY_MAYBE_UNUSED Vec128<T, N> MaxU(const Vec128<T, N> a, + const Vec128<T, N> b) { + const DFromV<decltype(a)> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; + const auto msb = Set(du, static_cast<T>(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +} + +} // namespace detail + +// Unsigned +template <size_t N> +HWY_API Vec128<uint8_t, N> Max(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_max_epu8(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<uint16_t, N> Max(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128<uint16_t, N>{_mm_max_epu16(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<uint32_t, N> Max(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128<uint32_t, N>{_mm_max_epu32(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<uint64_t, N> Max(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<uint64_t, N>{_mm_max_epu64(a.raw, b.raw)}; +#else + return detail::MaxU(a, b); +#endif +} + +// Signed +template <size_t N> +HWY_API Vec128<int8_t, N> Max(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128<int8_t, N>{_mm_max_epi8(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int16_t, N> Max(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_max_epi16(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<int32_t, N> Max(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { +#if HWY_TARGET == HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128<int32_t, N>{_mm_max_epi32(a.raw, b.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int64_t, N> Max(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int64_t, N>{_mm_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +template <size_t N> +HWY_API Vec128<float, N> Max(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_max_ps(a.raw, b.raw)}; +} +template <size_t N> +HWY_API Vec128<double, N> Max(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_max_pd(a.raw, b.raw)}; +} + +// ================================================== MEMORY (2) + +// ------------------------------ Non-temporal stores + +// On clang6, we see incorrect code generated for _mm_stream_pi, so +// round even partial vectors up to 16 bytes. +template <typename T, size_t N> +HWY_API void Stream(Vec128<T, N> v, Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT aligned) { + _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +template <size_t N> +HWY_API void Stream(const Vec128<float, N> v, Simd<float, N, 0> /* tag */, + float* HWY_RESTRICT aligned) { + _mm_stream_ps(aligned, v.raw); +} +template <size_t N> +HWY_API void Stream(const Vec128<double, N> v, Simd<double, N, 0> /* tag */, + double* HWY_RESTRICT aligned) { + _mm_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unfortunately the GCC/Clang intrinsics do not accept int64_t*. +using GatherIndex64 = long long int; // NOLINT(runtime/int) +static_assert(sizeof(GatherIndex64) == 8, "Must be 64-bit type"); + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template <typename T, size_t N> +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec128<T, N> v, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT base, + const Vec128<int32_t, N> offset) { + if (N == 4) { + _mm_i32scatter_epi32(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, offset.raw, v.raw, 1); + } +} +template <typename T, size_t N> +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec128<T, N> v, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT base, + const Vec128<int32_t, N> index) { + if (N == 4) { + _mm_i32scatter_epi32(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_epi32(base, mask, index.raw, v.raw, 4); + } +} + +template <typename T, size_t N> +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec128<T, N> v, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT base, + const Vec128<int64_t, N> offset) { + if (N == 2) { + _mm_i64scatter_epi64(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, offset.raw, v.raw, 1); + } +} +template <typename T, size_t N> +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec128<T, N> v, + Simd<T, N, 0> /* tag */, T* HWY_RESTRICT base, + const Vec128<int64_t, N> index) { + if (N == 2) { + _mm_i64scatter_epi64(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_epi64(base, mask, index.raw, v.raw, 8); + } +} + +} // namespace detail + +template <typename T, size_t N, typename Offset> +HWY_API void ScatterOffset(Vec128<T, N> v, Simd<T, N, 0> d, + T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag<sizeof(T)>(), v, d, base, offset); +} +template <typename T, size_t N, typename Index> +HWY_API void ScatterIndex(Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag<sizeof(T)>(), v, d, base, index); +} + +template <size_t N> +HWY_API void ScatterOffset(Vec128<float, N> v, Simd<float, N, 0> /* tag */, + float* HWY_RESTRICT base, + const Vec128<int32_t, N> offset) { + if (N == 4) { + _mm_i32scatter_ps(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, offset.raw, v.raw, 1); + } +} +template <size_t N> +HWY_API void ScatterIndex(Vec128<float, N> v, Simd<float, N, 0> /* tag */, + float* HWY_RESTRICT base, + const Vec128<int32_t, N> index) { + if (N == 4) { + _mm_i32scatter_ps(base, index.raw, v.raw, 4); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i32scatter_ps(base, mask, index.raw, v.raw, 4); + } +} + +template <size_t N> +HWY_API void ScatterOffset(Vec128<double, N> v, Simd<double, N, 0> /* tag */, + double* HWY_RESTRICT base, + const Vec128<int64_t, N> offset) { + if (N == 2) { + _mm_i64scatter_pd(base, offset.raw, v.raw, 1); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, offset.raw, v.raw, 1); + } +} +template <size_t N> +HWY_API void ScatterIndex(Vec128<double, N> v, Simd<double, N, 0> /* tag */, + double* HWY_RESTRICT base, + const Vec128<int64_t, N> index) { + if (N == 2) { + _mm_i64scatter_pd(base, index.raw, v.raw, 8); + } else { + const __mmask8 mask = (1u << N) - 1; + _mm_mask_i64scatter_pd(base, mask, index.raw, v.raw, 8); + } +} +#else // HWY_TARGET <= HWY_AVX3 + +template <typename T, size_t N, typename Offset, HWY_IF_LE128(T, N)> +HWY_API void ScatterOffset(Vec128<T, N> v, Simd<T, N, 0> d, + T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind<Offset, decltype(d)>(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast<uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template <typename T, size_t N, typename Index, HWY_IF_LE128(T, N)> +HWY_API void ScatterIndex(Vec128<T, N> v, Simd<T, N, 0> d, T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) T lanes[N]; + Store(v, d, lanes); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind<Index, decltype(d)>(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather (Load/Store) + +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +template <typename T, size_t N, typename Offset> +HWY_API Vec128<T, N> GatherOffset(const Simd<T, N, 0> d, + const T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + alignas(16) Offset offset_lanes[N]; + Store(offset, Rebind<Offset, decltype(d)>(), offset_lanes); + + alignas(16) T lanes[N]; + const uint8_t* base_bytes = reinterpret_cast<const uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(base_bytes + offset_lanes[i], &lanes[i]); + } + return Load(d, lanes); +} + +template <typename T, size_t N, typename Index> +HWY_API Vec128<T, N> GatherIndex(const Simd<T, N, 0> d, + const T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + alignas(16) Index index_lanes[N]; + Store(index, Rebind<Index, decltype(d)>(), index_lanes); + + alignas(16) T lanes[N]; + for (size_t i = 0; i < N; ++i) { + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +#else + +namespace detail { + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> GatherOffset(hwy::SizeTag<4> /* tag */, + Simd<T, N, 0> /* d */, + const T* HWY_RESTRICT base, + const Vec128<int32_t, N> offset) { + return Vec128<T, N>{_mm_i32gather_epi32( + reinterpret_cast<const int32_t*>(base), offset.raw, 1)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> GatherIndex(hwy::SizeTag<4> /* tag */, + Simd<T, N, 0> /* d */, + const T* HWY_RESTRICT base, + const Vec128<int32_t, N> index) { + return Vec128<T, N>{_mm_i32gather_epi32( + reinterpret_cast<const int32_t*>(base), index.raw, 4)}; +} + +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> GatherOffset(hwy::SizeTag<8> /* tag */, + Simd<T, N, 0> /* d */, + const T* HWY_RESTRICT base, + const Vec128<int64_t, N> offset) { + return Vec128<T, N>{_mm_i64gather_epi64( + reinterpret_cast<const GatherIndex64*>(base), offset.raw, 1)}; +} +template <typename T, size_t N> +HWY_INLINE Vec128<T, N> GatherIndex(hwy::SizeTag<8> /* tag */, + Simd<T, N, 0> /* d */, + const T* HWY_RESTRICT base, + const Vec128<int64_t, N> index) { + return Vec128<T, N>{_mm_i64gather_epi64( + reinterpret_cast<const GatherIndex64*>(base), index.raw, 8)}; +} + +} // namespace detail + +template <typename T, size_t N, typename Offset> +HWY_API Vec128<T, N> GatherOffset(Simd<T, N, 0> d, const T* HWY_RESTRICT base, + const Vec128<Offset, N> offset) { + return detail::GatherOffset(hwy::SizeTag<sizeof(T)>(), d, base, offset); +} +template <typename T, size_t N, typename Index> +HWY_API Vec128<T, N> GatherIndex(Simd<T, N, 0> d, const T* HWY_RESTRICT base, + const Vec128<Index, N> index) { + return detail::GatherIndex(hwy::SizeTag<sizeof(T)>(), d, base, index); +} + +template <size_t N> +HWY_API Vec128<float, N> GatherOffset(Simd<float, N, 0> /* tag */, + const float* HWY_RESTRICT base, + const Vec128<int32_t, N> offset) { + return Vec128<float, N>{_mm_i32gather_ps(base, offset.raw, 1)}; +} +template <size_t N> +HWY_API Vec128<float, N> GatherIndex(Simd<float, N, 0> /* tag */, + const float* HWY_RESTRICT base, + const Vec128<int32_t, N> index) { + return Vec128<float, N>{_mm_i32gather_ps(base, index.raw, 4)}; +} + +template <size_t N> +HWY_API Vec128<double, N> GatherOffset(Simd<double, N, 0> /* tag */, + const double* HWY_RESTRICT base, + const Vec128<int64_t, N> offset) { + return Vec128<double, N>{_mm_i64gather_pd(base, offset.raw, 1)}; +} +template <size_t N> +HWY_API Vec128<double, N> GatherIndex(Simd<double, N, 0> /* tag */, + const double* HWY_RESTRICT base, + const Vec128<int64_t, N> index) { + return Vec128<double, N>{_mm_i64gather_pd(base, index.raw, 8)}; +} + +#endif // HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +// Returns upper/lower half of a vector. +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> LowerHalf(Simd<T, N / 2, 0> /* tag */, + Vec128<T, N> v) { + return Vec128<T, N / 2>{v.raw}; +} + +template <typename T, size_t N> +HWY_API Vec128<T, N / 2> LowerHalf(Vec128<T, N> v) { + return LowerHalf(Simd<T, N / 2, 0>(), v); +} + +// ------------------------------ ShiftLeftBytes + +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec128<T, N>{_mm_slli_si128(v.raw, kBytes)}; +} + +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftBytes(const Vec128<T, N> v) { + return ShiftLeftBytes<kBytes>(DFromV<decltype(v)>(), v); +} + +// ------------------------------ ShiftLeftLanes + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); +} + +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftLeftLanes(const Vec128<T, N> v) { + return ShiftLeftLanes<kLanes>(DFromV<decltype(v)>(), v); +} + +// ------------------------------ ShiftRightBytes +template <int kBytes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightBytes(Simd<T, N, 0> /* tag */, Vec128<T, N> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // For partial vectors, clear upper lanes so we shift in zeros. + if (N != 16 / sizeof(T)) { + const Vec128<T> vfull{v.raw}; + v = Vec128<T, N>{IfThenElseZero(FirstN(Full128<T>(), N), vfull).raw}; + } + return Vec128<T, N>{_mm_srli_si128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template <int kLanes, typename T, size_t N> +HWY_API Vec128<T, N> ShiftRightLanes(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). +template <typename T> +HWY_API Vec64<T> UpperHalf(Half<Full128<T>> /* tag */, Vec128<T> v) { + return Vec64<T>{_mm_unpackhi_epi64(v.raw, v.raw)}; +} +HWY_API Vec128<float, 2> UpperHalf(Full64<float> /* tag */, Vec128<float> v) { + return Vec128<float, 2>{_mm_movehl_ps(v.raw, v.raw)}; +} +HWY_API Vec64<double> UpperHalf(Full64<double> /* tag */, Vec128<double> v) { + return Vec64<double>{_mm_unpackhi_pd(v.raw, v.raw)}; +} + +// Partial +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, (N + 1) / 2> UpperHalf(Half<Simd<T, N, 0>> /* tag */, + Vec128<T, N> v) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto vu = BitCast(du, v); + const auto upper = BitCast(d, ShiftRightBytes<N * sizeof(T) / 2>(du, vu)); + return Vec128<T, (N + 1) / 2>{upper.raw}; +} + +// ------------------------------ ExtractLane (UpperHalf) + +namespace detail { + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const int pair = _mm_extract_epi16(v.raw, kLane / 2); + constexpr int kShift = kLane & 1 ? 8 : 0; + return static_cast<T>((pair >> kShift) & 0xFF); +#else + return static_cast<T>(_mm_extract_epi8(v.raw, kLane) & 0xFF); +#endif +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + static_assert(kLane < N, "Lane index out of bounds"); + return static_cast<T>(_mm_extract_epi16(v.raw, kLane) & 0xFFFF); +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[kLane]; +#else + return static_cast<T>(_mm_extract_epi32(v.raw, kLane)); +#endif +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE T ExtractLane(const Vec128<T, N> v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + alignas(16) T lanes[2]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[kLane]; +#else + return static_cast<T>(_mm_extract_epi64(v.raw, kLane)); +#endif +} + +template <size_t kLane, size_t N> +HWY_INLINE float ExtractLane(const Vec128<float, N> v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) float lanes[4]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[kLane]; +#else + // Bug in the intrinsic, returns int but should be float. + const int32_t bits = _mm_extract_ps(v.raw, kLane); + float ret; + CopySameSize(&bits, &ret); + return ret; +#endif +} + +// There is no extract_pd; two overloads because there is no UpperHalf for N=1. +template <size_t kLane> +HWY_INLINE double ExtractLane(const Vec128<double, 1> v) { + static_assert(kLane == 0, "Lane index out of bounds"); + return GetLane(v); +} + +template <size_t kLane> +HWY_INLINE double ExtractLane(const Vec128<double> v) { + static_assert(kLane < 2, "Lane index out of bounds"); + const Half<DFromV<decltype(v)>> dh; + return kLane == 0 ? GetLane(v) : GetLane(UpperHalf(dh, v)); +} + +} // namespace detail + +// Requires one overload per vector length because ExtractLane<3> may be a +// compile error if it calls _mm_extract_epi64. +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 1> v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 2> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 4> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 8> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +template <typename T> +HWY_API T ExtractLane(const Vec128<T, 16> v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV<decltype(v)>(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (UpperHalf) + +namespace detail { + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV<decltype(v)> d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128<T, N>{_mm_insert_epi8(v.raw, t, kLane)}; +#endif +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128<T, N>{_mm_insert_epi16(v.raw, t, kLane)}; +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + alignas(16) T lanes[4]; + const DFromV<decltype(v)> d; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned<T> ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. + return Vec128<T, N>{_mm_insert_epi32(v.raw, ti, kLane)}; +#endif +} + +template <size_t kLane, typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 || HWY_ARCH_X86_32 + const DFromV<decltype(v)> d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + MakeSigned<T> ti; + CopySameSize(&t, &ti); // don't just cast because T might be float. + return Vec128<T, N>{_mm_insert_epi64(v.raw, ti, kLane)}; +#endif +} + +template <size_t kLane, size_t N> +HWY_INLINE Vec128<float, N> InsertLane(const Vec128<float, N> v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET == HWY_SSSE3 + const DFromV<decltype(v)> d; + alignas(16) float lanes[4]; + Store(v, d, lanes); + lanes[kLane] = t; + return Load(d, lanes); +#else + return Vec128<float, N>{_mm_insert_ps(v.raw, _mm_set_ss(t), kLane << 4)}; +#endif +} + +// There is no insert_pd; two overloads because there is no UpperHalf for N=1. +template <size_t kLane> +HWY_INLINE Vec128<double, 1> InsertLane(const Vec128<double, 1> v, double t) { + static_assert(kLane == 0, "Lane index out of bounds"); + return Set(DFromV<decltype(v)>(), t); +} + +template <size_t kLane> +HWY_INLINE Vec128<double> InsertLane(const Vec128<double> v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + const DFromV<decltype(v)> d; + const Vec128<double> vt = Set(d, t); + if (kLane == 0) { + return Vec128<double>{_mm_shuffle_pd(vt.raw, v.raw, 2)}; + } + return Vec128<double>{_mm_shuffle_pd(v.raw, vt.raw, 0)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls _mm_insert_epi64. + +template <typename T> +HWY_API Vec128<T, 1> InsertLane(const Vec128<T, 1> v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV<decltype(v)>(), t); +} + +template <typename T> +HWY_API Vec128<T, 2> InsertLane(const Vec128<T, 2> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 4> InsertLane(const Vec128<T, 4> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 8> InsertLane(const Vec128<T, 8> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template <typename T> +HWY_API Vec128<T, 16> InsertLane(const Vec128<T, 16> v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV<decltype(v)> d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ CombineShiftRightBytes + +template <int kBytes, typename T, class V = Vec128<T>> +HWY_API V CombineShiftRightBytes(Full128<T> d, V hi, V lo) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Vec128<uint8_t>{_mm_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +template <int kBytes, typename T, size_t N, HWY_IF_LE64(T, N), + class V = Vec128<T, N>> +HWY_API V CombineShiftRightBytes(Simd<T, N, 0> d, V hi, V lo) { + constexpr size_t kSize = N * sizeof(T); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition<uint8_t, decltype(d)> d8; + const Full128<uint8_t> d_full8; + using V8 = VFromD<decltype(d_full8)>; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(d_full8, hi8, lo8); + return V{BitCast(Full128<T>(), r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template <int kLane, size_t N> +HWY_API Vec128<uint16_t, N> Broadcast(const Vec128<uint16_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128<uint16_t, N>{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128<uint16_t, N>{_mm_unpackhi_epi64(hi, hi)}; + } +} +template <int kLane, size_t N> +HWY_API Vec128<uint32_t, N> Broadcast(const Vec128<uint32_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<uint32_t, N>{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template <int kLane, size_t N> +HWY_API Vec128<uint64_t, N> Broadcast(const Vec128<uint64_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<uint64_t, N>{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template <int kLane, size_t N> +HWY_API Vec128<int16_t, N> Broadcast(const Vec128<int16_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec128<int16_t, N>{_mm_unpacklo_epi64(lo, lo)}; + } else { + const __m128i hi = _mm_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec128<int16_t, N>{_mm_unpackhi_epi64(hi, hi)}; + } +} +template <int kLane, size_t N> +HWY_API Vec128<int32_t, N> Broadcast(const Vec128<int32_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<int32_t, N>{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template <int kLane, size_t N> +HWY_API Vec128<int64_t, N> Broadcast(const Vec128<int64_t, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<int64_t, N>{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Float +template <int kLane, size_t N> +HWY_API Vec128<float, N> Broadcast(const Vec128<float, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<float, N>{_mm_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template <int kLane, size_t N> +HWY_API Vec128<double, N> Broadcast(const Vec128<double, N> v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128<double, N>{_mm_shuffle_pd(v.raw, v.raw, 3 * kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template <typename T, size_t N = 16 / sizeof(T)> +struct Indices128 { + __m128i raw; +}; + +template <typename T, size_t N, typename TI, HWY_IF_LE128(T, N), + HWY_IF_LANE_SIZE(T, 4)> +HWY_API Indices128<T, N> IndicesFromVec(Simd<T, N, 0> d, Vec128<TI, N> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind<TI, decltype(d)> di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, N)))); +#endif + +#if HWY_TARGET <= HWY_AVX2 + (void)d; + return Indices128<T, N>{vec.raw}; +#else + const Repartition<uint8_t, decltype(d)> d8; + using V8 = VFromD<decltype(d8)>; + alignas(16) constexpr uint8_t kByteOffsets[16] = {0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + + // Broadcast each lane index to all 4 bytes of T + alignas(16) constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition<uint16_t, decltype(d)> d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + + return Indices128<T, N>{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif +} + +template <typename T, size_t N, typename TI, HWY_IF_LE128(T, N), + HWY_IF_LANE_SIZE(T, 8)> +HWY_API Indices128<T, N> IndicesFromVec(Simd<T, N, 0> d, Vec128<TI, N> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind<TI, decltype(d)> di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast<TI>(N))))); +#else + (void)d; +#endif + + // No change - even without AVX3, we can shuffle+blend. + return Indices128<T, N>{vec.raw}; +} + +template <typename T, size_t N, typename TI, HWY_IF_LE128(T, N)> +HWY_API Indices128<T, N> SetTableIndices(Simd<T, N, 0> d, const TI* idx) { + const Rebind<TI, decltype(d)> di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> TableLookupLanes(Vec128<T, N> v, Indices128<T, N> idx) { +#if HWY_TARGET <= HWY_AVX2 + const DFromV<decltype(v)> d; + const RebindToFloat<decltype(d)> df; + const Vec128<float, N> perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)}; + return BitCast(d, perm); +#else + return TableLookupBytes(v, Vec128<T, N>{idx.raw}); +#endif +} + +template <size_t N, HWY_IF_GE64(float, N)> +HWY_API Vec128<float, N> TableLookupLanes(Vec128<float, N> v, + Indices128<float, N> idx) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128<float, N>{_mm_permutevar_ps(v.raw, idx.raw)}; +#else + const DFromV<decltype(v)> df; + const RebindToSigned<decltype(df)> di; + return BitCast(df, + TableLookupBytes(BitCast(di, v), Vec128<int32_t, N>{idx.raw})); +#endif +} + +// Single lane: no change +template <typename T> +HWY_API Vec128<T, 1> TableLookupLanes(Vec128<T, 1> v, + Indices128<T, 1> /* idx */) { + return v; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> TableLookupLanes(Vec128<T> v, Indices128<T> idx) { + const Full128<T> d; + Vec128<int64_t> vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + // There is no _mm_permute[x]var_epi64. + vidx += vidx; // bit1 is the decider (unusual) + const Full128<double> df; + return BitCast( + d, Vec128<double>{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const Full128<int64_t> di; + const Vec128<int64_t> same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128<T> mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +HWY_API Vec128<double> TableLookupLanes(Vec128<double> v, + Indices128<double> idx) { + Vec128<int64_t> vidx{idx.raw}; +#if HWY_TARGET <= HWY_AVX2 + vidx += vidx; // bit1 is the decider (unusual) + return Vec128<double>{_mm_permutevar_pd(v.raw, vidx.raw)}; +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. + const Full128<double> d; + const Full128<int64_t> di; + const Vec128<int64_t> same = (vidx ^ Iota(di, 0)) - Set(di, 1); + const Mask128<double> mask_same = RebindMask(d, MaskFromVec(same)); + return IfThenElse(mask_same, v, Shuffle01(v)); +#endif +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template <typename T> +HWY_API Vec128<T> ReverseBlocks(Full128<T> /* tag */, const Vec128<T> v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template <typename T> +HWY_API Vec128<T, 1> Reverse(Simd<T, 1, 0> /* tag */, const Vec128<T, 1> v) { + return v; +} + +// Two lanes: shuffle +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, 2> Reverse(Full64<T> /* tag */, const Vec128<T, 2> v) { + return Vec128<T, 2>{Shuffle2301(Vec128<T>{v.raw}).raw}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> Reverse(Full128<T> /* tag */, const Vec128<T> v) { + return Shuffle01(v); +} + +// Four lanes: shuffle +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> Reverse(Full128<T> /* tag */, const Vec128<T> v) { + return Shuffle0123(v); +} + +// 16-bit +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse(Simd<T, N, 0> d, const Vec128<T, N> v) { +#if HWY_TARGET <= HWY_AVX3 + if (N == 1) return v; + if (N == 2) { + const Repartition<uint32_t, decltype(d)> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); + } + const RebindToSigned<decltype(d)> di; + alignas(16) constexpr int16_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + const Vec128<int16_t, N> idx = Load(di, kReverse + (N == 8 ? 0 : 4)); + return BitCast(d, Vec128<int16_t, N>{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide<RebindToUnsigned<decltype(d)>> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +#endif +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template <typename T> +HWY_API Vec128<T, 1> Reverse2(Simd<T, 1, 0> /* tag */, const Vec128<T, 1> v) { + return v; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> d, const Vec128<T, N> v) { + alignas(16) const T kShuffle[16] = {1, 0, 3, 2, 5, 4, 7, 6, + 9, 8, 11, 10, 13, 12, 15, 14}; + return TableLookupBytes(v, Load(d, kShuffle)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> d, const Vec128<T, N> v) { + const Repartition<uint32_t, decltype(d)> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return Shuffle2301(v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Reverse2(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> d, const Vec128<T, N> v) { + const RebindToSigned<decltype(d)> di; + // 4x 16-bit: a single shufflelo suffices. + if (N == 4) { + return BitCast(d, Vec128<int16_t, N>{_mm_shufflelo_epi16( + BitCast(di, v).raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET <= HWY_AVX3 + alignas(16) constexpr int16_t kReverse4[8] = {3, 2, 1, 0, 7, 6, 5, 4}; + const Vec128<int16_t, N> idx = Load(di, kReverse4); + return BitCast(d, Vec128<int16_t, N>{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide<decltype(di)> dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +// 4x 32-bit: use Shuffle0123 +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> Reverse4(Full128<T> /* tag */, const Vec128<T> v) { + return Shuffle0123(v); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> Reverse4(Simd<T, N, 0> /* tag */, Vec128<T, N> /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse8(Simd<T, N, 0> d, const Vec128<T, N> v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned<decltype(d)> di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec128<int16_t, N> idx = Load(di, kReverse8); + return BitCast(d, Vec128<int16_t, N>{ + _mm_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide<decltype(d)> dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API Vec128<T, N> Reverse8(Simd<T, N, 0> /* tag */, Vec128<T, N> /* v */) { + HWY_ASSERT(0); // don't have 8 lanes unless 16-bit +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template <size_t N, HWY_IF_LE128(uint8_t, N)> +HWY_API Vec128<uint8_t, N> InterleaveLower(const Vec128<uint8_t, N> a, + const Vec128<uint8_t, N> b) { + return Vec128<uint8_t, N>{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template <size_t N, HWY_IF_LE128(uint16_t, N)> +HWY_API Vec128<uint16_t, N> InterleaveLower(const Vec128<uint16_t, N> a, + const Vec128<uint16_t, N> b) { + return Vec128<uint16_t, N>{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template <size_t N, HWY_IF_LE128(uint32_t, N)> +HWY_API Vec128<uint32_t, N> InterleaveLower(const Vec128<uint32_t, N> a, + const Vec128<uint32_t, N> b) { + return Vec128<uint32_t, N>{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template <size_t N, HWY_IF_LE128(uint64_t, N)> +HWY_API Vec128<uint64_t, N> InterleaveLower(const Vec128<uint64_t, N> a, + const Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template <size_t N, HWY_IF_LE128(int8_t, N)> +HWY_API Vec128<int8_t, N> InterleaveLower(const Vec128<int8_t, N> a, + const Vec128<int8_t, N> b) { + return Vec128<int8_t, N>{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template <size_t N, HWY_IF_LE128(int16_t, N)> +HWY_API Vec128<int16_t, N> InterleaveLower(const Vec128<int16_t, N> a, + const Vec128<int16_t, N> b) { + return Vec128<int16_t, N>{_mm_unpacklo_epi16(a.raw, b.raw)}; +} +template <size_t N, HWY_IF_LE128(int32_t, N)> +HWY_API Vec128<int32_t, N> InterleaveLower(const Vec128<int32_t, N> a, + const Vec128<int32_t, N> b) { + return Vec128<int32_t, N>{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template <size_t N, HWY_IF_LE128(int64_t, N)> +HWY_API Vec128<int64_t, N> InterleaveLower(const Vec128<int64_t, N> a, + const Vec128<int64_t, N> b) { + return Vec128<int64_t, N>{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API Vec128<float, N> InterleaveLower(const Vec128<float, N> a, + const Vec128<float, N> b) { + return Vec128<float, N>{_mm_unpacklo_ps(a.raw, b.raw)}; +} +template <size_t N, HWY_IF_LE128(double, N)> +HWY_API Vec128<double, N> InterleaveLower(const Vec128<double, N> a, + const Vec128<double, N> b) { + return Vec128<double, N>{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +// Additional overload for the optional tag (also for 256/512). +template <class V> +HWY_API V InterleaveLower(DFromV<V> /* tag */, V a, V b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec128<uint8_t> InterleaveUpper(const Vec128<uint8_t> a, + const Vec128<uint8_t> b) { + return Vec128<uint8_t>{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128<uint16_t> InterleaveUpper(const Vec128<uint16_t> a, + const Vec128<uint16_t> b) { + return Vec128<uint16_t>{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128<uint32_t> InterleaveUpper(const Vec128<uint32_t> a, + const Vec128<uint32_t> b) { + return Vec128<uint32_t>{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128<uint64_t> InterleaveUpper(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + return Vec128<uint64_t>{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128<int8_t> InterleaveUpper(const Vec128<int8_t> a, + const Vec128<int8_t> b) { + return Vec128<int8_t>{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec128<int16_t> InterleaveUpper(const Vec128<int16_t> a, + const Vec128<int16_t> b) { + return Vec128<int16_t>{_mm_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec128<int32_t> InterleaveUpper(const Vec128<int32_t> a, + const Vec128<int32_t> b) { + return Vec128<int32_t>{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec128<int64_t> InterleaveUpper(const Vec128<int64_t> a, + const Vec128<int64_t> b) { + return Vec128<int64_t>{_mm_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec128<float> InterleaveUpper(const Vec128<float> a, + const Vec128<float> b) { + return Vec128<float>{_mm_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec128<double> InterleaveUpper(const Vec128<double> a, + const Vec128<double> b) { + return Vec128<double>{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +// Full +template <typename T, class V = Vec128<T>> +HWY_API V InterleaveUpper(Full128<T> /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template <typename T, size_t N, HWY_IF_LE64(T, N), class V = Vec128<T, N>> +HWY_API V InterleaveUpper(Simd<T, N, 0> d, V a, V b) { + const Half<decltype(d)> d2; + return InterleaveLower(d, V{UpperHalf(d2, a).raw}, V{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template <class V, class DW = RepartitionToWide<DFromV<V>>> +HWY_API VFromD<DW> ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>> +HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Vec128<T, N> Combine(Simd<T, N, 0> d, Vec128<T, N / 2> hi_half, + Vec128<T, N / 2> lo_half) { + const Half<decltype(d)> d2; + const RebindToUnsigned<decltype(d2)> du2; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128<UnsignedFromSize<N * sizeof(T) / 2>, 2>; + const VU lo{BitCast(du2, lo_half).raw}; + const VU hi{BitCast(du2, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T> +HWY_INLINE Vec128<T> ZeroExtendVector(hwy::NonFloatTag /*tag*/, + Full128<T> /* d */, Vec64<T> lo) { + return Vec128<T>{_mm_move_epi64(lo.raw)}; +} + +template <typename T> +HWY_INLINE Vec128<T> ZeroExtendVector(hwy::FloatTag /*tag*/, Full128<T> d, + Vec64<T> lo) { + const RebindToUnsigned<decltype(d)> du; + return BitCast(d, ZeroExtendVector(du, BitCast(Half<decltype(du)>(), lo))); +} + +} // namespace detail + +template <typename T> +HWY_API Vec128<T> ZeroExtendVector(Full128<T> d, Vec64<T> lo) { + return detail::ZeroExtendVector(hwy::IsFloatTag<T>(), d, lo); +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ZeroExtendVector(Simd<T, N, 0> d, Vec128<T, N / 2> lo) { + return IfThenElseZero(FirstN(d, N / 2), Vec128<T, N>{lo.raw}); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template <typename T> +HWY_API Vec128<T> ConcatLowerLower(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + const Repartition<uint64_t, decltype(d)> d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template <typename T> +HWY_API Vec128<T> ConcatUpperUpper(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + const Repartition<uint64_t, decltype(d)> d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template <typename T> +HWY_API Vec128<T> ConcatLowerUpper(Full128<T> d, const Vec128<T> hi, + const Vec128<T> lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template <typename T> +HWY_API Vec128<T> ConcatUpperLower(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + const Repartition<double, decltype(d)> dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128<double>{_mm_shuffle_pd(BitCast(dd, lo).raw, BitCast(dd, hi).raw, + _MM_SHUFFLE2(1, 0))}); +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _pd can do 3/cycle. + return BitCast(d, Vec128<double>{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128<float> ConcatUpperLower(Full128<float> d, Vec128<float> hi, + Vec128<float> lo) { +#if HWY_TARGET == HWY_SSSE3 + (void)d; + return Vec128<float>{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 2, 1, 0))}; +#else + // _mm_shuffle_ps has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + const RepartitionToWide<decltype(d)> dd; + return BitCast(d, Vec128<double>{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +HWY_API Vec128<double> ConcatUpperLower(Full128<double> /* tag */, + Vec128<double> hi, Vec128<double> lo) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128<double>{_mm_shuffle_pd(lo.raw, hi.raw, _MM_SHUFFLE2(1, 0))}; +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return Vec128<double>{_mm_blend_pd(hi.raw, lo.raw, 1)}; +#endif +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ConcatLowerLower(Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + const Half<decltype(d)> d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ConcatUpperUpper(Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + const Half<decltype(d)> d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ConcatLowerUpper(Simd<T, N, 0> d, const Vec128<T, N> hi, + const Vec128<T, N> lo) { + const Half<decltype(d)> d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template <typename T, size_t N, HWY_IF_LE64(T, N)> +HWY_API Vec128<T, N> ConcatUpperLower(Simd<T, N, 0> d, Vec128<T, N> hi, + Vec128<T, N> lo) { + const Half<decltype(d)> d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T> ConcatOdd(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + const Repartition<uint16_t, decltype(d)> dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec128<uint16_t> uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128<uint16_t> uL = ShiftRight<8>(BitCast(dw, lo)); + return Vec128<T>{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec64<T> ConcatOdd(Simd<T, 8, 0> d, Vec64<T> hi, Vec64<T> lo) { + const Repartition<uint32_t, decltype(d)> du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[8] = {1, 3, 5, 7}; + const Vec64<T> shuf = BitCast(d, Load(Full64<uint8_t>(), kCompactOddU8)); + const Vec64<T> L = TableLookupBytes(lo, shuf); + const Vec64<T> H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec32<T> ConcatOdd(Simd<T, 4, 0> d, Vec32<T> hi, Vec32<T> lo) { + const Repartition<uint16_t, decltype(d)> du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[4] = {1, 3}; + const Vec32<T> shuf = BitCast(d, Load(Full32<uint8_t>(), kCompactOddU8)); + const Vec32<T> L = TableLookupBytes(lo, shuf); + const Vec32<T> H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T> ConcatOdd(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. + const Repartition<int32_t, decltype(d)> dw; + const Vec128<int32_t> uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128<int32_t> uL = ShiftRight<16>(BitCast(dw, lo)); + return Vec128<T>{_mm_packs_epi32(uL.raw, uH.raw)}; +} + +// 16-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec64<T> ConcatOdd(Simd<T, 4, 0> d, Vec64<T> hi, Vec64<T> lo) { + const Repartition<uint32_t, decltype(d)> du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU16[8] = {2, 3, 6, 7}; + const Vec64<T> shuf = BitCast(d, Load(Full64<uint8_t>(), kCompactOddU16)); + const Vec64<T> L = TableLookupBytes(lo, shuf); + const Vec64<T> H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> ConcatOdd(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + const RebindToFloat<decltype(d)> df; + return BitCast( + d, Vec128<float>{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(3, 1, 3, 1))}); +} +template <size_t N> +HWY_API Vec128<float> ConcatOdd(Full128<float> /* tag */, Vec128<float> hi, + Vec128<float> lo) { + return Vec128<float>{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; +} + +// Any type x2 +template <typename T> +HWY_API Vec128<T, 2> ConcatOdd(Simd<T, 2, 0> d, Vec128<T, 2> hi, + Vec128<T, 2> lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec128<T> ConcatEven(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + const Repartition<uint16_t, decltype(d)> dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec128<uint16_t> mask = Set(dw, 0x00FF); + const Vec128<uint16_t> uH = And(BitCast(dw, hi), mask); + const Vec128<uint16_t> uL = And(BitCast(dw, lo), mask); + return Vec128<T>{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec64<T> ConcatEven(Simd<T, 8, 0> d, Vec64<T> hi, Vec64<T> lo) { + const Repartition<uint32_t, decltype(d)> du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[8] = {0, 2, 4, 6}; + const Vec64<T> shuf = BitCast(d, Load(Full64<uint8_t>(), kCompactEvenU8)); + const Vec64<T> L = TableLookupBytes(lo, shuf); + const Vec64<T> H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 8-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec32<T> ConcatEven(Simd<T, 4, 0> d, Vec32<T> hi, Vec32<T> lo) { + const Repartition<uint16_t, decltype(d)> du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[4] = {0, 2}; + const Vec32<T> shuf = BitCast(d, Load(Full32<uint8_t>(), kCompactEvenU8)); + const Vec32<T> L = TableLookupBytes(lo, shuf); + const Vec32<T> H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +} + +// 16-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec128<T> ConcatEven(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { +#if HWY_TARGET <= HWY_SSE4 + // Isolate lower 16 bits per u32 so we can pack. + const Repartition<uint32_t, decltype(d)> dw; + const Vec128<uint32_t> mask = Set(dw, 0x0000FFFF); + const Vec128<uint32_t> uH = And(BitCast(dw, hi), mask); + const Vec128<uint32_t> uL = And(BitCast(dw, lo), mask); + return Vec128<T>{_mm_packus_epi32(uL.raw, uH.raw)}; +#else + // packs_epi32 saturates 0x8000 to 0x7FFF. Instead ConcatEven within the two + // inputs, then concatenate them. + alignas(16) const T kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + const Vec128<T> shuf = BitCast(d, Load(d, kCompactEvenU16)); + const Vec128<T> L = TableLookupBytes(lo, shuf); + const Vec128<T> H = TableLookupBytes(hi, shuf); + return ConcatLowerLower(d, H, L); +#endif +} + +// 16-bit x4 +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec64<T> ConcatEven(Simd<T, 4, 0> d, Vec64<T> hi, Vec64<T> lo) { + const Repartition<uint32_t, decltype(d)> du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU16[8] = {0, 1, 4, 5}; + const Vec64<T> shuf = BitCast(d, Load(Full64<uint8_t>(), kCompactEvenU16)); + const Vec64<T> L = TableLookupBytes(lo, shuf); + const Vec64<T> H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +} + +// 32-bit full +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T> ConcatEven(Full128<T> d, Vec128<T> hi, Vec128<T> lo) { + const RebindToFloat<decltype(d)> df; + return BitCast( + d, Vec128<float>{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +HWY_API Vec128<float> ConcatEven(Full128<float> /* tag */, Vec128<float> hi, + Vec128<float> lo) { + return Vec128<float>{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; +} + +// Any T x2 +template <typename T> +HWY_API Vec128<T, 2> ConcatEven(Simd<T, 2, 0> d, Vec128<T, 2> hi, + Vec128<T, 2> lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> DupEven(Vec128<T, N> v) { + return Vec128<T, N>{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +template <size_t N> +HWY_API Vec128<float, N> DupEven(Vec128<float, N> v) { + return Vec128<float, N>{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> DupEven(const Vec128<T, N> v) { + return InterleaveLower(DFromV<decltype(v)>(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec128<T, N> DupOdd(Vec128<T, N> v) { + return Vec128<T, N>{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +template <size_t N> +HWY_API Vec128<float, N> DupOdd(Vec128<float, N> v) { + return Vec128<float, N>{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T, N> DupOdd(const Vec128<T, N> v) { + return InterleaveUpper(DFromV<decltype(v)>(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) { + const DFromV<decltype(a)> d; + const Repartition<uint8_t, decltype(d)> d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) { +#if HWY_TARGET == HWY_SSSE3 + const DFromV<decltype(a)> d; + const Repartition<uint8_t, decltype(d)> d8; + alignas(16) constexpr uint8_t mask[16] = {0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +#else + return Vec128<T, N>{_mm_blend_epi16(a.raw, b.raw, 0x55)}; +#endif +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i odd = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128i even = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128<T, N>{_mm_unpacklo_epi32(even, odd)}; +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _ps can do 3/cycle. + const DFromV<decltype(a)> d; + const RebindToFloat<decltype(d)> df; + return BitCast(d, Vec128<float, N>{_mm_blend_ps(BitCast(df, a).raw, + BitCast(df, b).raw, 5)}); +#endif +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. + const DFromV<decltype(a)> d; + const RebindToFloat<decltype(d)> dd; +#if HWY_TARGET == HWY_SSSE3 + return BitCast( + d, Vec128<double, N>{_mm_shuffle_pd( + BitCast(dd, b).raw, BitCast(dd, a).raw, _MM_SHUFFLE2(1, 0))}); +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return BitCast(d, Vec128<double, N>{_mm_blend_pd(BitCast(dd, a).raw, + BitCast(dd, b).raw, 1)}); +#endif +} + +template <size_t N> +HWY_API Vec128<float, N> OddEven(Vec128<float, N> a, Vec128<float, N> b) { +#if HWY_TARGET == HWY_SSSE3 + // SHUFPS must fill the lower half of the output from one input, so we + // need another shuffle. Unpack avoids another immediate byte. + const __m128 odd = _mm_shuffle_ps(a.raw, a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128 even = _mm_shuffle_ps(b.raw, b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128<float, N>{_mm_unpacklo_ps(even, odd)}; +#else + return Vec128<float, N>{_mm_blend_ps(a.raw, b.raw, 5)}; +#endif +} + +// ------------------------------ OddEvenBlocks +template <typename T, size_t N> +HWY_API Vec128<T, N> OddEvenBlocks(Vec128<T, N> /* odd */, Vec128<T, N> even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template <typename T, size_t N> +HWY_API Vec128<T, N> SwapAdjacentBlocks(Vec128<T, N> v) { + return v; +} + +// ------------------------------ Shl (ZipLower, Mul) + +// Use AVX2/3 variable shifts where available, otherwise multiply by powers of +// two from loading float exponents, which is considerably faster (according +// to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. + +namespace detail { +#if HWY_TARGET > HWY_AVX3 // AVX2 or older + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<MakeUnsigned<T>, N> Pow2(const Vec128<T, N> v) { + const DFromV<decltype(v)> d; + const RepartitionToWide<decltype(d)> dw; + const Rebind<float, decltype(dw)> df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // See comment below. + const Vec128<int32_t, N> bits0{_mm_cvtps_epi32(BitCast(df, f0).raw)}; + const Vec128<int32_t, N> bits1{_mm_cvtps_epi32(BitCast(df, f1).raw)}; + return Vec128<MakeUnsigned<T>, N>{_mm_packus_epi32(bits0.raw, bits1.raw)}; +} + +// Same, for 32-bit shifts. +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec128<MakeUnsigned<T>, N> Pow2(const Vec128<T, N> v) { + const DFromV<decltype(v)> d; + const auto exp = ShiftLeft<23>(v); + const auto f = exp + Set(d, 0x3F800000); // 1.0f + // Do not use ConvertTo because we rely on the native 0x80..00 overflow + // behavior. cvt instead of cvtt should be equivalent, but avoids test + // failure under GCC 10.2.1. + return Vec128<MakeUnsigned<T>, N>{_mm_cvtps_epi32(_mm_castsi128_ps(f.raw))}; +} + +#endif // HWY_TARGET > HWY_AVX3 + +template <size_t N> +HWY_API Vec128<uint16_t, N> Shl(hwy::UnsignedTag /*tag*/, Vec128<uint16_t, N> v, + Vec128<uint16_t, N> bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<uint16_t, N>{_mm_sllv_epi16(v.raw, bits.raw)}; +#else + return v * Pow2(bits); +#endif +} +HWY_API Vec128<uint16_t, 1> Shl(hwy::UnsignedTag /*tag*/, Vec128<uint16_t, 1> v, + Vec128<uint16_t, 1> bits) { + return Vec128<uint16_t, 1>{_mm_sll_epi16(v.raw, bits.raw)}; +} + +template <size_t N> +HWY_API Vec128<uint32_t, N> Shl(hwy::UnsignedTag /*tag*/, Vec128<uint32_t, N> v, + Vec128<uint32_t, N> bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + return v * Pow2(bits); +#else + return Vec128<uint32_t, N>{_mm_sllv_epi32(v.raw, bits.raw)}; +#endif +} +HWY_API Vec128<uint32_t, 1> Shl(hwy::UnsignedTag /*tag*/, Vec128<uint32_t, 1> v, + const Vec128<uint32_t, 1> bits) { + return Vec128<uint32_t, 1>{_mm_sll_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec128<uint64_t> Shl(hwy::UnsignedTag /*tag*/, Vec128<uint64_t> v, + Vec128<uint64_t> bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128<uint64_t> out0{_mm_sll_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128<uint64_t> out1{_mm_sll_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128<uint64_t>(), out1, out0); +#else + return Vec128<uint64_t>{_mm_sllv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64<uint64_t> Shl(hwy::UnsignedTag /*tag*/, Vec64<uint64_t> v, + Vec64<uint64_t> bits) { + return Vec64<uint64_t>{_mm_sll_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template <typename T, size_t N> +HWY_API Vec128<T, N> Shl(hwy::SignedTag /*tag*/, Vec128<T, N> v, + Vec128<T, N> bits) { + const DFromV<decltype(v)> di; + const RebindToUnsigned<decltype(di)> du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template <typename T, size_t N> +HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, Vec128<T, N> bits) { + return detail::Shl(hwy::TypeTag<T>(), v, bits); +} + +// ------------------------------ Shr (mul, mask, BroadcastSignBit) + +// Use AVX2+ variable shifts except for SSSE3/SSE4 or 16-bit. There, we use +// widening multiplication by powers of two obtained by loading float exponents, +// followed by a constant right-shift. This is still faster than a scalar or +// bit-test approach: https://gcc.godbolt.org/z/9G7Y9v. + +template <size_t N> +HWY_API Vec128<uint16_t, N> operator>>(const Vec128<uint16_t, N> in, + const Vec128<uint16_t, N> bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<uint16_t, N>{_mm_srlv_epi16(in.raw, bits.raw)}; +#else + const Simd<uint16_t, N, 0> d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), in, out); +#endif +} +HWY_API Vec128<uint16_t, 1> operator>>(const Vec128<uint16_t, 1> in, + const Vec128<uint16_t, 1> bits) { + return Vec128<uint16_t, 1>{_mm_srl_epi16(in.raw, bits.raw)}; +} + +template <size_t N> +HWY_API Vec128<uint32_t, N> operator>>(const Vec128<uint32_t, N> in, + const Vec128<uint32_t, N> bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // 32x32 -> 64 bit mul, then shift right by 32. + const Simd<uint32_t, N, 0> d32; + // Move odd lanes into position for the second mul. Shuffle more gracefully + // handles N=1 than repartitioning to u64 and shifting 32 bits right. + const Vec128<uint32_t, N> in31{_mm_shuffle_epi32(in.raw, 0x31)}; + // For bits=0, we cannot mul by 2^32, so fix the result later. + const auto mul = detail::Pow2(Set(d32, 32) - bits); + const auto out20 = ShiftRight<32>(MulEven(in, mul)); // z 2 z 0 + const Vec128<uint32_t, N> mul31{_mm_shuffle_epi32(mul.raw, 0x31)}; + // No need to shift right, already in the correct position. + const auto out31 = BitCast(d32, MulEven(in31, mul31)); // 3 ? 1 ? + const Vec128<uint32_t, N> out = OddEven(out31, BitCast(d32, out20)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d32), in, out); +#else + return Vec128<uint32_t, N>{_mm_srlv_epi32(in.raw, bits.raw)}; +#endif +} +HWY_API Vec128<uint32_t, 1> operator>>(const Vec128<uint32_t, 1> in, + const Vec128<uint32_t, 1> bits) { + return Vec128<uint32_t, 1>{_mm_srl_epi32(in.raw, bits.raw)}; +} + +HWY_API Vec128<uint64_t> operator>>(const Vec128<uint64_t> v, + const Vec128<uint64_t> bits) { +#if HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 + // Individual shifts and combine + const Vec128<uint64_t> out0{_mm_srl_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128<uint64_t> out1{_mm_srl_epi64(v.raw, bits1)}; + return ConcatUpperLower(Full128<uint64_t>(), out1, out0); +#else + return Vec128<uint64_t>{_mm_srlv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64<uint64_t> operator>>(const Vec64<uint64_t> v, + const Vec64<uint64_t> bits) { + return Vec64<uint64_t>{_mm_srl_epi64(v.raw, bits.raw)}; +} + +#if HWY_TARGET > HWY_AVX3 // AVX2 or older +namespace detail { + +// Also used in x86_256-inl.h. +template <class DI, class V> +HWY_INLINE V SignedShr(const DI di, const V v, const V count_i) { + const RebindToUnsigned<DI> du; + const auto count = BitCast(du, count_i); // same type as value to shift + // Clear sign and restore afterwards. This is preferable to shifting the MSB + // downwards because Shr is somewhat more expensive than Shl. + const auto sign = BroadcastSignBit(v); + const auto abs = BitCast(du, v ^ sign); // off by one, but fixed below + return BitCast(di, abs >> count) ^ sign; +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +template <size_t N> +HWY_API Vec128<int16_t, N> operator>>(const Vec128<int16_t, N> v, + const Vec128<int16_t, N> bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int16_t, N>{_mm_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd<int16_t, N, 0>(), v, bits); +#endif +} +HWY_API Vec128<int16_t, 1> operator>>(const Vec128<int16_t, 1> v, + const Vec128<int16_t, 1> bits) { + return Vec128<int16_t, 1>{_mm_sra_epi16(v.raw, bits.raw)}; +} + +template <size_t N> +HWY_API Vec128<int32_t, N> operator>>(const Vec128<int32_t, N> v, + const Vec128<int32_t, N> bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int32_t, N>{_mm_srav_epi32(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd<int32_t, N, 0>(), v, bits); +#endif +} +HWY_API Vec128<int32_t, 1> operator>>(const Vec128<int32_t, 1> v, + const Vec128<int32_t, 1> bits) { + return Vec128<int32_t, 1>{_mm_sra_epi32(v.raw, bits.raw)}; +} + +template <size_t N> +HWY_API Vec128<int64_t, N> operator>>(const Vec128<int64_t, N> v, + const Vec128<int64_t, N> bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<int64_t, N>{_mm_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Simd<int64_t, N, 0>(), v, bits); +#endif +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +HWY_INLINE Vec128<uint64_t> MulEven(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + alignas(16) uint64_t mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128<uint64_t>(), mul); +} + +HWY_INLINE Vec128<uint64_t> MulOdd(const Vec128<uint64_t> a, + const Vec128<uint64_t> b) { + alignas(16) uint64_t mul[2]; + const Half<Full128<uint64_t>> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128<uint64_t>(), mul); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template <class V, size_t N, class D16 = Simd<bfloat16_t, 2 * N, 0>> +HWY_API V ReorderWidenMulAccumulate(Simd<float, N, 0> df32, VFromD<D16> a, + VFromD<D16> b, const V sum0, V& sum1) { + // TODO(janwas): _mm_dpbf16_ps when available + const RebindToUnsigned<decltype(df32)> du32; + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo. Using shift/and instead of Zip + // leads to the odd/even order that RearrangeToOddPlusEven prefers. + using VU32 = VFromD<decltype(du32)>; + const VU32 odd = Set(du32, 0xFFFF0000u); + const VU32 ae = ShiftLeft<16>(BitCast(du32, a)); + const VU32 ao = And(BitCast(du32, a), odd); + const VU32 be = ShiftLeft<16>(BitCast(du32, b)); + const VU32 bo = And(BitCast(du32, b), odd); + sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1); + return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0); +} + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. +template <size_t N> +HWY_API Vec128<int32_t, N> ReorderWidenMulAccumulate( + Simd<int32_t, N, 0> /*d32*/, Vec128<int16_t, 2 * N> a, + Vec128<int16_t, 2 * N> b, const Vec128<int32_t, N> sum0, + Vec128<int32_t, N>& /*sum1*/) { + return sum0 + Vec128<int32_t, N>{_mm_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +template <size_t N> +HWY_API Vec128<int32_t, N> RearrangeToOddPlusEven(const Vec128<int32_t, N> sum0, + Vec128<int32_t, N> /*sum1*/) { + return sum0; // invariant already holds +} + +template <class VW> +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + return Add(sum0, sum1); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +template <size_t N> +HWY_API Vec128<uint16_t, N> PromoteTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + return Vec128<uint16_t, N>{_mm_unpacklo_epi8(v.raw, zero)}; +#else + return Vec128<uint16_t, N>{_mm_cvtepu8_epi16(v.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<uint32_t, N> PromoteTo(Simd<uint32_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128<uint32_t, N>{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; +#else + return Vec128<uint32_t, N>{_mm_cvtepu16_epi32(v.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<uint64_t, N> PromoteTo(Simd<uint64_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + return Vec128<uint64_t, N>{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; +#else + return Vec128<uint64_t, N>{_mm_cvtepu32_epi64(v.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<uint32_t, N> PromoteTo(Simd<uint32_t, N, 0> /* tag */, + const Vec128<uint8_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); + return Vec128<uint32_t, N>{_mm_unpacklo_epi16(u16, zero)}; +#else + return Vec128<uint32_t, N>{_mm_cvtepu8_epi32(v.raw)}; +#endif +} + +// Unsigned to signed: same plus cast. +template <size_t N> +HWY_API Vec128<int16_t, N> PromoteTo(Simd<int16_t, N, 0> di, + const Vec128<uint8_t, N> v) { + return BitCast(di, PromoteTo(Simd<uint16_t, N, 0>(), v)); +} +template <size_t N> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> di, + const Vec128<uint16_t, N> v) { + return BitCast(di, PromoteTo(Simd<uint32_t, N, 0>(), v)); +} +template <size_t N> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> di, + const Vec128<uint8_t, N> v) { + return BitCast(di, PromoteTo(Simd<uint32_t, N, 0>(), v)); +} + +// Signed: replicate sign bit. +template <size_t N> +HWY_API Vec128<int16_t, N> PromoteTo(Simd<int16_t, N, 0> /* tag */, + const Vec128<int8_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<8>(Vec128<int16_t, N>{_mm_unpacklo_epi8(v.raw, v.raw)}); +#else + return Vec128<int16_t, N>{_mm_cvtepi8_epi16(v.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<16>(Vec128<int32_t, N>{_mm_unpacklo_epi16(v.raw, v.raw)}); +#else + return Vec128<int32_t, N>{_mm_cvtepi16_epi32(v.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int64_t, N> PromoteTo(Simd<int64_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + return ShiftRight<32>(Vec128<int64_t, N>{_mm_unpacklo_epi32(v.raw, v.raw)}); +#else + return Vec128<int64_t, N>{_mm_cvtepi32_epi64(v.raw)}; +#endif +} +template <size_t N> +HWY_API Vec128<int32_t, N> PromoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<int8_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); + const __m128i x4 = _mm_unpacklo_epi16(x2, x2); + return ShiftRight<24>(Vec128<int32_t, N>{x4}); +#else + return Vec128<int32_t, N>{_mm_cvtepi8_epi32(v.raw)}; +#endif +} + +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template <size_t N> +HWY_INLINE_F16 Vec128<float, N> PromoteTo(Simd<float, N, 0> df32, + const Vec128<float16_t, N> v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToSigned<decltype(df32)> di32; + const RebindToUnsigned<decltype(df32)> du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128<uint16_t, N>{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + (void)df32; + return Vec128<float, N>{_mm_cvtph_ps(v.raw)}; +#endif +} + +template <size_t N> +HWY_API Vec128<float, N> PromoteTo(Simd<float, N, 0> df32, + const Vec128<bfloat16_t, N> v) { + const Rebind<uint16_t, decltype(df32)> du16; + const RebindToSigned<decltype(df32)> di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template <size_t N> +HWY_API Vec128<double, N> PromoteTo(Simd<double, N, 0> /* tag */, + const Vec128<float, N> v) { + return Vec128<double, N>{_mm_cvtps_pd(v.raw)}; +} + +template <size_t N> +HWY_API Vec128<double, N> PromoteTo(Simd<double, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<double, N>{_mm_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template <size_t N> +HWY_API Vec128<uint16_t, N> DemoteTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { +#if HWY_TARGET == HWY_SSSE3 + const Simd<int32_t, N, 0> di32; + const Simd<uint16_t, N * 2, 0> du16; + const auto zero_if_neg = AndNot(ShiftRight<31>(v), v); + const auto too_big = VecFromMask(di32, Gt(v, Set(di32, 0xFFFF))); + const auto clamped = Or(zero_if_neg, too_big); + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. + alignas(16) constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16, kLower2Bytes); + return Vec128<uint16_t, N>{TableLookupBytes(BitCast(du16, clamped), lo2).raw}; +#else + return Vec128<uint16_t, N>{_mm_packus_epi32(v.raw, v.raw)}; +#endif +} + +template <size_t N> +HWY_API Vec128<int16_t, N> DemoteTo(Simd<int16_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<int16_t, N>{_mm_packs_epi32(v.raw, v.raw)}; +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> DemoteTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128<uint8_t, N>{_mm_packus_epi16(i16, i16)}; +} + +template <size_t N> +HWY_API Vec128<uint8_t, N> DemoteTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<uint8_t, N>{_mm_packus_epi16(v.raw, v.raw)}; +} + +template <size_t N> +HWY_API Vec128<int8_t, N> DemoteTo(Simd<int8_t, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return Vec128<int8_t, N>{_mm_packs_epi16(i16, i16)}; +} + +template <size_t N> +HWY_API Vec128<int8_t, N> DemoteTo(Simd<int8_t, N, 0> /* tag */, + const Vec128<int16_t, N> v) { + return Vec128<int8_t, N>{_mm_packs_epi16(v.raw, v.raw)}; +} + +// Work around MSVC warning for _mm_cvtps_ph (8 is actually a valid immediate). +// clang-cl requires a non-empty string, so we 'ignore' the irrelevant -Wmain. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wmain") + +template <size_t N> +HWY_API Vec128<float16_t, N> DemoteTo(Simd<float16_t, N, 0> df16, + const Vec128<float, N> v) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_F16C) + const RebindToUnsigned<decltype(df16)> du16; + const Rebind<uint32_t, decltype(df16)> du; + const RebindToSigned<decltype(du)> di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + (void)df16; + return Vec128<float16_t, N>{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +HWY_DIAGNOSTICS(pop) + +template <size_t N> +HWY_API Vec128<bfloat16_t, N> DemoteTo(Simd<bfloat16_t, N, 0> dbf16, + const Vec128<float, N> v) { + // TODO(janwas): _mm_cvtneps_pbh once we have avx512bf16. + const Rebind<int32_t, decltype(dbf16)> di32; + const Rebind<uint32_t, decltype(dbf16)> du32; // for logical shift right + const Rebind<uint16_t, decltype(dbf16)> du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +template <size_t N> +HWY_API Vec128<bfloat16_t, 2 * N> ReorderDemote2To( + Simd<bfloat16_t, 2 * N, 0> dbf16, Vec128<float, N> a, Vec128<float, N> b) { + // TODO(janwas): _mm_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned<decltype(dbf16)> du16; + const Repartition<uint32_t, decltype(dbf16)> du32; + const Vec128<uint32_t, N> b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. +HWY_API Vec128<int16_t, 2> ReorderDemote2To(Simd<int16_t, 2, 0> dn, + Vec128<int32_t, 1> a, + Vec128<int32_t, 1> b) { + const Half<decltype(dn)> dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128<int16_t, 2> an{DemoteTo(dnh, a).raw}; + const Vec128<int16_t, 2> bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128<int16_t, 4> ReorderDemote2To(Simd<int16_t, 4, 0> dn, + Vec128<int32_t, 2> a, + Vec128<int32_t, 2> b) { + const Half<decltype(dn)> dnh; + // Pretend the result has twice as many lanes so we can InterleaveLower. + const Vec128<int16_t, 4> an{DemoteTo(dnh, a).raw}; + const Vec128<int16_t, 4> bn{DemoteTo(dnh, b).raw}; + return InterleaveLower(an, bn); +} +HWY_API Vec128<int16_t> ReorderDemote2To(Full128<int16_t> /*d16*/, + Vec128<int32_t> a, Vec128<int32_t> b) { + return Vec128<int16_t>{_mm_packs_epi32(a.raw, b.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> DemoteTo(Simd<float, N, 0> /* tag */, + const Vec128<double, N> v) { + return Vec128<float, N>{_mm_cvtpd_ps(v.raw)}; +} + +namespace detail { + +// For well-defined float->int demotion in all x86_*-inl.h. + +template <size_t N> +HWY_INLINE auto ClampF64ToI32Max(Simd<double, N, 0> d, decltype(Zero(d)) v) + -> decltype(Zero(d)) { + // The max can be exactly represented in binary64, so clamping beforehand + // prevents x86 conversion from raising an exception and returning 80..00. + return Min(v, Set(d, 2147483647.0)); +} + +// For ConvertTo float->int of same size, clamping before conversion would +// change the result because the max integer value is not exactly representable. +// Instead detect the overflow result after conversion and fix it. +template <class DI, class DF = RebindToFloat<DI>> +HWY_INLINE auto FixConversionOverflow(DI di, VFromD<DF> original, + decltype(Zero(di).raw) converted_raw) + -> VFromD<DI> { + // Combinations of original and output sign: + // --: normal <0 or -huge_val to 80..00: OK + // -+: -0 to 0 : OK + // +-: +huge_val to 80..00 : xor with FF..FF to get 7F..FF + // ++: normal >0 : OK + const auto converted = VFromD<DI>{converted_raw}; + const auto sign_wrong = AndNot(BitCast(di, original), converted); +#if HWY_COMPILER_GCC_ACTUAL + // Critical GCC 11 compiler bug (possibly also GCC 10): omits the Xor; also + // Add() if using that instead. Work around with one more instruction. + const RebindToUnsigned<DI> du; + const VFromD<DI> mask = BroadcastSignBit(sign_wrong); + const VFromD<DI> max = BitCast(di, ShiftRight<1>(BitCast(du, mask))); + return IfVecThenElse(mask, max, converted); +#else + return Xor(converted, BroadcastSignBit(sign_wrong)); +#endif +} + +} // namespace detail + +template <size_t N> +HWY_API Vec128<int32_t, N> DemoteTo(Simd<int32_t, N, 0> /* tag */, + const Vec128<double, N> v) { + const auto clamped = detail::ClampF64ToI32Max(Simd<double, N, 0>(), v); + return Vec128<int32_t, N>{_mm_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +template <size_t N> +HWY_API Vec128<uint8_t, N> U8FromU32(const Vec128<uint32_t, N> v) { + const Simd<uint32_t, N, 0> d32; + const Simd<uint8_t, N * 4, 0> d8; + alignas(16) static constexpr uint32_t k8From32[4] = { + 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; + // Also replicate bytes into all 32 bit lanes for safety. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + return LowerHalf(LowerHalf(BitCast(d8, quad))); +} + +// ------------------------------ Truncations + +template <typename From, typename To, + hwy::EnableIf<(sizeof(To) < sizeof(From))>* = nullptr> +HWY_API Vec128<To, 1> TruncateTo(Simd<To, 1, 0> /* tag */, + const Vec128<From, 1> v) { + static_assert(!IsSigned<To>() && !IsSigned<From>(), "Unsigned only"); + const Repartition<To, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + return Vec128<To, 1>{v1.raw}; +} + +HWY_API Vec128<uint8_t, 2> TruncateTo(Simd<uint8_t, 2, 0> /* tag */, + const Vec128<uint64_t, 2> v) { + const Full128<uint8_t> d8; + alignas(16) static constexpr uint8_t kMap[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + return LowerHalf(LowerHalf(LowerHalf(TableLookupBytes(v, Load(d8, kMap))))); +} + +HWY_API Vec128<uint16_t, 2> TruncateTo(Simd<uint16_t, 2, 0> /* tag */, + const Vec128<uint64_t, 2> v) { + const Full128<uint16_t> d16; + alignas(16) static constexpr uint16_t kMap[8] = { + 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d16, kMap)))); +} + +HWY_API Vec128<uint32_t, 2> TruncateTo(Simd<uint32_t, 2, 0> /* tag */, + const Vec128<uint64_t, 2> v) { + return Vec128<uint32_t, 2>{_mm_shuffle_epi32(v.raw, 0x88)}; +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + const Repartition<uint8_t, DFromV<decltype(v)>> d; + alignas(16) static constexpr uint8_t kMap[16] = { + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu, + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d, kMap)))); +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint16_t, N> TruncateTo(Simd<uint16_t, N, 0> /* tag */, + const Vec128<uint32_t, N> v) { + const Repartition<uint16_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template <size_t N, hwy::EnableIf<N >= 2>* = nullptr> +HWY_API Vec128<uint8_t, N> TruncateTo(Simd<uint8_t, N, 0> /* tag */, + const Vec128<uint16_t, N> v) { + const Repartition<uint8_t, DFromV<decltype(v)>> d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +template <size_t N> +HWY_API Vec128<float, N> ConvertTo(Simd<float, N, 0> /* tag */, + const Vec128<int32_t, N> v) { + return Vec128<float, N>{_mm_cvtepi32_ps(v.raw)}; +} + +template <size_t N> +HWY_API Vec128<float, N> ConvertTo(HWY_MAYBE_UNUSED Simd<float, N, 0> df, + const Vec128<uint32_t, N> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<float, N>{_mm_cvtepu32_ps(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned<decltype(df)> du32; + const RebindToSigned<decltype(df)> d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +#endif +} + +template <size_t N> +HWY_API Vec128<double, N> ConvertTo(Simd<double, N, 0> dd, + const Vec128<int64_t, N> v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dd; + return Vec128<double, N>{_mm_cvtepi64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition<uint32_t, decltype(dd)> d32; + const Repartition<uint64_t, decltype(dd)> d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +#endif +} + +template <size_t N> +HWY_API Vec128<double, N> ConvertTo(HWY_MAYBE_UNUSED Simd<double, N, 0> dd, + const Vec128<uint64_t, N> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128<double, N>{_mm_cvtepu64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned<decltype(dd)> d64; + using VU = VFromD<decltype(d64)>; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + auto uint64_to_double128_fast = [&dd](VU w) HWY_ATTR { + w = Or(w, VU{detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); + return BitCast(dd, w) - Set(dd, 0x0010000000000000); + }; + + const auto v_lo_dbl = uint64_to_double128_fast(v_lo); + return MulAdd(cnst2_32_dbl, uint64_to_double128_fast(v_hi), v_lo_dbl); +#endif +} + +// Truncates (rounds toward zero). +template <size_t N> +HWY_API Vec128<int32_t, N> ConvertTo(const Simd<int32_t, N, 0> di, + const Vec128<float, N> v) { + return detail::FixConversionOverflow(di, v, _mm_cvttps_epi32(v.raw)); +} + +// Full (partial handled below) +HWY_API Vec128<int64_t> ConvertTo(Full128<int64_t> di, const Vec128<double> v) { +#if HWY_TARGET <= HWY_AVX3 && HWY_ARCH_X86_64 + return detail::FixConversionOverflow(di, v, _mm_cvttpd_epi64(v.raw)); +#elif HWY_ARCH_X86_64 + const __m128i i0 = _mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw)); + const Half<Full128<double>> dd2; + const __m128i i1 = _mm_cvtsi64_si128(_mm_cvttsd_si64(UpperHalf(dd2, v).raw)); + return detail::FixConversionOverflow(di, v, _mm_unpacklo_epi64(i0, i1)); +#else + using VI = VFromD<decltype(di)>; + const VI k0 = Zero(di); + const VI k1 = Set(di, 1); + const VI k51 = Set(di, 51); + + // Exponent indicates whether the number can be represented as int64_t. + const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); + const VI exp = biased_exp - Set(di, 0x3FF); + const auto in_range = exp < Set(di, 63); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + const VI shift_mnt = Max(k51 - exp, k0); + const VI shift_int = Max(exp - k51, k0); + const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); + // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. + const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); + // For inputs larger than 2^52, insert zeros at the bottom. + const VI shifted = int52 << shift_int; + // Restore the one bit lost when shifting in the implicit 1-bit. + const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax<int64_t>()) - sign_mask; + const VI magnitude = IfThenElse(in_range, restored, limit); + + // If the input was negative, negate the integer (two's complement). + return (magnitude ^ sign_mask) - sign_mask; +#endif +} +HWY_API Vec64<int64_t> ConvertTo(Full64<int64_t> di, const Vec64<double> v) { + // Only need to specialize for non-AVX3, 64-bit (single scalar op) +#if HWY_TARGET > HWY_AVX3 && HWY_ARCH_X86_64 + const Vec64<int64_t> i0{_mm_cvtsi64_si128(_mm_cvttsd_si64(v.raw))}; + return detail::FixConversionOverflow(di, v, i0.raw); +#else + (void)di; + const auto full = ConvertTo(Full128<int64_t>(), Vec128<double>{v.raw}); + return Vec64<int64_t>{full.raw}; +#endif +} + +template <size_t N> +HWY_API Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) { + const Simd<int32_t, N, 0> di; + return detail::FixConversionOverflow(di, v, _mm_cvtps_epi32(v.raw)); +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +#if HWY_TARGET == HWY_SSSE3 + +// Toward nearest integer, ties to even +template <typename T, size_t N> +HWY_API Vec128<T, N> Round(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + // Rely on rounding after addition with a large value such that no mantissa + // bits remain (assuming the current mode is nearest-even). We may need a + // compiler flag for precise floating-point to prevent "optimizing" this out. + const Simd<T, N, 0> df; + const auto max = Set(df, MantissaEnd<T>()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +namespace detail { + +// Truncating to integer and converting back to float is correct except when the +// input magnitude is large, in which case the input was already an integer +// (because mantissa >> exponent is zero). +template <typename T, size_t N> +HWY_INLINE Mask128<T, N> UseInt(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + return Abs(v) < Set(Simd<T, N, 0>(), MantissaEnd<T>()); +} + +} // namespace detail + +// Toward zero, aka truncate +template <typename T, size_t N> +HWY_API Vec128<T, N> Trunc(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Simd<T, N, 0> df; + const RebindToSigned<decltype(df)> di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// Toward +infinity, aka ceiling +template <typename T, size_t N> +HWY_API Vec128<T, N> Ceil(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Simd<T, N, 0> df; + const RebindToSigned<decltype(df)> di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +// Toward -infinity, aka floor +template <typename T, size_t N> +HWY_API Vec128<T, N> Floor(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Simd<T, N, 0> df; + const RebindToSigned<decltype(df)> di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#else + +// Toward nearest integer, ties to even +template <size_t N> +HWY_API Vec128<float, N> Round(const Vec128<float, N> v) { + return Vec128<float, N>{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +template <size_t N> +HWY_API Vec128<double, N> Round(const Vec128<double, N> v) { + return Vec128<double, N>{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +template <size_t N> +HWY_API Vec128<float, N> Trunc(const Vec128<float, N> v) { + return Vec128<float, N>{ + _mm_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +template <size_t N> +HWY_API Vec128<double, N> Trunc(const Vec128<double, N> v) { + return Vec128<double, N>{ + _mm_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +template <size_t N> +HWY_API Vec128<float, N> Ceil(const Vec128<float, N> v) { + return Vec128<float, N>{ + _mm_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +template <size_t N> +HWY_API Vec128<double, N> Ceil(const Vec128<double, N> v) { + return Vec128<double, N>{ + _mm_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +template <size_t N> +HWY_API Vec128<float, N> Floor(const Vec128<float, N> v) { + return Vec128<float, N>{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +template <size_t N> +HWY_API Vec128<double, N> Floor(const Vec128<double, N> v) { + return Vec128<double, N>{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +#endif // !HWY_SSSE3 + +// ------------------------------ Floating-point classification + +template <size_t N> +HWY_API Mask128<float, N> IsNaN(const Vec128<float, N> v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128<float, N>{_mm_fpclass_ps_mask(v.raw, 0x81)}; +#else + return Mask128<float, N>{_mm_cmpunord_ps(v.raw, v.raw)}; +#endif +} +template <size_t N> +HWY_API Mask128<double, N> IsNaN(const Vec128<double, N> v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128<double, N>{_mm_fpclass_pd_mask(v.raw, 0x81)}; +#else + return Mask128<double, N>{_mm_cmpunord_pd(v.raw, v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +template <size_t N> +HWY_API Mask128<float, N> IsInf(const Vec128<float, N> v) { + return Mask128<float, N>{_mm_fpclass_ps_mask(v.raw, 0x18)}; +} +template <size_t N> +HWY_API Mask128<double, N> IsInf(const Vec128<double, N> v) { + return Mask128<double, N>{_mm_fpclass_pd_mask(v.raw, 0x18)}; +} + +// Returns whether normal/subnormal/zero. +template <size_t N> +HWY_API Mask128<float, N> IsFinite(const Vec128<float, N> v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128<float, N>{_mm_fpclass_ps_mask(v.raw, 0x99)}); +} +template <size_t N> +HWY_API Mask128<double, N> IsFinite(const Vec128<double, N> v) { + return Not(Mask128<double, N>{_mm_fpclass_pd_mask(v.raw, 0x99)}); +} + +#else + +template <typename T, size_t N> +HWY_API Mask128<T, N> IsInf(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Simd<T, N, 0> d; + const RebindToSigned<decltype(d)> di; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2<T>()))); +} + +// Returns whether normal/subnormal/zero. +template <typename T, size_t N> +HWY_API Mask128<T, N> IsFinite(const Vec128<T, N> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + const VFromD<decltype(du)> vu = BitCast(du, v); + // Shift left to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). MSVC seems to generate + // incorrect code if we instead add vu + vu. + const VFromD<decltype(di)> exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(ShiftLeft<1>(vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>()))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128<uint8_t> AESRound(Vec128<uint8_t> state, + Vec128<uint8_t> round_key) { + return Vec128<uint8_t>{_mm_aesenc_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128<uint8_t> AESLastRound(Vec128<uint8_t> state, + Vec128<uint8_t> round_key) { + return Vec128<uint8_t>{_mm_aesenclast_si128(state.raw, round_key.raw)}; +} + +template <size_t N, HWY_IF_LE128(uint64_t, N)> +HWY_API Vec128<uint64_t, N> CLMulLower(Vec128<uint64_t, N> a, + Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{_mm_clmulepi64_si128(a.raw, b.raw, 0x00)}; +} + +template <size_t N, HWY_IF_LE128(uint64_t, N)> +HWY_API Vec128<uint64_t, N> CLMulUpper(Vec128<uint64_t, N> a, + Vec128<uint64_t, N> b) { + return Vec128<uint64_t, N>{_mm_clmulepi64_si128(a.raw, b.raw, 0x11)}; +} + +#endif // !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET != HWY_SSSE3 + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const Vec128<T, N> vbits{_mm_cvtsi32_si128(static_cast<int>(mask_bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast<uint16_t>(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast<uint32_t>(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(16) constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T, size_t N, HWY_IF_LE128(T, N)> +HWY_API Mask128<T, N> LoadMaskBits(Simd<T, N, 0> d, + const uint8_t* HWY_RESTRICT bits) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return Mask128<T, N>::FromBits(mask_bits); +#else + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::LoadMaskBits(d, mask_bits); +#endif +} + +template <typename T> +struct CompressIsPartition { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 supports native compress, but a table-based approach allows + // 'partitioning' (also moving mask=false lanes to the top), which helps + // vqsort. This is only feasible for eight or less lanes, i.e. sizeof(T) == 8 + // on AVX3. For simplicity, we only use tables for 64-bit lanes (not AVX3 + // u32x8 etc.). + enum { value = (sizeof(T) == 8) }; +#else + // generic_ops-inl does not guarantee IsPartition for 8-bit. + enum { value = (sizeof(T) != 1) }; +#endif +}; + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template <typename T, size_t N> +HWY_API size_t StoreMaskBits(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (N < 8) { + const int mask_bits = (1 << N) - 1; + bits[0] = static_cast<uint8_t>(bits[0] & mask_bits); + } + + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +template <typename T, size_t N> +HWY_API size_t CountTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + const uint64_t mask_bits = static_cast<uint64_t>(mask.raw) & ((1u << N) - 1); + return PopCount(mask_bits); +} + +template <typename T, size_t N> +HWY_API size_t FindKnownFirstTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + const uint32_t mask_bits = static_cast<uint32_t>(mask.raw) & ((1u << N) - 1); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template <typename T, size_t N> +HWY_API intptr_t FindFirstTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + const uint32_t mask_bits = static_cast<uint32_t>(mask.raw) & ((1u << N) - 1); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template <typename T, size_t N> +HWY_API bool AllFalse(const Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + const uint64_t mask_bits = static_cast<uint64_t>(mask.raw) & ((1u << N) - 1); + return mask_bits == 0; +} + +template <typename T, size_t N> +HWY_API bool AllTrue(const Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + const uint64_t mask_bits = static_cast<uint64_t>(mask.raw) & ((1u << N) - 1); + // Cannot use _kortestc because we may have less than 8 mask bits. + return mask_bits == (1u << N) - 1; +} + +// ------------------------------ Compress + +// 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> Compress(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +template <size_t N, HWY_IF_GE64(float, N)> +HWY_API Vec128<float, N> Compress(Vec128<float, N> v, Mask128<float, N> mask) { + return Vec128<float, N>{_mm_maskz_compress_ps(mask.raw, v.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> Compress(Vec128<T> v, Mask128<T> mask) { + HWY_DASSERT(mask.raw < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Full128<T> d; + const Repartition<uint8_t, decltype(d)> d8; + const auto index = Load(d8, u8_indices + 16 * mask.raw); + return BitCast(d, TableLookupBytes(BitCast(d8, v), index)); +} + +// ------------------------------ CompressNot (Compress) + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> CompressNot(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) { + // See CompressIsPartition, PrintCompressNot64x2NibbleTables + alignas(16) constexpr uint64_t packed_array[16] = {0x00000010, 0x00000001, + 0x00000010, 0x00000010}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm_permutexvar_epi64 will ignore the upper bits. + const Full128<T> d; + const RebindToUnsigned<decltype(d)> du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(16) constexpr uint64_t shifts[2] = {0, 4}; + const auto indices = Indices128<T>{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v, + Mask128<uint64_t> /* m */) { + return v; +} + +// ------------------------------ CompressStore + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4)> +HWY_API size_t CompressStore(Vec128<T, N> v, Mask128<T, N> mask, + Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8)> +HWY_API size_t CompressStore(Vec128<T, N> v, Mask128<T, N> mask, + Simd<T, N, 0> /* tag */, + T* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <size_t N, HWY_IF_LE128(float, N)> +HWY_API size_t CompressStore(Vec128<float, N> v, Mask128<float, N> mask, + Simd<float, N, 0> /* tag */, + float* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <size_t N, HWY_IF_LE128(double, N)> +HWY_API size_t CompressStore(Vec128<double, N> v, Mask128<double, N> mask, + Simd<double, N, 0> /* tag */, + double* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & ((1ull << N) - 1)); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore (CompressStore) +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + // AVX-512 already does the blending at no extra cost (latency 11, + // rthroughput 2 - same as compress plus store). + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { + // We're relying on the mask to blend. Clear the undefined upper bits. + if (N != 16 / sizeof(T)) { + m = And(m, FirstN(d, N)); + } + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + const Vec128<T, N> compressed = Compress(v, m); +#if HWY_MEM_OPS_MIGHT_FAULT + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. + alignas(16) T buf[N]; + Store(compressed, d, buf); + memcpy(unaligned, buf, count * sizeof(T)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBitsStore(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#else // AVX2 or below + +// ------------------------------ StoreMaskBits + +namespace detail { + +constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { + return static_cast<uint64_t>(static_cast<unsigned>(mask_bits)); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, + const Mask128<T, N> mask) { + const Simd<T, N, 0> d; + const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, + const Mask128<T, N> mask) { + // Remove useless lower half of each u16 while preserving the sign bit. + const auto sign_bits = _mm_packs_epi16(mask.raw, _mm_setzero_si128()); + return U64FromInt(_mm_movemask_epi8(sign_bits)); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, + const Mask128<T, N> mask) { + const Simd<T, N, 0> d; + const Simd<float, N, 0> df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_ps(sign_bits.raw)); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<8> /*tag*/, + const Mask128<T, N> mask) { + const Simd<T, N, 0> d; + const Simd<double, N, 0> df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return U64FromInt(_mm_movemask_pd(sign_bits.raw)); +} + +// Returns the lowest N of the _mm_movemask* bits. +template <typename T, size_t N> +constexpr uint64_t OnlyActive(uint64_t mask_bits) { + return ((N * sizeof(T)) == 16) ? mask_bits : mask_bits & ((1ull << N) - 1); +} + +template <typename T, size_t N> +HWY_INLINE uint64_t BitsFromMask(const Mask128<T, N> mask) { + return OnlyActive<T, N>(BitsFromMask(hwy::SizeTag<sizeof(T)>(), mask)); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template <typename T, size_t N> +HWY_API size_t StoreMaskBits(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask, uint8_t* bits) { + constexpr size_t kNumBytes = (N + 7) / 8; + const uint64_t mask_bits = detail::BitsFromMask(mask); + CopyBytes<kNumBytes>(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +template <typename T, size_t N> +HWY_API bool AllFalse(const Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template <typename T, size_t N> +HWY_API bool AllTrue(const Simd<T, N, 0> /* tag */, const Mask128<T, N> mask) { + constexpr uint64_t kAllBits = + detail::OnlyActive<T, N>((1ull << (16 / sizeof(T))) - 1); + return detail::BitsFromMask(mask) == kAllBits; +} + +template <typename T, size_t N> +HWY_API size_t CountTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template <typename T, size_t N> +HWY_API size_t FindKnownFirstTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); +} + +template <typename T, size_t N> +HWY_API intptr_t FindFirstTrue(const Simd<T, N, 0> /* tag */, + const Mask128<T, N> mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<T, N> IndicesFromBits(Simd<T, N, 0> d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind<uint8_t, decltype(d)> d8; + const Simd<uint16_t, N, 0> du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128<uint8_t, 2 * N> byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128<uint16_t, N> pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec128<T, N> IndicesFromNotBits(Simd<T, N, 0> d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind<uint8_t, decltype(d)> d8; + const Simd<uint16_t, N, 0> du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. + alignas(16) constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128<uint8_t, 2 * N> byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128<uint16_t, N> pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4), HWY_IF_LE128(T, N)> +HWY_INLINE Vec128<T, N> IndicesFromBits(Simd<T, N, 0> d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 4), HWY_IF_LE128(T, N)> +HWY_INLINE Vec128<T, N> IndicesFromNotBits(Simd<T, N, 0> d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8), HWY_IF_LE128(T, N)> +HWY_INLINE Vec128<T, N> IndicesFromBits(Simd<T, N, 0> d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N, HWY_IF_LANE_SIZE(T, 8), HWY_IF_LE128(T, N)> +HWY_INLINE Vec128<T, N> IndicesFromNotBits(Simd<T, N, 0> d, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> CompressBits(Vec128<T, N> v, uint64_t mask_bits) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> CompressNotBits(Vec128<T, N> v, uint64_t mask_bits) { + const Simd<T, N, 0> d; + const RebindToUnsigned<decltype(d)> du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> Compress(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +// Two lanes: conditional swap +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> Compress(Vec128<T> v, Mask128<T> mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. + const Full128<T> d; + const Vec128<T> m = VecFromMask(d, mask); + const Vec128<T> maskL = DupEven(m); + const Vec128<T> maskH = DupOdd(m); + const Vec128<T> swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API Vec128<T, N> Compress(Vec128<T, N> v, Mask128<T, N> mask) { + return detail::CompressBits(v, detail::BitsFromMask(mask)); +} + +// Single lane: no-op +template <typename T> +HWY_API Vec128<T, 1> CompressNot(Vec128<T, 1> v, Mask128<T, 1> /*m*/) { + return v; +} + +// Two lanes: conditional swap +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128<T> d; + const Vec128<T> m = VecFromMask(d, mask); + const Vec128<T> maskL = DupEven(m); + const Vec128<T> maskH = DupOdd(m); + const Vec128<T> swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template <typename T, size_t N, HWY_IF_LANE_SIZE_ONE_OF(T, 0x14)> +HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, Mask128<T, N> mask) { + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. + if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, detail::BitsFromMask(Not(mask))); + } + return detail::CompressNotBits(v, detail::BitsFromMask(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v, + Mask128<uint64_t> /* m */) { + return v; +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec128<T, N> CompressBits(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressStore(Vec128<T, N> v, Mask128<T, N> m, Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBlendedStore(Vec128<T, N> v, Mask128<T, N> m, + Simd<T, N, 0> d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; + + const uint64_t mask_bits = detail::BitsFromMask(m); + HWY_DASSERT(mask_bits < (1ull << N)); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <typename T, size_t N, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBitsStore(Vec128<T, N> v, + const uint8_t* HWY_RESTRICT bits, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; + + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes<kNumBytes>(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Reductions + +namespace detail { + +// N=1 for any T: no-op +template <typename T> +HWY_INLINE Vec128<T, 1> SumOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} +template <typename T> +HWY_INLINE Vec128<T, 1> MinOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} +template <typename T> +HWY_INLINE Vec128<T, 1> MaxOfLanes(hwy::SizeTag<sizeof(T)> /* tag */, + const Vec128<T, 1> v) { + return v; +} + +// u32/i32/f32: + +// N=2 +template <typename T> +HWY_INLINE Vec128<T, 2> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return v10 + Shuffle2301(v10); +} +template <typename T> +HWY_INLINE Vec128<T, 2> MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return Min(v10, Shuffle2301(v10)); +} +template <typename T> +HWY_INLINE Vec128<T, 2> MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T, 2> v10) { + return Max(v10, Shuffle2301(v10)); +} + +// N=4 (full) +template <typename T> +HWY_INLINE Vec128<T> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = v3210 + v1032; + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template <typename T> +HWY_INLINE Vec128<T> MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = Min(v3210, v1032); + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template <typename T> +HWY_INLINE Vec128<T> MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec128<T> v3210) { + const Vec128<T> v1032 = Shuffle1032(v3210); + const Vec128<T> v31_20_31_20 = Max(v3210, v1032); + const Vec128<T> v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +// u64/i64/f64: + +// N=2 (full) +template <typename T> +HWY_INLINE Vec128<T> SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<T> v10) { + const Vec128<T> v01 = Shuffle01(v10); + return v10 + v01; +} +template <typename T> +HWY_INLINE Vec128<T> MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<T> v10) { + const Vec128<T> v01 = Shuffle01(v10); + return Min(v10, v01); +} +template <typename T> +HWY_INLINE Vec128<T> MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec128<T> v10) { + const Vec128<T> v01 = Shuffle01(v10); + return Max(v10, v01); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +// u8, N=8, N=16: +HWY_API Vec64<uint8_t> SumOfLanes(hwy::SizeTag<1> /* tag */, Vec64<uint8_t> v) { + const Full64<uint8_t> d; + return Set(d, static_cast<uint8_t>(GetLane(SumsOf8(v)) & 0xFF)); +} +HWY_API Vec128<uint8_t> SumOfLanes(hwy::SizeTag<1> /* tag */, + Vec128<uint8_t> v) { + const Full128<uint8_t> d; + Vec128<uint64_t> sums = SumOfLanes(hwy::SizeTag<8>(), SumsOf8(v)); + return Set(d, static_cast<uint8_t>(GetLane(sums) & 0xFF)); +} + +template <size_t N, HWY_IF_GE64(int8_t, N)> +HWY_API Vec128<int8_t, N> SumOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128<int8_t, N> v) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto is_neg = v < Zero(d); + + // Sum positive and negative lanes separately, then combine to get the result. + const auto positive = SumsOf8(BitCast(du, IfThenZeroElse(is_neg, v))); + const auto negative = SumsOf8(BitCast(du, IfThenElseZero(is_neg, Abs(v)))); + return Set(d, static_cast<int8_t>(GetLane( + SumOfLanes(hwy::SizeTag<8>(), positive - negative)) & + 0xFF)); +} + +#if HWY_TARGET <= HWY_SSE4 +HWY_API Vec128<uint16_t> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t> v) { + using V = decltype(v); + return Broadcast<0>(V{_mm_minpos_epu16(v.raw)}); +} +HWY_API Vec64<uint8_t> MinOfLanes(hwy::SizeTag<1> /* tag */, Vec64<uint8_t> v) { + const Full64<uint8_t> d; + const Full128<uint16_t> d16; + return TruncateTo(d, MinOfLanes(hwy::SizeTag<2>(), PromoteTo(d16, v))); +} +HWY_API Vec128<uint8_t> MinOfLanes(hwy::SizeTag<1> tag, + Vec128<uint8_t> v) { + const Half<DFromV<decltype(v)>> d; + Vec64<uint8_t> result = + Min(MinOfLanes(tag, UpperHalf(d, v)), MinOfLanes(tag, LowerHalf(d, v))); + return Combine(DFromV<decltype(v)>(), result, result); +} + +HWY_API Vec128<uint16_t> MaxOfLanes(hwy::SizeTag<2> tag, Vec128<uint16_t> v) { + const Vec128<uint16_t> m(Set(DFromV<decltype(v)>(), LimitsMax<uint16_t>())); + return m - MinOfLanes(tag, m - v); +} +HWY_API Vec64<uint8_t> MaxOfLanes(hwy::SizeTag<1> tag, Vec64<uint8_t> v) { + const Vec64<uint8_t> m(Set(DFromV<decltype(v)>(), LimitsMax<uint8_t>())); + return m - MinOfLanes(tag, m - v); +} +HWY_API Vec128<uint8_t> MaxOfLanes(hwy::SizeTag<1> tag, Vec128<uint8_t> v) { + const Vec128<uint8_t> m(Set(DFromV<decltype(v)>(), LimitsMax<uint8_t>())); + return m - MinOfLanes(tag, m - v); +} +#elif HWY_TARGET == HWY_SSSE3 +template <size_t N, HWY_IF_GE64(uint8_t, N)> +HWY_API Vec128<uint8_t, N> MaxOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128<uint8_t, N> v) { + const DFromV<decltype(v)> d; + const RepartitionToWide<decltype(d)> d16; + const RepartitionToWide<decltype(d16)> d32; + Vec128<uint8_t, N> vm = Max(v, Reverse2(d, v)); + vm = Max(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); + vm = Max(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); + if (N > 8) { + const RepartitionToWide<decltype(d32)> d64; + vm = Max(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); + } + return vm; +} + +template <size_t N, HWY_IF_GE64(uint8_t, N)> +HWY_API Vec128<uint8_t, N> MinOfLanes(hwy::SizeTag<1> /* tag */, + const Vec128<uint8_t, N> v) { + const DFromV<decltype(v)> d; + const RepartitionToWide<decltype(d)> d16; + const RepartitionToWide<decltype(d16)> d32; + Vec128<uint8_t, N> vm = Min(v, Reverse2(d, v)); + vm = Min(vm, BitCast(d, Reverse2(d16, BitCast(d16, vm)))); + vm = Min(vm, BitCast(d, Reverse2(d32, BitCast(d32, vm)))); + if (N > 8) { + const RepartitionToWide<decltype(d32)> d64; + vm = Min(vm, BitCast(d, Reverse2(d64, BitCast(d64, vm)))); + } + return vm; +} +#endif + +// Implement min/max of i8 in terms of u8 by toggling the sign bit. +template <size_t N, HWY_IF_GE64(int8_t, N)> +HWY_API Vec128<int8_t, N> MinOfLanes(hwy::SizeTag<1> tag, + const Vec128<int8_t, N> v) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto mask = SignBit(du); + const auto vu = Xor(BitCast(du, v), mask); + return BitCast(d, Xor(MinOfLanes(tag, vu), mask)); +} +template <size_t N, HWY_IF_GE64(int8_t, N)> +HWY_API Vec128<int8_t, N> MaxOfLanes(hwy::SizeTag<1> tag, + const Vec128<int8_t, N> v) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto mask = SignBit(du); + const auto vu = Xor(BitCast(du, v), mask); + return BitCast(d, Xor(MaxOfLanes(tag, vu), mask)); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +template <size_t N, HWY_IF_GE32(uint16_t, N)> +HWY_API Vec128<uint16_t, N> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<uint16_t, N> v) { + const Simd<uint16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +template <size_t N, HWY_IF_GE32(int16_t, N)> +HWY_API Vec128<int16_t, N> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec128<int16_t, N> v) { + const Simd<int16_t, N, 0> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for u/i/f 32/64. Returns the same value in each lane. +template <typename T, size_t N> +HWY_API Vec128<T, N> SumOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::SumOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MinOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::MinOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} +template <typename T, size_t N> +HWY_API Vec128<T, N> MaxOfLanes(Simd<T, N, 0> /* tag */, const Vec128<T, N> v) { + return detail::MaxOfLanes(hwy::SizeTag<sizeof(T)>(), v); +} + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. Also used by x86_256/x86_512. +template <class D, class V = VFromD<D>> +HWY_INLINE V Lt128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const V ltHL = VecFromMask(d, Lt(a, b)); + const V ltLX = ShiftLeftLanes<1>(ltHL); + const V vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +} + +// Returns vector-mask for Eq128. Also used by x86_256/x86_512. +template <class D, class V = VFromD<D>> +HWY_INLINE V Eq128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +} + +template <class D, class V = VFromD<D>> +HWY_INLINE V Ne128Vec(const D d, const V a, const V b) { + static_assert(!IsSigned<TFromD<D>>() && sizeof(TFromD<D>) == 8, + "D must be u64"); + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +} + +template <class D, class V = VFromD<D>> +HWY_INLINE V Lt128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template <class D, class V = VFromD<D>> +HWY_INLINE V Eq128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template <class D, class V = VFromD<D>> +HWY_INLINE V Ne128UpperVec(const D d, const V a, const V b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const V neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template <class D, class V = VFromD<D>> +HWY_API MFromD<D> Lt128(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template <class D, class V = VFromD<D>> +HWY_API MFromD<D> Eq128(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template <class D, class V = VFromD<D>> +HWY_API MFromD<D> Ne128(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template <class D, class V = VFromD<D>> +HWY_API MFromD<D> Lt128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template <class D, class V = VFromD<D>> +HWY_API MFromD<D> Eq128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template <class D, class V = VFromD<D>> +HWY_API MFromD<D> Ne128Upper(D d, const V a, const V b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template <class D, class V = VFromD<D>> +HWY_API V Min128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template <class D, class V = VFromD<D>> +HWY_API V Max128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template <class D, class V = VFromD<D>> +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template <class D, class V = VFromD<D>> +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/highway/hwy/ops/x86_256-inl.h b/third_party/highway/hwy/ops/x86_256-inl.h new file mode 100644 index 0000000000..3539520adf --- /dev/null +++ b/third_party/highway/hwy/ops/x86_256-inl.h @@ -0,0 +1,5548 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when +// compiling for that target. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +// Must come before HWY_COMPILER_CLANGCL +#include <immintrin.h> // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including <immintrin.h> should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +#include <avxintrin.h> +// avxintrin defines __m256i and must come before avx2intrin. +#include <avx2intrin.h> +#include <bmi2intrin.h> // _pext_u64 +#include <f16cintrin.h> +#include <fmaintrin.h> +#include <smmintrin.h> +#endif // HWY_COMPILER_CLANGCL + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcpy + +#if HWY_IS_MSAN +#include <sanitizer/msan_interface.h> +#endif + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template <typename T> +struct Raw256 { + using type = __m256i; +}; +template <> +struct Raw256<float> { + using type = __m256; +}; +template <> +struct Raw256<double> { + using type = __m256d; +}; + +} // namespace detail + +template <typename T> +class Vec256 { + using Raw = typename detail::Raw256<T>::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +#if HWY_TARGET <= HWY_AVX3 + +namespace detail { + +// Template arg: sizeof(lane type) +template <size_t size> +struct RawMask256 {}; +template <> +struct RawMask256<1> { + using type = __mmask32; +}; +template <> +struct RawMask256<2> { + using type = __mmask16; +}; +template <> +struct RawMask256<4> { + using type = __mmask8; +}; +template <> +struct RawMask256<8> { + using type = __mmask8; +}; + +} // namespace detail + +template <typename T> +struct Mask256 { + using Raw = typename detail::RawMask256<sizeof(T)>::type; + + static Mask256<T> FromBits(uint64_t mask_bits) { + return Mask256<T>{static_cast<Raw>(mask_bits)}; + } + + Raw raw; +}; + +#else // AVX2 + +// FF..FF or 0. +template <typename T> +struct Mask256 { + typename detail::Raw256<T>::type raw; +}; + +#endif // HWY_TARGET <= HWY_AVX3 + +template <typename T> +using Full256 = Simd<T, 32 / sizeof(T), 0>; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } +HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); } +HWY_INLINE __m256i BitCastToInteger(__m256d v) { + return _mm256_castpd_si256(v); +} + +template <typename T> +HWY_INLINE Vec256<uint8_t> BitCastToByte(Vec256<T> v) { + return Vec256<uint8_t>{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template <typename T> +struct BitCastFromInteger256 { + HWY_INLINE __m256i operator()(__m256i v) { return v; } +}; +template <> +struct BitCastFromInteger256<float> { + HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); } +}; +template <> +struct BitCastFromInteger256<double> { + HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); } +}; + +template <typename T> +HWY_INLINE Vec256<T> BitCastFromByte(Full256<T> /* tag */, Vec256<uint8_t> v) { + return Vec256<T>{BitCastFromInteger256<T>()(v.raw)}; +} + +} // namespace detail + +template <typename T, typename FromT> +HWY_API Vec256<T> BitCast(Full256<T> d, Vec256<FromT> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector. +template <typename T> +HWY_API Vec256<T> Zero(Full256<T> /* tag */) { + return Vec256<T>{_mm256_setzero_si256()}; +} +HWY_API Vec256<float> Zero(Full256<float> /* tag */) { + return Vec256<float>{_mm256_setzero_ps()}; +} +HWY_API Vec256<double> Zero(Full256<double> /* tag */) { + return Vec256<double>{_mm256_setzero_pd()}; +} + +// Returns a vector with all lanes set to "t". +HWY_API Vec256<uint8_t> Set(Full256<uint8_t> /* tag */, const uint8_t t) { + return Vec256<uint8_t>{_mm256_set1_epi8(static_cast<char>(t))}; // NOLINT +} +HWY_API Vec256<uint16_t> Set(Full256<uint16_t> /* tag */, const uint16_t t) { + return Vec256<uint16_t>{_mm256_set1_epi16(static_cast<short>(t))}; // NOLINT +} +HWY_API Vec256<uint32_t> Set(Full256<uint32_t> /* tag */, const uint32_t t) { + return Vec256<uint32_t>{_mm256_set1_epi32(static_cast<int>(t))}; +} +HWY_API Vec256<uint64_t> Set(Full256<uint64_t> /* tag */, const uint64_t t) { + return Vec256<uint64_t>{ + _mm256_set1_epi64x(static_cast<long long>(t))}; // NOLINT +} +HWY_API Vec256<int8_t> Set(Full256<int8_t> /* tag */, const int8_t t) { + return Vec256<int8_t>{_mm256_set1_epi8(static_cast<char>(t))}; // NOLINT +} +HWY_API Vec256<int16_t> Set(Full256<int16_t> /* tag */, const int16_t t) { + return Vec256<int16_t>{_mm256_set1_epi16(static_cast<short>(t))}; // NOLINT +} +HWY_API Vec256<int32_t> Set(Full256<int32_t> /* tag */, const int32_t t) { + return Vec256<int32_t>{_mm256_set1_epi32(t)}; +} +HWY_API Vec256<int64_t> Set(Full256<int64_t> /* tag */, const int64_t t) { + return Vec256<int64_t>{ + _mm256_set1_epi64x(static_cast<long long>(t))}; // NOLINT +} +HWY_API Vec256<float> Set(Full256<float> /* tag */, const float t) { + return Vec256<float>{_mm256_set1_ps(t)}; +} +HWY_API Vec256<double> Set(Full256<double> /* tag */, const double t) { + return Vec256<double>{_mm256_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template <typename T> +HWY_API Vec256<T> Undefined(Full256<T> /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec256<T>{_mm256_undefined_si256()}; +} +HWY_API Vec256<float> Undefined(Full256<float> /* tag */) { + return Vec256<float>{_mm256_undefined_ps()}; +} +HWY_API Vec256<double> Undefined(Full256<double> /* tag */) { + return Vec256<double>{_mm256_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ And + +template <typename T> +HWY_API Vec256<T> And(Vec256<T> a, Vec256<T> b) { + return Vec256<T>{_mm256_and_si256(a.raw, b.raw)}; +} + +HWY_API Vec256<float> And(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_and_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> And(const Vec256<double> a, const Vec256<double> b) { + return Vec256<double>{_mm256_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template <typename T> +HWY_API Vec256<T> AndNot(Vec256<T> not_mask, Vec256<T> mask) { + return Vec256<T>{_mm256_andnot_si256(not_mask.raw, mask.raw)}; +} +HWY_API Vec256<float> AndNot(const Vec256<float> not_mask, + const Vec256<float> mask) { + return Vec256<float>{_mm256_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec256<double> AndNot(const Vec256<double> not_mask, + const Vec256<double> mask) { + return Vec256<double>{_mm256_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template <typename T> +HWY_API Vec256<T> Or(Vec256<T> a, Vec256<T> b) { + return Vec256<T>{_mm256_or_si256(a.raw, b.raw)}; +} + +HWY_API Vec256<float> Or(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_or_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> Or(const Vec256<double> a, const Vec256<double> b) { + return Vec256<double>{_mm256_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template <typename T> +HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) { + return Vec256<T>{_mm256_xor_si256(a.raw, b.raw)}; +} + +HWY_API Vec256<float> Xor(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> Xor(const Vec256<double> a, const Vec256<double> b) { + return Vec256<double>{_mm256_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Not +template <typename T> +HWY_API Vec256<T> Not(const Vec256<T> v) { + using TU = MakeUnsigned<T>; +#if HWY_TARGET <= HWY_AVX3 + const __m256i vu = BitCast(Full256<TU>(), v).raw; + return BitCast(Full256<T>(), + Vec256<TU>{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); +#else + return Xor(v, BitCast(Full256<T>(), Vec256<TU>{_mm256_set1_epi32(-1)})); +#endif +} + +// ------------------------------ Xor3 +template <typename T> +HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) { +#if HWY_TARGET <= HWY_AVX3 + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +#else + return Xor(x1, Xor(x2, x3)); +#endif +} + +// ------------------------------ Or3 +template <typename T> +HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) { +#if HWY_TARGET <= HWY_AVX3 + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template <typename T> +HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) { +#if HWY_TARGET <= HWY_AVX3 + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template <typename T> +HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) { +#if HWY_TARGET <= HWY_AVX3 + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ Operator overloads (internal-only if float) + +template <typename T> +HWY_API Vec256<T> operator&(const Vec256<T> a, const Vec256<T> b) { + return And(a, b); +} + +template <typename T> +HWY_API Vec256<T> operator|(const Vec256<T> a, const Vec256<T> b) { + return Or(a, b); +} + +template <typename T> +HWY_API Vec256<T> operator^(const Vec256<T> a, const Vec256<T> b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template <typename T> +HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec256<T> v) { + return Vec256<T>{_mm256_popcnt_epi8(v.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec256<T> v) { + return Vec256<T>{_mm256_popcnt_epi16(v.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec256<T> v) { + return Vec256<T>{_mm256_popcnt_epi32(v.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec256<T> v) { + return Vec256<T>{_mm256_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec256<T> PopulationCount(Vec256<T> v) { + return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ CopySign + +template <typename T> +HWY_API Vec256<T> CopySign(const Vec256<T> magn, const Vec256<T> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + + const Full256<T> d; + const auto msb = SignBit(d); + +#if HWY_TARGET <= HWY_AVX3 + const Rebind<MakeUnsigned<T>, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m256i out = _mm256_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +#else + return Or(AndNot(msb, magn), And(msb, sign)); +#endif +} + +template <typename T> +HWY_API Vec256<T> CopySignToAbs(const Vec256<T> abs, const Vec256<T> sign) { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +#else + return Or(abs, And(SignBit(Full256<T>()), sign)); +#endif +} + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template <typename T> +HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<1> /* tag */, Mask256<T> mask, + Vec256<T> yes, Vec256<T> no) { + return Vec256<T>{_mm256_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<2> /* tag */, Mask256<T> mask, + Vec256<T> yes, Vec256<T> no) { + return Vec256<T>{_mm256_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<4> /* tag */, Mask256<T> mask, + Vec256<T> yes, Vec256<T> no) { + return Vec256<T>{_mm256_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<8> /* tag */, Mask256<T> mask, + Vec256<T> yes, Vec256<T> no) { + return Vec256<T>{_mm256_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) { + return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no); +} +HWY_API Vec256<float> IfThenElse(Mask256<float> mask, Vec256<float> yes, + Vec256<float> no) { + return Vec256<float>{_mm256_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} +HWY_API Vec256<double> IfThenElse(Mask256<double> mask, Vec256<double> yes, + Vec256<double> no) { + return Vec256<double>{_mm256_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template <typename T> +HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256<T> mask, + Vec256<T> yes) { + return Vec256<T>{_mm256_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256<T> mask, + Vec256<T> yes) { + return Vec256<T>{_mm256_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256<T> mask, + Vec256<T> yes) { + return Vec256<T>{_mm256_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256<T> mask, + Vec256<T> yes) { + return Vec256<T>{_mm256_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) { + return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes); +} +HWY_API Vec256<float> IfThenElseZero(Mask256<float> mask, Vec256<float> yes) { + return Vec256<float>{_mm256_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec256<double> IfThenElseZero(Mask256<double> mask, + Vec256<double> yes) { + return Vec256<double>{_mm256_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template <typename T> +HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256<T> mask, + Vec256<T> no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec256<T>{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256<T> mask, + Vec256<T> no) { + return Vec256<T>{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256<T> mask, + Vec256<T> no) { + return Vec256<T>{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T> +HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256<T> mask, + Vec256<T> no) { + return Vec256<T>{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) { + return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no); +} +HWY_API Vec256<float> IfThenZeroElse(Mask256<float> mask, Vec256<float> no) { + return Vec256<float>{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec256<double> IfThenZeroElse(Mask256<double> mask, Vec256<double> no) { + return Vec256<double>{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template <typename T> +HWY_API Vec256<T> ZeroIfNegative(const Vec256<T> v) { + static_assert(IsSigned<T>(), "Only for float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenZeroElse(MaskFromVec(v), v); +} + +// ------------------------------ Mask logical + +namespace detail { + +template <typename T> +HWY_INLINE Mask256<T> And(hwy::SizeTag<1> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kand_mask32(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask32>(a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> And(hwy::SizeTag<2> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kand_mask16(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> And(hwy::SizeTag<4> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> And(hwy::SizeTag<8> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<1> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask32>(~a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<2> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<4> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<8> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask256<T> Or(hwy::SizeTag<1> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kor_mask32(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask32>(a.raw | b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> Or(hwy::SizeTag<2> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kor_mask16(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> Or(hwy::SizeTag<4> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> Or(hwy::SizeTag<8> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask256<T> Xor(hwy::SizeTag<1> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask32>(a.raw ^ b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> Xor(hwy::SizeTag<2> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> Xor(hwy::SizeTag<4> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> Xor(hwy::SizeTag<8> /*tag*/, const Mask256<T> a, + const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask256<T> a, const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask256<T> a, const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask256<T> a, const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask256<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template <typename T> +HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask256<T> a, const Mask256<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256<T>{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask256<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} + +} // namespace detail + +template <typename T> +HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) { + return detail::And(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) { + return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) { + return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) { + return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask256<T> Not(const Mask256<T> m) { + // Flip only the valid bits. + constexpr size_t N = 32 / sizeof(T); + return Xor(m, Mask256<T>::FromBits((1ull << N) - 1)); +} + +template <typename T> +HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) { + return detail::ExclusiveNeither(hwy::SizeTag<sizeof(T)>(), a, b); +} + +#else // AVX2 + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template <typename T> +HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) { + return Mask256<T>{v.raw}; +} + +template <typename T> +HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { + return Vec256<T>{v.raw}; +} + +template <typename T> +HWY_API Vec256<T> VecFromMask(Full256<T> /* tag */, const Mask256<T> v) { + return Vec256<T>{v.raw}; +} + +// ------------------------------ IfThenElse + +// mask ? yes : no +template <typename T> +HWY_API Vec256<T> IfThenElse(const Mask256<T> mask, const Vec256<T> yes, + const Vec256<T> no) { + return Vec256<T>{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256<float> IfThenElse(const Mask256<float> mask, + const Vec256<float> yes, + const Vec256<float> no) { + return Vec256<float>{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256<double> IfThenElse(const Mask256<double> mask, + const Vec256<double> yes, + const Vec256<double> no) { + return Vec256<double>{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +// mask ? yes : 0 +template <typename T> +HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) { + return yes & VecFromMask(Full256<T>(), mask); +} + +// mask ? 0 : no +template <typename T> +HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) { + return AndNot(VecFromMask(Full256<T>(), mask), no); +} + +template <typename T> +HWY_API Vec256<T> ZeroIfNegative(Vec256<T> v) { + static_assert(IsSigned<T>(), "Only for float"); + const auto zero = Zero(Full256<T>()); + // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes + return IfThenElse(MaskFromVec(v), zero, v); +} + +// ------------------------------ Mask logical + +template <typename T> +HWY_API Mask256<T> Not(const Mask256<T> m) { + return MaskFromVec(Not(VecFromMask(Full256<T>(), m))); +} + +template <typename T> +HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template <typename T> +HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) { + const Full256<T> d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template <typename TFrom, typename TTo> +HWY_API Mask256<TTo> RebindMask(Full256<TTo> /*tag*/, Mask256<TFrom> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask256<TTo>{m.raw}; +} + +namespace detail { + +template <typename T> +HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<1> /*tag*/, const Vec256<T> v, + const Vec256<T> bit) { + return Mask256<T>{_mm256_test_epi8_mask(v.raw, bit.raw)}; +} +template <typename T> +HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<2> /*tag*/, const Vec256<T> v, + const Vec256<T> bit) { + return Mask256<T>{_mm256_test_epi16_mask(v.raw, bit.raw)}; +} +template <typename T> +HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<4> /*tag*/, const Vec256<T> v, + const Vec256<T> bit) { + return Mask256<T>{_mm256_test_epi32_mask(v.raw, bit.raw)}; +} +template <typename T> +HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<8> /*tag*/, const Vec256<T> v, + const Vec256<T> bit) { + return Mask256<T>{_mm256_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit); +} + +// ------------------------------ Equality + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256<float> operator==(Vec256<float> a, Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256<double> operator==(Vec256<double> a, Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256<float> operator!=(Vec256<float> a, Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask256<double> operator!=(Vec256<double> a, Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask256<int8_t> operator>(Vec256<int8_t> a, Vec256<int8_t> b) { + return Mask256<int8_t>{_mm256_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask256<int16_t> operator>(Vec256<int16_t> a, Vec256<int16_t> b) { + return Mask256<int16_t>{_mm256_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask256<int32_t> operator>(Vec256<int32_t> a, Vec256<int32_t> b) { + return Mask256<int32_t>{_mm256_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask256<int64_t> operator>(Vec256<int64_t> a, Vec256<int64_t> b) { + return Mask256<int64_t>{_mm256_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256<uint8_t> operator>(Vec256<uint8_t> a, Vec256<uint8_t> b) { + return Mask256<uint8_t>{_mm256_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask256<uint16_t> operator>(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Mask256<uint16_t>{_mm256_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask256<uint32_t> operator>(const Vec256<uint32_t> a, + const Vec256<uint32_t> b) { + return Mask256<uint32_t>{_mm256_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask256<uint64_t> operator>(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { + return Mask256<uint64_t>{_mm256_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256<float> operator>(Vec256<float> a, Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256<double> operator>(Vec256<double> a, Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +HWY_API Mask256<float> operator>=(Vec256<float> a, Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256<double> operator>=(Vec256<double> a, Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Mask + +namespace detail { + +template <typename T> +HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256<T> v) { + return Mask256<T>{_mm256_movepi8_mask(v.raw)}; +} +template <typename T> +HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256<T> v) { + return Mask256<T>{_mm256_movepi16_mask(v.raw)}; +} +template <typename T> +HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256<T> v) { + return Mask256<T>{_mm256_movepi32_mask(v.raw)}; +} +template <typename T> +HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256<T> v) { + return Mask256<T>{_mm256_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) { + return detail::MaskFromVec(hwy::SizeTag<sizeof(T)>(), v); +} +// There do not seem to be native floating-point versions of these instructions. +HWY_API Mask256<float> MaskFromVec(const Vec256<float> v) { + return Mask256<float>{MaskFromVec(BitCast(Full256<int32_t>(), v)).raw}; +} +HWY_API Mask256<double> MaskFromVec(const Vec256<double> v) { + return Mask256<double>{MaskFromVec(BitCast(Full256<int64_t>(), v)).raw}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { + return Vec256<T>{_mm256_movm_epi8(v.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { + return Vec256<T>{_mm256_movm_epi16(v.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { + return Vec256<T>{_mm256_movm_epi32(v.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> VecFromMask(const Mask256<T> v) { + return Vec256<T>{_mm256_movm_epi64(v.raw)}; +} + +HWY_API Vec256<float> VecFromMask(const Mask256<float> v) { + return Vec256<float>{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; +} + +HWY_API Vec256<double> VecFromMask(const Mask256<double> v) { + return Vec256<double>{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; +} + +template <typename T> +HWY_API Vec256<T> VecFromMask(Full256<T> /* tag */, const Mask256<T> v) { + return VecFromMask(v); +} + +#else // AVX2 + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template <typename TFrom, typename TTo> +HWY_API Mask256<TTo> RebindMask(Full256<TTo> d_to, Mask256<TFrom> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MaskFromVec(BitCast(d_to, VecFromMask(Full256<TFrom>(), m))); +} + +template <typename T> +HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi8(a.raw, b.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi16(a.raw, b.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi32(a.raw, b.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) { + return Mask256<T>{_mm256_cmpeq_epi64(a.raw, b.raw)}; +} + +HWY_API Mask256<float> operator==(const Vec256<float> a, + const Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256<double> operator==(const Vec256<double> a, + const Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template <typename T> +HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) { + return Not(a == b); +} +HWY_API Mask256<float> operator!=(const Vec256<float> a, + const Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)}; +} +HWY_API Mask256<double> operator!=(const Vec256<double> a, + const Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8 +// to perform an unsigned comparison instead of the intended signed. Workaround +// is to cast to an explicitly signed type. See https://godbolt.org/z/PL7Ujy +#if HWY_COMPILER_GCC != 0 && HWY_COMPILER_GCC < 930 +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 +#else +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 +#endif + +HWY_API Mask256<int8_t> Gt(hwy::SignedTag /*tag*/, Vec256<int8_t> a, + Vec256<int8_t> b) { +#if HWY_AVX2_GCC_CMPGT8_WORKAROUND + using i8x32 = signed char __attribute__((__vector_size__(32))); + return Mask256<int8_t>{static_cast<__m256i>(reinterpret_cast<i8x32>(a.raw) > + reinterpret_cast<i8x32>(b.raw))}; +#else + return Mask256<int8_t>{_mm256_cmpgt_epi8(a.raw, b.raw)}; +#endif +} +HWY_API Mask256<int16_t> Gt(hwy::SignedTag /*tag*/, Vec256<int16_t> a, + Vec256<int16_t> b) { + return Mask256<int16_t>{_mm256_cmpgt_epi16(a.raw, b.raw)}; +} +HWY_API Mask256<int32_t> Gt(hwy::SignedTag /*tag*/, Vec256<int32_t> a, + Vec256<int32_t> b) { + return Mask256<int32_t>{_mm256_cmpgt_epi32(a.raw, b.raw)}; +} +HWY_API Mask256<int64_t> Gt(hwy::SignedTag /*tag*/, Vec256<int64_t> a, + Vec256<int64_t> b) { + return Mask256<int64_t>{_mm256_cmpgt_epi64(a.raw, b.raw)}; +} + +template <typename T> +HWY_INLINE Mask256<T> Gt(hwy::UnsignedTag /*tag*/, Vec256<T> a, Vec256<T> b) { + const Full256<T> du; + const RebindToSigned<decltype(du)> di; + const Vec256<T> msb = Set(du, (LimitsMax<T>() >> 1) + 1); + return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); +} + +HWY_API Mask256<float> Gt(hwy::FloatTag /*tag*/, Vec256<float> a, + Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256<double> Gt(hwy::FloatTag /*tag*/, Vec256<double> a, + Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; +} + +} // namespace detail + +template <typename T> +HWY_API Mask256<T> operator>(Vec256<T> a, Vec256<T> b) { + return detail::Gt(hwy::TypeTag<T>(), a, b); +} + +// ------------------------------ Weak inequality + +HWY_API Mask256<float> operator>=(const Vec256<float> a, + const Vec256<float> b) { + return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256<double> operator>=(const Vec256<double> a, + const Vec256<double> b) { + return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template <typename T> +HWY_API Mask256<T> operator<(const Vec256<T> a, const Vec256<T> b) { + return b > a; +} + +template <typename T> +HWY_API Mask256<T> operator<=(const Vec256<T> a, const Vec256<T> b) { + return b >= a; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256<uint8_t> Min(const Vec256<uint8_t> a, const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> Min(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec256<uint32_t> Min(const Vec256<uint32_t> a, + const Vec256<uint32_t> b) { + return Vec256<uint32_t>{_mm256_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> Min(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<uint64_t>{_mm256_min_epu64(a.raw, b.raw)}; +#else + const Full256<uint64_t> du; + const Full256<int64_t> di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +#endif +} + +// Signed +HWY_API Vec256<int8_t> Min(const Vec256<int8_t> a, const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> Min(const Vec256<int16_t> a, const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<int32_t> Min(const Vec256<int32_t> a, const Vec256<int32_t> b) { + return Vec256<int32_t>{_mm256_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<int64_t> Min(const Vec256<int64_t> a, const Vec256<int64_t> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<int64_t>{_mm256_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +HWY_API Vec256<float> Min(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_min_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> Min(const Vec256<double> a, const Vec256<double> b) { + return Vec256<double>{_mm256_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256<uint8_t> Max(const Vec256<uint8_t> a, const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> Max(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec256<uint32_t> Max(const Vec256<uint32_t> a, + const Vec256<uint32_t> b) { + return Vec256<uint32_t>{_mm256_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> Max(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<uint64_t>{_mm256_max_epu64(a.raw, b.raw)}; +#else + const Full256<uint64_t> du; + const Full256<int64_t> di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +#endif +} + +// Signed +HWY_API Vec256<int8_t> Max(const Vec256<int8_t> a, const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> Max(const Vec256<int16_t> a, const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<int32_t> Max(const Vec256<int32_t> a, const Vec256<int32_t> b) { + return Vec256<int32_t>{_mm256_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<int64_t> Max(const Vec256<int64_t> a, const Vec256<int64_t> b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<int64_t>{_mm256_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +HWY_API Vec256<float> Max(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_max_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> Max(const Vec256<double> a, const Vec256<double> b) { + return Vec256<double>{_mm256_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ FirstN (Iota, Lt) + +template <typename T> +HWY_API Mask256<T> FirstN(const Full256<T> d, size_t n) { +#if HWY_TARGET <= HWY_AVX3 + (void)d; + constexpr size_t N = 32 / sizeof(T); +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << N) - 1; + // BZHI only looks at the lower 8 bits of n! + return Mask256<T>::FromBits((n > 255) ? all : _bzhi_u64(all, n)); +#else + const uint32_t all = static_cast<uint32_t>((1ull << N) - 1); + // BZHI only looks at the lower 8 bits of n! + return Mask256<T>::FromBits( + (n > 255) ? all : _bzhi_u32(all, static_cast<uint32_t>(n))); +#endif // HWY_ARCH_X86_64 +#else + const RebindToSigned<decltype(d)> di; // Signed comparisons are cheaper. + return RebindMask(d, Iota(di, 0) < Set(di, static_cast<MakeSigned<T>>(n))); +#endif +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256<uint8_t> operator+(const Vec256<uint8_t> a, + const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> operator+(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<uint32_t> operator+(const Vec256<uint32_t> a, + const Vec256<uint32_t> b) { + return Vec256<uint32_t>{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> operator+(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { + return Vec256<uint64_t>{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256<int8_t> operator+(const Vec256<int8_t> a, + const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> operator+(const Vec256<int16_t> a, + const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<int32_t> operator+(const Vec256<int32_t> a, + const Vec256<int32_t> b) { + return Vec256<int32_t>{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<int64_t> operator+(const Vec256<int64_t> a, + const Vec256<int64_t> b) { + return Vec256<int64_t>{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256<float> operator+(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_add_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> operator+(const Vec256<double> a, + const Vec256<double> b) { + return Vec256<double>{_mm256_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec256<uint8_t> operator-(const Vec256<uint8_t> a, + const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> operator-(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<uint32_t> operator-(const Vec256<uint32_t> a, + const Vec256<uint32_t> b) { + return Vec256<uint32_t>{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> operator-(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { + return Vec256<uint64_t>{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256<int8_t> operator-(const Vec256<int8_t> a, + const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> operator-(const Vec256<int16_t> a, + const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<int32_t> operator-(const Vec256<int32_t> a, + const Vec256<int32_t> b) { + return Vec256<int32_t>{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<int64_t> operator-(const Vec256<int64_t> a, + const Vec256<int64_t> b) { + return Vec256<int64_t>{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256<float> operator-(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> operator-(const Vec256<double> a, + const Vec256<double> b) { + return Vec256<double>{_mm256_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256<uint64_t> SumsOf8(const Vec256<uint8_t> v) { + return Vec256<uint64_t>{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256<uint8_t> SaturatedAdd(const Vec256<uint8_t> a, + const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> SaturatedAdd(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256<int8_t> SaturatedAdd(const Vec256<int8_t> a, + const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> SaturatedAdd(const Vec256<int16_t> a, + const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec256<uint8_t> SaturatedSub(const Vec256<uint8_t> a, + const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> SaturatedSub(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256<int8_t> SaturatedSub(const Vec256<int8_t> a, + const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> SaturatedSub(const Vec256<int16_t> a, + const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec256<uint8_t> AverageRound(const Vec256<uint8_t> a, + const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> AverageRound(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec256<int8_t> Abs(const Vec256<int8_t> v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (wrong result) + const auto zero = Zero(Full256<int8_t>()); + return Vec256<int8_t>{_mm256_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec256<int8_t>{_mm256_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec256<int16_t> Abs(const Vec256<int16_t> v) { + return Vec256<int16_t>{_mm256_abs_epi16(v.raw)}; +} +HWY_API Vec256<int32_t> Abs(const Vec256<int32_t> v) { + return Vec256<int32_t>{_mm256_abs_epi32(v.raw)}; +} +// i64 is implemented after BroadcastSignBit. + +HWY_API Vec256<float> Abs(const Vec256<float> v) { + const Vec256<int32_t> mask{_mm256_set1_epi32(0x7FFFFFFF)}; + return v & BitCast(Full256<float>(), mask); +} +HWY_API Vec256<double> Abs(const Vec256<double> v) { + const Vec256<int64_t> mask{_mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; + return v & BitCast(Full256<double>(), mask); +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256<uint16_t> operator*(Vec256<uint16_t> a, Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<uint32_t> operator*(Vec256<uint32_t> a, Vec256<uint32_t> b) { + return Vec256<uint32_t>{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256<int16_t> operator*(Vec256<int16_t> a, Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<int32_t> operator*(Vec256<int32_t> a, Vec256<int32_t> b) { + return Vec256<int32_t>{_mm256_mullo_epi32(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec256<uint16_t> MulHigh(Vec256<uint16_t> a, Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> MulHigh(Vec256<int16_t> a, Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec256<int16_t> MulFixedPoint15(Vec256<int16_t> a, Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec256<int64_t> MulEven(Vec256<int32_t> a, Vec256<int32_t> b) { + return Vec256<int64_t>{_mm256_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> MulEven(Vec256<uint32_t> a, Vec256<uint32_t> b) { + return Vec256<uint64_t>{_mm256_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeft + +template <int kBits> +HWY_API Vec256<uint16_t> ShiftLeft(const Vec256<uint16_t> v) { + return Vec256<uint16_t>{_mm256_slli_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<uint32_t> ShiftLeft(const Vec256<uint32_t> v) { + return Vec256<uint32_t>{_mm256_slli_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<uint64_t> ShiftLeft(const Vec256<uint64_t> v) { + return Vec256<uint64_t>{_mm256_slli_epi64(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<int16_t> ShiftLeft(const Vec256<int16_t> v) { + return Vec256<int16_t>{_mm256_slli_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<int32_t> ShiftLeft(const Vec256<int32_t> v) { + return Vec256<int32_t>{_mm256_slli_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<int64_t> ShiftLeft(const Vec256<int64_t> v) { + return Vec256<int64_t>{_mm256_slli_epi64(v.raw, kBits)}; +} + +template <int kBits, typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec256<T> ShiftLeft(const Vec256<T> v) { + const Full256<T> d8; + const RepartitionToWide<decltype(d8)> d16; + const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template <int kBits> +HWY_API Vec256<uint16_t> ShiftRight(const Vec256<uint16_t> v) { + return Vec256<uint16_t>{_mm256_srli_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<uint32_t> ShiftRight(const Vec256<uint32_t> v) { + return Vec256<uint32_t>{_mm256_srli_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<uint64_t> ShiftRight(const Vec256<uint64_t> v) { + return Vec256<uint64_t>{_mm256_srli_epi64(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<uint8_t> ShiftRight(const Vec256<uint8_t> v) { + const Full256<uint8_t> d8; + // Use raw instead of BitCast to support N=1. + const Vec256<uint8_t> shifted{ShiftRight<kBits>(Vec256<uint16_t>{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template <int kBits> +HWY_API Vec256<int16_t> ShiftRight(const Vec256<int16_t> v) { + return Vec256<int16_t>{_mm256_srai_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<int32_t> ShiftRight(const Vec256<int32_t> v) { + return Vec256<int32_t>{_mm256_srai_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec256<int8_t> ShiftRight(const Vec256<int8_t> v) { + const Full256<int8_t> di; + const Full256<uint8_t> du; + const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// i64 is implemented after BroadcastSignBit. + +// ------------------------------ RotateRight + +template <int kBits> +HWY_API Vec256<uint32_t> RotateRight(const Vec256<uint32_t> v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256<uint32_t>{_mm256_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(31, 32 - kBits)>(v)); +#endif +} + +template <int kBits> +HWY_API Vec256<uint64_t> RotateRight(const Vec256<uint64_t> v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256<uint64_t>{_mm256_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight<kBits>(v), ShiftLeft<HWY_MIN(63, 64 - kBits)>(v)); +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec256<int8_t> BroadcastSignBit(const Vec256<int8_t> v) { + return VecFromMask(v < Zero(Full256<int8_t>())); +} + +HWY_API Vec256<int16_t> BroadcastSignBit(const Vec256<int16_t> v) { + return ShiftRight<15>(v); +} + +HWY_API Vec256<int32_t> BroadcastSignBit(const Vec256<int32_t> v) { + return ShiftRight<31>(v); +} + +HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) { +#if HWY_TARGET == HWY_AVX2 + return VecFromMask(v < Zero(Full256<int64_t>())); +#else + return Vec256<int64_t>{_mm256_srai_epi64(v.raw, 63)}; +#endif +} + +template <int kBits> +HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<int64_t>{_mm256_srai_epi64(v.raw, kBits)}; +#else + const Full256<int64_t> di; + const Full256<uint64_t> du; + const auto right = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +HWY_API Vec256<int64_t> Abs(const Vec256<int64_t> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<int64_t>{_mm256_abs_epi64(v.raw)}; +#else + const auto zero = Zero(Full256<int64_t>()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +HWY_API Vec256<int8_t> IfNegativeThenElse(Vec256<int8_t> v, Vec256<int8_t> yes, + Vec256<int8_t> no) { + // int8: AVX2 IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes, Vec256<T> no) { + static_assert(IsSigned<T>(), "Only works for signed/float"); + const Full256<T> d; + const RebindToSigned<decltype(d)> di; + + // 16-bit: no native blendv, so copy sign to lower byte's MSB. + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes, Vec256<T> no) { + static_assert(IsSigned<T>(), "Only works for signed/float"); + const Full256<T> d; + const RebindToFloat<decltype(d)> df; + + // 32/64-bit: use float IfThenElse, which only looks at the MSB. + const MFromD<decltype(df)> msb = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); +} + +// ------------------------------ ShiftLeftSame + +HWY_API Vec256<uint16_t> ShiftLeftSame(const Vec256<uint16_t> v, + const int bits) { + return Vec256<uint16_t>{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256<uint32_t> ShiftLeftSame(const Vec256<uint32_t> v, + const int bits) { + return Vec256<uint32_t>{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256<uint64_t> ShiftLeftSame(const Vec256<uint64_t> v, + const int bits) { + return Vec256<uint64_t>{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256<int16_t> ShiftLeftSame(const Vec256<int16_t> v, const int bits) { + return Vec256<int16_t>{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256<int32_t> ShiftLeftSame(const Vec256<int32_t> v, const int bits) { + return Vec256<int32_t>{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256<int64_t> ShiftLeftSame(const Vec256<int64_t> v, const int bits) { + return Vec256<int64_t>{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec256<T> ShiftLeftSame(const Vec256<T> v, const int bits) { + const Full256<T> d8; + const RepartitionToWide<decltype(d8)> d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +HWY_API Vec256<uint16_t> ShiftRightSame(const Vec256<uint16_t> v, + const int bits) { + return Vec256<uint16_t>{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256<uint32_t> ShiftRightSame(const Vec256<uint32_t> v, + const int bits) { + return Vec256<uint32_t>{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256<uint64_t> ShiftRightSame(const Vec256<uint64_t> v, + const int bits) { + return Vec256<uint64_t>{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256<uint8_t> ShiftRightSame(Vec256<uint8_t> v, const int bits) { + const Full256<uint8_t> d8; + const RepartitionToWide<decltype(d8)> d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits)); +} + +HWY_API Vec256<int16_t> ShiftRightSame(const Vec256<int16_t> v, + const int bits) { + return Vec256<int16_t>{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256<int32_t> ShiftRightSame(const Vec256<int32_t> v, + const int bits) { + return Vec256<int32_t>{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256<int64_t> ShiftRightSame(const Vec256<int64_t> v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<int64_t>{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const Full256<int64_t> di; + const Full256<uint64_t> du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +HWY_API Vec256<int8_t> ShiftRightSame(Vec256<int8_t> v, const int bits) { + const Full256<int8_t> di; + const Full256<uint8_t> du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Neg (Xor, Sub) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template <typename T> +HWY_INLINE Vec256<T> Neg(hwy::FloatTag /*tag*/, const Vec256<T> v) { + return Xor(v, SignBit(Full256<T>())); +} + +// Not floating-point +template <typename T> +HWY_INLINE Vec256<T> Neg(hwy::NonFloatTag /*tag*/, const Vec256<T> v) { + return Zero(Full256<T>()) - v; +} + +} // namespace detail + +template <typename T> +HWY_API Vec256<T> Neg(const Vec256<T> v) { + return detail::Neg(hwy::IsFloatTag<T>(), v); +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec256<float> operator*(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> operator*(const Vec256<double> a, + const Vec256<double> b) { + return Vec256<double>{_mm256_mul_pd(a.raw, b.raw)}; +} + +HWY_API Vec256<float> operator/(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_div_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> operator/(const Vec256<double> a, + const Vec256<double> b) { + return Vec256<double>{_mm256_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec256<float> ApproximateReciprocal(const Vec256<float> v) { + return Vec256<float>{_mm256_rcp_ps(v.raw)}; +} + +// Absolute value of difference. +HWY_API Vec256<float> AbsDiff(const Vec256<float> a, const Vec256<float> b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec256<float> MulAdd(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256<float>{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256<double> MulAdd(const Vec256<double> mul, const Vec256<double> x, + const Vec256<double> add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256<double>{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +HWY_API Vec256<float> NegMulAdd(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256<float>{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256<double> NegMulAdd(const Vec256<double> mul, + const Vec256<double> x, + const Vec256<double> add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256<double>{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +HWY_API Vec256<float> MulSub(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256<float>{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256<double> MulSub(const Vec256<double> mul, const Vec256<double> x, + const Vec256<double> sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256<double>{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +HWY_API Vec256<float> NegMulSub(const Vec256<float> mul, const Vec256<float> x, + const Vec256<float> sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256<float>{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256<double> NegMulSub(const Vec256<double> mul, + const Vec256<double> x, + const Vec256<double> sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256<double>{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec256<float> Sqrt(const Vec256<float> v) { + return Vec256<float>{_mm256_sqrt_ps(v.raw)}; +} +HWY_API Vec256<double> Sqrt(const Vec256<double> v) { + return Vec256<double>{_mm256_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec256<float> ApproximateReciprocalSqrt(const Vec256<float> v) { + return Vec256<float>{_mm256_rsqrt_ps(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, tie to even +HWY_API Vec256<float> Round(const Vec256<float> v) { + return Vec256<float>{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256<double> Round(const Vec256<double> v) { + return Vec256<double>{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +HWY_API Vec256<float> Trunc(const Vec256<float> v) { + return Vec256<float>{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256<double> Trunc(const Vec256<double> v) { + return Vec256<double>{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256<float> Ceil(const Vec256<float> v) { + return Vec256<float>{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256<double> Ceil(const Vec256<double> v) { + return Vec256<double>{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +HWY_API Vec256<float> Floor(const Vec256<float> v) { + return Vec256<float>{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256<double> Floor(const Vec256<double> v) { + return Vec256<double>{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +// ------------------------------ Floating-point classification + +HWY_API Mask256<float> IsNaN(const Vec256<float> v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256<float>{_mm256_fpclass_ps_mask(v.raw, 0x81)}; +#else + return Mask256<float>{_mm256_cmp_ps(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} +HWY_API Mask256<double> IsNaN(const Vec256<double> v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256<double>{_mm256_fpclass_pd_mask(v.raw, 0x81)}; +#else + return Mask256<double>{_mm256_cmp_pd(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +HWY_API Mask256<float> IsInf(const Vec256<float> v) { + return Mask256<float>{_mm256_fpclass_ps_mask(v.raw, 0x18)}; +} +HWY_API Mask256<double> IsInf(const Vec256<double> v) { + return Mask256<double>{_mm256_fpclass_pd_mask(v.raw, 0x18)}; +} + +HWY_API Mask256<float> IsFinite(const Vec256<float> v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask256<float>{_mm256_fpclass_ps_mask(v.raw, 0x99)}); +} +HWY_API Mask256<double> IsFinite(const Vec256<double> v) { + return Not(Mask256<double>{_mm256_fpclass_pd_mask(v.raw, 0x99)}); +} + +#else + +template <typename T> +HWY_API Mask256<T> IsInf(const Vec256<T> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Full256<T> d; + const RebindToSigned<decltype(d)> di; + const VFromD<decltype(di)> vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vi, vi), Set(di, hwy::MaxExponentTimes2<T>()))); +} + +// Returns whether normal/subnormal/zero. +template <typename T> +HWY_API Mask256<T> IsFinite(const Vec256<T> v) { + static_assert(IsFloat<T>(), "Only for float"); + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du; + const RebindToSigned<decltype(d)> di; // cheaper than unsigned comparison + const VFromD<decltype(du)> vu = BitCast(du, v); + // Shift left to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). MSVC seems to generate + // incorrect code if we instead add vu + vu. + const VFromD<decltype(di)> exp = + BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(ShiftLeft<1>(vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>()))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== MEMORY + +// ------------------------------ Load + +template <typename T> +HWY_API Vec256<T> Load(Full256<T> /* tag */, const T* HWY_RESTRICT aligned) { + return Vec256<T>{ + _mm256_load_si256(reinterpret_cast<const __m256i*>(aligned))}; +} +HWY_API Vec256<float> Load(Full256<float> /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec256<float>{_mm256_load_ps(aligned)}; +} +HWY_API Vec256<double> Load(Full256<double> /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec256<double>{_mm256_load_pd(aligned)}; +} + +template <typename T> +HWY_API Vec256<T> LoadU(Full256<T> /* tag */, const T* HWY_RESTRICT p) { + return Vec256<T>{_mm256_loadu_si256(reinterpret_cast<const __m256i*>(p))}; +} +HWY_API Vec256<float> LoadU(Full256<float> /* tag */, + const float* HWY_RESTRICT p) { + return Vec256<float>{_mm256_loadu_ps(p)}; +} +HWY_API Vec256<double> LoadU(Full256<double> /* tag */, + const double* HWY_RESTRICT p) { + return Vec256<double>{_mm256_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec256<T>{_mm256_maskz_loadu_epi8(m.raw, p)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec256<T>{_mm256_maskz_loadu_epi16(m.raw, p)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec256<T>{_mm256_maskz_loadu_epi32(m.raw, p)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec256<T>{_mm256_maskz_loadu_epi64(m.raw, p)}; +} + +HWY_API Vec256<float> MaskedLoad(Mask256<float> m, Full256<float> /* tag */, + const float* HWY_RESTRICT p) { + return Vec256<float>{_mm256_maskz_loadu_ps(m.raw, p)}; +} + +HWY_API Vec256<double> MaskedLoad(Mask256<double> m, Full256<double> /* tag */, + const double* HWY_RESTRICT p) { + return Vec256<double>{_mm256_maskz_loadu_pd(m.raw, p)}; +} + +#else // AVX2 + +// There is no maskload_epi8/16, so blend instead. +template <typename T, hwy::EnableIf<sizeof(T) <= 2>* = nullptr> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> d, + const T* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> /* tag */, + const T* HWY_RESTRICT p) { + auto pi = reinterpret_cast<const int*>(p); // NOLINT + return Vec256<T>{_mm256_maskload_epi32(pi, m.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> MaskedLoad(Mask256<T> m, Full256<T> /* tag */, + const T* HWY_RESTRICT p) { + auto pi = reinterpret_cast<const long long*>(p); // NOLINT + return Vec256<T>{_mm256_maskload_epi64(pi, m.raw)}; +} + +HWY_API Vec256<float> MaskedLoad(Mask256<float> m, Full256<float> d, + const float* HWY_RESTRICT p) { + const Vec256<int32_t> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + return Vec256<float>{_mm256_maskload_ps(p, mi.raw)}; +} + +HWY_API Vec256<double> MaskedLoad(Mask256<double> m, Full256<double> d, + const double* HWY_RESTRICT p) { + const Vec256<int64_t> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + return Vec256<double>{_mm256_maskload_pd(p, mi.raw)}; +} + +#endif + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. +template <typename T> +HWY_API Vec256<T> LoadDup128(Full256<T> /* tag */, const T* HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + // Workaround for incorrect results with _mm256_broadcastsi128_si256. Note + // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the + // upper half undefined) is fine because we're overwriting that anyway. + // This workaround seems in turn to generate incorrect code in MSVC 2022 + // (19.31), so use broadcastsi128 there. + const __m128i v128 = LoadU(Full128<T>(), p).raw; + return Vec256<T>{ + _mm256_inserti128_si256(_mm256_castsi128_si256(v128), v128, 1)}; +#else + return Vec256<T>{_mm256_broadcastsi128_si256(LoadU(Full128<T>(), p).raw)}; +#endif +} +HWY_API Vec256<float> LoadDup128(Full256<float> /* tag */, + const float* const HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const __m128 v128 = LoadU(Full128<float>(), p).raw; + return Vec256<float>{ + _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)}; +#else + return Vec256<float>{_mm256_broadcast_ps(reinterpret_cast<const __m128*>(p))}; +#endif +} +HWY_API Vec256<double> LoadDup128(Full256<double> /* tag */, + const double* const HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const __m128d v128 = LoadU(Full128<double>(), p).raw; + return Vec256<double>{ + _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)}; +#else + return Vec256<double>{ + _mm256_broadcast_pd(reinterpret_cast<const __m128d*>(p))}; +#endif +} + +// ------------------------------ Store + +template <typename T> +HWY_API void Store(Vec256<T> v, Full256<T> /* tag */, T* HWY_RESTRICT aligned) { + _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Store(const Vec256<float> v, Full256<float> /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec256<double> v, Full256<double> /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_store_pd(aligned, v.raw); +} + +template <typename T> +HWY_API void StoreU(Vec256<T> v, Full256<T> /* tag */, T* HWY_RESTRICT p) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); +} +HWY_API void StoreU(const Vec256<float> v, Full256<float> /* tag */, + float* HWY_RESTRICT p) { + _mm256_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec256<double> v, Full256<double> /* tag */, + double* HWY_RESTRICT p) { + _mm256_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +#if HWY_TARGET <= HWY_AVX3 + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi8(p, m.raw, v.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi16(p, m.raw, v.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi32(p, m.raw, v.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> /* tag */, + T* HWY_RESTRICT p) { + _mm256_mask_storeu_epi64(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256<float> v, Mask256<float> m, + Full256<float> /* tag */, float* HWY_RESTRICT p) { + _mm256_mask_storeu_ps(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256<double> v, Mask256<double> m, + Full256<double> /* tag */, double* HWY_RESTRICT p) { + _mm256_mask_storeu_pd(p, m.raw, v.raw); +} + +#else // AVX2 + +// Intel SDM says "No AC# reported for any mask bit combinations". However, AMD +// allows AC# if "Alignment checking enabled and: 256-bit memory operand not +// 32-byte aligned". Fortunately AC# is not enabled by default and requires both +// OS support (CR0) and the application to set rflags.AC. We assume these remain +// disabled because x86/x64 code and compiler output often contain misaligned +// scalar accesses, which would also fault. +// +// Caveat: these are slow on AMD Jaguar/Bulldozer. + +template <typename T, hwy::EnableIf<sizeof(T) <= 2>* = nullptr> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> d, + T* HWY_RESTRICT p) { + // There is no maskload_epi8/16. Blending is also unsafe because loading a + // full vector that crosses the array end causes asan faults. Resort to scalar + // code; the caller should instead use memcpy, assuming m is FirstN(d, n). + const RebindToUnsigned<decltype(d)> du; + using TU = TFromD<decltype(du)>; + alignas(32) TU buf[32 / sizeof(T)]; + alignas(32) TU mask[32 / sizeof(T)]; + Store(BitCast(du, v), du, buf); + Store(BitCast(du, VecFromMask(d, m)), du, mask); + for (size_t i = 0; i < 32 / sizeof(T); ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> /* tag */, + T* HWY_RESTRICT p) { + auto pi = reinterpret_cast<int*>(p); // NOLINT + _mm256_maskstore_epi32(pi, m.raw, v.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> /* tag */, + T* HWY_RESTRICT p) { + auto pi = reinterpret_cast<long long*>(p); // NOLINT + _mm256_maskstore_epi64(pi, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256<float> v, Mask256<float> m, Full256<float> d, + float* HWY_RESTRICT p) { + const Vec256<int32_t> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + _mm256_maskstore_ps(p, mi.raw, v.raw); +} + +HWY_API void BlendedStore(Vec256<double> v, Mask256<double> m, + Full256<double> d, double* HWY_RESTRICT p) { + const Vec256<int64_t> mi = + BitCast(RebindToSigned<decltype(d)>(), VecFromMask(d, m)); + _mm256_maskstore_pd(p, mi.raw, v.raw); +} + +#endif + +// ------------------------------ Non-temporal stores + +template <typename T> +HWY_API void Stream(Vec256<T> v, Full256<T> /* tag */, + T* HWY_RESTRICT aligned) { + _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +HWY_API void Stream(const Vec256<float> v, Full256<float> /* tag */, + float* HWY_RESTRICT aligned) { + _mm256_stream_ps(aligned, v.raw); +} +HWY_API void Stream(const Vec256<double> v, Full256<double> /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template <typename T> +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec256<T> v, + Full256<T> /* tag */, T* HWY_RESTRICT base, + const Vec256<int32_t> offset) { + _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); +} +template <typename T> +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec256<T> v, + Full256<T> /* tag */, T* HWY_RESTRICT base, + const Vec256<int32_t> index) { + _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template <typename T> +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec256<T> v, + Full256<T> /* tag */, T* HWY_RESTRICT base, + const Vec256<int64_t> offset) { + _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); +} +template <typename T> +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec256<T> v, + Full256<T> /* tag */, T* HWY_RESTRICT base, + const Vec256<int64_t> index) { + _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +} // namespace detail + +template <typename T, typename Offset> +HWY_API void ScatterOffset(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT base, + const Vec256<Offset> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag<sizeof(T)>(), v, d, base, offset); +} +template <typename T, typename Index> +HWY_API void ScatterIndex(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT base, + const Vec256<Index> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag<sizeof(T)>(), v, d, base, index); +} + +HWY_API void ScatterOffset(Vec256<float> v, Full256<float> /* tag */, + float* HWY_RESTRICT base, + const Vec256<int32_t> offset) { + _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec256<float> v, Full256<float> /* tag */, + float* HWY_RESTRICT base, + const Vec256<int32_t> index) { + _mm256_i32scatter_ps(base, index.raw, v.raw, 4); +} + +HWY_API void ScatterOffset(Vec256<double> v, Full256<double> /* tag */, + double* HWY_RESTRICT base, + const Vec256<int64_t> offset) { + _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec256<double> v, Full256<double> /* tag */, + double* HWY_RESTRICT base, + const Vec256<int64_t> index) { + _mm256_i64scatter_pd(base, index.raw, v.raw, 8); +} + +#else + +template <typename T, typename Offset> +HWY_API void ScatterOffset(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT base, + const Vec256<Offset> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Offset offset_lanes[N]; + Store(offset, Full256<Offset>(), offset_lanes); + + uint8_t* base_bytes = reinterpret_cast<uint8_t*>(base); + for (size_t i = 0; i < N; ++i) { + CopyBytes<sizeof(T)>(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template <typename T, typename Index> +HWY_API void ScatterIndex(Vec256<T> v, Full256<T> d, T* HWY_RESTRICT base, + const Vec256<Index> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + + constexpr size_t N = 32 / sizeof(T); + alignas(32) T lanes[N]; + Store(v, d, lanes); + + alignas(32) Index index_lanes[N]; + Store(index, Full256<Index>(), index_lanes); + + for (size_t i = 0; i < N; ++i) { + base[index_lanes[i]] = lanes[i]; + } +} + +#endif + +// ------------------------------ Gather + +namespace detail { + +template <typename T> +HWY_INLINE Vec256<T> GatherOffset(hwy::SizeTag<4> /* tag */, + Full256<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec256<int32_t> offset) { + return Vec256<T>{_mm256_i32gather_epi32( + reinterpret_cast<const int32_t*>(base), offset.raw, 1)}; +} +template <typename T> +HWY_INLINE Vec256<T> GatherIndex(hwy::SizeTag<4> /* tag */, + Full256<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec256<int32_t> index) { + return Vec256<T>{_mm256_i32gather_epi32( + reinterpret_cast<const int32_t*>(base), index.raw, 4)}; +} + +template <typename T> +HWY_INLINE Vec256<T> GatherOffset(hwy::SizeTag<8> /* tag */, + Full256<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec256<int64_t> offset) { + return Vec256<T>{_mm256_i64gather_epi64( + reinterpret_cast<const GatherIndex64*>(base), offset.raw, 1)}; +} +template <typename T> +HWY_INLINE Vec256<T> GatherIndex(hwy::SizeTag<8> /* tag */, + Full256<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec256<int64_t> index) { + return Vec256<T>{_mm256_i64gather_epi64( + reinterpret_cast<const GatherIndex64*>(base), index.raw, 8)}; +} + +} // namespace detail + +template <typename T, typename Offset> +HWY_API Vec256<T> GatherOffset(Full256<T> d, const T* HWY_RESTRICT base, + const Vec256<Offset> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::GatherOffset(hwy::SizeTag<sizeof(T)>(), d, base, offset); +} +template <typename T, typename Index> +HWY_API Vec256<T> GatherIndex(Full256<T> d, const T* HWY_RESTRICT base, + const Vec256<Index> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::GatherIndex(hwy::SizeTag<sizeof(T)>(), d, base, index); +} + +HWY_API Vec256<float> GatherOffset(Full256<float> /* tag */, + const float* HWY_RESTRICT base, + const Vec256<int32_t> offset) { + return Vec256<float>{_mm256_i32gather_ps(base, offset.raw, 1)}; +} +HWY_API Vec256<float> GatherIndex(Full256<float> /* tag */, + const float* HWY_RESTRICT base, + const Vec256<int32_t> index) { + return Vec256<float>{_mm256_i32gather_ps(base, index.raw, 4)}; +} + +HWY_API Vec256<double> GatherOffset(Full256<double> /* tag */, + const double* HWY_RESTRICT base, + const Vec256<int64_t> offset) { + return Vec256<double>{_mm256_i64gather_pd(base, offset.raw, 1)}; +} +HWY_API Vec256<double> GatherIndex(Full256<double> /* tag */, + const double* HWY_RESTRICT base, + const Vec256<int64_t> index) { + return Vec256<double>{_mm256_i64gather_pd(base, index.raw, 8)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template <typename T> +HWY_API Vec128<T> LowerHalf(Full128<T> /* tag */, Vec256<T> v) { + return Vec128<T>{_mm256_castsi256_si128(v.raw)}; +} +HWY_API Vec128<float> LowerHalf(Full128<float> /* tag */, Vec256<float> v) { + return Vec128<float>{_mm256_castps256_ps128(v.raw)}; +} +HWY_API Vec128<double> LowerHalf(Full128<double> /* tag */, Vec256<double> v) { + return Vec128<double>{_mm256_castpd256_pd128(v.raw)}; +} + +template <typename T> +HWY_API Vec128<T> LowerHalf(Vec256<T> v) { + return LowerHalf(Full128<T>(), v); +} + +// ------------------------------ UpperHalf + +template <typename T> +HWY_API Vec128<T> UpperHalf(Full128<T> /* tag */, Vec256<T> v) { + return Vec128<T>{_mm256_extracti128_si256(v.raw, 1)}; +} +HWY_API Vec128<float> UpperHalf(Full128<float> /* tag */, Vec256<float> v) { + return Vec128<float>{_mm256_extractf128_ps(v.raw, 1)}; +} +HWY_API Vec128<double> UpperHalf(Full128<double> /* tag */, Vec256<double> v) { + return Vec128<double>{_mm256_extractf128_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template <typename T> +HWY_API T ExtractLane(const Vec256<T> v, size_t i) { + const Full256<T> d; + HWY_DASSERT(i < Lanes(d)); + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template <typename T> +HWY_API Vec256<T> InsertLane(const Vec256<T> v, size_t i, T t) { + const Full256<T> d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ GetLane (LowerHalf) +template <typename T> +HWY_API T GetLane(const Vec256<T> v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +// Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper +// bits undefined. Although it makes sense for them to be zero (VEX encoded +// 128-bit instructions zero the upper lanes to avoid large penalties), a +// compiler could decide to optimize out code that relies on this. +// +// The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the +// zeroing, but it is not available on MSVC until 15.7 nor GCC until 10.1. For +// older GCC, we can still obtain the desired code thanks to pattern +// recognition; note that the expensive insert instruction is not actually +// generated, see https://gcc.godbolt.org/z/1MKGaP. + +#if !defined(HWY_HAVE_ZEXT) +#if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC >= 1915) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 500) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1000) +#define HWY_HAVE_ZEXT 1 +#else +#define HWY_HAVE_ZEXT 0 +#endif +#endif // defined(HWY_HAVE_ZEXT) + +template <typename T> +HWY_API Vec256<T> ZeroExtendVector(Full256<T> /* tag */, Vec128<T> lo) { +#if HWY_HAVE_ZEXT +return Vec256<T>{_mm256_zextsi128_si256(lo.raw)}; +#else + return Vec256<T>{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; +#endif +} +HWY_API Vec256<float> ZeroExtendVector(Full256<float> /* tag */, + Vec128<float> lo) { +#if HWY_HAVE_ZEXT + return Vec256<float>{_mm256_zextps128_ps256(lo.raw)}; +#else + return Vec256<float>{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; +#endif +} +HWY_API Vec256<double> ZeroExtendVector(Full256<double> /* tag */, + Vec128<double> lo) { +#if HWY_HAVE_ZEXT + return Vec256<double>{_mm256_zextpd128_pd256(lo.raw)}; +#else + return Vec256<double>{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ Combine + +template <typename T> +HWY_API Vec256<T> Combine(Full256<T> d, Vec128<T> hi, Vec128<T> lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256<T>{_mm256_inserti128_si256(lo256.raw, hi.raw, 1)}; +} +HWY_API Vec256<float> Combine(Full256<float> d, Vec128<float> hi, + Vec128<float> lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256<float>{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; +} +HWY_API Vec256<double> Combine(Full256<double> d, Vec128<double> hi, + Vec128<double> lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256<double>{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes + +template <int kBytes, typename T> +HWY_API Vec256<T> ShiftLeftBytes(Full256<T> /* tag */, const Vec256<T> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bslli_epi128. + return Vec256<T>{_mm256_slli_si256(v.raw, kBytes)}; +} + +template <int kBytes, typename T> +HWY_API Vec256<T> ShiftLeftBytes(const Vec256<T> v) { + return ShiftLeftBytes<kBytes>(Full256<T>(), v); +} + +// ------------------------------ ShiftLeftLanes + +template <int kLanes, typename T> +HWY_API Vec256<T> ShiftLeftLanes(Full256<T> d, const Vec256<T> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); +} + +template <int kLanes, typename T> +HWY_API Vec256<T> ShiftLeftLanes(const Vec256<T> v) { + return ShiftLeftLanes<kLanes>(Full256<T>(), v); +} + +// ------------------------------ ShiftRightBytes + +template <int kBytes, typename T> +HWY_API Vec256<T> ShiftRightBytes(Full256<T> /* tag */, const Vec256<T> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bsrli_epi128. + return Vec256<T>{_mm256_srli_si256(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template <int kLanes, typename T> +HWY_API Vec256<T> ShiftRightLanes(Full256<T> d, const Vec256<T> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v))); +} + +// ------------------------------ CombineShiftRightBytes + +// Extracts 128 bits from <hi, lo> by skipping the least-significant kBytes. +template <int kBytes, typename T, class V = Vec256<T>> +HWY_API V CombineShiftRightBytes(Full256<T> d, V hi, V lo) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Vec256<uint8_t>{_mm256_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template <int kLane> +HWY_API Vec256<uint16_t> Broadcast(const Vec256<uint16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256<uint16_t>{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256<uint16_t>{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template <int kLane> +HWY_API Vec256<uint32_t> Broadcast(const Vec256<uint32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template <int kLane> +HWY_API Vec256<uint64_t> Broadcast(const Vec256<uint64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256<uint64_t>{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Signed +template <int kLane> +HWY_API Vec256<int16_t> Broadcast(const Vec256<int16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec256<int16_t>{_mm256_unpacklo_epi64(lo, lo)}; + } else { + const __m256i hi = + _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec256<int16_t>{_mm256_unpackhi_epi64(hi, hi)}; + } +} +template <int kLane> +HWY_API Vec256<int32_t> Broadcast(const Vec256<int32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} +template <int kLane> +HWY_API Vec256<int64_t> Broadcast(const Vec256<int64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256<int64_t>{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +// Float +template <int kLane> +HWY_API Vec256<float> Broadcast(Vec256<float> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} +template <int kLane> +HWY_API Vec256<double> Broadcast(const Vec256<double> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256<double>{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec256<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Shuffle2301(const Vec256<T> v) { + return Vec256<T>{_mm256_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec256<float> Shuffle2301(const Vec256<float> v) { + return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// Used by generic_ops-inl.h +namespace detail { + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Shuffle2301(const Vec256<T> a, const Vec256<T> b) { + const Full256<T> d; + const RebindToFloat<decltype(d)> df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec256<float>{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Shuffle1230(const Vec256<T> a, const Vec256<T> b) { + const Full256<T> d; + const RebindToFloat<decltype(d)> df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec256<float>{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Shuffle3012(const Vec256<T> a, const Vec256<T> b) { + const Full256<T> d; + const RebindToFloat<decltype(d)> df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec256<float>{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec256<uint32_t> Shuffle1032(const Vec256<uint32_t> v) { + return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256<int32_t> Shuffle1032(const Vec256<int32_t> v) { + return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256<float> Shuffle1032(const Vec256<float> v) { + // Shorter encoding than _mm256_permute_ps. + return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec256<uint64_t> Shuffle01(const Vec256<uint64_t> v) { + return Vec256<uint64_t>{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256<int64_t> Shuffle01(const Vec256<int64_t> v) { + return Vec256<int64_t>{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256<double> Shuffle01(const Vec256<double> v) { + // Shorter encoding than _mm256_permute_pd. + return Vec256<double>{_mm256_shuffle_pd(v.raw, v.raw, 5)}; +} + +// Rotate right 32 bits +HWY_API Vec256<uint32_t> Shuffle0321(const Vec256<uint32_t> v) { + return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256<int32_t> Shuffle0321(const Vec256<int32_t> v) { + return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256<float> Shuffle0321(const Vec256<float> v) { + return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec256<uint32_t> Shuffle2103(const Vec256<uint32_t> v) { + return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256<int32_t> Shuffle2103(const Vec256<int32_t> v) { + return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256<float> Shuffle2103(const Vec256<float> v) { + return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec256<uint32_t> Shuffle0123(const Vec256<uint32_t> v) { + return Vec256<uint32_t>{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256<int32_t> Shuffle0123(const Vec256<int32_t> v) { + return Vec256<int32_t>{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256<float> Shuffle0123(const Vec256<float> v) { + return Vec256<float>{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template <typename T> +struct Indices256 { + __m256i raw; +}; + +// Native 8x32 instruction: indices remain unchanged +template <typename T, typename TI, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Indices256<T> IndicesFromVec(Full256<T> /* tag */, Vec256<TI> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256<TI> di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast<TI>(32 / sizeof(T)))))); +#endif + return Indices256<T>{vec.raw}; +} + +// 64-bit lanes: convert indices to 8x32 unless AVX3 is available +template <typename T, typename TI, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Indices256<T> IndicesFromVec(Full256<T> d, Vec256<TI> idx64) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + const Rebind<TI, decltype(d)> di; + (void)di; // potentially unused +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) && + AllTrue(di, Lt(idx64, Set(di, static_cast<TI>(32 / sizeof(T)))))); +#endif + +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Indices256<T>{idx64.raw}; +#else + const Repartition<float, decltype(d)> df; // 32-bit! + // Replicate 64-bit index into upper 32 bits + const Vec256<TI> dup = + BitCast(di, Vec256<float>{_mm256_moveldup_ps(BitCast(df, idx64).raw)}); + // For each idx64 i, idx32 are 2*i and 2*i+1. + const Vec256<TI> idx32 = dup + dup + Set(di, TI(1) << 32); + return Indices256<T>{idx32.raw}; +#endif +} + +template <typename T, typename TI> +HWY_API Indices256<T> SetTableIndices(const Full256<T> d, const TI* idx) { + const Rebind<TI, decltype(d)> di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) { + return Vec256<T>{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<T>{_mm256_permutexvar_epi64(idx.raw, v.raw)}; +#else + return Vec256<T>{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +#endif +} + +HWY_API Vec256<float> TableLookupLanes(const Vec256<float> v, + const Indices256<float> idx) { + return Vec256<float>{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; +} + +HWY_API Vec256<double> TableLookupLanes(const Vec256<double> v, + const Indices256<double> idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<double>{_mm256_permutexvar_pd(idx.raw, v.raw)}; +#else + const Full256<double> df; + const Full256<uint64_t> du; + return BitCast(df, Vec256<uint64_t>{_mm256_permutevar8x32_epi32( + BitCast(du, v).raw, idx.raw)}); +#endif +} + +// ------------------------------ SwapAdjacentBlocks + +template <typename T> +HWY_API Vec256<T> SwapAdjacentBlocks(Vec256<T> v) { + return Vec256<T>{_mm256_permute2x128_si256(v.raw, v.raw, 0x01)}; +} + +HWY_API Vec256<float> SwapAdjacentBlocks(Vec256<float> v) { + return Vec256<float>{_mm256_permute2f128_ps(v.raw, v.raw, 0x01)}; +} + +HWY_API Vec256<double> SwapAdjacentBlocks(Vec256<double> v) { + return Vec256<double>{_mm256_permute2f128_pd(v.raw, v.raw, 0x01)}; +} + +// ------------------------------ Reverse (RotateRight) + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Reverse(Full256<T> d, const Vec256<T> v) { + alignas(32) constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Reverse(Full256<T> d, const Vec256<T> v) { + alignas(32) constexpr int64_t kReverse[4] = {3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> Reverse(Full256<T> d, const Vec256<T> v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned<decltype(d)> di; + alignas(32) constexpr int16_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec256<int16_t> idx = Load(di, kReverse); + return BitCast(d, Vec256<int16_t>{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide<RebindToUnsigned<decltype(d)>> du32; + const Vec256<uint32_t> rev32 = Reverse(du32, BitCast(du32, v)); + return BitCast(d, RotateRight<16>(rev32)); +#endif +} + +// ------------------------------ Reverse2 + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> Reverse2(Full256<T> d, const Vec256<T> v) { + const Full256<uint32_t> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Reverse2(Full256<T> /* tag */, const Vec256<T> v) { + return Shuffle2301(v); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Reverse2(Full256<T> /* tag */, const Vec256<T> v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 (SwapAdjacentBlocks) + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> Reverse4(Full256<T> d, const Vec256<T> v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned<decltype(d)> di; + alignas(32) constexpr int16_t kReverse4[16] = {3, 2, 1, 0, 7, 6, 5, 4, + 11, 10, 9, 8, 15, 14, 13, 12}; + const Vec256<int16_t> idx = Load(di, kReverse4); + return BitCast(d, Vec256<int16_t>{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide<decltype(d)> dw; + return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Reverse4(Full256<T> /* tag */, const Vec256<T> v) { + return Shuffle0123(v); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Reverse4(Full256<T> /* tag */, const Vec256<T> v) { + // Could also use _mm256_permute4x64_epi64. + return SwapAdjacentBlocks(Shuffle01(v)); +} + +// ------------------------------ Reverse8 + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> Reverse8(Full256<T> d, const Vec256<T> v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned<decltype(d)> di; + alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec256<int16_t> idx = Load(di, kReverse8); + return BitCast(d, Vec256<int16_t>{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide<decltype(d)> dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Reverse8(Full256<T> d, const Vec256<T> v) { + return Reverse(d, v); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Reverse8(Full256<T> /* tag */, const Vec256<T> /* v */) { + HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +HWY_API Vec256<uint8_t> InterleaveLower(const Vec256<uint8_t> a, + const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> InterleaveLower(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<uint32_t> InterleaveLower(const Vec256<uint32_t> a, + const Vec256<uint32_t> b) { + return Vec256<uint32_t>{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> InterleaveLower(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { + return Vec256<uint64_t>{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256<int8_t> InterleaveLower(const Vec256<int8_t> a, + const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> InterleaveLower(const Vec256<int16_t> a, + const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<int32_t> InterleaveLower(const Vec256<int32_t> a, + const Vec256<int32_t> b) { + return Vec256<int32_t>{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<int64_t> InterleaveLower(const Vec256<int64_t> a, + const Vec256<int64_t> b) { + return Vec256<int64_t>{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256<float> InterleaveLower(const Vec256<float> a, + const Vec256<float> b) { + return Vec256<float>{_mm256_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> InterleaveLower(const Vec256<double> a, + const Vec256<double> b) { + return Vec256<double>{_mm256_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec256<uint8_t> InterleaveUpper(const Vec256<uint8_t> a, + const Vec256<uint8_t> b) { + return Vec256<uint8_t>{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<uint16_t> InterleaveUpper(const Vec256<uint16_t> a, + const Vec256<uint16_t> b) { + return Vec256<uint16_t>{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<uint32_t> InterleaveUpper(const Vec256<uint32_t> a, + const Vec256<uint32_t> b) { + return Vec256<uint32_t>{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> InterleaveUpper(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { + return Vec256<uint64_t>{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256<int8_t> InterleaveUpper(const Vec256<int8_t> a, + const Vec256<int8_t> b) { + return Vec256<int8_t>{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec256<int16_t> InterleaveUpper(const Vec256<int16_t> a, + const Vec256<int16_t> b) { + return Vec256<int16_t>{_mm256_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec256<int32_t> InterleaveUpper(const Vec256<int32_t> a, + const Vec256<int32_t> b) { + return Vec256<int32_t>{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec256<int64_t> InterleaveUpper(const Vec256<int64_t> a, + const Vec256<int64_t> b) { + return Vec256<int64_t>{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256<float> InterleaveUpper(const Vec256<float> a, + const Vec256<float> b) { + return Vec256<float>{_mm256_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec256<double> InterleaveUpper(const Vec256<double> a, + const Vec256<double> b) { + return Vec256<double>{_mm256_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template <typename T, class V = Vec256<T>> +HWY_API V InterleaveUpper(Full256<T> /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template <typename T, typename TW = MakeWide<T>> +HWY_API Vec256<TW> ZipLower(Vec256<T> a, Vec256<T> b) { + return BitCast(Full256<TW>(), InterleaveLower(a, b)); +} +template <typename T, typename TW = MakeWide<T>> +HWY_API Vec256<TW> ZipLower(Full256<TW> dw, Vec256<T> a, Vec256<T> b) { + return BitCast(dw, InterleaveLower(a, b)); +} + +template <typename T, typename TW = MakeWide<T>> +HWY_API Vec256<TW> ZipUpper(Full256<TW> dw, Vec256<T> a, Vec256<T> b) { + return BitCast(dw, InterleaveUpper(Full256<T>(), a, b)); +} + +// ------------------------------ Blocks (LowerHalf, ZeroExtendVector) + +// _mm256_broadcastsi128_si256 has 7 cycle latency on ICL. +// _mm256_permute2x128_si256 is slow on Zen1 (8 uops), so we avoid it (at no +// extra cost) for LowerLower and UpperLower. + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template <typename T> +HWY_API Vec256<T> ConcatLowerLower(Full256<T> d, const Vec256<T> hi, + const Vec256<T> lo) { + const Half<decltype(d)> d2; + return Vec256<T>{_mm256_inserti128_si256(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +HWY_API Vec256<float> ConcatLowerLower(Full256<float> d, const Vec256<float> hi, + const Vec256<float> lo) { + const Half<decltype(d)> d2; + return Vec256<float>{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +HWY_API Vec256<double> ConcatLowerLower(Full256<double> d, + const Vec256<double> hi, + const Vec256<double> lo) { + const Half<decltype(d)> d2; + return Vec256<double>{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template <typename T> +HWY_API Vec256<T> ConcatLowerUpper(Full256<T> /* tag */, const Vec256<T> hi, + const Vec256<T> lo) { + return Vec256<T>{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x21)}; +} +HWY_API Vec256<float> ConcatLowerUpper(Full256<float> /* tag */, + const Vec256<float> hi, + const Vec256<float> lo) { + return Vec256<float>{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; +} +HWY_API Vec256<double> ConcatLowerUpper(Full256<double> /* tag */, + const Vec256<double> hi, + const Vec256<double> lo) { + return Vec256<double>{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template <typename T> +HWY_API Vec256<T> ConcatUpperLower(Full256<T> /* tag */, const Vec256<T> hi, + const Vec256<T> lo) { + return Vec256<T>{_mm256_blend_epi32(hi.raw, lo.raw, 0x0F)}; +} +HWY_API Vec256<float> ConcatUpperLower(Full256<float> /* tag */, + const Vec256<float> hi, + const Vec256<float> lo) { + return Vec256<float>{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; +} +HWY_API Vec256<double> ConcatUpperLower(Full256<double> /* tag */, + const Vec256<double> hi, + const Vec256<double> lo) { + return Vec256<double>{_mm256_blend_pd(hi.raw, lo.raw, 3)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template <typename T> +HWY_API Vec256<T> ConcatUpperUpper(Full256<T> /* tag */, const Vec256<T> hi, + const Vec256<T> lo) { + return Vec256<T>{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x31)}; +} +HWY_API Vec256<float> ConcatUpperUpper(Full256<float> /* tag */, + const Vec256<float> hi, + const Vec256<float> lo) { + return Vec256<float>{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x31)}; +} +HWY_API Vec256<double> ConcatUpperUpper(Full256<double> /* tag */, + const Vec256<double> hi, + const Vec256<double> lo) { + return Vec256<double>{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x31)}; +} + +// ------------------------------ ConcatOdd + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec256<T> ConcatOdd(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(32) constexpr uint8_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast(d, Vec256<uint16_t>{_mm256_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide<decltype(du)> dw; + // Unsigned 8-bit shift so we can pack. + const Vec256<uint16_t> uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec256<uint16_t> uL = ShiftRight<8>(BitCast(dw, lo)); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return Vec256<T>{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> ConcatOdd(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint16_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast(d, Vec256<uint16_t>{_mm256_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide<decltype(du)> dw; + // Unsigned 16-bit shift so we can pack. + const Vec256<uint32_t> uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec256<uint32_t> uL = ShiftRight<16>(BitCast(dw, lo)); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return Vec256<T>{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> ConcatOdd(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast(d, Vec256<uint32_t>{_mm256_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat<decltype(d)> df; + const Vec256<float> v3131{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return Vec256<T>{_mm256_permute4x64_epi64(BitCast(du, v3131).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +HWY_API Vec256<float> ConcatOdd(Full256<float> d, Vec256<float> hi, + Vec256<float> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return Vec256<float>{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + const Vec256<float> v3131{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast(d, Vec256<uint32_t>{_mm256_permute4x64_epi64( + BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> ConcatOdd(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return BitCast(d, Vec256<uint64_t>{_mm256_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat<decltype(d)> df; + const Vec256<double> v31{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)}; + return Vec256<T>{ + _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +HWY_API Vec256<double> ConcatOdd(Full256<double> d, Vec256<double> hi, + Vec256<double> lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return Vec256<double>{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + (void)d; + const Vec256<double> v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)}; + return Vec256<double>{ + _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ ConcatEven + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec256<T> ConcatEven(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast(d, Vec256<uint32_t>{_mm256_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide<decltype(du)> dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec256<uint16_t> mask = Set(dw, 0x00FF); + const Vec256<uint16_t> uH = And(BitCast(dw, hi), mask); + const Vec256<uint16_t> uL = And(BitCast(dw, lo), mask); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return Vec256<T>{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec256<T> ConcatEven(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint16_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast(d, Vec256<uint32_t>{_mm256_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide<decltype(du)> dw; + // Isolate lower 16 bits per u32 so we can pack. + const Vec256<uint32_t> mask = Set(dw, 0x0000FFFF); + const Vec256<uint32_t> uH = And(BitCast(dw, hi), mask); + const Vec256<uint32_t> uL = And(BitCast(dw, lo), mask); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return Vec256<T>{_mm256_permute4x64_epi64(u16, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> ConcatEven(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast(d, Vec256<uint32_t>{_mm256_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat<decltype(d)> df; + const Vec256<float> v2020{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return Vec256<T>{_mm256_permute4x64_epi64(BitCast(du, v2020).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +HWY_API Vec256<float> ConcatEven(Full256<float> d, Vec256<float> hi, + Vec256<float> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return Vec256<float>{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + const Vec256<float> v2020{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast(d, Vec256<uint32_t>{_mm256_permute4x64_epi64( + BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); + +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> ConcatEven(Full256<T> d, Vec256<T> hi, Vec256<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return BitCast(d, Vec256<uint64_t>{_mm256_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +#else + const RebindToFloat<decltype(d)> df; + const Vec256<double> v20{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; + return Vec256<T>{ + _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +HWY_API Vec256<double> ConcatEven(Full256<double> d, Vec256<double> hi, + Vec256<double> lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return Vec256<double>{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +#else + (void)d; + const Vec256<double> v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; + return Vec256<double>{ + _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ DupEven (InterleaveLower) + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> DupEven(Vec256<T> v) { + return Vec256<T>{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +HWY_API Vec256<float> DupEven(Vec256<float> v) { + return Vec256<float>{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> DupEven(const Vec256<T> v) { + return InterleaveLower(Full256<T>(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> DupOdd(Vec256<T> v) { + return Vec256<T>{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +HWY_API Vec256<float> DupOdd(Vec256<float> v) { + return Vec256<float>{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> DupOdd(const Vec256<T> v) { + return InterleaveUpper(Full256<T>(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template <typename T> +HWY_INLINE Vec256<T> OddEven(hwy::SizeTag<1> /* tag */, const Vec256<T> a, + const Vec256<T> b) { + const Full256<T> d; + const Full256<uint8_t> d8; + alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, LoadDup128(d8, mask))), b, a); +} +template <typename T> +HWY_INLINE Vec256<T> OddEven(hwy::SizeTag<2> /* tag */, const Vec256<T> a, + const Vec256<T> b) { + return Vec256<T>{_mm256_blend_epi16(a.raw, b.raw, 0x55)}; +} +template <typename T> +HWY_INLINE Vec256<T> OddEven(hwy::SizeTag<4> /* tag */, const Vec256<T> a, + const Vec256<T> b) { + return Vec256<T>{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; +} +template <typename T> +HWY_INLINE Vec256<T> OddEven(hwy::SizeTag<8> /* tag */, const Vec256<T> a, + const Vec256<T> b) { + return Vec256<T>{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec256<T> OddEven(const Vec256<T> a, const Vec256<T> b) { + return detail::OddEven(hwy::SizeTag<sizeof(T)>(), a, b); +} +HWY_API Vec256<float> OddEven(const Vec256<float> a, const Vec256<float> b) { + return Vec256<float>{_mm256_blend_ps(a.raw, b.raw, 0x55)}; +} + +HWY_API Vec256<double> OddEven(const Vec256<double> a, const Vec256<double> b) { + return Vec256<double>{_mm256_blend_pd(a.raw, b.raw, 5)}; +} + +// ------------------------------ OddEvenBlocks + +template <typename T> +Vec256<T> OddEvenBlocks(Vec256<T> odd, Vec256<T> even) { + return Vec256<T>{_mm256_blend_epi32(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256<float> OddEvenBlocks(Vec256<float> odd, Vec256<float> even) { + return Vec256<float>{_mm256_blend_ps(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256<double> OddEvenBlocks(Vec256<double> odd, Vec256<double> even) { + return Vec256<double>{_mm256_blend_pd(odd.raw, even.raw, 0x3u)}; +} + +// ------------------------------ ReverseBlocks (ConcatLowerUpper) + +template <typename T> +HWY_API Vec256<T> ReverseBlocks(Full256<T> d, Vec256<T> v) { + return ConcatLowerUpper(d, v, v); +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template <typename T, typename TI> +HWY_API Vec256<TI> TableLookupBytes(const Vec256<T> bytes, + const Vec256<TI> from) { + return Vec256<TI>{_mm256_shuffle_epi8(bytes.raw, from.raw)}; +} + +// Partial index vector +template <typename T, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytes(const Vec256<T> bytes, + const Vec128<TI, NI> from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256<TI>(), Vec128<TI>{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128<TI, NI>{LowerHalf(Full128<TI>(), tbl_full).raw}; +} + +// Partial table vector +template <typename T, size_t N, typename TI> +HWY_API Vec256<TI> TableLookupBytes(const Vec128<T, N> bytes, + const Vec256<TI> from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256<T>(), Vec128<T>{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by x86_128. + +// ------------------------------ Shl (Mul, ZipLower) + +namespace detail { + +#if HWY_TARGET > HWY_AVX3 && !HWY_IDE // AVX2 or older + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template <typename T> +HWY_INLINE Vec256<MakeUnsigned<T>> Pow2(const Vec256<T> v) { + static_assert(sizeof(T) == 2, "Only for 16-bit"); + const Full256<T> d; + const RepartitionToWide<decltype(d)> dw; + const Rebind<float, decltype(dw)> df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // Do not use ConvertTo because it checks for overflow, which is redundant + // because we only care about v in [0, 16). + const Vec256<int32_t> bits0{_mm256_cvttps_epi32(BitCast(df, f0).raw)}; + const Vec256<int32_t> bits1{_mm256_cvttps_epi32(BitCast(df, f1).raw)}; + return Vec256<MakeUnsigned<T>>{_mm256_packus_epi32(bits0.raw, bits1.raw)}; +} + +#endif // HWY_TARGET > HWY_AVX3 + +HWY_INLINE Vec256<uint16_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint16_t> v, + Vec256<uint16_t> bits) { +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + return Vec256<uint16_t>{_mm256_sllv_epi16(v.raw, bits.raw)}; +#else + return v * Pow2(bits); +#endif +} + +HWY_INLINE Vec256<uint32_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint32_t> v, + Vec256<uint32_t> bits) { + return Vec256<uint32_t>{_mm256_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256<uint64_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint64_t> v, + Vec256<uint64_t> bits) { + return Vec256<uint64_t>{_mm256_sllv_epi64(v.raw, bits.raw)}; +} + +template <typename T> +HWY_INLINE Vec256<T> Shl(hwy::SignedTag /*tag*/, Vec256<T> v, Vec256<T> bits) { + // Signed left shifts are the same as unsigned. + const Full256<T> di; + const Full256<MakeUnsigned<T>> du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template <typename T> +HWY_API Vec256<T> operator<<(Vec256<T> v, Vec256<T> bits) { + return detail::Shl(hwy::TypeTag<T>(), v, bits); +} + +// ------------------------------ Shr (MulHigh, IfThenElse, Not) + +HWY_API Vec256<uint16_t> operator>>(Vec256<uint16_t> v, Vec256<uint16_t> bits) { +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + return Vec256<uint16_t>{_mm256_srlv_epi16(v.raw, bits.raw)}; +#else + Full256<uint16_t> d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + auto out = MulHigh(v, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d), v, out); +#endif +} + +HWY_API Vec256<uint32_t> operator>>(Vec256<uint32_t> v, Vec256<uint32_t> bits) { + return Vec256<uint32_t>{_mm256_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256<uint64_t> operator>>(Vec256<uint64_t> v, Vec256<uint64_t> bits) { + return Vec256<uint64_t>{_mm256_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec256<int16_t> operator>>(Vec256<int16_t> v, Vec256<int16_t> bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<int16_t>{_mm256_srav_epi16(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256<int16_t>(), v, bits); +#endif +} + +HWY_API Vec256<int32_t> operator>>(Vec256<int32_t> v, Vec256<int32_t> bits) { + return Vec256<int32_t>{_mm256_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256<int64_t> operator>>(Vec256<int64_t> v, Vec256<int64_t> bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<int64_t>{_mm256_srav_epi64(v.raw, bits.raw)}; +#else + return detail::SignedShr(Full256<int64_t>(), v, bits); +#endif +} + +HWY_INLINE Vec256<uint64_t> MulEven(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { + const Full256<uint64_t> du64; + const RepartitionToNarrow<decltype(du64)> du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveLower(mulL, mulH); +} + +HWY_INLINE Vec256<uint64_t> MulOdd(const Vec256<uint64_t> a, + const Vec256<uint64_t> b) { + const Full256<uint64_t> du64; + const RepartitionToNarrow<decltype(du64)> du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need bits [95:64] (= upper half of input) + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Same as above, but we're using the odd results (upper 64 bits per block). + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveUpper(du64, mulL, mulH); +} + +// ------------------------------ ReorderWidenMulAccumulate +HWY_API Vec256<int32_t> ReorderWidenMulAccumulate(Full256<int32_t> /*d32*/, + Vec256<int16_t> a, + Vec256<int16_t> b, + const Vec256<int32_t> sum0, + Vec256<int32_t>& /*sum1*/) { + return sum0 + Vec256<int32_t>{_mm256_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +HWY_API Vec256<int32_t> RearrangeToOddPlusEven(const Vec256<int32_t> sum0, + Vec256<int32_t> /*sum1*/) { + return sum0; // invariant already holds +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +HWY_API Vec256<double> PromoteTo(Full256<double> /* tag */, + const Vec128<float, 4> v) { + return Vec256<double>{_mm256_cvtps_pd(v.raw)}; +} + +HWY_API Vec256<double> PromoteTo(Full256<double> /* tag */, + const Vec128<int32_t, 4> v) { + return Vec256<double>{_mm256_cvtepi32_pd(v.raw)}; +} + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +HWY_API Vec256<uint16_t> PromoteTo(Full256<uint16_t> /* tag */, + Vec128<uint8_t> v) { + return Vec256<uint16_t>{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256<uint32_t> PromoteTo(Full256<uint32_t> /* tag */, + Vec128<uint8_t, 8> v) { + return Vec256<uint32_t>{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256<int16_t> PromoteTo(Full256<int16_t> /* tag */, + Vec128<uint8_t> v) { + return Vec256<int16_t>{_mm256_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec256<int32_t> PromoteTo(Full256<int32_t> /* tag */, + Vec128<uint8_t, 8> v) { + return Vec256<int32_t>{_mm256_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec256<uint32_t> PromoteTo(Full256<uint32_t> /* tag */, + Vec128<uint16_t> v) { + return Vec256<uint32_t>{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256<int32_t> PromoteTo(Full256<int32_t> /* tag */, + Vec128<uint16_t> v) { + return Vec256<int32_t>{_mm256_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec256<uint64_t> PromoteTo(Full256<uint64_t> /* tag */, + Vec128<uint32_t> v) { + return Vec256<uint64_t>{_mm256_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +HWY_API Vec256<int16_t> PromoteTo(Full256<int16_t> /* tag */, + Vec128<int8_t> v) { + return Vec256<int16_t>{_mm256_cvtepi8_epi16(v.raw)}; +} +HWY_API Vec256<int32_t> PromoteTo(Full256<int32_t> /* tag */, + Vec128<int8_t, 8> v) { + return Vec256<int32_t>{_mm256_cvtepi8_epi32(v.raw)}; +} +HWY_API Vec256<int32_t> PromoteTo(Full256<int32_t> /* tag */, + Vec128<int16_t> v) { + return Vec256<int32_t>{_mm256_cvtepi16_epi32(v.raw)}; +} +HWY_API Vec256<int64_t> PromoteTo(Full256<int64_t> /* tag */, + Vec128<int32_t> v) { + return Vec256<int64_t>{_mm256_cvtepi32_epi64(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec128<uint16_t> DemoteTo(Full128<uint16_t> /* tag */, + const Vec256<int32_t> v) { + const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); + // Concatenating lower halves of both 128-bit blocks afterward is more + // efficient than an extra input with low block = high block of v. + return Vec128<uint16_t>{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; +} + +HWY_API Vec128<int16_t> DemoteTo(Full128<int16_t> /* tag */, + const Vec256<int32_t> v) { + const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); + return Vec128<int16_t>{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; +} + +HWY_API Vec128<uint8_t, 8> DemoteTo(Full64<uint8_t> /* tag */, + const Vec256<int32_t> v) { + const __m256i u16_blocks = _mm256_packus_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i u16_concat = _mm256_permute4x64_epi64(u16_blocks, 0x88); + const __m128i u16 = _mm256_castsi256_si128(u16_concat); + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const __m128i i16 = _mm_and_si128(u16, _mm_set1_epi16(0x7FFF)); + return Vec128<uint8_t, 8>{_mm_packus_epi16(i16, i16)}; +} + +HWY_API Vec128<uint8_t> DemoteTo(Full128<uint8_t> /* tag */, + const Vec256<int16_t> v) { + const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); + return Vec128<uint8_t>{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; +} + +HWY_API Vec128<int8_t, 8> DemoteTo(Full64<int8_t> /* tag */, + const Vec256<int32_t> v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return Vec128<int8_t, 8>{_mm_packs_epi16(i16, i16)}; +} + +HWY_API Vec128<int8_t> DemoteTo(Full128<int8_t> /* tag */, + const Vec256<int16_t> v) { + const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); + return Vec128<int8_t>{ + _mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; +} + + // Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'". + // 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") + +HWY_API Vec128<float16_t> DemoteTo(Full128<float16_t> df16, + const Vec256<float> v) { +#ifdef HWY_DISABLE_F16C + const RebindToUnsigned<decltype(df16)> du16; + const Rebind<uint32_t, decltype(df16)> du; + const RebindToSigned<decltype(du)> di; + const auto bits32 = BitCast(du, v); + const auto sign = ShiftRight<31>(bits32); + const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); + const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); + + const auto k15 = Set(di, 15); + const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); + const auto is_tiny = exp < Set(di, -24); + + const auto is_subnormal = exp < Set(di, -14); + const auto biased_exp16 = + BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); + const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) + const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + + (mantissa32 >> (Set(du, 13) + sub_exp)); + const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, + ShiftRight<13>(mantissa32)); // <1024 + + const auto sign16 = ShiftLeft<15>(sign); + const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; + const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); + return BitCast(df16, DemoteTo(du16, bits16)); +#else + (void)df16; + return Vec128<float16_t>{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; +#endif +} + +HWY_DIAGNOSTICS(pop) + +HWY_API Vec128<bfloat16_t> DemoteTo(Full128<bfloat16_t> dbf16, + const Vec256<float> v) { + // TODO(janwas): _mm256_cvtneps_pbh once we have avx512bf16. + const Rebind<int32_t, decltype(dbf16)> di32; + const Rebind<uint32_t, decltype(dbf16)> du32; // for logical shift right + const Rebind<uint16_t, decltype(dbf16)> du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +HWY_API Vec256<bfloat16_t> ReorderDemote2To(Full256<bfloat16_t> dbf16, + Vec256<float> a, Vec256<float> b) { + // TODO(janwas): _mm256_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned<decltype(dbf16)> du16; + const Repartition<uint32_t, decltype(dbf16)> du32; + const Vec256<uint32_t> b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec256<int16_t> ReorderDemote2To(Full256<int16_t> /*d16*/, + Vec256<int32_t> a, Vec256<int32_t> b) { + return Vec256<int16_t>{_mm256_packs_epi32(a.raw, b.raw)}; +} + +HWY_API Vec128<float> DemoteTo(Full128<float> /* tag */, + const Vec256<double> v) { + return Vec128<float>{_mm256_cvtpd_ps(v.raw)}; +} + +HWY_API Vec128<int32_t> DemoteTo(Full128<int32_t> /* tag */, + const Vec256<double> v) { + const auto clamped = detail::ClampF64ToI32Max(Full256<double>(), v); + return Vec128<int32_t>{_mm256_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128<uint8_t, 8> U8FromU32(const Vec256<uint32_t> v) { + const Full256<uint32_t> d32; + alignas(32) static constexpr uint32_t k8From32[8] = { + 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; + // Place first four bytes in lo[0], remaining 4 in hi[1]. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Full128<uint32_t>(), quad); + const auto pair = LowerHalf(lo | hi); + return BitCast(Full64<uint8_t>(), pair); +} + +// ------------------------------ Truncations + +namespace detail { + +// LO and HI each hold four indices of bytes within a 128-bit block. +template <uint32_t LO, uint32_t HI, typename T> +HWY_INLINE Vec128<uint32_t> LookupAndConcatHalves(Vec256<T> v) { + const Full256<uint32_t> d32; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint32_t kMap[8] = { + LO, HI, 0x10101010 + LO, 0x10101010 + HI, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); +#else + alignas(32) static constexpr uint32_t kMap[8] = {LO, HI, ~0u, ~0u, + ~0u, ~0u, LO, HI}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto result = _mm256_permute4x64_epi64(quad.raw, 0xCC); + // Possible alternative: + // const auto lo = LowerHalf(quad); + // const auto hi = UpperHalf(Full128<uint32_t>(), quad); + // const auto result = lo | hi; +#endif + + return Vec128<uint32_t>{_mm256_castsi256_si128(result)}; +} + +// LO and HI each hold two indices of bytes within a 128-bit block. +template <uint16_t LO, uint16_t HI, typename T> +HWY_INLINE Vec128<uint32_t, 2> LookupAndConcatQuarters(Vec256<T> v) { + const Full256<uint16_t> d16; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint16_t kMap[16] = { + LO, HI, 0x1010 + LO, 0x1010 + HI, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d16, kMap).raw); + return LowerHalf(Vec128<uint32_t>{_mm256_castsi256_si128(result)}); +#else + constexpr uint16_t ff = static_cast<uint16_t>(~0u); + alignas(32) static constexpr uint16_t kMap[16] = { + LO, ff, HI, ff, ff, ff, ff, ff, ff, ff, ff, ff, LO, ff, HI, ff}; + const auto quad = TableLookupBytes(v, Load(d16, kMap)); + const auto mixed = _mm256_permute4x64_epi64(quad.raw, 0xCC); + const auto half = _mm256_castsi256_si128(mixed); + return LowerHalf(Vec128<uint32_t>{_mm_packus_epi32(half, half)}); +#endif +} + +} // namespace detail + +HWY_API Vec128<uint8_t, 4> TruncateTo(Simd<uint8_t, 4, 0> /* tag */, + const Vec256<uint64_t> v) { + const Full256<uint32_t> d32; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) constexpr uint32_t kMap[8] = {0x18100800u, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(v.raw, Load(d32, kMap).raw); + return LowerHalf(LowerHalf(LowerHalf(Vec256<uint8_t>{result}))); +#else + alignas(32) static constexpr uint32_t kMap[8] = {0xFFFF0800u, ~0u, ~0u, ~0u, + 0x0800FFFFu, ~0u, ~0u, ~0u}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Full128<uint32_t>(), quad); + const auto result = lo | hi; + return LowerHalf(LowerHalf(Vec128<uint8_t>{result.raw})); +#endif +} + +HWY_API Vec128<uint16_t, 4> TruncateTo(Simd<uint16_t, 4, 0> /* tag */, + const Vec256<uint64_t> v) { + const auto result = detail::LookupAndConcatQuarters<0x100, 0x908>(v); + return Vec128<uint16_t, 4>{result.raw}; +} + +HWY_API Vec128<uint32_t> TruncateTo(Simd<uint32_t, 4, 0> /* tag */, + const Vec256<uint64_t> v) { + const Full256<uint32_t> d32; + alignas(32) constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto v32 = + TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven)); + return LowerHalf(Vec256<uint32_t>{v32.raw}); +} + +HWY_API Vec128<uint8_t, 8> TruncateTo(Simd<uint8_t, 8, 0> /* tag */, + const Vec256<uint32_t> v) { + const auto full = detail::LookupAndConcatQuarters<0x400, 0xC08>(v); + return Vec128<uint8_t, 8>{full.raw}; +} + +HWY_API Vec128<uint16_t> TruncateTo(Simd<uint16_t, 8, 0> /* tag */, + const Vec256<uint32_t> v) { + const auto full = detail::LookupAndConcatHalves<0x05040100, 0x0D0C0908>(v); + return Vec128<uint16_t>{full.raw}; +} + +HWY_API Vec128<uint8_t> TruncateTo(Simd<uint8_t, 16, 0> /* tag */, + const Vec256<uint16_t> v) { + const auto full = detail::LookupAndConcatHalves<0x06040200, 0x0E0C0A08>(v); + return Vec128<uint8_t>{full.raw}; +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +HWY_API Vec256<float> ConvertTo(Full256<float> /* tag */, + const Vec256<int32_t> v) { + return Vec256<float>{_mm256_cvtepi32_ps(v.raw)}; +} + +HWY_API Vec256<double> ConvertTo(Full256<double> dd, const Vec256<int64_t> v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dd; + return Vec256<double>{_mm256_cvtepi64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition<uint32_t, decltype(dd)> d32; + const Repartition<uint64_t, decltype(dd)> d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +#endif +} + +HWY_API Vec256<float> ConvertTo(HWY_MAYBE_UNUSED Full256<float> df, + const Vec256<uint32_t> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<float>{_mm256_cvtepu32_ps(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned<decltype(df)> du32; + const RebindToSigned<decltype(df)> d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +#endif +} + +HWY_API Vec256<double> ConvertTo(HWY_MAYBE_UNUSED Full256<double> dd, + const Vec256<uint64_t> v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256<double>{_mm256_cvtepu64_pd(v.raw)}; +#else + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned<decltype(dd)> d64; + using VU = VFromD<decltype(d64)>; + + const VU msk_lo = Set(d64, 0xFFFFFFFFULL); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + auto uint64_to_double256_fast = [&dd](Vec256<uint64_t> w) HWY_ATTR { + w = Or(w, Vec256<uint64_t>{ + detail::BitCastToInteger(Set(dd, 0x0010000000000000).raw)}); + return BitCast(dd, w) - Set(dd, 0x0010000000000000); + }; + + const auto v_lo_dbl = uint64_to_double256_fast(v_lo); + return MulAdd(cnst2_32_dbl, uint64_to_double256_fast(v_hi), v_lo_dbl); +#endif +} + +// Truncates (rounds toward zero). +HWY_API Vec256<int32_t> ConvertTo(Full256<int32_t> d, const Vec256<float> v) { + return detail::FixConversionOverflow(d, v, _mm256_cvttps_epi32(v.raw)); +} + +HWY_API Vec256<int64_t> ConvertTo(Full256<int64_t> di, const Vec256<double> v) { +#if HWY_TARGET <= HWY_AVX3 + return detail::FixConversionOverflow(di, v, _mm256_cvttpd_epi64(v.raw)); +#else + using VI = decltype(Zero(di)); + const VI k0 = Zero(di); + const VI k1 = Set(di, 1); + const VI k51 = Set(di, 51); + + // Exponent indicates whether the number can be represented as int64_t. + const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); + const VI exp = biased_exp - Set(di, 0x3FF); + const auto in_range = exp < Set(di, 63); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + const VI shift_mnt = Max(k51 - exp, k0); + const VI shift_int = Max(exp - k51, k0); + const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); + // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. + const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); + // For inputs larger than 2^52, insert zeros at the bottom. + const VI shifted = int52 << shift_int; + // Restore the one bit lost when shifting in the implicit 1-bit. + const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax<int64_t>()) - sign_mask; + const VI magnitude = IfThenElse(in_range, restored, limit); + + // If the input was negative, negate the integer (two's complement). + return (magnitude ^ sign_mask) - sign_mask; +#endif +} + +HWY_API Vec256<int32_t> NearestInt(const Vec256<float> v) { + const Full256<int32_t> di; + return detail::FixConversionOverflow(di, v, _mm256_cvtps_epi32(v.raw)); +} + + +HWY_API Vec256<float> PromoteTo(Full256<float> df32, + const Vec128<float16_t> v) { +#ifdef HWY_DISABLE_F16C + const RebindToSigned<decltype(df32)> di32; + const RebindToUnsigned<decltype(df32)> du32; + // Expand to u32 so we can shift. + const auto bits16 = PromoteTo(du32, Vec128<uint16_t>{v.raw}); + const auto sign = ShiftRight<15>(bits16); + const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); + const auto mantissa = bits16 & Set(du32, 0x3FF); + const auto subnormal = + BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * + Set(df32, 1.0f / 16384 / 1024)); + + const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); + const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); + const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; + const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); + return BitCast(df32, ShiftLeft<31>(sign) | bits32); +#else + (void)df32; + return Vec256<float>{_mm256_cvtph_ps(v.raw)}; +#endif +} + +HWY_API Vec256<float> PromoteTo(Full256<float> df32, + const Vec128<bfloat16_t> v) { + const Rebind<uint16_t, decltype(df32)> du16; + const RebindToSigned<decltype(df32)> di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec256<uint8_t> AESRound(Vec256<uint8_t> state, + Vec256<uint8_t> round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256<uint8_t>{_mm256_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full256<uint8_t> d; + const Half<decltype(d)> d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256<uint8_t> AESLastRound(Vec256<uint8_t> state, + Vec256<uint8_t> round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256<uint8_t>{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full256<uint8_t> d; + const Half<decltype(d)> d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256<uint64_t> CLMulLower(Vec256<uint64_t> a, Vec256<uint64_t> b) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256<uint64_t>{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; +#else + const Full256<uint64_t> d; + const Half<decltype(d)> d2; + return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulLower(LowerHalf(a), LowerHalf(b))); +#endif +} + +HWY_API Vec256<uint64_t> CLMulUpper(Vec256<uint64_t> a, Vec256<uint64_t> b) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec256<uint64_t>{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)}; +#else + const Full256<uint64_t> d; + const Half<decltype(d)> d2; + return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulUpper(LowerHalf(a), LowerHalf(b))); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template <typename T, typename T2> +HWY_API Vec256<T> Iota(const Full256<T> d, const T2 first) { + HWY_ALIGN T lanes[32 / sizeof(T)]; + for (size_t i = 0; i < 32 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i); + } + return Load(d, lanes); +} + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadMaskBits + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T> +HWY_API Mask256<T> LoadMaskBits(const Full256<T> /* tag */, + const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes<kNumBytes>(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return Mask256<T>::FromBits(mask_bits); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template <typename T> +HWY_API size_t StoreMaskBits(const Full256<T> /* tag */, const Mask256<T> mask, + uint8_t* bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + CopyBytes<kNumBytes>(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (N < 8) { + const int mask_bits = static_cast<int>((1ull << N) - 1); + bits[0] = static_cast<uint8_t>(bits[0] & mask_bits); + } + return kNumBytes; +} + +// ------------------------------ Mask testing + +template <typename T> +HWY_API size_t CountTrue(const Full256<T> /* tag */, const Mask256<T> mask) { + return PopCount(static_cast<uint64_t>(mask.raw)); +} + +template <typename T> +HWY_API size_t FindKnownFirstTrue(const Full256<T> /* tag */, + const Mask256<T> mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template <typename T> +HWY_API intptr_t FindFirstTrue(const Full256<T> d, const Mask256<T> mask) { + return mask.raw ? static_cast<intptr_t>(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +// Beware: the suffix indicates the number of mask bits, not lane size! + +namespace detail { + +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256<T> mask) { + return (uint64_t{mask.raw} & 0xF) == 0; +} + +} // namespace detail + +template <typename T> +HWY_API bool AllFalse(const Full256<T> /* tag */, const Mask256<T> mask) { + return detail::AllFalse(hwy::SizeTag<sizeof(T)>(), mask); +} + +namespace detail { + +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFu; +#endif +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFu; +#endif +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFu; +#endif +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256<T> mask) { + // Cannot use _kortestc because we have less than 8 mask bits. + return mask.raw == 0xFu; +} + +} // namespace detail + +template <typename T> +HWY_API bool AllTrue(const Full256<T> /* tag */, const Mask256<T> mask) { + return detail::AllTrue(hwy::SizeTag<sizeof(T)>(), mask); +} + +// ------------------------------ Compress + +// 16-bit is defined in x86_512 so we can use 512-bit vectors. + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec256<T> Compress(Vec256<T> v, Mask256<T> mask) { + return Vec256<T>{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} + +HWY_API Vec256<float> Compress(Vec256<float> v, Mask256<float> mask) { + return Vec256<float>{_mm256_maskz_compress_ps(mask.raw, v.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> Compress(Vec256<T> v, Mask256<T> mask) { + // See CompressIsPartition. + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintCompress64x4NibbleTables + 0x00003210, 0x00003210, 0x00003201, 0x00003210, 0x00003102, 0x00003120, + 0x00003021, 0x00003210, 0x00002103, 0x00002130, 0x00002031, 0x00002310, + 0x00001032, 0x00001320, 0x00000321, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256<T>{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressNot (Compress) + +// Implemented in x86_512 for lane size != 8. + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec256<T> CompressNot(Vec256<T> v, Mask256<T> mask) { + // See CompressIsPartition. + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintCompressNot64x4NibbleTables + 0x00003210, 0x00000321, 0x00001320, 0x00001032, 0x00002310, 0x00002031, + 0x00002130, 0x00002103, 0x00003210, 0x00003021, 0x00003120, 0x00003102, + 0x00003210, 0x00003201, 0x00003210, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(32) constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256<T>{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressStore + +// 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API size_t CompressStore(Vec256<T> v, Mask256<T> mask, Full256<T> /* tag */, + T* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API size_t CompressStore(Vec256<T> v, Mask256<T> mask, Full256<T> /* tag */, + T* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +HWY_API size_t CompressStore(Vec256<float> v, Mask256<float> mask, + Full256<float> /* tag */, + float* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +HWY_API size_t CompressStore(Vec256<double> v, Mask256<double> mask, + Full256<double> /* tag */, + double* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw} & 0xFull); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore (CompressStore) + +template <typename T> +HWY_API size_t CompressBlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> d, + T* HWY_RESTRICT unaligned) { + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) > 2) { + // Native (32 or 64-bit) AVX-512 instruction already does the blending at no + // extra cost (latency 11, rthroughput 2 - same as compress plus store). + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBitsStore(Vec256<T> v, const uint8_t* HWY_RESTRICT bits, + Full256<T> d, T* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#else // AVX2 + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_LE128 there. +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Mask256<T> LoadMaskBits256(Full256<T> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + const Repartition<uint32_t, decltype(d)> du32; + const auto vbits = BitCast(du, Set(du32, static_cast<uint32_t>(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. + const Repartition<uint64_t, decltype(d)> du64; + alignas(32) constexpr uint64_t kRep8[4] = { + 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, + 0x0303030303030303ull}; + const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); + + alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Mask256<T> LoadMaskBits256(Full256<T> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(32) constexpr uint16_t kBit[16] = { + 1, 2, 4, 8, 16, 32, 64, 128, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + const auto vmask_bits = Set(du, static_cast<uint16_t>(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Mask256<T> LoadMaskBits256(Full256<T> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast<uint32_t>(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Mask256<T> LoadMaskBits256(Full256<T> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> du; + alignas(32) constexpr uint64_t kBit[8] = {1, 2, 4, 8}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T> +HWY_API Mask256<T> LoadMaskBits(Full256<T> d, + const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes<kNumBytes>(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::LoadMaskBits256(d, mask_bits); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE uint64_t BitsFromMask(const Mask256<T> mask) { + const Full256<T> d; + const Full256<uint8_t> d8; + const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; + // Prevent sign-extension of 32-bit masks because the intrinsic returns int. + return static_cast<uint32_t>(_mm256_movemask_epi8(sign_bits)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE uint64_t BitsFromMask(const Mask256<T> mask) { +#if HWY_ARCH_X86_64 + const Full256<T> d; + const Full256<uint8_t> d8; + const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + const uint64_t sign_bits8 = BitsFromMask(mask8); + // Skip the bits from the lower byte of each u16 (better not to use the + // same packs_epi16 as SSE4, because that requires an extra swizzle here). + return _pext_u64(sign_bits8, 0xAAAAAAAAull); +#else + // Slow workaround for 32-bit builds, which lack _pext_u64. + // Remove useless lower half of each u16 while preserving the sign bit. + // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. + const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); + // Move odd qwords (value zero) to top so they don't affect the mask value. + const auto compressed = + _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0)); + return static_cast<unsigned>(_mm256_movemask_epi8(compressed)); +#endif // HWY_ARCH_X86_64 +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE uint64_t BitsFromMask(const Mask256<T> mask) { + const Full256<T> d; + const Full256<float> df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast<unsigned>(_mm256_movemask_ps(sign_bits)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE uint64_t BitsFromMask(const Mask256<T> mask) { + const Full256<T> d; + const Full256<double> df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast<unsigned>(_mm256_movemask_pd(sign_bits)); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template <typename T> +HWY_API size_t StoreMaskBits(const Full256<T> /* tag */, const Mask256<T> mask, + uint8_t* bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + const uint64_t mask_bits = detail::BitsFromMask(mask); + CopyBytes<kNumBytes>(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask +// lane is 0 or ~0. +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API bool AllFalse(const Full256<T> d, const Mask256<T> mask) { + const Repartition<uint8_t, decltype(d)> d8; + const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return detail::BitsFromMask(mask8) == 0; +} + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API bool AllFalse(const Full256<T> /* tag */, const Mask256<T> mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return detail::BitsFromMask(mask) == 0; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API bool AllTrue(const Full256<T> d, const Mask256<T> mask) { + const Repartition<uint8_t, decltype(d)> d8; + const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return detail::BitsFromMask(mask8) == (1ull << 32) - 1; +} +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API bool AllTrue(const Full256<T> /* tag */, const Mask256<T> mask) { + constexpr uint64_t kAllBits = (1ull << (32 / sizeof(T))) - 1; + return detail::BitsFromMask(mask) == kAllBits; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API size_t CountTrue(const Full256<T> d, const Mask256<T> mask) { + const Repartition<uint8_t, decltype(d)> d8; + const Mask256<uint8_t> mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return PopCount(detail::BitsFromMask(mask8)) >> 1; +} +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_API size_t CountTrue(const Full256<T> /* tag */, const Mask256<T> mask) { + return PopCount(detail::BitsFromMask(mask)); +} + +template <typename T> +HWY_API size_t FindKnownFirstTrue(const Full256<T> /* tag */, + const Mask256<T> mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return Num0BitsBelowLS1Bit_Nonzero64(mask_bits); +} + +template <typename T> +HWY_API intptr_t FindFirstTrue(const Full256<T> /* tag */, + const Mask256<T> mask) { + const uint64_t mask_bits = detail::BitsFromMask(mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec256<uint32_t> IndicesFromBits(Full256<T> d, uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintCompress32x8Tables + 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, + 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, + 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, + 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, + 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, + 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, + 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, + 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, + 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, + 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, + 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, + 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, + 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, + 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, + 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, + 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, + 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, + 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, + 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, + 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, + 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, + 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, + 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, + 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, + 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 0x6531fca8, + 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98, + 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, + 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, + 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, + 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, + 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, + 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, + 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, + 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, + 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, + 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, + 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, + 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, + 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, + 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, + 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, + 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, + 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; + + // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. + const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Vec256<uint32_t> IndicesFromBits(Full256<T> d, uint64_t mask_bits) { + const Repartition<uint32_t, decltype(d)> d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) constexpr uint32_t u32_indices[128] = { + // PrintCompress64x4PairTables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, + 10, 11, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7, + 12, 13, 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 2, 3, 6, 7, + 10, 11, 12, 13, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 6, 7, + 14, 15, 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 2, 3, 4, 5, + 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 4, 5, + 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 2, 3, + 10, 11, 12, 13, 14, 15, 0, 1, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_INLINE Vec256<uint32_t> IndicesFromNotBits(Full256<T> d, + uint64_t mask_bits) { + const RebindToUnsigned<decltype(d)> d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintCompressNot32x8Tables + 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, + 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, + 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, + 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, + 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, + 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, + 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, + 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, + 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, + 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, + 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, + 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, + 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, + 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, + 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, + 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, + 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9, + 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, + 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, + 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, + 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, + 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, + 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, + 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, + 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9, + 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda, + 0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, + 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb, + 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, + 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, + 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, + 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, + 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, + 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, + 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, + 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, + 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, + 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, + 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, + 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, + 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, + 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, + 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; + + // No need to mask because <_mm256_permutevar8x32_epi32> ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. + const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_INLINE Vec256<uint32_t> IndicesFromNotBits(Full256<T> d, + uint64_t mask_bits) { + const Repartition<uint32_t, decltype(d)> d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) constexpr uint32_t u32_indices[128] = { + // PrintCompressNot64x4PairTables + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, + 8, 9, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, + 8, 9, 10, 11, 14, 15, 12, 13, 10, 11, 14, 15, 8, 9, 12, 13, + 8, 9, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 8, 9, 14, 15, + 8, 9, 12, 13, 10, 11, 14, 15, 12, 13, 8, 9, 10, 11, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 8, 9, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 2)> +HWY_INLINE Vec256<T> Compress(Vec256<T> v, const uint64_t mask_bits) { + const Full256<T> d; + const Repartition<uint32_t, decltype(d)> du32; + + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256<uint32_t> indices{IndicesFromBits(d, mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec256<T> Compress(Vec256<T> v, const uint64_t mask_bits) { + const Full256<T> d; + const RebindToUnsigned<decltype(d)> du; + const auto vu16 = BitCast(du, v); // (required for float16_t inputs) + const Half<decltype(du)> duh; + const auto half0 = LowerHalf(duh, vu16); + const auto half1 = UpperHalf(duh, vu16); + + const uint64_t mask_bits0 = mask_bits & 0xFF; + const uint64_t mask_bits1 = mask_bits >> 8; + const auto compressed0 = detail::CompressBits(half0, mask_bits0); + const auto compressed1 = detail::CompressBits(half1, mask_bits1); + + alignas(32) uint16_t all_true[16] = {}; + // Store mask=true lanes, left to right. + const size_t num_true0 = PopCount(mask_bits0); + Store(compressed0, duh, all_true); + StoreU(compressed1, duh, all_true + num_true0); + + if (hwy::HWY_NAMESPACE::CompressIsPartition<T>::value) { + // Store mask=false lanes, right to left. The second vector fills the upper + // half with right-aligned false lanes. The first vector is shifted + // rightwards to overwrite the true lanes of the second. + alignas(32) uint16_t all_false[16] = {}; + const size_t num_true1 = PopCount(mask_bits1); + Store(compressed1, duh, all_false + 8); + StoreU(compressed0, duh, all_false + num_true1); + + const auto mask = FirstN(du, num_true0 + num_true1); + return BitCast(d, + IfThenElse(mask, Load(du, all_true), Load(du, all_false))); + } else { + // Only care about the mask=true lanes. + return BitCast(d, Load(du, all_true)); + } +} + +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x110)> // 4 or 8 bytes +HWY_INLINE Vec256<T> CompressNot(Vec256<T> v, const uint64_t mask_bits) { + const Full256<T> d; + const Repartition<uint32_t, decltype(d)> du32; + + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256<uint32_t> indices{IndicesFromNotBits(d, mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_INLINE Vec256<T> CompressNot(Vec256<T> v, const uint64_t mask_bits) { + // Compress ensures only the lower 16 bits are set, so flip those. + return Compress(v, mask_bits ^ 0xFFFF); +} + +} // namespace detail + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec256<T> Compress(Vec256<T> v, Mask256<T> m) { + return detail::Compress(v, detail::BitsFromMask(m)); +} + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec256<T> CompressNot(Vec256<T> v, Mask256<T> m) { + return detail::CompressNot(v, detail::BitsFromMask(m)); +} + +HWY_API Vec256<uint64_t> CompressBlocksNot(Vec256<uint64_t> v, + Mask256<uint64_t> mask) { + return CompressNot(v, mask); +} + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API Vec256<T> CompressBits(Vec256<T> v, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes<kNumBytes>(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressStore(Vec256<T> v, Mask256<T> m, Full256<T> d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <typename T, HWY_IF_LANE_SIZE_ONE_OF(T, 0x110)> // 4 or 8 bytes +HWY_API size_t CompressBlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + + const Repartition<uint32_t, decltype(d)> du32; + HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). Nibble MSB encodes FirstN. + const Vec256<uint32_t> idx_and_mask = detail::IndicesFromBits(d, mask_bits); + // Shift nibble MSB into MSB + const Mask256<uint32_t> mask32 = MaskFromVec(ShiftLeft<28>(idx_and_mask)); + // First cast to unsigned (RebindMask cannot change lane size) + const Mask256<MakeUnsigned<T>> mask_u{mask32.raw}; + const Mask256<T> mask = RebindMask(d, mask_u); + const Vec256<T> compressed = + BitCast(d, TableLookupLanes(BitCast(du32, v), + Indices256<uint32_t>{idx_and_mask.raw})); + + BlendedStore(compressed, mask, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API size_t CompressBlendedStore(Vec256<T> v, Mask256<T> m, Full256<T> d, + T* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = detail::BitsFromMask(m); + const size_t count = PopCount(mask_bits); + const Vec256<T> compressed = detail::Compress(v, mask_bits); + +#if HWY_MEM_OPS_MIGHT_FAULT // true if HWY_IS_MSAN + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. + alignas(32) T buf[16]; + Store(compressed, d, buf); + memcpy(unaligned, buf, count * sizeof(T)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + return count; +} + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t CompressBitsStore(Vec256<T> v, const uint8_t* HWY_RESTRICT bits, + Full256<T> d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes<kNumBytes>(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. + +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template <typename T> +HWY_API void LoadTransposedBlocks3(Full256<T> d, + const T* HWY_RESTRICT unaligned, + Vec256<T>& A, Vec256<T>& B, Vec256<T>& C) { + constexpr size_t N = 32 / sizeof(T); + const Vec256<T> v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const Vec256<T> v32 = LoadU(d, unaligned + 1 * N); + const Vec256<T> v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template <typename T> +HWY_API void LoadTransposedBlocks4(Full256<T> d, + const T* HWY_RESTRICT unaligned, + Vec256<T>& A, Vec256<T>& B, Vec256<T>& C, + Vec256<T>& D) { + constexpr size_t N = 32 / sizeof(T); + const Vec256<T> v10 = LoadU(d, unaligned + 0 * N); + const Vec256<T> v32 = LoadU(d, unaligned + 1 * N); + const Vec256<T> v54 = LoadU(d, unaligned + 2 * N); + const Vec256<T> v76 = LoadU(d, unaligned + 3 * N); + + A = ConcatLowerLower(d, v54, v10); + B = ConcatUpperUpper(d, v54, v10); + C = ConcatLowerLower(d, v76, v32); + D = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template <typename T> +HWY_API void StoreTransposedBlocks2(const Vec256<T> i, const Vec256<T> j, + const Full256<T> d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template <typename T> +HWY_API void StoreTransposedBlocks3(const Vec256<T> i, const Vec256<T> j, + const Vec256<T> k, Full256<T> d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template <typename T> +HWY_API void StoreTransposedBlocks4(const Vec256<T> i, const Vec256<T> j, + const Vec256<T> k, const Vec256<T> l, + Full256<T> d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 32 / sizeof(T); + // Write lower halves, then upper. + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ Reductions + +namespace detail { + +// Returns sum{lane[i]} in each lane. "v3210" is a replicated 128-bit block. +// Same logic as x86/128.h, but with Vec256 arguments. +template <typename T> +HWY_INLINE Vec256<T> SumOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256<T> v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = v3210 + v1032; + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return v20_31_20_31 + v31_20_31_20; +} +template <typename T> +HWY_INLINE Vec256<T> MinOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256<T> v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Min(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Min(v20_31_20_31, v31_20_31_20); +} +template <typename T> +HWY_INLINE Vec256<T> MaxOfLanes(hwy::SizeTag<4> /* tag */, + const Vec256<T> v3210) { + const auto v1032 = Shuffle1032(v3210); + const auto v31_20_31_20 = Max(v3210, v1032); + const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); + return Max(v20_31_20_31, v31_20_31_20); +} + +template <typename T> +HWY_INLINE Vec256<T> SumOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256<T> v10) { + const auto v01 = Shuffle01(v10); + return v10 + v01; +} +template <typename T> +HWY_INLINE Vec256<T> MinOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256<T> v10) { + const auto v01 = Shuffle01(v10); + return Min(v10, v01); +} +template <typename T> +HWY_INLINE Vec256<T> MaxOfLanes(hwy::SizeTag<8> /* tag */, + const Vec256<T> v10) { + const auto v01 = Shuffle01(v10); + return Max(v10, v01); +} + +HWY_API Vec256<uint16_t> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec256<uint16_t> v) { + const Full256<uint16_t> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +HWY_API Vec256<int16_t> SumOfLanes(hwy::SizeTag<2> /* tag */, + Vec256<int16_t> v) { + const Full256<int16_t> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(hwy::SizeTag<4>(), even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +HWY_API Vec256<uint16_t> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec256<uint16_t> v) { + const Full256<uint16_t> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec256<int16_t> MinOfLanes(hwy::SizeTag<2> /* tag */, + Vec256<int16_t> v) { + const Full256<int16_t> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(hwy::SizeTag<4>(), Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +HWY_API Vec256<uint16_t> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec256<uint16_t> v) { + const Full256<uint16_t> d; + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec256<int16_t> MaxOfLanes(hwy::SizeTag<2> /* tag */, + Vec256<int16_t> v) { + const Full256<int16_t> d; + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(hwy::SizeTag<4>(), Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +} // namespace detail + +// Supported for {uif}{32,64},{ui}16. Returns the broadcasted result. +template <typename T> +HWY_API Vec256<T> SumOfLanes(Full256<T> d, const Vec256<T> vHL) { + const Vec256<T> vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::SumOfLanes(hwy::SizeTag<sizeof(T)>(), vLH + vHL); +} +template <typename T> +HWY_API Vec256<T> MinOfLanes(Full256<T> d, const Vec256<T> vHL) { + const Vec256<T> vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::MinOfLanes(hwy::SizeTag<sizeof(T)>(), Min(vLH, vHL)); +} +template <typename T> +HWY_API Vec256<T> MaxOfLanes(Full256<T> d, const Vec256<T> vHL) { + const Vec256<T> vLH = ConcatLowerUpper(d, vHL, vHL); + return detail::MaxOfLanes(hwy::SizeTag<sizeof(T)>(), Max(vLH, vHL)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/highway/hwy/ops/x86_512-inl.h b/third_party/highway/hwy/ops/x86_512-inl.h new file mode 100644 index 0000000000..5f3b34c357 --- /dev/null +++ b/third_party/highway/hwy/ops/x86_512-inl.h @@ -0,0 +1,4605 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 512-bit AVX512 vectors and operations. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4703 6001 26494, ignored "-Wmaybe-uninitialized") +#endif + +#include <immintrin.h> // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including <immintrin.h> should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +// clang-format off +#include <smmintrin.h> + +#include <avxintrin.h> +#include <avx2intrin.h> +#include <f16cintrin.h> +#include <fmaintrin.h> + +#include <avx512fintrin.h> +#include <avx512vlintrin.h> +#include <avx512bwintrin.h> +#include <avx512dqintrin.h> +#include <avx512vlbwintrin.h> +#include <avx512vldqintrin.h> +#include <avx512bitalgintrin.h> +#include <avx512vlbitalgintrin.h> +#include <avx512vpopcntdqintrin.h> +#include <avx512vpopcntdqvlintrin.h> +// clang-format on +#endif // HWY_COMPILER_CLANGCL + +#include <stddef.h> +#include <stdint.h> + +#if HWY_IS_MSAN +#include <sanitizer/msan_interface.h> +#endif + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_256-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { + +template <typename T> +struct Raw512 { + using type = __m512i; +}; +template <> +struct Raw512<float> { + using type = __m512; +}; +template <> +struct Raw512<double> { + using type = __m512d; +}; + +// Template arg: sizeof(lane type) +template <size_t size> +struct RawMask512 {}; +template <> +struct RawMask512<1> { + using type = __mmask64; +}; +template <> +struct RawMask512<2> { + using type = __mmask32; +}; +template <> +struct RawMask512<4> { + using type = __mmask16; +}; +template <> +struct RawMask512<8> { + using type = __mmask8; +}; + +} // namespace detail + +template <typename T> +class Vec512 { + using Raw = typename detail::Raw512<T>::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 64 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec512& operator*=(const Vec512 other) { + return *this = (*this * other); + } + HWY_INLINE Vec512& operator/=(const Vec512 other) { + return *this = (*this / other); + } + HWY_INLINE Vec512& operator+=(const Vec512 other) { + return *this = (*this + other); + } + HWY_INLINE Vec512& operator-=(const Vec512 other) { + return *this = (*this - other); + } + HWY_INLINE Vec512& operator&=(const Vec512 other) { + return *this = (*this & other); + } + HWY_INLINE Vec512& operator|=(const Vec512 other) { + return *this = (*this | other); + } + HWY_INLINE Vec512& operator^=(const Vec512 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +// Mask register: one bit per lane. +template <typename T> +struct Mask512 { + using Raw = typename detail::RawMask512<sizeof(T)>::type; + Raw raw; +}; + +template <typename T> +using Full512 = Simd<T, 64 / sizeof(T), 0>; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; } +HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); } +HWY_INLINE __m512i BitCastToInteger(__m512d v) { + return _mm512_castpd_si512(v); +} + +template <typename T> +HWY_INLINE Vec512<uint8_t> BitCastToByte(Vec512<T> v) { + return Vec512<uint8_t>{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template <typename T> +struct BitCastFromInteger512 { + HWY_INLINE __m512i operator()(__m512i v) { return v; } +}; +template <> +struct BitCastFromInteger512<float> { + HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } +}; +template <> +struct BitCastFromInteger512<double> { + HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } +}; + +template <typename T> +HWY_INLINE Vec512<T> BitCastFromByte(Full512<T> /* tag */, Vec512<uint8_t> v) { + return Vec512<T>{BitCastFromInteger512<T>()(v.raw)}; +} + +} // namespace detail + +template <typename T, typename FromT> +HWY_API Vec512<T> BitCast(Full512<T> d, Vec512<FromT> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Returns an all-zero vector. +template <typename T> +HWY_API Vec512<T> Zero(Full512<T> /* tag */) { + return Vec512<T>{_mm512_setzero_si512()}; +} +HWY_API Vec512<float> Zero(Full512<float> /* tag */) { + return Vec512<float>{_mm512_setzero_ps()}; +} +HWY_API Vec512<double> Zero(Full512<double> /* tag */) { + return Vec512<double>{_mm512_setzero_pd()}; +} + +// Returns a vector with all lanes set to "t". +HWY_API Vec512<uint8_t> Set(Full512<uint8_t> /* tag */, const uint8_t t) { + return Vec512<uint8_t>{_mm512_set1_epi8(static_cast<char>(t))}; // NOLINT +} +HWY_API Vec512<uint16_t> Set(Full512<uint16_t> /* tag */, const uint16_t t) { + return Vec512<uint16_t>{_mm512_set1_epi16(static_cast<short>(t))}; // NOLINT +} +HWY_API Vec512<uint32_t> Set(Full512<uint32_t> /* tag */, const uint32_t t) { + return Vec512<uint32_t>{_mm512_set1_epi32(static_cast<int>(t))}; +} +HWY_API Vec512<uint64_t> Set(Full512<uint64_t> /* tag */, const uint64_t t) { + return Vec512<uint64_t>{ + _mm512_set1_epi64(static_cast<long long>(t))}; // NOLINT +} +HWY_API Vec512<int8_t> Set(Full512<int8_t> /* tag */, const int8_t t) { + return Vec512<int8_t>{_mm512_set1_epi8(static_cast<char>(t))}; // NOLINT +} +HWY_API Vec512<int16_t> Set(Full512<int16_t> /* tag */, const int16_t t) { + return Vec512<int16_t>{_mm512_set1_epi16(static_cast<short>(t))}; // NOLINT +} +HWY_API Vec512<int32_t> Set(Full512<int32_t> /* tag */, const int32_t t) { + return Vec512<int32_t>{_mm512_set1_epi32(t)}; +} +HWY_API Vec512<int64_t> Set(Full512<int64_t> /* tag */, const int64_t t) { + return Vec512<int64_t>{ + _mm512_set1_epi64(static_cast<long long>(t))}; // NOLINT +} +HWY_API Vec512<float> Set(Full512<float> /* tag */, const float t) { + return Vec512<float>{_mm512_set1_ps(t)}; +} +HWY_API Vec512<double> Set(Full512<double> /* tag */, const double t) { + return Vec512<double>{_mm512_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template <typename T> +HWY_API Vec512<T> Undefined(Full512<T> /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return Vec512<T>{_mm512_undefined_epi32()}; +} +HWY_API Vec512<float> Undefined(Full512<float> /* tag */) { + return Vec512<float>{_mm512_undefined_ps()}; +} +HWY_API Vec512<double> Undefined(Full512<double> /* tag */) { + return Vec512<double>{_mm512_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== LOGICAL + +// ------------------------------ Not + +template <typename T> +HWY_API Vec512<T> Not(const Vec512<T> v) { + using TU = MakeUnsigned<T>; + const __m512i vu = BitCast(Full512<TU>(), v).raw; + return BitCast(Full512<T>(), + Vec512<TU>{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); +} + +// ------------------------------ And + +template <typename T> +HWY_API Vec512<T> And(const Vec512<T> a, const Vec512<T> b) { + return Vec512<T>{_mm512_and_si512(a.raw, b.raw)}; +} + +HWY_API Vec512<float> And(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_and_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> And(const Vec512<double> a, const Vec512<double> b) { + return Vec512<double>{_mm512_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template <typename T> +HWY_API Vec512<T> AndNot(const Vec512<T> not_mask, const Vec512<T> mask) { + return Vec512<T>{_mm512_andnot_si512(not_mask.raw, mask.raw)}; +} +HWY_API Vec512<float> AndNot(const Vec512<float> not_mask, + const Vec512<float> mask) { + return Vec512<float>{_mm512_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec512<double> AndNot(const Vec512<double> not_mask, + const Vec512<double> mask) { + return Vec512<double>{_mm512_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template <typename T> +HWY_API Vec512<T> Or(const Vec512<T> a, const Vec512<T> b) { + return Vec512<T>{_mm512_or_si512(a.raw, b.raw)}; +} + +HWY_API Vec512<float> Or(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_or_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> Or(const Vec512<double> a, const Vec512<double> b) { + return Vec512<double>{_mm512_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template <typename T> +HWY_API Vec512<T> Xor(const Vec512<T> a, const Vec512<T> b) { + return Vec512<T>{_mm512_xor_si512(a.raw, b.raw)}; +} + +HWY_API Vec512<float> Xor(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> Xor(const Vec512<double> a, const Vec512<double> b) { + return Vec512<double>{_mm512_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor3 +template <typename T> +HWY_API Vec512<T> Xor3(Vec512<T> x1, Vec512<T> x2, Vec512<T> x3) { + const Full512<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96); + return BitCast(d, VU{ret}); +} + +// ------------------------------ Or3 +template <typename T> +HWY_API Vec512<T> Or3(Vec512<T> o1, Vec512<T> o2, Vec512<T> o3) { + const Full512<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE); + return BitCast(d, VU{ret}); +} + +// ------------------------------ OrAnd +template <typename T> +HWY_API Vec512<T> OrAnd(Vec512<T> o, Vec512<T> a1, Vec512<T> a2) { + const Full512<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); + return BitCast(d, VU{ret}); +} + +// ------------------------------ IfVecThenElse +template <typename T> +HWY_API Vec512<T> IfVecThenElse(Vec512<T> mask, Vec512<T> yes, Vec512<T> no) { + const Full512<T> d; + const RebindToUnsigned<decltype(d)> du; + using VU = VFromD<decltype(du)>; + return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, + BitCast(du, yes).raw, + BitCast(du, no).raw, 0xCA)}); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template <typename T> +HWY_API Vec512<T> operator&(const Vec512<T> a, const Vec512<T> b) { + return And(a, b); +} + +template <typename T> +HWY_API Vec512<T> operator|(const Vec512<T> a, const Vec512<T> b) { + return Or(a, b); +} + +template <typename T> +HWY_API Vec512<T> operator^(const Vec512<T> a, const Vec512<T> b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET == HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template <typename T> +HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec512<T> v) { + return Vec512<T>{_mm512_popcnt_epi8(v.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec512<T> v) { + return Vec512<T>{_mm512_popcnt_epi16(v.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec512<T> v) { + return Vec512<T>{_mm512_popcnt_epi32(v.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec512<T> v) { + return Vec512<T>{_mm512_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec512<T> PopulationCount(Vec512<T> v) { + return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ CopySign + +template <typename T> +HWY_API Vec512<T> CopySign(const Vec512<T> magn, const Vec512<T> sign) { + static_assert(IsFloat<T>(), "Only makes sense for floating-point"); + + const Full512<T> d; + const auto msb = SignBit(d); + + const Rebind<MakeUnsigned<T>, decltype(d)> du; + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + // The lane size does not matter because we are not using predication. + const __m512i out = _mm512_ternarylogic_epi32( + BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); + return BitCast(d, decltype(Zero(du)){out}); +} + +template <typename T> +HWY_API Vec512<T> CopySignToAbs(const Vec512<T> abs, const Vec512<T> sign) { + // AVX3 can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== MASK + +// ------------------------------ FirstN + +// Possibilities for constructing a bitmask of N ones: +// - kshift* only consider the lowest byte of the shift count, so they would +// not correctly handle large n. +// - Scalar shifts >= 64 are UB. +// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, +// we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. + +#if HWY_ARCH_X86_32 +namespace detail { + +// 32 bit mask is sufficient for lane size >= 2. +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_INLINE Mask512<T> FirstN(size_t n) { + Mask512<T> m; + const uint32_t all = ~uint32_t{0}; + // BZHI only looks at the lower 8 bits of n! + m.raw = static_cast<decltype(m.raw)>((n > 255) ? all : _bzhi_u32(all, n)); + return m; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_INLINE Mask512<T> FirstN(size_t n) { + const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0}; + return Mask512<T>{static_cast<__mmask64>(bits)}; +} + +} // namespace detail +#endif // HWY_ARCH_X86_32 + +template <typename T> +HWY_API Mask512<T> FirstN(const Full512<T> /*tag*/, size_t n) { +#if HWY_ARCH_X86_64 + Mask512<T> m; + const uint64_t all = ~uint64_t{0}; + // BZHI only looks at the lower 8 bits of n! + m.raw = static_cast<decltype(m.raw)>((n > 255) ? all : _bzhi_u64(all, n)); + return m; +#else + return detail::FirstN<T>(n); +#endif // HWY_ARCH_X86_64 +} + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template <typename T> +HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<1> /* tag */, + const Mask512<T> mask, const Vec512<T> yes, + const Vec512<T> no) { + return Vec512<T>{_mm512_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<2> /* tag */, + const Mask512<T> mask, const Vec512<T> yes, + const Vec512<T> no) { + return Vec512<T>{_mm512_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<4> /* tag */, + const Mask512<T> mask, const Vec512<T> yes, + const Vec512<T> no) { + return Vec512<T>{_mm512_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<8> /* tag */, + const Mask512<T> mask, const Vec512<T> yes, + const Vec512<T> no) { + return Vec512<T>{_mm512_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec512<T> IfThenElse(const Mask512<T> mask, const Vec512<T> yes, + const Vec512<T> no) { + return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no); +} +HWY_API Vec512<float> IfThenElse(const Mask512<float> mask, + const Vec512<float> yes, + const Vec512<float> no) { + return Vec512<float>{_mm512_mask_mov_ps(no.raw, mask.raw, yes.raw)}; +} +HWY_API Vec512<double> IfThenElse(const Mask512<double> mask, + const Vec512<double> yes, + const Vec512<double> no) { + return Vec512<double>{_mm512_mask_mov_pd(no.raw, mask.raw, yes.raw)}; +} + +namespace detail { + +template <typename T> +HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<1> /* tag */, + const Mask512<T> mask, + const Vec512<T> yes) { + return Vec512<T>{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<2> /* tag */, + const Mask512<T> mask, + const Vec512<T> yes) { + return Vec512<T>{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<4> /* tag */, + const Mask512<T> mask, + const Vec512<T> yes) { + return Vec512<T>{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<8> /* tag */, + const Mask512<T> mask, + const Vec512<T> yes) { + return Vec512<T>{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec512<T> IfThenElseZero(const Mask512<T> mask, const Vec512<T> yes) { + return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes); +} +HWY_API Vec512<float> IfThenElseZero(const Mask512<float> mask, + const Vec512<float> yes) { + return Vec512<float>{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec512<double> IfThenElseZero(const Mask512<double> mask, + const Vec512<double> yes) { + return Vec512<double>{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template <typename T> +HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */, + const Mask512<T> mask, const Vec512<T> no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec512<T>{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */, + const Mask512<T> mask, const Vec512<T> no) { + return Vec512<T>{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */, + const Mask512<T> mask, const Vec512<T> no) { + return Vec512<T>{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template <typename T> +HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */, + const Mask512<T> mask, const Vec512<T> no) { + return Vec512<T>{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Vec512<T> IfThenZeroElse(const Mask512<T> mask, const Vec512<T> no) { + return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no); +} +HWY_API Vec512<float> IfThenZeroElse(const Mask512<float> mask, + const Vec512<float> no) { + return Vec512<float>{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec512<double> IfThenZeroElse(const Mask512<double> mask, + const Vec512<double> no) { + return Vec512<double>{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template <typename T> +HWY_API Vec512<T> IfNegativeThenElse(Vec512<T> v, Vec512<T> yes, Vec512<T> no) { + static_assert(IsSigned<T>(), "Only works for signed/float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenElse(MaskFromVec(v), yes, no); +} + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Vec512<T> ZeroIfNegative(const Vec512<T> v) { + // AVX3 MaskFromVec only looks at the MSB + return IfThenZeroElse(MaskFromVec(v), v); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec512<uint8_t> operator+(const Vec512<uint8_t> a, + const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> operator+(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<uint32_t> operator+(const Vec512<uint32_t> a, + const Vec512<uint32_t> b) { + return Vec512<uint32_t>{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> operator+(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + return Vec512<uint64_t>{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512<int8_t> operator+(const Vec512<int8_t> a, + const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> operator+(const Vec512<int16_t> a, + const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<int32_t> operator+(const Vec512<int32_t> a, + const Vec512<int32_t> b) { + return Vec512<int32_t>{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<int64_t> operator+(const Vec512<int64_t> a, + const Vec512<int64_t> b) { + return Vec512<int64_t>{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512<float> operator+(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_add_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> operator+(const Vec512<double> a, + const Vec512<double> b) { + return Vec512<double>{_mm512_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec512<uint8_t> operator-(const Vec512<uint8_t> a, + const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> operator-(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<uint32_t> operator-(const Vec512<uint32_t> a, + const Vec512<uint32_t> b) { + return Vec512<uint32_t>{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> operator-(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + return Vec512<uint64_t>{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512<int8_t> operator-(const Vec512<int8_t> a, + const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> operator-(const Vec512<int16_t> a, + const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<int32_t> operator-(const Vec512<int32_t> a, + const Vec512<int32_t> b) { + return Vec512<int32_t>{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<int64_t> operator-(const Vec512<int64_t> a, + const Vec512<int64_t> b) { + return Vec512<int64_t>{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512<float> operator-(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> operator-(const Vec512<double> a, + const Vec512<double> b) { + return Vec512<double>{_mm512_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec512<uint64_t> SumsOf8(const Vec512<uint8_t> v) { + return Vec512<uint64_t>{_mm512_sad_epu8(v.raw, _mm512_setzero_si512())}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec512<uint8_t> SaturatedAdd(const Vec512<uint8_t> a, + const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> SaturatedAdd(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512<int8_t> SaturatedAdd(const Vec512<int8_t> a, + const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> SaturatedAdd(const Vec512<int16_t> a, + const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec512<uint8_t> SaturatedSub(const Vec512<uint8_t> a, + const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> SaturatedSub(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512<int8_t> SaturatedSub(const Vec512<int8_t> a, + const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> SaturatedSub(const Vec512<int16_t> a, + const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec512<uint8_t> AverageRound(const Vec512<uint8_t> a, + const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> AverageRound(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec512<int8_t> Abs(const Vec512<int8_t> v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (untested due to internal compiler error) + const auto zero = Zero(Full512<int8_t>()); + return Vec512<int8_t>{_mm512_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec512<int8_t>{_mm512_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec512<int16_t> Abs(const Vec512<int16_t> v) { + return Vec512<int16_t>{_mm512_abs_epi16(v.raw)}; +} +HWY_API Vec512<int32_t> Abs(const Vec512<int32_t> v) { + return Vec512<int32_t>{_mm512_abs_epi32(v.raw)}; +} +HWY_API Vec512<int64_t> Abs(const Vec512<int64_t> v) { + return Vec512<int64_t>{_mm512_abs_epi64(v.raw)}; +} + +// These aren't native instructions, they also involve AND with constant. +HWY_API Vec512<float> Abs(const Vec512<float> v) { + return Vec512<float>{_mm512_abs_ps(v.raw)}; +} +HWY_API Vec512<double> Abs(const Vec512<double> v) { + return Vec512<double>{_mm512_abs_pd(v.raw)}; +} +// ------------------------------ ShiftLeft + +template <int kBits> +HWY_API Vec512<uint16_t> ShiftLeft(const Vec512<uint16_t> v) { + return Vec512<uint16_t>{_mm512_slli_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<uint32_t> ShiftLeft(const Vec512<uint32_t> v) { + return Vec512<uint32_t>{_mm512_slli_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<uint64_t> ShiftLeft(const Vec512<uint64_t> v) { + return Vec512<uint64_t>{_mm512_slli_epi64(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<int16_t> ShiftLeft(const Vec512<int16_t> v) { + return Vec512<int16_t>{_mm512_slli_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<int32_t> ShiftLeft(const Vec512<int32_t> v) { + return Vec512<int32_t>{_mm512_slli_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<int64_t> ShiftLeft(const Vec512<int64_t> v) { + return Vec512<int64_t>{_mm512_slli_epi64(v.raw, kBits)}; +} + +template <int kBits, typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec512<T> ShiftLeft(const Vec512<T> v) { + const Full512<T> d8; + const RepartitionToWide<decltype(d8)> d16; + const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF))); +} + +// ------------------------------ ShiftRight + +template <int kBits> +HWY_API Vec512<uint16_t> ShiftRight(const Vec512<uint16_t> v) { + return Vec512<uint16_t>{_mm512_srli_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<uint32_t> ShiftRight(const Vec512<uint32_t> v) { + return Vec512<uint32_t>{_mm512_srli_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<uint64_t> ShiftRight(const Vec512<uint64_t> v) { + return Vec512<uint64_t>{_mm512_srli_epi64(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<uint8_t> ShiftRight(const Vec512<uint8_t> v) { + const Full512<uint8_t> d8; + // Use raw instead of BitCast to support N=1. + const Vec512<uint8_t> shifted{ShiftRight<kBits>(Vec512<uint16_t>{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template <int kBits> +HWY_API Vec512<int16_t> ShiftRight(const Vec512<int16_t> v) { + return Vec512<int16_t>{_mm512_srai_epi16(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<int32_t> ShiftRight(const Vec512<int32_t> v) { + return Vec512<int32_t>{_mm512_srai_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<int64_t> ShiftRight(const Vec512<int64_t> v) { + return Vec512<int64_t>{_mm512_srai_epi64(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<int8_t> ShiftRight(const Vec512<int8_t> v) { + const Full512<int8_t> di; + const Full512<uint8_t> du; + const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight + +template <int kBits> +HWY_API Vec512<uint32_t> RotateRight(const Vec512<uint32_t> v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + return Vec512<uint32_t>{_mm512_ror_epi32(v.raw, kBits)}; +} + +template <int kBits> +HWY_API Vec512<uint64_t> RotateRight(const Vec512<uint64_t> v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + return Vec512<uint64_t>{_mm512_ror_epi64(v.raw, kBits)}; +} + +// ------------------------------ ShiftLeftSame + +HWY_API Vec512<uint16_t> ShiftLeftSame(const Vec512<uint16_t> v, + const int bits) { + return Vec512<uint16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512<uint32_t> ShiftLeftSame(const Vec512<uint32_t> v, + const int bits) { + return Vec512<uint32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512<uint64_t> ShiftLeftSame(const Vec512<uint64_t> v, + const int bits) { + return Vec512<uint64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512<int16_t> ShiftLeftSame(const Vec512<int16_t> v, const int bits) { + return Vec512<int16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512<int32_t> ShiftLeftSame(const Vec512<int32_t> v, const int bits) { + return Vec512<int32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512<int64_t> ShiftLeftSame(const Vec512<int64_t> v, const int bits) { + return Vec512<int64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec512<T> ShiftLeftSame(const Vec512<T> v, const int bits) { + const Full512<T> d8; + const RepartitionToWide<decltype(d8)> d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame + +HWY_API Vec512<uint16_t> ShiftRightSame(const Vec512<uint16_t> v, + const int bits) { + return Vec512<uint16_t>{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512<uint32_t> ShiftRightSame(const Vec512<uint32_t> v, + const int bits) { + return Vec512<uint32_t>{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512<uint64_t> ShiftRightSame(const Vec512<uint64_t> v, + const int bits) { + return Vec512<uint64_t>{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512<uint8_t> ShiftRightSame(Vec512<uint8_t> v, const int bits) { + const Full512<uint8_t> d8; + const RepartitionToWide<decltype(d8)> d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits)); +} + +HWY_API Vec512<int16_t> ShiftRightSame(const Vec512<int16_t> v, + const int bits) { + return Vec512<int16_t>{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512<int32_t> ShiftRightSame(const Vec512<int32_t> v, + const int bits) { + return Vec512<int32_t>{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512<int64_t> ShiftRightSame(const Vec512<int64_t> v, + const int bits) { + return Vec512<int64_t>{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) { + const Full512<int8_t> di; + const Full512<uint8_t> du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Shl + +HWY_API Vec512<uint16_t> operator<<(const Vec512<uint16_t> v, + const Vec512<uint16_t> bits) { + return Vec512<uint16_t>{_mm512_sllv_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512<uint32_t> operator<<(const Vec512<uint32_t> v, + const Vec512<uint32_t> bits) { + return Vec512<uint32_t>{_mm512_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512<uint64_t> operator<<(const Vec512<uint64_t> v, + const Vec512<uint64_t> bits) { + return Vec512<uint64_t>{_mm512_sllv_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template <typename T, HWY_IF_SIGNED(T)> +HWY_API Vec512<T> operator<<(const Vec512<T> v, const Vec512<T> bits) { + const Full512<T> di; + const Full512<MakeUnsigned<T>> du; + return BitCast(di, BitCast(du, v) << BitCast(du, bits)); +} + +// ------------------------------ Shr + +HWY_API Vec512<uint16_t> operator>>(const Vec512<uint16_t> v, + const Vec512<uint16_t> bits) { + return Vec512<uint16_t>{_mm512_srlv_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512<uint32_t> operator>>(const Vec512<uint32_t> v, + const Vec512<uint32_t> bits) { + return Vec512<uint32_t>{_mm512_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512<uint64_t> operator>>(const Vec512<uint64_t> v, + const Vec512<uint64_t> bits) { + return Vec512<uint64_t>{_mm512_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec512<int16_t> operator>>(const Vec512<int16_t> v, + const Vec512<int16_t> bits) { + return Vec512<int16_t>{_mm512_srav_epi16(v.raw, bits.raw)}; +} + +HWY_API Vec512<int32_t> operator>>(const Vec512<int32_t> v, + const Vec512<int32_t> bits) { + return Vec512<int32_t>{_mm512_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512<int64_t> operator>>(const Vec512<int64_t> v, + const Vec512<int64_t> bits) { + return Vec512<int64_t>{_mm512_srav_epi64(v.raw, bits.raw)}; +} + +// ------------------------------ Minimum + +// Unsigned +HWY_API Vec512<uint8_t> Min(const Vec512<uint8_t> a, const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> Min(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec512<uint32_t> Min(const Vec512<uint32_t> a, + const Vec512<uint32_t> b) { + return Vec512<uint32_t>{_mm512_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> Min(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + return Vec512<uint64_t>{_mm512_min_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512<int8_t> Min(const Vec512<int8_t> a, const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> Min(const Vec512<int16_t> a, const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<int32_t> Min(const Vec512<int32_t> a, const Vec512<int32_t> b) { + return Vec512<int32_t>{_mm512_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<int64_t> Min(const Vec512<int64_t> a, const Vec512<int64_t> b) { + return Vec512<int64_t>{_mm512_min_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512<float> Min(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_min_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> Min(const Vec512<double> a, const Vec512<double> b) { + return Vec512<double>{_mm512_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +HWY_API Vec512<uint8_t> Max(const Vec512<uint8_t> a, const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> Max(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec512<uint32_t> Max(const Vec512<uint32_t> a, + const Vec512<uint32_t> b) { + return Vec512<uint32_t>{_mm512_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> Max(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + return Vec512<uint64_t>{_mm512_max_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512<int8_t> Max(const Vec512<int8_t> a, const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> Max(const Vec512<int16_t> a, const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<int32_t> Max(const Vec512<int32_t> a, const Vec512<int32_t> b) { + return Vec512<int32_t>{_mm512_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<int64_t> Max(const Vec512<int64_t> a, const Vec512<int64_t> b) { + return Vec512<int64_t>{_mm512_max_epi64(a.raw, b.raw)}; +} + +// Float +HWY_API Vec512<float> Max(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_max_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> Max(const Vec512<double> a, const Vec512<double> b) { + return Vec512<double>{_mm512_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec512<uint16_t> operator*(Vec512<uint16_t> a, Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<uint32_t> operator*(Vec512<uint32_t> a, Vec512<uint32_t> b) { + return Vec512<uint32_t>{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> operator*(Vec512<uint64_t> a, Vec512<uint64_t> b) { + return Vec512<uint64_t>{_mm512_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec256<uint64_t> operator*(Vec256<uint64_t> a, Vec256<uint64_t> b) { + return Vec256<uint64_t>{_mm256_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec128<uint64_t> operator*(Vec128<uint64_t> a, Vec128<uint64_t> b) { + return Vec128<uint64_t>{_mm_mullo_epi64(a.raw, b.raw)}; +} + +// Per-target flag to prevent generic_ops-inl.h from defining i64 operator*. +#ifdef HWY_NATIVE_I64MULLO +#undef HWY_NATIVE_I64MULLO +#else +#define HWY_NATIVE_I64MULLO +#endif + +// Signed +HWY_API Vec512<int16_t> operator*(Vec512<int16_t> a, Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<int32_t> operator*(Vec512<int32_t> a, Vec512<int32_t> b) { + return Vec512<int32_t>{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<int64_t> operator*(Vec512<int64_t> a, Vec512<int64_t> b) { + return Vec512<int64_t>{_mm512_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec256<int64_t> operator*(Vec256<int64_t> a, Vec256<int64_t> b) { + return Vec256<int64_t>{_mm256_mullo_epi64(a.raw, b.raw)}; +} +HWY_API Vec128<int64_t> operator*(Vec128<int64_t> a, Vec128<int64_t> b) { + return Vec128<int64_t>{_mm_mullo_epi64(a.raw, b.raw)}; +} +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec512<uint16_t> MulHigh(Vec512<uint16_t> a, Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> MulHigh(Vec512<int16_t> a, Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec512<int16_t> MulFixedPoint15(Vec512<int16_t> a, Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec512<int64_t> MulEven(Vec512<int32_t> a, Vec512<int32_t> b) { + return Vec512<int64_t>{_mm512_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> MulEven(Vec512<uint32_t> a, Vec512<uint32_t> b) { + return Vec512<uint64_t>{_mm512_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ Neg (Sub) + +template <typename T, HWY_IF_FLOAT(T)> +HWY_API Vec512<T> Neg(const Vec512<T> v) { + return Xor(v, SignBit(Full512<T>())); +} + +template <typename T, HWY_IF_NOT_FLOAT(T)> +HWY_API Vec512<T> Neg(const Vec512<T> v) { + return Zero(Full512<T>()) - v; +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec512<float> operator*(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> operator*(const Vec512<double> a, + const Vec512<double> b) { + return Vec512<double>{_mm512_mul_pd(a.raw, b.raw)}; +} + +HWY_API Vec512<float> operator/(const Vec512<float> a, const Vec512<float> b) { + return Vec512<float>{_mm512_div_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> operator/(const Vec512<double> a, + const Vec512<double> b) { + return Vec512<double>{_mm512_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +HWY_API Vec512<float> ApproximateReciprocal(const Vec512<float> v) { + return Vec512<float>{_mm512_rcp14_ps(v.raw)}; +} + +// Absolute value of difference. +HWY_API Vec512<float> AbsDiff(const Vec512<float> a, const Vec512<float> b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +HWY_API Vec512<float> MulAdd(const Vec512<float> mul, const Vec512<float> x, + const Vec512<float> add) { + return Vec512<float>{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512<double> MulAdd(const Vec512<double> mul, const Vec512<double> x, + const Vec512<double> add) { + return Vec512<double>{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns add - mul * x +HWY_API Vec512<float> NegMulAdd(const Vec512<float> mul, const Vec512<float> x, + const Vec512<float> add) { + return Vec512<float>{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512<double> NegMulAdd(const Vec512<double> mul, + const Vec512<double> x, + const Vec512<double> add) { + return Vec512<double>{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns mul * x - sub +HWY_API Vec512<float> MulSub(const Vec512<float> mul, const Vec512<float> x, + const Vec512<float> sub) { + return Vec512<float>{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512<double> MulSub(const Vec512<double> mul, const Vec512<double> x, + const Vec512<double> sub) { + return Vec512<double>{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// Returns -mul * x - sub +HWY_API Vec512<float> NegMulSub(const Vec512<float> mul, const Vec512<float> x, + const Vec512<float> sub) { + return Vec512<float>{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512<double> NegMulSub(const Vec512<double> mul, + const Vec512<double> x, + const Vec512<double> sub) { + return Vec512<double>{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec512<float> Sqrt(const Vec512<float> v) { + return Vec512<float>{_mm512_sqrt_ps(v.raw)}; +} +HWY_API Vec512<double> Sqrt(const Vec512<double> v) { + return Vec512<double>{_mm512_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec512<float> ApproximateReciprocalSqrt(const Vec512<float> v) { + return Vec512<float>{_mm512_rsqrt14_ps(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Toward nearest integer, tie to even +HWY_API Vec512<float> Round(const Vec512<float> v) { + return Vec512<float>{_mm512_roundscale_ps( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512<double> Round(const Vec512<double> v) { + return Vec512<double>{_mm512_roundscale_pd( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +HWY_API Vec512<float> Trunc(const Vec512<float> v) { + return Vec512<float>{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512<double> Trunc(const Vec512<double> v) { + return Vec512<double>{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec512<float> Ceil(const Vec512<float> v) { + return Vec512<float>{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512<double> Ceil(const Vec512<double> v) { + return Vec512<double>{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +HWY_API Vec512<float> Floor(const Vec512<float> v) { + return Vec512<float>{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512<double> Floor(const Vec512<double> v) { + return Vec512<double>{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template <typename TFrom, typename TTo> +HWY_API Mask512<TTo> RebindMask(Full512<TTo> /*tag*/, Mask512<TFrom> m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask512<TTo>{m.raw}; +} + +namespace detail { + +template <typename T> +HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<1> /*tag*/, const Vec512<T> v, + const Vec512<T> bit) { + return Mask512<T>{_mm512_test_epi8_mask(v.raw, bit.raw)}; +} +template <typename T> +HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<2> /*tag*/, const Vec512<T> v, + const Vec512<T> bit) { + return Mask512<T>{_mm512_test_epi16_mask(v.raw, bit.raw)}; +} +template <typename T> +HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<4> /*tag*/, const Vec512<T> v, + const Vec512<T> bit) { + return Mask512<T>{_mm512_test_epi32_mask(v.raw, bit.raw)}; +} +template <typename T> +HWY_INLINE Mask512<T> TestBit(hwy::SizeTag<8> /*tag*/, const Vec512<T> v, + const Vec512<T> bit) { + return Mask512<T>{_mm512_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Mask512<T> TestBit(const Vec512<T> v, const Vec512<T> bit) { + static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit); +} + +// ------------------------------ Equality + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Mask512<T> operator==(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512<float> operator==(Vec512<float> a, Vec512<float> b) { + return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask512<double> operator==(Vec512<double> a, Vec512<double> b) { + return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Mask512<T> operator!=(Vec512<T> a, Vec512<T> b) { + return Mask512<T>{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512<float> operator!=(Vec512<float> a, Vec512<float> b) { + return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask512<double> operator!=(Vec512<double> a, Vec512<double> b) { + return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask512<uint8_t> operator>(Vec512<uint8_t> a, Vec512<uint8_t> b) { + return Mask512<uint8_t>{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask512<uint16_t> operator>(Vec512<uint16_t> a, Vec512<uint16_t> b) { + return Mask512<uint16_t>{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask512<uint32_t> operator>(Vec512<uint32_t> a, Vec512<uint32_t> b) { + return Mask512<uint32_t>{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask512<uint64_t> operator>(Vec512<uint64_t> a, Vec512<uint64_t> b) { + return Mask512<uint64_t>{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512<int8_t> operator>(Vec512<int8_t> a, Vec512<int8_t> b) { + return Mask512<int8_t>{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask512<int16_t> operator>(Vec512<int16_t> a, Vec512<int16_t> b) { + return Mask512<int16_t>{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask512<int32_t> operator>(Vec512<int32_t> a, Vec512<int32_t> b) { + return Mask512<int32_t>{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask512<int64_t> operator>(Vec512<int64_t> a, Vec512<int64_t> b) { + return Mask512<int64_t>{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512<float> operator>(Vec512<float> a, Vec512<float> b) { + return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask512<double> operator>(Vec512<double> a, Vec512<double> b) { + return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +HWY_API Mask512<float> operator>=(Vec512<float> a, Vec512<float> b) { + return Mask512<float>{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask512<double> operator>=(Vec512<double> a, Vec512<double> b) { + return Mask512<double>{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +// ------------------------------ Reversed comparisons + +template <typename T> +HWY_API Mask512<T> operator<(Vec512<T> a, Vec512<T> b) { + return b > a; +} + +template <typename T> +HWY_API Mask512<T> operator<=(Vec512<T> a, Vec512<T> b) { + return b >= a; +} + +// ------------------------------ Mask + +namespace detail { + +template <typename T> +HWY_INLINE Mask512<T> MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec512<T> v) { + return Mask512<T>{_mm512_movepi8_mask(v.raw)}; +} +template <typename T> +HWY_INLINE Mask512<T> MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec512<T> v) { + return Mask512<T>{_mm512_movepi16_mask(v.raw)}; +} +template <typename T> +HWY_INLINE Mask512<T> MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec512<T> v) { + return Mask512<T>{_mm512_movepi32_mask(v.raw)}; +} +template <typename T> +HWY_INLINE Mask512<T> MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec512<T> v) { + return Mask512<T>{_mm512_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template <typename T> +HWY_API Mask512<T> MaskFromVec(const Vec512<T> v) { + return detail::MaskFromVec(hwy::SizeTag<sizeof(T)>(), v); +} +// There do not seem to be native floating-point versions of these instructions. +HWY_API Mask512<float> MaskFromVec(const Vec512<float> v) { + return Mask512<float>{MaskFromVec(BitCast(Full512<int32_t>(), v)).raw}; +} +HWY_API Mask512<double> MaskFromVec(const Vec512<double> v) { + return Mask512<double>{MaskFromVec(BitCast(Full512<int64_t>(), v)).raw}; +} + +HWY_API Vec512<uint8_t> VecFromMask(const Mask512<uint8_t> v) { + return Vec512<uint8_t>{_mm512_movm_epi8(v.raw)}; +} +HWY_API Vec512<int8_t> VecFromMask(const Mask512<int8_t> v) { + return Vec512<int8_t>{_mm512_movm_epi8(v.raw)}; +} + +HWY_API Vec512<uint16_t> VecFromMask(const Mask512<uint16_t> v) { + return Vec512<uint16_t>{_mm512_movm_epi16(v.raw)}; +} +HWY_API Vec512<int16_t> VecFromMask(const Mask512<int16_t> v) { + return Vec512<int16_t>{_mm512_movm_epi16(v.raw)}; +} + +HWY_API Vec512<uint32_t> VecFromMask(const Mask512<uint32_t> v) { + return Vec512<uint32_t>{_mm512_movm_epi32(v.raw)}; +} +HWY_API Vec512<int32_t> VecFromMask(const Mask512<int32_t> v) { + return Vec512<int32_t>{_mm512_movm_epi32(v.raw)}; +} +HWY_API Vec512<float> VecFromMask(const Mask512<float> v) { + return Vec512<float>{_mm512_castsi512_ps(_mm512_movm_epi32(v.raw))}; +} + +HWY_API Vec512<uint64_t> VecFromMask(const Mask512<uint64_t> v) { + return Vec512<uint64_t>{_mm512_movm_epi64(v.raw)}; +} +HWY_API Vec512<int64_t> VecFromMask(const Mask512<int64_t> v) { + return Vec512<int64_t>{_mm512_movm_epi64(v.raw)}; +} +HWY_API Vec512<double> VecFromMask(const Mask512<double> v) { + return Vec512<double>{_mm512_castsi512_pd(_mm512_movm_epi64(v.raw))}; +} + +template <typename T> +HWY_API Vec512<T> VecFromMask(Full512<T> /* tag */, const Mask512<T> v) { + return VecFromMask(v); +} + +// ------------------------------ Mask logical + +namespace detail { + +template <typename T> +HWY_INLINE Mask512<T> Not(hwy::SizeTag<1> /*tag*/, const Mask512<T> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_knot_mask64(m.raw)}; +#else + return Mask512<T>{~m.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Not(hwy::SizeTag<2> /*tag*/, const Mask512<T> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_knot_mask32(m.raw)}; +#else + return Mask512<T>{~m.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Not(hwy::SizeTag<4> /*tag*/, const Mask512<T> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_knot_mask16(m.raw)}; +#else + return Mask512<T>{static_cast<uint16_t>(~m.raw & 0xFFFF)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Not(hwy::SizeTag<8> /*tag*/, const Mask512<T> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_knot_mask8(m.raw)}; +#else + return Mask512<T>{static_cast<uint8_t>(~m.raw & 0xFF)}; +#endif +} + +template <typename T> +HWY_INLINE Mask512<T> And(hwy::SizeTag<1> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kand_mask64(a.raw, b.raw)}; +#else + return Mask512<T>{a.raw & b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> And(hwy::SizeTag<2> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kand_mask32(a.raw, b.raw)}; +#else + return Mask512<T>{a.raw & b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> And(hwy::SizeTag<4> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kand_mask16(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint16_t>(a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> And(hwy::SizeTag<8> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kand_mask8(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint8_t>(a.raw & b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<1> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kandn_mask64(a.raw, b.raw)}; +#else + return Mask512<T>{~a.raw & b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<2> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask512<T>{~a.raw & b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<4> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint16_t>(~a.raw & b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> AndNot(hwy::SizeTag<8> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint8_t>(~a.raw & b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask512<T> Or(hwy::SizeTag<1> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kor_mask64(a.raw, b.raw)}; +#else + return Mask512<T>{a.raw | b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Or(hwy::SizeTag<2> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kor_mask32(a.raw, b.raw)}; +#else + return Mask512<T>{a.raw | b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Or(hwy::SizeTag<4> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kor_mask16(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint16_t>(a.raw | b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Or(hwy::SizeTag<8> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kor_mask8(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint8_t>(a.raw | b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask512<T> Xor(hwy::SizeTag<1> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxor_mask64(a.raw, b.raw)}; +#else + return Mask512<T>{a.raw ^ b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Xor(hwy::SizeTag<2> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask512<T>{a.raw ^ b.raw}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Xor(hwy::SizeTag<4> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint16_t>(a.raw ^ b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> Xor(hwy::SizeTag<8> /*tag*/, const Mask512<T> a, + const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<uint8_t>(a.raw ^ b.raw)}; +#endif +} + +template <typename T> +HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask512<T> a, const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxnor_mask64(a.raw, b.raw)}; +#else + return Mask512<T>{~(a.raw ^ b.raw)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask512<T> a, const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask512<T> a, const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template <typename T> +HWY_INLINE Mask512<T> ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask512<T> a, const Mask512<T> b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512<T>{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask512<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} + +} // namespace detail + +template <typename T> +HWY_API Mask512<T> Not(const Mask512<T> m) { + return detail::Not(hwy::SizeTag<sizeof(T)>(), m); +} + +template <typename T> +HWY_API Mask512<T> And(const Mask512<T> a, Mask512<T> b) { + return detail::And(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask512<T> AndNot(const Mask512<T> a, Mask512<T> b) { + return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask512<T> Or(const Mask512<T> a, Mask512<T> b) { + return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask512<T> Xor(const Mask512<T> a, Mask512<T> b) { + return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b); +} + +template <typename T> +HWY_API Mask512<T> ExclusiveNeither(const Mask512<T> a, Mask512<T> b) { + return detail::ExclusiveNeither(hwy::SizeTag<sizeof(T)>(), a, b); +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec512<int8_t> BroadcastSignBit(const Vec512<int8_t> v) { + return VecFromMask(v < Zero(Full512<int8_t>())); +} + +HWY_API Vec512<int16_t> BroadcastSignBit(const Vec512<int16_t> v) { + return ShiftRight<15>(v); +} + +HWY_API Vec512<int32_t> BroadcastSignBit(const Vec512<int32_t> v) { + return ShiftRight<31>(v); +} + +HWY_API Vec512<int64_t> BroadcastSignBit(const Vec512<int64_t> v) { + return Vec512<int64_t>{_mm512_srai_epi64(v.raw, 63)}; +} + +// ------------------------------ Floating-point classification (Not) + +HWY_API Mask512<float> IsNaN(const Vec512<float> v) { + return Mask512<float>{_mm512_fpclass_ps_mask(v.raw, 0x81)}; +} +HWY_API Mask512<double> IsNaN(const Vec512<double> v) { + return Mask512<double>{_mm512_fpclass_pd_mask(v.raw, 0x81)}; +} + +HWY_API Mask512<float> IsInf(const Vec512<float> v) { + return Mask512<float>{_mm512_fpclass_ps_mask(v.raw, 0x18)}; +} +HWY_API Mask512<double> IsInf(const Vec512<double> v) { + return Mask512<double>{_mm512_fpclass_pd_mask(v.raw, 0x18)}; +} + +// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for +// positive, so we have to check for inf/NaN and negate. +HWY_API Mask512<float> IsFinite(const Vec512<float> v) { + return Not(Mask512<float>{_mm512_fpclass_ps_mask(v.raw, 0x99)}); +} +HWY_API Mask512<double> IsFinite(const Vec512<double> v) { + return Not(Mask512<double>{_mm512_fpclass_pd_mask(v.raw, 0x99)}); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template <typename T> +HWY_API Vec512<T> Load(Full512<T> /* tag */, const T* HWY_RESTRICT aligned) { + return Vec512<T>{_mm512_load_si512(aligned)}; +} +HWY_API Vec512<float> Load(Full512<float> /* tag */, + const float* HWY_RESTRICT aligned) { + return Vec512<float>{_mm512_load_ps(aligned)}; +} +HWY_API Vec512<double> Load(Full512<double> /* tag */, + const double* HWY_RESTRICT aligned) { + return Vec512<double>{_mm512_load_pd(aligned)}; +} + +template <typename T> +HWY_API Vec512<T> LoadU(Full512<T> /* tag */, const T* HWY_RESTRICT p) { + return Vec512<T>{_mm512_loadu_si512(p)}; +} +HWY_API Vec512<float> LoadU(Full512<float> /* tag */, + const float* HWY_RESTRICT p) { + return Vec512<float>{_mm512_loadu_ps(p)}; +} +HWY_API Vec512<double> LoadU(Full512<double> /* tag */, + const double* HWY_RESTRICT p) { + return Vec512<double>{_mm512_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec512<T> MaskedLoad(Mask512<T> m, Full512<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec512<T>{_mm512_maskz_loadu_epi8(m.raw, p)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec512<T> MaskedLoad(Mask512<T> m, Full512<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec512<T>{_mm512_maskz_loadu_epi16(m.raw, p)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> MaskedLoad(Mask512<T> m, Full512<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec512<T>{_mm512_maskz_loadu_epi32(m.raw, p)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> MaskedLoad(Mask512<T> m, Full512<T> /* tag */, + const T* HWY_RESTRICT p) { + return Vec512<T>{_mm512_maskz_loadu_epi64(m.raw, p)}; +} + +HWY_API Vec512<float> MaskedLoad(Mask512<float> m, Full512<float> /* tag */, + const float* HWY_RESTRICT p) { + return Vec512<float>{_mm512_maskz_loadu_ps(m.raw, p)}; +} + +HWY_API Vec512<double> MaskedLoad(Mask512<double> m, Full512<double> /* tag */, + const double* HWY_RESTRICT p) { + return Vec512<double>{_mm512_maskz_loadu_pd(m.raw, p)}; +} + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. +template <typename T> +HWY_API Vec512<T> LoadDup128(Full512<T> /* tag */, + const T* const HWY_RESTRICT p) { + const auto x4 = LoadU(Full128<T>(), p); + return Vec512<T>{_mm512_broadcast_i32x4(x4.raw)}; +} +HWY_API Vec512<float> LoadDup128(Full512<float> /* tag */, + const float* const HWY_RESTRICT p) { + const __m128 x4 = _mm_loadu_ps(p); + return Vec512<float>{_mm512_broadcast_f32x4(x4)}; +} + +HWY_API Vec512<double> LoadDup128(Full512<double> /* tag */, + const double* const HWY_RESTRICT p) { + const __m128d x2 = _mm_loadu_pd(p); + return Vec512<double>{_mm512_broadcast_f64x2(x2)}; +} + +// ------------------------------ Store + +template <typename T> +HWY_API void Store(const Vec512<T> v, Full512<T> /* tag */, + T* HWY_RESTRICT aligned) { + _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +HWY_API void Store(const Vec512<float> v, Full512<float> /* tag */, + float* HWY_RESTRICT aligned) { + _mm512_store_ps(aligned, v.raw); +} +HWY_API void Store(const Vec512<double> v, Full512<double> /* tag */, + double* HWY_RESTRICT aligned) { + _mm512_store_pd(aligned, v.raw); +} + +template <typename T> +HWY_API void StoreU(const Vec512<T> v, Full512<T> /* tag */, + T* HWY_RESTRICT p) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); +} +HWY_API void StoreU(const Vec512<float> v, Full512<float> /* tag */, + float* HWY_RESTRICT p) { + _mm512_storeu_ps(p, v.raw); +} +HWY_API void StoreU(const Vec512<double> v, Full512<double>, + double* HWY_RESTRICT p) { + _mm512_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API void BlendedStore(Vec512<T> v, Mask512<T> m, Full512<T> /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi8(p, m.raw, v.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API void BlendedStore(Vec512<T> v, Mask512<T> m, Full512<T> /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi16(p, m.raw, v.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API void BlendedStore(Vec512<T> v, Mask512<T> m, Full512<T> /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi32(p, m.raw, v.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API void BlendedStore(Vec512<T> v, Mask512<T> m, Full512<T> /* tag */, + T* HWY_RESTRICT p) { + _mm512_mask_storeu_epi64(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec512<float> v, Mask512<float> m, + Full512<float> /* tag */, float* HWY_RESTRICT p) { + _mm512_mask_storeu_ps(p, m.raw, v.raw); +} + +HWY_API void BlendedStore(Vec512<double> v, Mask512<double> m, + Full512<double> /* tag */, double* HWY_RESTRICT p) { + _mm512_mask_storeu_pd(p, m.raw, v.raw); +} + +// ------------------------------ Non-temporal stores + +template <typename T> +HWY_API void Stream(const Vec512<T> v, Full512<T> /* tag */, + T* HWY_RESTRICT aligned) { + _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +HWY_API void Stream(const Vec512<float> v, Full512<float> /* tag */, + float* HWY_RESTRICT aligned) { + _mm512_stream_ps(aligned, v.raw); +} +HWY_API void Stream(const Vec512<double> v, Full512<double>, + double* HWY_RESTRICT aligned) { + _mm512_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +namespace detail { + +template <typename T> +HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec512<T> v, + Full512<T> /* tag */, T* HWY_RESTRICT base, + const Vec512<int32_t> offset) { + _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); +} +template <typename T> +HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec512<T> v, + Full512<T> /* tag */, T* HWY_RESTRICT base, + const Vec512<int32_t> index) { + _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template <typename T> +HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec512<T> v, + Full512<T> /* tag */, T* HWY_RESTRICT base, + const Vec512<int64_t> offset) { + _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); +} +template <typename T> +HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec512<T> v, + Full512<T> /* tag */, T* HWY_RESTRICT base, + const Vec512<int64_t> index) { + _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +} // namespace detail + +template <typename T, typename Offset> +HWY_API void ScatterOffset(Vec512<T> v, Full512<T> d, T* HWY_RESTRICT base, + const Vec512<Offset> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::ScatterOffset(hwy::SizeTag<sizeof(T)>(), v, d, base, offset); +} +template <typename T, typename Index> +HWY_API void ScatterIndex(Vec512<T> v, Full512<T> d, T* HWY_RESTRICT base, + const Vec512<Index> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::ScatterIndex(hwy::SizeTag<sizeof(T)>(), v, d, base, index); +} + +HWY_API void ScatterOffset(Vec512<float> v, Full512<float> /* tag */, + float* HWY_RESTRICT base, + const Vec512<int32_t> offset) { + _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec512<float> v, Full512<float> /* tag */, + float* HWY_RESTRICT base, + const Vec512<int32_t> index) { + _mm512_i32scatter_ps(base, index.raw, v.raw, 4); +} + +HWY_API void ScatterOffset(Vec512<double> v, Full512<double> /* tag */, + double* HWY_RESTRICT base, + const Vec512<int64_t> offset) { + _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); +} +HWY_API void ScatterIndex(Vec512<double> v, Full512<double> /* tag */, + double* HWY_RESTRICT base, + const Vec512<int64_t> index) { + _mm512_i64scatter_pd(base, index.raw, v.raw, 8); +} + +// ------------------------------ Gather + +namespace detail { + +template <typename T> +HWY_INLINE Vec512<T> GatherOffset(hwy::SizeTag<4> /* tag */, + Full512<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec512<int32_t> offset) { + return Vec512<T>{_mm512_i32gather_epi32(offset.raw, base, 1)}; +} +template <typename T> +HWY_INLINE Vec512<T> GatherIndex(hwy::SizeTag<4> /* tag */, + Full512<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec512<int32_t> index) { + return Vec512<T>{_mm512_i32gather_epi32(index.raw, base, 4)}; +} + +template <typename T> +HWY_INLINE Vec512<T> GatherOffset(hwy::SizeTag<8> /* tag */, + Full512<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec512<int64_t> offset) { + return Vec512<T>{_mm512_i64gather_epi64(offset.raw, base, 1)}; +} +template <typename T> +HWY_INLINE Vec512<T> GatherIndex(hwy::SizeTag<8> /* tag */, + Full512<T> /* tag */, + const T* HWY_RESTRICT base, + const Vec512<int64_t> index) { + return Vec512<T>{_mm512_i64gather_epi64(index.raw, base, 8)}; +} + +} // namespace detail + +template <typename T, typename Offset> +HWY_API Vec512<T> GatherOffset(Full512<T> d, const T* HWY_RESTRICT base, + const Vec512<Offset> offset) { + static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); + return detail::GatherOffset(hwy::SizeTag<sizeof(T)>(), d, base, offset); +} +template <typename T, typename Index> +HWY_API Vec512<T> GatherIndex(Full512<T> d, const T* HWY_RESTRICT base, + const Vec512<Index> index) { + static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); + return detail::GatherIndex(hwy::SizeTag<sizeof(T)>(), d, base, index); +} + +HWY_API Vec512<float> GatherOffset(Full512<float> /* tag */, + const float* HWY_RESTRICT base, + const Vec512<int32_t> offset) { + return Vec512<float>{_mm512_i32gather_ps(offset.raw, base, 1)}; +} +HWY_API Vec512<float> GatherIndex(Full512<float> /* tag */, + const float* HWY_RESTRICT base, + const Vec512<int32_t> index) { + return Vec512<float>{_mm512_i32gather_ps(index.raw, base, 4)}; +} + +HWY_API Vec512<double> GatherOffset(Full512<double> /* tag */, + const double* HWY_RESTRICT base, + const Vec512<int64_t> offset) { + return Vec512<double>{_mm512_i64gather_pd(offset.raw, base, 1)}; +} +HWY_API Vec512<double> GatherIndex(Full512<double> /* tag */, + const double* HWY_RESTRICT base, + const Vec512<int64_t> index) { + return Vec512<double>{_mm512_i64gather_pd(index.raw, base, 8)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template <typename T> +HWY_API Vec256<T> LowerHalf(Full256<T> /* tag */, Vec512<T> v) { + return Vec256<T>{_mm512_castsi512_si256(v.raw)}; +} +HWY_API Vec256<float> LowerHalf(Full256<float> /* tag */, Vec512<float> v) { + return Vec256<float>{_mm512_castps512_ps256(v.raw)}; +} +HWY_API Vec256<double> LowerHalf(Full256<double> /* tag */, Vec512<double> v) { + return Vec256<double>{_mm512_castpd512_pd256(v.raw)}; +} + +template <typename T> +HWY_API Vec256<T> LowerHalf(Vec512<T> v) { + return LowerHalf(Full256<T>(), v); +} + +// ------------------------------ UpperHalf + +template <typename T> +HWY_API Vec256<T> UpperHalf(Full256<T> /* tag */, Vec512<T> v) { + return Vec256<T>{_mm512_extracti32x8_epi32(v.raw, 1)}; +} +HWY_API Vec256<float> UpperHalf(Full256<float> /* tag */, Vec512<float> v) { + return Vec256<float>{_mm512_extractf32x8_ps(v.raw, 1)}; +} +HWY_API Vec256<double> UpperHalf(Full256<double> /* tag */, Vec512<double> v) { + return Vec256<double>{_mm512_extractf64x4_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template <typename T> +HWY_API T ExtractLane(const Vec512<T> v, size_t i) { + const Full512<T> d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template <typename T> +HWY_API Vec512<T> InsertLane(const Vec512<T> v, size_t i, T t) { + const Full512<T> d; + HWY_DASSERT(i < Lanes(d)); + alignas(64) T lanes[64 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ GetLane (LowerHalf) +template <typename T> +HWY_API T GetLane(const Vec512<T> v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +template <typename T> +HWY_API Vec512<T> ZeroExtendVector(Full512<T> /* tag */, Vec256<T> lo) { +#if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. + return Vec512<T>{_mm512_zextsi256_si512(lo.raw)}; +#else + return Vec512<T>{_mm512_inserti32x8(_mm512_setzero_si512(), lo.raw, 0)}; +#endif +} +HWY_API Vec512<float> ZeroExtendVector(Full512<float> /* tag */, + Vec256<float> lo) { +#if HWY_HAVE_ZEXT + return Vec512<float>{_mm512_zextps256_ps512(lo.raw)}; +#else + return Vec512<float>{_mm512_insertf32x8(_mm512_setzero_ps(), lo.raw, 0)}; +#endif +} +HWY_API Vec512<double> ZeroExtendVector(Full512<double> /* tag */, + Vec256<double> lo) { +#if HWY_HAVE_ZEXT + return Vec512<double>{_mm512_zextpd256_pd512(lo.raw)}; +#else + return Vec512<double>{_mm512_insertf64x4(_mm512_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ Combine + +template <typename T> +HWY_API Vec512<T> Combine(Full512<T> d, Vec256<T> hi, Vec256<T> lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512<T>{_mm512_inserti32x8(lo512.raw, hi.raw, 1)}; +} +HWY_API Vec512<float> Combine(Full512<float> d, Vec256<float> hi, + Vec256<float> lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512<float>{_mm512_insertf32x8(lo512.raw, hi.raw, 1)}; +} +HWY_API Vec512<double> Combine(Full512<double> d, Vec256<double> hi, + Vec256<double> lo) { + const auto lo512 = ZeroExtendVector(d, lo); + return Vec512<double>{_mm512_insertf64x4(lo512.raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes + +template <int kBytes, typename T> +HWY_API Vec512<T> ShiftLeftBytes(Full512<T> /* tag */, const Vec512<T> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec512<T>{_mm512_bslli_epi128(v.raw, kBytes)}; +} + +template <int kBytes, typename T> +HWY_API Vec512<T> ShiftLeftBytes(const Vec512<T> v) { + return ShiftLeftBytes<kBytes>(Full512<T>(), v); +} + +// ------------------------------ ShiftLeftLanes + +template <int kLanes, typename T> +HWY_API Vec512<T> ShiftLeftLanes(Full512<T> d, const Vec512<T> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v))); +} + +template <int kLanes, typename T> +HWY_API Vec512<T> ShiftLeftLanes(const Vec512<T> v) { + return ShiftLeftLanes<kLanes>(Full512<T>(), v); +} + +// ------------------------------ ShiftRightBytes +template <int kBytes, typename T> +HWY_API Vec512<T> ShiftRightBytes(Full512<T> /* tag */, const Vec512<T> v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return Vec512<T>{_mm512_bsrli_epi128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightLanes +template <int kLanes, typename T> +HWY_API Vec512<T> ShiftRightLanes(Full512<T> d, const Vec512<T> v) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v))); +} + +// ------------------------------ CombineShiftRightBytes + +template <int kBytes, typename T, class V = Vec512<T>> +HWY_API V CombineShiftRightBytes(Full512<T> d, V hi, V lo) { + const Repartition<uint8_t, decltype(d)> d8; + return BitCast(d, Vec512<uint8_t>{_mm512_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast/splat any lane + +// Unsigned +template <int kLane> +HWY_API Vec512<uint16_t> Broadcast(const Vec512<uint16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec512<uint16_t>{_mm512_unpacklo_epi64(lo, lo)}; + } else { + const __m512i hi = + _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec512<uint16_t>{_mm512_unpackhi_epi64(hi, hi)}; + } +} +template <int kLane> +HWY_API Vec512<uint32_t> Broadcast(const Vec512<uint32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, perm)}; +} +template <int kLane> +HWY_API Vec512<uint64_t> Broadcast(const Vec512<uint64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, perm)}; +} + +// Signed +template <int kLane> +HWY_API Vec512<int16_t> Broadcast(const Vec512<int16_t> v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); + return Vec512<int16_t>{_mm512_unpacklo_epi64(lo, lo)}; + } else { + const __m512i hi = + _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); + return Vec512<int16_t>{_mm512_unpackhi_epi64(hi, hi)}; + } +} +template <int kLane> +HWY_API Vec512<int32_t> Broadcast(const Vec512<int32_t> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, perm)}; +} +template <int kLane> +HWY_API Vec512<int64_t> Broadcast(const Vec512<int64_t> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, perm)}; +} + +// Float +template <int kLane> +HWY_API Vec512<float> Broadcast(const Vec512<float> v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, perm)}; +} +template <int kLane> +HWY_API Vec512<double> Broadcast(const Vec512<double> v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); + return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, perm)}; +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec512<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Shuffle2301(const Vec512<T> v) { + return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; +} +HWY_API Vec512<float> Shuffle2301(const Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +namespace detail { + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Shuffle2301(const Vec512<T> a, const Vec512<T> b) { + const Full512<T> d; + const RebindToFloat<decltype(d)> df; + return BitCast( + d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_CDAB)}); +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Shuffle1230(const Vec512<T> a, const Vec512<T> b) { + const Full512<T> d; + const RebindToFloat<decltype(d)> df; + return BitCast( + d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_BCDA)}); +} +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Shuffle3012(const Vec512<T> a, const Vec512<T> b) { + const Full512<T> d; + const RebindToFloat<decltype(d)> df; + return BitCast( + d, Vec512<float>{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_DABC)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec512<uint32_t> Shuffle1032(const Vec512<uint32_t> v) { + return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512<int32_t> Shuffle1032(const Vec512<int32_t> v) { + return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512<float> Shuffle1032(const Vec512<float> v) { + // Shorter encoding than _mm512_permute_ps. + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512<uint64_t> Shuffle01(const Vec512<uint64_t> v) { + return Vec512<uint64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512<int64_t> Shuffle01(const Vec512<int64_t> v) { + return Vec512<int64_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512<double> Shuffle01(const Vec512<double> v) { + // Shorter encoding than _mm512_permute_pd. + return Vec512<double>{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; +} + +// Rotate right 32 bits +HWY_API Vec512<uint32_t> Shuffle0321(const Vec512<uint32_t> v) { + return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512<int32_t> Shuffle0321(const Vec512<int32_t> v) { + return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512<float> Shuffle0321(const Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; +} +// Rotate left 32 bits +HWY_API Vec512<uint32_t> Shuffle2103(const Vec512<uint32_t> v) { + return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512<int32_t> Shuffle2103(const Vec512<int32_t> v) { + return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512<float> Shuffle2103(const Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; +} + +// Reverse +HWY_API Vec512<uint32_t> Shuffle0123(const Vec512<uint32_t> v) { + return Vec512<uint32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512<int32_t> Shuffle0123(const Vec512<int32_t> v) { + return Vec512<int32_t>{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512<float> Shuffle0123(const Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. +template <typename T> +struct Indices512 { + __m512i raw; +}; + +template <typename T, typename TI> +HWY_API Indices512<T> IndicesFromVec(Full512<T> /* tag */, Vec512<TI> vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full512<TI> di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast<TI>(64 / sizeof(T)))))); +#endif + return Indices512<T>{vec.raw}; +} + +template <typename T, typename TI> +HWY_API Indices512<T> SetTableIndices(const Full512<T> d, const TI* idx) { + const Rebind<TI, decltype(d)> di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) { + return Vec512<T>{_mm512_permutexvar_epi32(idx.raw, v.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> TableLookupLanes(Vec512<T> v, Indices512<T> idx) { + return Vec512<T>{_mm512_permutexvar_epi64(idx.raw, v.raw)}; +} + +HWY_API Vec512<float> TableLookupLanes(Vec512<float> v, Indices512<float> idx) { + return Vec512<float>{_mm512_permutexvar_ps(idx.raw, v.raw)}; +} + +HWY_API Vec512<double> TableLookupLanes(Vec512<double> v, + Indices512<double> idx) { + return Vec512<double>{_mm512_permutexvar_pd(idx.raw, v.raw)}; +} + +// ------------------------------ Reverse + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec512<T> Reverse(Full512<T> d, const Vec512<T> v) { + const RebindToSigned<decltype(d)> di; + alignas(64) constexpr int16_t kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec512<int16_t> idx = Load(di, kReverse); + return BitCast(d, Vec512<int16_t>{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Reverse(Full512<T> d, const Vec512<T> v) { + alignas(64) constexpr int32_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> Reverse(Full512<T> d, const Vec512<T> v) { + alignas(64) constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +// ------------------------------ Reverse2 + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec512<T> Reverse2(Full512<T> d, const Vec512<T> v) { + const Full512<uint32_t> du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Reverse2(Full512<T> /* tag */, const Vec512<T> v) { + return Shuffle2301(v); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> Reverse2(Full512<T> /* tag */, const Vec512<T> v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec512<T> Reverse4(Full512<T> d, const Vec512<T> v) { + const RebindToSigned<decltype(d)> di; + alignas(64) constexpr int16_t kReverse4[32] = { + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, + 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; + const Vec512<int16_t> idx = Load(di, kReverse4); + return BitCast(d, Vec512<int16_t>{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Reverse4(Full512<T> /* tag */, const Vec512<T> v) { + return Shuffle0123(v); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> Reverse4(Full512<T> /* tag */, const Vec512<T> v) { + return Vec512<T>{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} +HWY_API Vec512<double> Reverse4(Full512<double> /* tag */, Vec512<double> v) { + return Vec512<double>{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} + +// ------------------------------ Reverse8 + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec512<T> Reverse8(Full512<T> d, const Vec512<T> v) { + const RebindToSigned<decltype(d)> di; + alignas(64) constexpr int16_t kReverse8[32] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; + const Vec512<int16_t> idx = Load(di, kReverse8); + return BitCast(d, Vec512<int16_t>{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> Reverse8(Full512<T> d, const Vec512<T> v) { + const RebindToSigned<decltype(d)> di; + alignas(64) constexpr int32_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec512<int32_t> idx = Load(di, kReverse8); + return BitCast(d, Vec512<int32_t>{ + _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> Reverse8(Full512<T> d, const Vec512<T> v) { + return Reverse(d, v); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +HWY_API Vec512<uint8_t> InterleaveLower(const Vec512<uint8_t> a, + const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> InterleaveLower(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<uint32_t> InterleaveLower(const Vec512<uint32_t> a, + const Vec512<uint32_t> b) { + return Vec512<uint32_t>{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> InterleaveLower(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + return Vec512<uint64_t>{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512<int8_t> InterleaveLower(const Vec512<int8_t> a, + const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> InterleaveLower(const Vec512<int16_t> a, + const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_unpacklo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<int32_t> InterleaveLower(const Vec512<int32_t> a, + const Vec512<int32_t> b) { + return Vec512<int32_t>{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<int64_t> InterleaveLower(const Vec512<int64_t> a, + const Vec512<int64_t> b) { + return Vec512<int64_t>{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512<float> InterleaveLower(const Vec512<float> a, + const Vec512<float> b) { + return Vec512<float>{_mm512_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> InterleaveLower(const Vec512<double> a, + const Vec512<double> b) { + return Vec512<double>{_mm512_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +// All functions inside detail lack the required D parameter. +namespace detail { + +HWY_API Vec512<uint8_t> InterleaveUpper(const Vec512<uint8_t> a, + const Vec512<uint8_t> b) { + return Vec512<uint8_t>{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<uint16_t> InterleaveUpper(const Vec512<uint16_t> a, + const Vec512<uint16_t> b) { + return Vec512<uint16_t>{_mm512_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<uint32_t> InterleaveUpper(const Vec512<uint32_t> a, + const Vec512<uint32_t> b) { + return Vec512<uint32_t>{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<uint64_t> InterleaveUpper(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + return Vec512<uint64_t>{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512<int8_t> InterleaveUpper(const Vec512<int8_t> a, + const Vec512<int8_t> b) { + return Vec512<int8_t>{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +HWY_API Vec512<int16_t> InterleaveUpper(const Vec512<int16_t> a, + const Vec512<int16_t> b) { + return Vec512<int16_t>{_mm512_unpackhi_epi16(a.raw, b.raw)}; +} +HWY_API Vec512<int32_t> InterleaveUpper(const Vec512<int32_t> a, + const Vec512<int32_t> b) { + return Vec512<int32_t>{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +HWY_API Vec512<int64_t> InterleaveUpper(const Vec512<int64_t> a, + const Vec512<int64_t> b) { + return Vec512<int64_t>{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +HWY_API Vec512<float> InterleaveUpper(const Vec512<float> a, + const Vec512<float> b) { + return Vec512<float>{_mm512_unpackhi_ps(a.raw, b.raw)}; +} +HWY_API Vec512<double> InterleaveUpper(const Vec512<double> a, + const Vec512<double> b) { + return Vec512<double>{_mm512_unpackhi_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template <typename T, class V = Vec512<T>> +HWY_API V InterleaveUpper(Full512<T> /* tag */, V a, V b) { + return detail::InterleaveUpper(a, b); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template <typename T, typename TW = MakeWide<T>> +HWY_API Vec512<TW> ZipLower(Vec512<T> a, Vec512<T> b) { + return BitCast(Full512<TW>(), InterleaveLower(a, b)); +} +template <typename T, typename TW = MakeWide<T>> +HWY_API Vec512<TW> ZipLower(Full512<TW> /* d */, Vec512<T> a, Vec512<T> b) { + return BitCast(Full512<TW>(), InterleaveLower(a, b)); +} + +template <typename T, typename TW = MakeWide<T>> +HWY_API Vec512<TW> ZipUpper(Full512<TW> d, Vec512<T> a, Vec512<T> b) { + return BitCast(Full512<TW>(), InterleaveUpper(d, a, b)); +} + +// ------------------------------ Concat* halves + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template <typename T> +HWY_API Vec512<T> ConcatLowerLower(Full512<T> /* tag */, const Vec512<T> hi, + const Vec512<T> lo) { + return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +HWY_API Vec512<float> ConcatLowerLower(Full512<float> /* tag */, + const Vec512<float> hi, + const Vec512<float> lo) { + return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +HWY_API Vec512<double> ConcatLowerLower(Full512<double> /* tag */, + const Vec512<double> hi, + const Vec512<double> lo) { + return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template <typename T> +HWY_API Vec512<T> ConcatUpperUpper(Full512<T> /* tag */, const Vec512<T> hi, + const Vec512<T> lo) { + return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +HWY_API Vec512<float> ConcatUpperUpper(Full512<float> /* tag */, + const Vec512<float> hi, + const Vec512<float> lo) { + return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +HWY_API Vec512<double> ConcatUpperUpper(Full512<double> /* tag */, + const Vec512<double> hi, + const Vec512<double> lo) { + return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template <typename T> +HWY_API Vec512<T> ConcatLowerUpper(Full512<T> /* tag */, const Vec512<T> hi, + const Vec512<T> lo) { + return Vec512<T>{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512<float> ConcatLowerUpper(Full512<float> /* tag */, + const Vec512<float> hi, + const Vec512<float> lo) { + return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512<double> ConcatLowerUpper(Full512<double> /* tag */, + const Vec512<double> hi, + const Vec512<double> lo) { + return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template <typename T> +HWY_API Vec512<T> ConcatUpperLower(Full512<T> /* tag */, const Vec512<T> hi, + const Vec512<T> lo) { + // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks + // are efficiently loaded from 32-bit regs. + const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); + return Vec512<T>{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)}; +} +HWY_API Vec512<float> ConcatUpperLower(Full512<float> /* tag */, + const Vec512<float> hi, + const Vec512<float> lo) { + const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); + return Vec512<float>{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; +} +HWY_API Vec512<double> ConcatUpperLower(Full512<double> /* tag */, + const Vec512<double> hi, + const Vec512<double> lo) { + const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); + return Vec512<double>{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; +} + +// ------------------------------ ConcatOdd + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec512<T> ConcatOdd(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[64] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, + 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, + 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, + 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, + 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; + return BitCast(d, + Vec512<uint8_t>{_mm512_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide<decltype(du)> dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec512<uint16_t> uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec512<uint16_t> uL = ShiftRight<8>(BitCast(dw, lo)); + const Vec512<uint64_t> u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512<uint64_t> du64; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec512<T> ConcatOdd(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint16_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast(d, Vec512<uint16_t>{_mm512_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> ConcatOdd(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast(d, Vec512<uint32_t>{_mm512_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +} + +HWY_API Vec512<float> ConcatOdd(Full512<float> d, Vec512<float> hi, + Vec512<float> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31}; + return Vec512<float>{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, hi.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> ConcatOdd(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast(d, Vec512<uint64_t>{_mm512_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +} + +HWY_API Vec512<double> ConcatOdd(Full512<double> d, Vec512<double> hi, + Vec512<double> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return Vec512<double>{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +} + +// ------------------------------ ConcatEven + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API Vec512<T> ConcatEven(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; +#if HWY_TARGET == HWY_AVX3_DL + alignas(64) constexpr uint8_t kIdx[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, + 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, + 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, + 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, + 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; + return BitCast(d, + Vec512<uint32_t>{_mm512_mask2_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask64{0xFFFFFFFFFFFFFFFFull}, BitCast(du, hi).raw)}); +#else + const RepartitionToWide<decltype(du)> dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec512<uint16_t> mask = Set(dw, 0x00FF); + const Vec512<uint16_t> uH = And(BitCast(dw, hi), mask); + const Vec512<uint16_t> uL = And(BitCast(dw, lo), mask); + const Vec512<uint64_t> u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512<uint64_t> du64; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template <typename T, HWY_IF_LANE_SIZE(T, 2)> +HWY_API Vec512<T> ConcatEven(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint16_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast(d, Vec512<uint32_t>{_mm512_mask2_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask32{0xFFFFFFFFu}, BitCast(du, hi).raw)}); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> ConcatEven(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast(d, Vec512<uint32_t>{_mm512_mask2_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, BitCast(du, hi).raw)}); +} + +HWY_API Vec512<float> ConcatEven(Full512<float> d, Vec512<float> hi, + Vec512<float> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + return Vec512<float>{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, + __mmask16{0xFFFF}, hi.raw)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> ConcatEven(Full512<T> d, Vec512<T> hi, Vec512<T> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast(d, Vec512<uint64_t>{_mm512_mask2_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, + BitCast(du, hi).raw)}); +} + +HWY_API Vec512<double> ConcatEven(Full512<double> d, Vec512<double> hi, + Vec512<double> lo) { + const RebindToUnsigned<decltype(d)> du; + alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return Vec512<double>{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, + __mmask8{0xFF}, hi.raw)}; +} + +// ------------------------------ DupEven (InterleaveLower) + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> DupEven(Vec512<T> v) { + return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; +} +HWY_API Vec512<float> DupEven(Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> DupEven(const Vec512<T> v) { + return InterleaveLower(Full512<T>(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template <typename T, HWY_IF_LANE_SIZE(T, 4)> +HWY_API Vec512<T> DupOdd(Vec512<T> v) { + return Vec512<T>{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; +} +HWY_API Vec512<float> DupOdd(Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> DupOdd(const Vec512<T> v) { + return InterleaveUpper(Full512<T>(), v, v); +} + +// ------------------------------ OddEven + +template <typename T> +HWY_API Vec512<T> OddEven(const Vec512<T> a, const Vec512<T> b) { + constexpr size_t s = sizeof(T); + constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; + return IfThenElse(Mask512<T>{0x5555555555555555ull >> shift}, b, a); +} + +// ------------------------------ OddEvenBlocks + +template <typename T> +HWY_API Vec512<T> OddEvenBlocks(Vec512<T> odd, Vec512<T> even) { + return Vec512<T>{_mm512_mask_blend_epi64(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +HWY_API Vec512<float> OddEvenBlocks(Vec512<float> odd, Vec512<float> even) { + return Vec512<float>{ + _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; +} + +HWY_API Vec512<double> OddEvenBlocks(Vec512<double> odd, Vec512<double> even) { + return Vec512<double>{ + _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +// ------------------------------ SwapAdjacentBlocks + +template <typename T> +HWY_API Vec512<T> SwapAdjacentBlocks(Vec512<T> v) { + return Vec512<T>{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512<float> SwapAdjacentBlocks(Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512<double> SwapAdjacentBlocks(Vec512<double> v) { + return Vec512<double>{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +// ------------------------------ ReverseBlocks + +template <typename T> +HWY_API Vec512<T> ReverseBlocks(Full512<T> /* tag */, Vec512<T> v) { + return Vec512<T>{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512<float> ReverseBlocks(Full512<float> /* tag */, Vec512<float> v) { + return Vec512<float>{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512<double> ReverseBlocks(Full512<double> /* tag */, + Vec512<double> v) { + return Vec512<double>{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template <typename T, typename TI> +HWY_API Vec512<TI> TableLookupBytes(Vec512<T> bytes, Vec512<TI> indices) { + return Vec512<TI>{_mm512_shuffle_epi8(bytes.raw, indices.raw)}; +} + +// Partial index vector +template <typename T, typename TI, size_t NI> +HWY_API Vec128<TI, NI> TableLookupBytes(Vec512<T> bytes, Vec128<TI, NI> from) { + const Full512<TI> d512; + const Half<decltype(d512)> d256; + const Half<decltype(d256)> d128; + // First expand to full 128, then 256, then 512. + const Vec128<TI> from_full{from.raw}; + const auto from_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); + const auto tbl_full = TableLookupBytes(bytes, from_512); + // Shrink to 256, then 128, then partial. + return Vec128<TI, NI>{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; +} +template <typename T, typename TI> +HWY_API Vec256<TI> TableLookupBytes(Vec512<T> bytes, Vec256<TI> from) { + const auto from_512 = ZeroExtendVector(Full512<TI>(), from); + return LowerHalf(Full256<TI>(), TableLookupBytes(bytes, from_512)); +} + +// Partial table vector +template <typename T, size_t N, typename TI> +HWY_API Vec512<TI> TableLookupBytes(Vec128<T, N> bytes, Vec512<TI> from) { + const Full512<TI> d512; + const Half<decltype(d512)> d256; + const Half<decltype(d256)> d128; + // First expand to full 128, then 256, then 512. + const Vec128<T> bytes_full{bytes.raw}; + const auto bytes_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); + return TableLookupBytes(bytes_512, from); +} +template <typename T, typename TI> +HWY_API Vec512<TI> TableLookupBytes(Vec256<T> bytes, Vec512<TI> from) { + const auto bytes_512 = ZeroExtendVector(Full512<T>(), bytes); + return TableLookupBytes(bytes_512, from); +} + +// Partial both are handled by x86_128/256. + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. +HWY_API Vec512<uint16_t> PromoteTo(Full512<uint16_t> /* tag */, + Vec256<uint8_t> v) { + return Vec512<uint16_t>{_mm512_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec512<uint32_t> PromoteTo(Full512<uint32_t> /* tag */, + Vec128<uint8_t> v) { + return Vec512<uint32_t>{_mm512_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec512<int16_t> PromoteTo(Full512<int16_t> /* tag */, + Vec256<uint8_t> v) { + return Vec512<int16_t>{_mm512_cvtepu8_epi16(v.raw)}; +} +HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, + Vec128<uint8_t> v) { + return Vec512<int32_t>{_mm512_cvtepu8_epi32(v.raw)}; +} +HWY_API Vec512<uint32_t> PromoteTo(Full512<uint32_t> /* tag */, + Vec256<uint16_t> v) { + return Vec512<uint32_t>{_mm512_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, + Vec256<uint16_t> v) { + return Vec512<int32_t>{_mm512_cvtepu16_epi32(v.raw)}; +} +HWY_API Vec512<uint64_t> PromoteTo(Full512<uint64_t> /* tag */, + Vec256<uint32_t> v) { + return Vec512<uint64_t>{_mm512_cvtepu32_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +HWY_API Vec512<int16_t> PromoteTo(Full512<int16_t> /* tag */, + Vec256<int8_t> v) { + return Vec512<int16_t>{_mm512_cvtepi8_epi16(v.raw)}; +} +HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, + Vec128<int8_t> v) { + return Vec512<int32_t>{_mm512_cvtepi8_epi32(v.raw)}; +} +HWY_API Vec512<int32_t> PromoteTo(Full512<int32_t> /* tag */, + Vec256<int16_t> v) { + return Vec512<int32_t>{_mm512_cvtepi16_epi32(v.raw)}; +} +HWY_API Vec512<int64_t> PromoteTo(Full512<int64_t> /* tag */, + Vec256<int32_t> v) { + return Vec512<int64_t>{_mm512_cvtepi32_epi64(v.raw)}; +} + +// Float +HWY_API Vec512<float> PromoteTo(Full512<float> /* tag */, + const Vec256<float16_t> v) { + return Vec512<float>{_mm512_cvtph_ps(v.raw)}; +} + +HWY_API Vec512<float> PromoteTo(Full512<float> df32, + const Vec256<bfloat16_t> v) { + const Rebind<uint16_t, decltype(df32)> du16; + const RebindToSigned<decltype(df32)> di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +HWY_API Vec512<double> PromoteTo(Full512<double> /* tag */, Vec256<float> v) { + return Vec512<double>{_mm512_cvtps_pd(v.raw)}; +} + +HWY_API Vec512<double> PromoteTo(Full512<double> /* tag */, Vec256<int32_t> v) { + return Vec512<double>{_mm512_cvtepi32_pd(v.raw)}; +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +HWY_API Vec256<uint16_t> DemoteTo(Full256<uint16_t> /* tag */, + const Vec512<int32_t> v) { + const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512<uint64_t>(), kLanes); + const Vec512<uint16_t> even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; + return LowerHalf(even); +} + +HWY_API Vec256<int16_t> DemoteTo(Full256<int16_t> /* tag */, + const Vec512<int32_t> v) { + const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512<uint64_t>(), kLanes); + const Vec512<int16_t> even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128<uint8_t, 16> DemoteTo(Full128<uint8_t> /* tag */, + const Vec512<int32_t> v) { + const Vec512<uint16_t> u16{_mm512_packus_epi32(v.raw, v.raw)}; + // packus treats the input as signed; we want unsigned. Clear the MSB to get + // unsigned saturation to u8. + const Vec512<int16_t> i16{ + _mm512_and_si512(u16.raw, _mm512_set1_epi16(0x7FFF))}; + const Vec512<uint8_t> u8{_mm512_packus_epi16(i16.raw, i16.raw)}; + + alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12}; + const auto idx32 = LoadDup128(Full512<uint32_t>(), kLanes); + const Vec512<uint8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +HWY_API Vec256<uint8_t> DemoteTo(Full256<uint8_t> /* tag */, + const Vec512<int16_t> v) { + const Vec512<uint8_t> u8{_mm512_packus_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512<uint64_t>(), kLanes); + const Vec512<uint8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128<int8_t, 16> DemoteTo(Full128<int8_t> /* tag */, + const Vec512<int32_t> v) { + const Vec512<int16_t> i16{_mm512_packs_epi32(v.raw, v.raw)}; + const Vec512<int8_t> i8{_mm512_packs_epi16(i16.raw, i16.raw)}; + + alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12, + 0, 4, 8, 12, 0, 4, 8, 12}; + const auto idx32 = LoadDup128(Full512<uint32_t>(), kLanes); + const Vec512<int8_t> fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +HWY_API Vec256<int8_t> DemoteTo(Full256<int8_t> /* tag */, + const Vec512<int16_t> v) { + const Vec512<int8_t> u8{_mm512_packs_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(Full512<uint64_t>(), kLanes); + const Vec512<int8_t> even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +HWY_API Vec256<float16_t> DemoteTo(Full256<float16_t> /* tag */, + const Vec512<float> v) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Vec256<float16_t>{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; + HWY_DIAGNOSTICS(pop) +} + +HWY_API Vec256<bfloat16_t> DemoteTo(Full256<bfloat16_t> dbf16, + const Vec512<float> v) { + // TODO(janwas): _mm512_cvtneps_pbh once we have avx512bf16. + const Rebind<int32_t, decltype(dbf16)> di32; + const Rebind<uint32_t, decltype(dbf16)> du32; // for logical shift right + const Rebind<uint16_t, decltype(dbf16)> du16; + const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); + return BitCast(dbf16, DemoteTo(du16, bits_in_32)); +} + +HWY_API Vec512<bfloat16_t> ReorderDemote2To(Full512<bfloat16_t> dbf16, + Vec512<float> a, Vec512<float> b) { + // TODO(janwas): _mm512_cvtne2ps_pbh once we have avx512bf16. + const RebindToUnsigned<decltype(dbf16)> du16; + const Repartition<uint32_t, decltype(dbf16)> du32; + const Vec512<uint32_t> b_in_even = ShiftRight<16>(BitCast(du32, b)); + return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); +} + +HWY_API Vec512<int16_t> ReorderDemote2To(Full512<int16_t> /*d16*/, + Vec512<int32_t> a, Vec512<int32_t> b) { + return Vec512<int16_t>{_mm512_packs_epi32(a.raw, b.raw)}; +} + +HWY_API Vec256<float> DemoteTo(Full256<float> /* tag */, + const Vec512<double> v) { + return Vec256<float>{_mm512_cvtpd_ps(v.raw)}; +} + +HWY_API Vec256<int32_t> DemoteTo(Full256<int32_t> /* tag */, + const Vec512<double> v) { + const auto clamped = detail::ClampF64ToI32Max(Full512<double>(), v); + return Vec256<int32_t>{_mm512_cvttpd_epi32(clamped.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128<uint8_t, 16> U8FromU32(const Vec512<uint32_t> v) { + const Full512<uint32_t> d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, + ~0u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + // Gather the lowest 4 bytes of 4 128-bit blocks. + alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; + const Vec512<uint8_t> bytes{ + _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +// ------------------------------ Truncations + +HWY_API Vec128<uint8_t, 8> TruncateTo(Simd<uint8_t, 8, 0> d, + const Vec512<uint64_t> v) { +#if HWY_TARGET == HWY_AVX3_DL + (void)d; + const Full512<uint8_t> d8; + alignas(16) static constexpr uint8_t k8From64[16] = { + 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56}; + const Vec512<uint8_t> bytes{ + _mm512_permutexvar_epi8(LoadDup128(d8, k8From64).raw, v.raw)}; + return LowerHalf(LowerHalf(LowerHalf(bytes))); +#else + const Full512<uint32_t> d32; + alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512<uint32_t> even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return TruncateTo(d, LowerHalf(even)); +#endif +} + +HWY_API Vec128<uint16_t, 8> TruncateTo(Simd<uint16_t, 8, 0> /* tag */, + const Vec512<uint64_t> v) { + const Full512<uint16_t> d16; + alignas(16) static constexpr uint16_t k16From64[8] = { + 0, 4, 8, 12, 16, 20, 24, 28}; + const Vec512<uint16_t> bytes{ + _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +HWY_API Vec256<uint32_t> TruncateTo(Simd<uint32_t, 8, 0> /* tag */, + const Vec512<uint64_t> v) { + const Full512<uint32_t> d32; + alignas(64) constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512<uint32_t> even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return LowerHalf(even); +} + +HWY_API Vec128<uint8_t, 16> TruncateTo(Simd<uint8_t, 16, 0> /* tag */, + const Vec512<uint32_t> v) { +#if HWY_TARGET == HWY_AVX3_DL + const Full512<uint8_t> d8; + alignas(16) static constexpr uint8_t k8From32[16] = { + 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + const Vec512<uint8_t> bytes{ + _mm512_permutexvar_epi32(LoadDup128(d8, k8From32).raw, v.raw)}; +#else + const Full512<uint32_t> d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, + ~0u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); + // Gather the lowest 4 bytes of 4 128-bit blocks. + alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; + const Vec512<uint8_t> bytes{ + _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(LowerHalf(bytes)); +} + +HWY_API Vec256<uint16_t> TruncateTo(Simd<uint16_t, 16, 0> /* tag */, + const Vec512<uint32_t> v) { + const Full512<uint16_t> d16; + alignas(64) static constexpr uint16_t k16From32[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + const Vec512<uint16_t> bytes{ + _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; + return LowerHalf(bytes); +} + +HWY_API Vec256<uint8_t> TruncateTo(Simd<uint8_t, 32, 0> /* tag */, + const Vec512<uint16_t> v) { +#if HWY_TARGET == HWY_AVX3_DL + const Full512<uint8_t> d8; + alignas(64) static constexpr uint8_t k8From16[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + const Vec512<uint8_t> bytes{ + _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; +#else + const Full512<uint32_t> d32; + alignas(16) static constexpr uint32_t k16From32[4] = { + 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u}; + const auto quads = TableLookupBytes(v, LoadDup128(d32, k16From32)); + alignas(64) static constexpr uint32_t kIndex32[16] = { + 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; + const Vec512<uint8_t> bytes{ + _mm512_permutexvar_epi32(Load(d32, kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(bytes); +} + +// ------------------------------ Convert integer <=> floating point + +HWY_API Vec512<float> ConvertTo(Full512<float> /* tag */, + const Vec512<int32_t> v) { + return Vec512<float>{_mm512_cvtepi32_ps(v.raw)}; +} + +HWY_API Vec512<double> ConvertTo(Full512<double> /* tag */, + const Vec512<int64_t> v) { + return Vec512<double>{_mm512_cvtepi64_pd(v.raw)}; +} + +HWY_API Vec512<float> ConvertTo(Full512<float> /* tag*/, + const Vec512<uint32_t> v) { + return Vec512<float>{_mm512_cvtepu32_ps(v.raw)}; +} + +HWY_API Vec512<double> ConvertTo(Full512<double> /* tag*/, + const Vec512<uint64_t> v) { + return Vec512<double>{_mm512_cvtepu64_pd(v.raw)}; +} + +// Truncates (rounds toward zero). +HWY_API Vec512<int32_t> ConvertTo(Full512<int32_t> d, const Vec512<float> v) { + return detail::FixConversionOverflow(d, v, _mm512_cvttps_epi32(v.raw)); +} +HWY_API Vec512<int64_t> ConvertTo(Full512<int64_t> di, const Vec512<double> v) { + return detail::FixConversionOverflow(di, v, _mm512_cvttpd_epi64(v.raw)); +} + +HWY_API Vec512<int32_t> NearestInt(const Vec512<float> v) { + const Full512<int32_t> di; + return detail::FixConversionOverflow(di, v, _mm512_cvtps_epi32(v.raw)); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec512<uint8_t> AESRound(Vec512<uint8_t> state, + Vec512<uint8_t> round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512<uint8_t>{_mm512_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full512<uint8_t> d; + const Half<decltype(d)> d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512<uint8_t> AESLastRound(Vec512<uint8_t> state, + Vec512<uint8_t> round_key) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512<uint8_t>{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full512<uint8_t> d; + const Half<decltype(d)> d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512<uint64_t> CLMulLower(Vec512<uint64_t> va, Vec512<uint64_t> vb) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const Full512<uint64_t> d; + const Full128<uint64_t> d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +HWY_API Vec512<uint64_t> CLMulUpper(Vec512<uint64_t> va, Vec512<uint64_t> vb) { +#if HWY_TARGET == HWY_AVX3_DL + return Vec512<uint64_t>{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const Full512<uint64_t> d; + const Full128<uint64_t> d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// Returns a vector with lane i=[0, N) set to "first" + i. +template <typename T, typename T2> +Vec512<T> Iota(const Full512<T> d, const T2 first) { + HWY_ALIGN T lanes[64 / sizeof(T)]; + for (size_t i = 0; i < 64 / sizeof(T); ++i) { + lanes[i] = + AddWithWraparound(hwy::IsFloatTag<T>(), static_cast<T>(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +namespace detail { + +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template <typename T> +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} + +} // namespace detail + +template <typename T> +HWY_API bool AllFalse(const Full512<T> /* tag */, const Mask512<T> mask) { + return detail::AllFalse(hwy::SizeTag<sizeof(T)>(), mask); +} + +namespace detail { + +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFFFFFFFFFull; +#endif +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFull; +#endif +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFull; +#endif +} +template <typename T> +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512<T> mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFull; +#endif +} + +} // namespace detail + +template <typename T> +HWY_API bool AllTrue(const Full512<T> /* tag */, const Mask512<T> mask) { + return detail::AllTrue(hwy::SizeTag<sizeof(T)>(), mask); +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template <typename T> +HWY_API Mask512<T> LoadMaskBits(const Full512<T> /* tag */, + const uint8_t* HWY_RESTRICT bits) { + Mask512<T> mask; + CopyBytes<8 / sizeof(T)>(bits, &mask.raw); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return mask; +} + +// `p` points to at least 8 writable bytes. +template <typename T> +HWY_API size_t StoreMaskBits(const Full512<T> /* tag */, const Mask512<T> mask, + uint8_t* bits) { + const size_t kNumBytes = 8 / sizeof(T); + CopyBytes<kNumBytes>(&mask.raw, bits); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return kNumBytes; +} + +template <typename T> +HWY_API size_t CountTrue(const Full512<T> /* tag */, const Mask512<T> mask) { + return PopCount(static_cast<uint64_t>(mask.raw)); +} + +template <typename T, HWY_IF_NOT_LANE_SIZE(T, 1)> +HWY_API size_t FindKnownFirstTrue(const Full512<T> /* tag */, + const Mask512<T> mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 1)> +HWY_API size_t FindKnownFirstTrue(const Full512<T> /* tag */, + const Mask512<T> mask) { + return Num0BitsBelowLS1Bit_Nonzero64(mask.raw); +} + +template <typename T> +HWY_API intptr_t FindFirstTrue(const Full512<T> d, const Mask512<T> mask) { + return mask.raw ? static_cast<intptr_t>(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +// ------------------------------ Compress + +// Always implement 8-bit here even if we lack VBMI2 because we can do better +// than generic_ops (8 at a time) via the native 32-bit compress (16 at a time). +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +namespace detail { + +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 +template <size_t N> +HWY_INLINE Vec128<uint8_t, N> NativeCompress(const Vec128<uint8_t, N> v, + const Mask128<uint8_t, N> mask) { + return Vec128<uint8_t, N>{_mm_maskz_compress_epi8(mask.raw, v.raw)}; +} +HWY_INLINE Vec256<uint8_t> NativeCompress(const Vec256<uint8_t> v, + const Mask256<uint8_t> mask) { + return Vec256<uint8_t>{_mm256_maskz_compress_epi8(mask.raw, v.raw)}; +} +HWY_INLINE Vec512<uint8_t> NativeCompress(const Vec512<uint8_t> v, + const Mask512<uint8_t> mask) { + return Vec512<uint8_t>{_mm512_maskz_compress_epi8(mask.raw, v.raw)}; +} + +template <size_t N> +HWY_INLINE Vec128<uint16_t, N> NativeCompress(const Vec128<uint16_t, N> v, + const Mask128<uint16_t, N> mask) { + return Vec128<uint16_t, N>{_mm_maskz_compress_epi16(mask.raw, v.raw)}; +} +HWY_INLINE Vec256<uint16_t> NativeCompress(const Vec256<uint16_t> v, + const Mask256<uint16_t> mask) { + return Vec256<uint16_t>{_mm256_maskz_compress_epi16(mask.raw, v.raw)}; +} +HWY_INLINE Vec512<uint16_t> NativeCompress(const Vec512<uint16_t> v, + const Mask512<uint16_t> mask) { + return Vec512<uint16_t>{_mm512_maskz_compress_epi16(mask.raw, v.raw)}; +} + +template <size_t N> +HWY_INLINE void NativeCompressStore(Vec128<uint8_t, N> v, + Mask128<uint8_t, N> mask, + Simd<uint8_t, N, 0> /* d */, + uint8_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256<uint8_t> v, Mask256<uint8_t> mask, + Full256<uint8_t> /* d */, + uint8_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512<uint8_t> v, Mask512<uint8_t> mask, + Full512<uint8_t> /* d */, + uint8_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} + +template <size_t N> +HWY_INLINE void NativeCompressStore(Vec128<uint16_t, N> v, + Mask128<uint16_t, N> mask, + Simd<uint16_t, N, 0> /* d */, + uint16_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256<uint16_t> v, Mask256<uint16_t> mask, + Full256<uint16_t> /* d */, + uint16_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512<uint16_t> v, Mask512<uint16_t> mask, + Full512<uint16_t> /* d */, + uint16_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} + +#endif // HWY_TARGET == HWY_AVX3_DL + +template <size_t N> +HWY_INLINE Vec128<uint32_t, N> NativeCompress(const Vec128<uint32_t, N> v, + const Mask128<uint32_t, N> mask) { + return Vec128<uint32_t, N>{_mm_maskz_compress_epi32(mask.raw, v.raw)}; +} +HWY_INLINE Vec256<uint32_t> NativeCompress(Vec256<uint32_t> v, + Mask256<uint32_t> mask) { + return Vec256<uint32_t>{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} +HWY_INLINE Vec512<uint32_t> NativeCompress(Vec512<uint32_t> v, + Mask512<uint32_t> mask) { + return Vec512<uint32_t>{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; +} +// We use table-based compress for 64-bit lanes, see CompressIsPartition. + +template <size_t N> +HWY_INLINE void NativeCompressStore(Vec128<uint32_t, N> v, + Mask128<uint32_t, N> mask, + Simd<uint32_t, N, 0> /* d */, + uint32_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256<uint32_t> v, Mask256<uint32_t> mask, + Full256<uint32_t> /* d */, + uint32_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512<uint32_t> v, Mask512<uint32_t> mask, + Full512<uint32_t> /* d */, + uint32_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} + +template <size_t N> +HWY_INLINE void NativeCompressStore(Vec128<uint64_t, N> v, + Mask128<uint64_t, N> mask, + Simd<uint64_t, N, 0> /* d */, + uint64_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256<uint64_t> v, Mask256<uint64_t> mask, + Full256<uint64_t> /* d */, + uint64_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec512<uint64_t> v, Mask512<uint64_t> mask, + Full512<uint64_t> /* d */, + uint64_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} + +// For u8x16 and <= u16x16 we can avoid store+load for Compress because there is +// only a single compressed vector (u32x16). Other EmuCompress are implemented +// after the EmuCompressStore they build upon. +template <size_t N> +HWY_INLINE Vec128<uint8_t, N> EmuCompress(Vec128<uint8_t, N> v, + Mask128<uint8_t, N> mask) { + const Simd<uint8_t, N, 0> d; + const Rebind<uint32_t, decltype(d)> d32; + const auto v0 = PromoteTo(d32, v); + + const uint64_t mask_bits{mask.raw}; + // Mask type is __mmask16 if v is full 128, else __mmask8. + using M32 = MFromD<decltype(d32)>; + const M32 m0{static_cast<typename M32::Raw>(mask_bits)}; + return TruncateTo(d, Compress(v0, m0)); +} + +template <size_t N> +HWY_INLINE Vec128<uint16_t, N> EmuCompress(Vec128<uint16_t, N> v, + Mask128<uint16_t, N> mask) { + const Simd<uint16_t, N, 0> d; + const Rebind<int32_t, decltype(d)> di32; + const RebindToUnsigned<decltype(di32)> du32; + const MFromD<decltype(du32)> mask32{static_cast<__mmask8>(mask.raw)}; + // DemoteTo is 2 ops, but likely lower latency than TruncateTo on SKX. + // Only i32 -> u16 is supported, whereas NativeCompress expects u32. + const VFromD<decltype(du32)> v32 = BitCast(du32, PromoteTo(di32, v)); + return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); +} + +HWY_INLINE Vec256<uint16_t> EmuCompress(Vec256<uint16_t> v, + Mask256<uint16_t> mask) { + const Full256<uint16_t> d; + const Rebind<int32_t, decltype(d)> di32; + const RebindToUnsigned<decltype(di32)> du32; + const Mask512<uint32_t> mask32{static_cast<__mmask16>(mask.raw)}; + const Vec512<uint32_t> v32 = BitCast(du32, PromoteTo(di32, v)); + return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); +} + +// See above - small-vector EmuCompressStore are implemented via EmuCompress. +template <typename T, size_t N> +HWY_INLINE void EmuCompressStore(Vec128<T, N> v, Mask128<T, N> mask, + Simd<T, N, 0> d, T* HWY_RESTRICT unaligned) { + StoreU(EmuCompress(v, mask), d, unaligned); +} + +HWY_INLINE void EmuCompressStore(Vec256<uint16_t> v, Mask256<uint16_t> mask, + Full256<uint16_t> d, + uint16_t* HWY_RESTRICT unaligned) { + StoreU(EmuCompress(v, mask), d, unaligned); +} + +// Main emulation logic for wider vector, starting with EmuCompressStore because +// it is most convenient to merge pieces using memory (concatenating vectors at +// byte offsets is difficult). +HWY_INLINE void EmuCompressStore(Vec256<uint8_t> v, Mask256<uint8_t> mask, + Full256<uint8_t> d, + uint8_t* HWY_RESTRICT unaligned) { + const uint64_t mask_bits{mask.raw}; + const Half<decltype(d)> dh; + const Rebind<uint32_t, decltype(dh)> d32; + const Vec512<uint32_t> v0 = PromoteTo(d32, LowerHalf(v)); + const Vec512<uint32_t> v1 = PromoteTo(d32, UpperHalf(dh, v)); + const Mask512<uint32_t> m0{static_cast<__mmask16>(mask_bits & 0xFFFFu)}; + const Mask512<uint32_t> m1{static_cast<__mmask16>(mask_bits >> 16)}; + const Vec128<uint8_t> c0 = TruncateTo(dh, NativeCompress(v0, m0)); + const Vec128<uint8_t> c1 = TruncateTo(dh, NativeCompress(v1, m1)); + uint8_t* HWY_RESTRICT pos = unaligned; + StoreU(c0, dh, pos); + StoreU(c1, dh, pos + CountTrue(d32, m0)); +} + +HWY_INLINE void EmuCompressStore(Vec512<uint8_t> v, Mask512<uint8_t> mask, + Full512<uint8_t> d, + uint8_t* HWY_RESTRICT unaligned) { + const uint64_t mask_bits{mask.raw}; + const Half<Half<decltype(d)>> dq; + const Rebind<uint32_t, decltype(dq)> d32; + HWY_ALIGN uint8_t lanes[64]; + Store(v, d, lanes); + const Vec512<uint32_t> v0 = PromoteTo(d32, LowerHalf(LowerHalf(v))); + const Vec512<uint32_t> v1 = PromoteTo(d32, Load(dq, lanes + 16)); + const Vec512<uint32_t> v2 = PromoteTo(d32, Load(dq, lanes + 32)); + const Vec512<uint32_t> v3 = PromoteTo(d32, Load(dq, lanes + 48)); + const Mask512<uint32_t> m0{static_cast<__mmask16>(mask_bits & 0xFFFFu)}; + const Mask512<uint32_t> m1{ + static_cast<uint16_t>((mask_bits >> 16) & 0xFFFFu)}; + const Mask512<uint32_t> m2{ + static_cast<uint16_t>((mask_bits >> 32) & 0xFFFFu)}; + const Mask512<uint32_t> m3{static_cast<__mmask16>(mask_bits >> 48)}; + const Vec128<uint8_t> c0 = TruncateTo(dq, NativeCompress(v0, m0)); + const Vec128<uint8_t> c1 = TruncateTo(dq, NativeCompress(v1, m1)); + const Vec128<uint8_t> c2 = TruncateTo(dq, NativeCompress(v2, m2)); + const Vec128<uint8_t> c3 = TruncateTo(dq, NativeCompress(v3, m3)); + uint8_t* HWY_RESTRICT pos = unaligned; + StoreU(c0, dq, pos); + pos += CountTrue(d32, m0); + StoreU(c1, dq, pos); + pos += CountTrue(d32, m1); + StoreU(c2, dq, pos); + pos += CountTrue(d32, m2); + StoreU(c3, dq, pos); +} + +HWY_INLINE void EmuCompressStore(Vec512<uint16_t> v, Mask512<uint16_t> mask, + Full512<uint16_t> d, + uint16_t* HWY_RESTRICT unaligned) { + const Repartition<int32_t, decltype(d)> di32; + const RebindToUnsigned<decltype(di32)> du32; + const Half<decltype(d)> dh; + const Vec512<uint32_t> promoted0 = + BitCast(du32, PromoteTo(di32, LowerHalf(dh, v))); + const Vec512<uint32_t> promoted1 = + BitCast(du32, PromoteTo(di32, UpperHalf(dh, v))); + + const uint64_t mask_bits{mask.raw}; + const uint64_t maskL = mask_bits & 0xFFFF; + const uint64_t maskH = mask_bits >> 16; + const Mask512<uint32_t> mask0{static_cast<__mmask16>(maskL)}; + const Mask512<uint32_t> mask1{static_cast<__mmask16>(maskH)}; + const Vec512<uint32_t> compressed0 = NativeCompress(promoted0, mask0); + const Vec512<uint32_t> compressed1 = NativeCompress(promoted1, mask1); + + const Vec256<uint16_t> demoted0 = DemoteTo(dh, BitCast(di32, compressed0)); + const Vec256<uint16_t> demoted1 = DemoteTo(dh, BitCast(di32, compressed1)); + + // Store 256-bit halves + StoreU(demoted0, dh, unaligned); + StoreU(demoted1, dh, unaligned + PopCount(maskL)); +} + +// Finally, the remaining EmuCompress for wide vectors, using EmuCompressStore. +template <typename T> // 1 or 2 bytes +HWY_INLINE Vec512<T> EmuCompress(Vec512<T> v, Mask512<T> mask) { + const Full512<T> d; + HWY_ALIGN T buf[2 * 64 / sizeof(T)]; + EmuCompressStore(v, mask, d, buf); + return Load(d, buf); +} + +HWY_INLINE Vec256<uint8_t> EmuCompress(Vec256<uint8_t> v, + const Mask256<uint8_t> mask) { + const Full256<uint8_t> d; + HWY_ALIGN uint8_t buf[2 * 32 / sizeof(uint8_t)]; + EmuCompressStore(v, mask, d, buf); + return Load(d, buf); +} + +} // namespace detail + +template <class V, class M, HWY_IF_LANE_SIZE_ONE_OF_V(V, 0x6)> // 1 or 2 bytes +HWY_API V Compress(V v, const M mask) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto mu = RebindMask(du, mask); +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +#else + return BitCast(d, detail::EmuCompress(BitCast(du, v), mu)); +#endif +} + +template <class V, class M, HWY_IF_LANE_SIZE_V(V, 4)> +HWY_API V Compress(V v, const M mask) { + const DFromV<decltype(v)> d; + const RebindToUnsigned<decltype(d)> du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> Compress(Vec512<T> v, Mask512<T> mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) constexpr uint64_t packed_array[256] = { + // From PrintCompress32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120, + 0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310, + 0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140, + 0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210, + 0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320, + 0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510, + 0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530, + 0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210, + 0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420, + 0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310, + 0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160, + 0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210, + 0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320, + 0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410, + 0x75310642, 0x75316420, 0x75306421, 0x75364210, 0x75210643, 0x75216430, + 0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210, + 0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520, + 0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310, + 0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540, + 0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210, + 0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320, + 0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710, + 0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730, + 0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210, + 0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420, + 0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310, + 0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750, + 0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210, + 0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320, + 0x64075321, 0x64753210, 0x63210754, 0x63217540, 0x63207541, 0x63275410, + 0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430, + 0x62075431, 0x62754310, 0x61075432, 0x61754320, 0x60754321, 0x67543210, + 0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620, + 0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310, + 0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640, + 0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210, + 0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320, + 0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510, + 0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530, + 0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210, + 0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420, + 0x30765421, 0x37654210, 0x21076543, 0x21765430, 0x20765431, 0x27654310, + 0x10765432, 0x17654320, 0x07654321, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. + const Full512<T> d; + const RebindToUnsigned<decltype(d)> du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + const auto indices = Indices512<T>{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressNot + +template <class V, class M, HWY_IF_NOT_LANE_SIZE_V(V, 8)> +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +template <typename T, HWY_IF_LANE_SIZE(T, 8)> +HWY_API Vec512<T> CompressNot(Vec512<T> v, Mask512<T> mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) constexpr uint64_t packed_array[256] = { + // From PrintCompressNot32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431, + 0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542, + 0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321, + 0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653, + 0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651, + 0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432, + 0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421, + 0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764, + 0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631, + 0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762, + 0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321, + 0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543, + 0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541, + 0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532, + 0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521, + 0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075, + 0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431, + 0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742, + 0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321, + 0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073, + 0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071, + 0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432, + 0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421, + 0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654, + 0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531, + 0x74216530, 0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652, + 0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321, + 0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643, + 0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641, + 0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632, + 0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621, + 0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106, + 0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431, + 0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542, + 0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321, + 0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053, + 0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051, + 0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432, + 0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421, + 0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104, + 0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031, + 0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102, + 0x76543210, 0x76543201, 0x76543210, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. + const Full512<T> d; + const RebindToUnsigned<decltype(d)> du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + const auto indices = Indices512<T>{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// uint64_t lanes. Only implement for 256 and 512-bit vectors because this is a +// no-op for 128-bit. +template <class V, class M, hwy::EnableIf<(sizeof(V) > 16)>* = nullptr> +HWY_API V CompressBlocksNot(V v, M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressBits +template <class V> +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV<V>(), bits)); +} + +// ------------------------------ CompressStore + +template <class V, class D, HWY_IF_LANE_SIZE_ONE_OF_V(V, 0x6)> // 1 or 2 bytes +HWY_API size_t CompressStore(V v, MFromD<D> mask, D d, + TFromD<D>* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; + const auto mu = RebindMask(du, mask); + auto pu = reinterpret_cast<TFromD<decltype(du)> * HWY_RESTRICT>(unaligned); +#if HWY_TARGET == HWY_AVX3_DL // VBMI2 + detail::NativeCompressStore(BitCast(du, v), mu, du, pu); +#else + detail::EmuCompressStore(BitCast(du, v), mu, du, pu); +#endif + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(pu, count); + return count; +} + +template <class V, class D, HWY_IF_LANE_SIZE_ONE_OF_V(V, 0x110)> // 4 or 8 +HWY_API size_t CompressStore(V v, MFromD<D> mask, D d, + TFromD<D>* HWY_RESTRICT unaligned) { + const RebindToUnsigned<decltype(d)> du; + const auto mu = RebindMask(du, mask); + using TU = TFromD<decltype(du)>; + TU* HWY_RESTRICT pu = reinterpret_cast<TU*>(unaligned); + detail::NativeCompressStore(BitCast(du, v), mu, du, pu); + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(pu, count); + return count; +} + +// Additional overloads to avoid casting to uint32_t (delay?). +HWY_API size_t CompressStore(Vec512<float> v, Mask512<float> mask, + Full512<float> /* tag */, + float* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +HWY_API size_t CompressStore(Vec512<double> v, Mask512<double> mask, + Full512<double> /* tag */, + double* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore +template <class D, typename T = TFromD<D>> +HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> m, D d, + T* HWY_RESTRICT unaligned) { + // Native CompressStore already does the blending at no extra cost (latency + // 11, rthroughput 2 - same as compress plus store). + if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) > 2) { + return CompressStore(v, m, d, unaligned); + } else { + const size_t count = CountTrue(d, m); + BlendedStore(Compress(v, m), FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore +template <class D> +HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD<D>* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ LoadInterleaved4 + +// Actually implemented in generic_ops, we just overload LoadTransposedBlocks4. +namespace detail { + +// Type-safe wrapper. +template <_MM_PERM_ENUM kPerm, typename T> +Vec512<T> Shuffle128(const Vec512<T> lo, const Vec512<T> hi) { + return Vec512<T>{_mm512_shuffle_i64x2(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512<float> Shuffle128(const Vec512<float> lo, const Vec512<float> hi) { + return Vec512<float>{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512<double> Shuffle128(const Vec512<double> lo, const Vec512<double> hi) { + return Vec512<double>{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)}; +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// Output: +// 9 6 3 0 (LSB of A) +// a 7 4 1 +// b 8 5 2 +template <typename T> +HWY_API void LoadTransposedBlocks3(Full512<T> d, + const T* HWY_RESTRICT unaligned, + Vec512<T>& A, Vec512<T>& B, Vec512<T>& C) { + constexpr size_t N = 64 / sizeof(T); + const Vec512<T> v3210 = LoadU(d, unaligned + 0 * N); + const Vec512<T> v7654 = LoadU(d, unaligned + 1 * N); + const Vec512<T> vba98 = LoadU(d, unaligned + 2 * N); + + const Vec512<T> v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654); + const Vec512<T> va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98); + + A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976); + B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976); + C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98); +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// f e d c +// Output: +// c 8 4 0 (LSB of A) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +template <typename T> +HWY_API void LoadTransposedBlocks4(Full512<T> d, + const T* HWY_RESTRICT unaligned, + Vec512<T>& A, Vec512<T>& B, Vec512<T>& C, + Vec512<T>& D) { + constexpr size_t N = 64 / sizeof(T); + const Vec512<T> v3210 = LoadU(d, unaligned + 0 * N); + const Vec512<T> v7654 = LoadU(d, unaligned + 1 * N); + const Vec512<T> vba98 = LoadU(d, unaligned + 2 * N); + const Vec512<T> vfedc = LoadU(d, unaligned + 3 * N); + + const Vec512<T> v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654); + const Vec512<T> vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc); + const Vec512<T> v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654); + const Vec512<T> vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc); + A = detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98); + B = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98); + C = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba); + D = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2 + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. + +namespace detail { + +// Input (128-bit blocks): +// 6 4 2 0 (LSB of i) +// 7 5 3 1 +// Output: +// 3 2 1 0 +// 7 6 5 4 +template <typename T> +HWY_API void StoreTransposedBlocks2(const Vec512<T> i, const Vec512<T> j, + const Full512<T> d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const auto j1_i1_j0_i0 = + detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0); + const auto j3_i3_j2_i2 = + detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2); + StoreU(j1_i1_j0_i0, d, unaligned + 0 * N); + StoreU(j3_i3_j2_i2, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 9 6 3 0 (LSB of i) +// a 7 4 1 +// b 8 5 2 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +template <typename T> +HWY_API void StoreTransposedBlocks3(const Vec512<T> i, const Vec512<T> j, + const Vec512<T> k, Full512<T> d, + T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const Vec512<T> j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j); + const Vec512<T> i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i); + const Vec512<T> j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j); + + const Vec512<T> out0 = // i1 k0 j0 i0 + detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0); + const Vec512<T> out1 = // j2 i2 k1 j1 + detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0); + const Vec512<T> out2 = // k3 j3 i3 k2 + detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1); + + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// c 8 4 0 (LSB of i) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +// f e d c +template <typename T> +HWY_API void StoreTransposedBlocks4(const Vec512<T> i, const Vec512<T> j, + const Vec512<T> k, const Vec512<T> l, + Full512<T> d, T* HWY_RESTRICT unaligned) { + constexpr size_t N = 64 / sizeof(T); + const Vec512<T> j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const Vec512<T> l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l); + const Vec512<T> j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const Vec512<T> l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l); + const Vec512<T> out0 = + detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0); + const Vec512<T> out1 = + detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0); + const Vec512<T> out2 = + detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2); + const Vec512<T> out3 = + detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ MulEven/Odd (Shuffle2301, InterleaveLower) + +HWY_INLINE Vec512<uint64_t> MulEven(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + const Full512<uint64_t> du64; + const RepartitionToNarrow<decltype(du64)> du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveLower(mulL, mulH); +} + +HWY_INLINE Vec512<uint64_t> MulOdd(const Vec512<uint64_t> a, + const Vec512<uint64_t> b) { + const Full512<uint64_t> du64; + const RepartitionToNarrow<decltype(du64)> du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need bits [95:64] (= upper half of input) + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Same as above, but we're using the odd results (upper 64 bits per block). + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + const auto mulH = MulEven(aH, bH) + w1 + k; + const auto mulL = ShiftLeft<32>(t) + w3; + return InterleaveUpper(du64, mulL, mulH); +} + +// ------------------------------ ReorderWidenMulAccumulate +HWY_API Vec512<int32_t> ReorderWidenMulAccumulate(Full512<int32_t> /*d32*/, + Vec512<int16_t> a, + Vec512<int16_t> b, + const Vec512<int32_t> sum0, + Vec512<int32_t>& /*sum1*/) { + return sum0 + Vec512<int32_t>{_mm512_madd_epi16(a.raw, b.raw)}; +} + +HWY_API Vec512<int32_t> RearrangeToOddPlusEven(const Vec512<int32_t> sum0, + Vec512<int32_t> /*sum1*/) { + return sum0; // invariant already holds +} + +// ------------------------------ Reductions + +// Returns the sum in each lane. +HWY_API Vec512<int32_t> SumOfLanes(Full512<int32_t> d, Vec512<int32_t> v) { + return Set(d, _mm512_reduce_add_epi32(v.raw)); +} +HWY_API Vec512<int64_t> SumOfLanes(Full512<int64_t> d, Vec512<int64_t> v) { + return Set(d, _mm512_reduce_add_epi64(v.raw)); +} +HWY_API Vec512<uint32_t> SumOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) { + return Set(d, static_cast<uint32_t>(_mm512_reduce_add_epi32(v.raw))); +} +HWY_API Vec512<uint64_t> SumOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) { + return Set(d, static_cast<uint64_t>(_mm512_reduce_add_epi64(v.raw))); +} +HWY_API Vec512<float> SumOfLanes(Full512<float> d, Vec512<float> v) { + return Set(d, _mm512_reduce_add_ps(v.raw)); +} +HWY_API Vec512<double> SumOfLanes(Full512<double> d, Vec512<double> v) { + return Set(d, _mm512_reduce_add_pd(v.raw)); +} +HWY_API Vec512<uint16_t> SumOfLanes(Full512<uint16_t> d, Vec512<uint16_t> v) { + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(d32, even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} +HWY_API Vec512<int16_t> SumOfLanes(Full512<int16_t> d, Vec512<int16_t> v) { + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto sum = SumOfLanes(d32, even + odd); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(sum)), BitCast(d, sum)); +} + +// Returns the minimum in each lane. +HWY_API Vec512<int32_t> MinOfLanes(Full512<int32_t> d, Vec512<int32_t> v) { + return Set(d, _mm512_reduce_min_epi32(v.raw)); +} +HWY_API Vec512<int64_t> MinOfLanes(Full512<int64_t> d, Vec512<int64_t> v) { + return Set(d, _mm512_reduce_min_epi64(v.raw)); +} +HWY_API Vec512<uint32_t> MinOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) { + return Set(d, _mm512_reduce_min_epu32(v.raw)); +} +HWY_API Vec512<uint64_t> MinOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) { + return Set(d, _mm512_reduce_min_epu64(v.raw)); +} +HWY_API Vec512<float> MinOfLanes(Full512<float> d, Vec512<float> v) { + return Set(d, _mm512_reduce_min_ps(v.raw)); +} +HWY_API Vec512<double> MinOfLanes(Full512<double> d, Vec512<double> v) { + return Set(d, _mm512_reduce_min_pd(v.raw)); +} +HWY_API Vec512<uint16_t> MinOfLanes(Full512<uint16_t> d, Vec512<uint16_t> v) { + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(d32, Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec512<int16_t> MinOfLanes(Full512<int16_t> d, Vec512<int16_t> v) { + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MinOfLanes(d32, Min(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// Returns the maximum in each lane. +HWY_API Vec512<int32_t> MaxOfLanes(Full512<int32_t> d, Vec512<int32_t> v) { + return Set(d, _mm512_reduce_max_epi32(v.raw)); +} +HWY_API Vec512<int64_t> MaxOfLanes(Full512<int64_t> d, Vec512<int64_t> v) { + return Set(d, _mm512_reduce_max_epi64(v.raw)); +} +HWY_API Vec512<uint32_t> MaxOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) { + return Set(d, _mm512_reduce_max_epu32(v.raw)); +} +HWY_API Vec512<uint64_t> MaxOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) { + return Set(d, _mm512_reduce_max_epu64(v.raw)); +} +HWY_API Vec512<float> MaxOfLanes(Full512<float> d, Vec512<float> v) { + return Set(d, _mm512_reduce_max_ps(v.raw)); +} +HWY_API Vec512<double> MaxOfLanes(Full512<double> d, Vec512<double> v) { + return Set(d, _mm512_reduce_max_pd(v.raw)); +} +HWY_API Vec512<uint16_t> MaxOfLanes(Full512<uint16_t> d, Vec512<uint16_t> v) { + const RepartitionToWide<decltype(d)> d32; + const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(d32, Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} +HWY_API Vec512<int16_t> MaxOfLanes(Full512<int16_t> d, Vec512<int16_t> v) { + const RepartitionToWide<decltype(d)> d32; + // Sign-extend + const auto even = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, v))); + const auto odd = ShiftRight<16>(BitCast(d32, v)); + const auto min = MaxOfLanes(d32, Max(even, odd)); + // Also broadcast into odd lanes. + return OddEven(BitCast(d, ShiftLeft<16>(min)), BitCast(d, min)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/third_party/highway/hwy/per_target.cc b/third_party/highway/hwy/per_target.cc new file mode 100644 index 0000000000..4cbf152328 --- /dev/null +++ b/third_party/highway/hwy/per_target.cc @@ -0,0 +1,50 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/per_target.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/per_target.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +// On SVE, Lanes rounds down to a power of two, but we want to know the actual +// size here. Otherwise, hypothetical SVE with 48 bytes would round down to 32 +// and we'd enable HWY_SVE_256, and then fail reverse_test because Reverse on +// HWY_SVE_256 requires the actual vector to be a power of two. +#if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE_256 +size_t GetVectorBytes() { return detail::AllHardwareLanes(hwy::SizeTag<1>()); } +#else +size_t GetVectorBytes() { return Lanes(ScalableTag<uint8_t>()); } +#endif +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(GetVectorBytes); // Local function. +} // namespace + +size_t VectorBytes() { return HWY_DYNAMIC_DISPATCH(GetVectorBytes)(); } + +} // namespace hwy +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/per_target.h b/third_party/highway/hwy/per_target.h new file mode 100644 index 0000000000..da85de3226 --- /dev/null +++ b/third_party/highway/hwy/per_target.h @@ -0,0 +1,37 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PER_TARGET_H_ +#define HIGHWAY_HWY_PER_TARGET_H_ + +#include <stddef.h> + +// Per-target functions. + +namespace hwy { + +// Returns size in bytes of a vector, i.e. `Lanes(ScalableTag<uint8_t>())`. +// +// Do not cache the result, which may change after calling DisableTargets, or +// if software requests a different vector size (e.g. when entering/exiting SME +// streaming mode). Instead call this right before the code that depends on the +// result, without any DisableTargets or SME transition in-between. Note that +// this involves an indirect call, so prefer not to call this frequently nor +// unnecessarily. +size_t VectorBytes(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_PER_TARGET_H_ diff --git a/third_party/highway/hwy/print-inl.h b/third_party/highway/hwy/print-inl.h new file mode 100644 index 0000000000..6490d90dd6 --- /dev/null +++ b/third_party/highway/hwy/print-inl.h @@ -0,0 +1,55 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Print() function + +#include <stdint.h> + +#include "hwy/aligned_allocator.h" +#include "hwy/highway.h" +#include "hwy/print.h" + +// Per-target include guard +#if defined(HIGHWAY_HWY_PRINT_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_PRINT_INL_H_ +#undef HIGHWAY_HWY_PRINT_INL_H_ +#else +#define HIGHWAY_HWY_PRINT_INL_H_ +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Prints lanes around `lane`, in memory order. +template <class D, class V = VFromD<D>> +void Print(const D d, const char* caption, VecArg<V> v, size_t lane_u = 0, + size_t max_lanes = 7) { + const size_t N = Lanes(d); + using T = TFromD<D>; + auto lanes = AllocateAligned<T>(N); + Store(v, d, lanes.get()); + + const auto info = hwy::detail::MakeTypeInfo<T>(); + hwy::detail::PrintArray(info, caption, lanes.get(), N, lane_u, max_lanes); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/third_party/highway/hwy/print.cc b/third_party/highway/hwy/print.cc new file mode 100644 index 0000000000..0b52cde1b9 --- /dev/null +++ b/third_party/highway/hwy/print.cc @@ -0,0 +1,110 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/print.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> +#include <stddef.h> +#include <stdio.h> + +#include "hwy/base.h" + +namespace hwy { +namespace detail { + +HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100) { + const char prefix = info.is_float ? 'f' : (info.is_signed ? 'i' : 'u'); + // Omit the xN suffix for scalars. + if (N == 1) { + // NOLINTNEXTLINE + snprintf(string100, 64, "%c%d", prefix, + static_cast<int>(info.sizeof_t * 8)); + } else { + // NOLINTNEXTLINE + snprintf(string100, 64, "%c%dx%d", prefix, + static_cast<int>(info.sizeof_t * 8), static_cast<int>(N)); + } +} + +HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, + char* string100) { + if (info.sizeof_t == 1) { + uint8_t byte; + CopyBytes<1>(ptr, &byte); // endian-safe: we ensured sizeof(T)=1. + snprintf(string100, 100, "0x%02X", byte); // NOLINT + } else if (info.sizeof_t == 2) { + uint16_t bits; + CopyBytes<2>(ptr, &bits); + snprintf(string100, 100, "0x%04X", bits); // NOLINT + } else if (info.sizeof_t == 4) { + if (info.is_float) { + float value; + CopyBytes<4>(ptr, &value); + snprintf(string100, 100, "%g", static_cast<double>(value)); // NOLINT + } else if (info.is_signed) { + int32_t value; + CopyBytes<4>(ptr, &value); + snprintf(string100, 100, "%d", value); // NOLINT + } else { + uint32_t value; + CopyBytes<4>(ptr, &value); + snprintf(string100, 100, "%u", value); // NOLINT + } + } else { + HWY_ASSERT(info.sizeof_t == 8); + if (info.is_float) { + double value; + CopyBytes<8>(ptr, &value); + snprintf(string100, 100, "%g", value); // NOLINT + } else if (info.is_signed) { + int64_t value; + CopyBytes<8>(ptr, &value); + snprintf(string100, 100, "%" PRIi64 "", value); // NOLINT + } else { + uint64_t value; + CopyBytes<8>(ptr, &value); + snprintf(string100, 100, "%" PRIu64 "", value); // NOLINT + } + } +} + +HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption, + const void* array_void, size_t N, size_t lane_u, + size_t max_lanes) { + const uint8_t* array_bytes = reinterpret_cast<const uint8_t*>(array_void); + + char type_name[100]; + TypeName(info, N, type_name); + + const intptr_t lane = intptr_t(lane_u); + const size_t begin = static_cast<size_t>(HWY_MAX(0, lane - 2)); + const size_t end = HWY_MIN(begin + max_lanes, N); + fprintf(stderr, "%s %s [%" PRIu64 "+ ->]:\n ", type_name, caption, + static_cast<uint64_t>(begin)); + for (size_t i = begin; i < end; ++i) { + const void* ptr = array_bytes + i * info.sizeof_t; + char str[100]; + ToString(info, ptr, str); + fprintf(stderr, "%s,", str); + } + if (begin >= end) fprintf(stderr, "(out of bounds)"); + fprintf(stderr, "\n"); +} + +} // namespace detail +} // namespace hwy diff --git a/third_party/highway/hwy/print.h b/third_party/highway/hwy/print.h new file mode 100644 index 0000000000..13792866a3 --- /dev/null +++ b/third_party/highway/hwy/print.h @@ -0,0 +1,73 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_PRINT_H_ +#define HWY_PRINT_H_ + +// Helpers for printing vector lanes. + +#include <stddef.h> +#include <stdio.h> + +#include "hwy/base.h" +#include "hwy/highway_export.h" + +namespace hwy { + +namespace detail { + +// For implementing value comparisons etc. as type-erased functions to reduce +// template bloat. +struct TypeInfo { + size_t sizeof_t; + bool is_float; + bool is_signed; +}; + +template <typename T> +HWY_INLINE TypeInfo MakeTypeInfo() { + TypeInfo info; + info.sizeof_t = sizeof(T); + info.is_float = IsFloat<T>(); + info.is_signed = IsSigned<T>(); + return info; +} + +HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100); +HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, + char* string100); + +HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption, + const void* array_void, size_t N, + size_t lane_u = 0, size_t max_lanes = 7); + +} // namespace detail + +template <typename T> +HWY_NOINLINE void PrintValue(T value) { + char str[100]; + detail::ToString(hwy::detail::MakeTypeInfo<T>(), &value, str); + fprintf(stderr, "%s,", str); +} + +template <typename T> +HWY_NOINLINE void PrintArray(const T* value, size_t count) { + detail::PrintArray(hwy::detail::MakeTypeInfo<T>(), "", value, count, 0, + count); +} + +} // namespace hwy + +#endif // HWY_PRINT_H_ diff --git a/third_party/highway/hwy/targets.cc b/third_party/highway/hwy/targets.cc new file mode 100644 index 0000000000..dc4217c8fe --- /dev/null +++ b/third_party/highway/hwy/targets.cc @@ -0,0 +1,433 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/targets.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> // PRIx64 +#include <stdarg.h> +#include <stddef.h> +#include <stdint.h> +#include <stdio.h> + +#include <atomic> + +#include "hwy/per_target.h" // VectorBytes + +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN +#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace +#endif + +#include <stdlib.h> // abort / exit + +#if HWY_ARCH_X86 +#include <xmmintrin.h> +#if HWY_COMPILER_MSVC +#include <intrin.h> +#else // !HWY_COMPILER_MSVC +#include <cpuid.h> +#endif // HWY_COMPILER_MSVC + +#elif HWY_ARCH_ARM && HWY_OS_LINUX && !defined(TOOLCHAIN_MISS_SYS_AUXV_H) +#include <sys/auxv.h> +#endif // HWY_ARCH_* + +namespace hwy { +namespace { + +#if HWY_ARCH_X86 + +HWY_INLINE bool IsBitSet(const uint32_t reg, const int index) { + return (reg & (1U << index)) != 0; +} + +// Calls CPUID instruction with eax=level and ecx=count and returns the result +// in abcd array where abcd = {eax, ebx, ecx, edx} (hence the name abcd). +HWY_INLINE void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* HWY_RESTRICT abcd) { +#if HWY_COMPILER_MSVC + int regs[4]; + __cpuidex(regs, level, count); + for (int i = 0; i < 4; ++i) { + abcd[i] = regs[i]; + } +#else // HWY_COMPILER_MSVC + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif // HWY_COMPILER_MSVC +} + +// Returns the lower 32 bits of extended control register 0. +// Requires CPU support for "OSXSAVE" (see below). +uint32_t ReadXCR0() { +#if HWY_COMPILER_MSVC + return static_cast<uint32_t>(_xgetbv(0)); +#else // HWY_COMPILER_MSVC + uint32_t xcr0, xcr0_high; + const uint32_t index = 0; + asm volatile(".byte 0x0F, 0x01, 0xD0" + : "=a"(xcr0), "=d"(xcr0_high) + : "c"(index)); + return xcr0; +#endif // HWY_COMPILER_MSVC +} + +#endif // HWY_ARCH_X86 + +// When running tests, this value can be set to the mocked supported targets +// mask. Only written to from a single thread before the test starts. +int64_t supported_targets_for_test_ = 0; + +// Mask of targets disabled at runtime with DisableTargets. +int64_t supported_mask_ = LimitsMax<int64_t>(); + +#if HWY_ARCH_X86 +// Arbitrary bit indices indicating which instruction set extensions are +// supported. Use enum to ensure values are distinct. +enum class FeatureIndex : uint32_t { + kSSE = 0, + kSSE2, + kSSE3, + kSSSE3, + + kSSE41, + kSSE42, + kCLMUL, + kAES, + + kAVX, + kAVX2, + kF16C, + kFMA, + kLZCNT, + kBMI, + kBMI2, + + kAVX512F, + kAVX512VL, + kAVX512DQ, + kAVX512BW, + + kVNNI, + kVPCLMULQDQ, + kVBMI, + kVBMI2, + kVAES, + kPOPCNTDQ, + kBITALG, + + kSentinel +}; +static_assert(static_cast<size_t>(FeatureIndex::kSentinel) < 64, + "Too many bits for u64"); + +HWY_INLINE constexpr uint64_t Bit(FeatureIndex index) { + return 1ull << static_cast<size_t>(index); +} + +constexpr uint64_t kGroupSSSE3 = + Bit(FeatureIndex::kSSE) | Bit(FeatureIndex::kSSE2) | + Bit(FeatureIndex::kSSE3) | Bit(FeatureIndex::kSSSE3); + +constexpr uint64_t kGroupSSE4 = + Bit(FeatureIndex::kSSE41) | Bit(FeatureIndex::kSSE42) | + Bit(FeatureIndex::kCLMUL) | Bit(FeatureIndex::kAES) | kGroupSSSE3; + +// We normally assume BMI/BMI2/FMA are available if AVX2 is. This allows us to +// use BZHI and (compiler-generated) MULX. However, VirtualBox lacks them +// [https://www.virtualbox.org/ticket/15471]. Thus we provide the option of +// avoiding using and requiring these so AVX2 can still be used. +#ifdef HWY_DISABLE_BMI2_FMA +constexpr uint64_t kGroupBMI2_FMA = 0; +#else +constexpr uint64_t kGroupBMI2_FMA = Bit(FeatureIndex::kBMI) | + Bit(FeatureIndex::kBMI2) | + Bit(FeatureIndex::kFMA); +#endif + +#ifdef HWY_DISABLE_F16C +constexpr uint64_t kGroupF16C = 0; +#else +constexpr uint64_t kGroupF16C = Bit(FeatureIndex::kF16C); +#endif + +constexpr uint64_t kGroupAVX2 = + Bit(FeatureIndex::kAVX) | Bit(FeatureIndex::kAVX2) | + Bit(FeatureIndex::kLZCNT) | kGroupBMI2_FMA | kGroupF16C | kGroupSSE4; + +constexpr uint64_t kGroupAVX3 = + Bit(FeatureIndex::kAVX512F) | Bit(FeatureIndex::kAVX512VL) | + Bit(FeatureIndex::kAVX512DQ) | Bit(FeatureIndex::kAVX512BW) | kGroupAVX2; + +constexpr uint64_t kGroupAVX3_DL = + Bit(FeatureIndex::kVNNI) | Bit(FeatureIndex::kVPCLMULQDQ) | + Bit(FeatureIndex::kVBMI) | Bit(FeatureIndex::kVBMI2) | + Bit(FeatureIndex::kVAES) | Bit(FeatureIndex::kPOPCNTDQ) | + Bit(FeatureIndex::kBITALG) | kGroupAVX3; + +#endif // HWY_ARCH_X86 + +// Returns targets supported by the CPU, independently of DisableTargets. +// Factored out of SupportedTargets to make its structure more obvious. Note +// that x86 CPUID may take several hundred cycles. +int64_t DetectTargets() { + // Apps will use only one of these (the default is EMU128), but compile flags + // for this TU may differ from that of the app, so allow both. + int64_t bits = HWY_SCALAR | HWY_EMU128; + +#if HWY_ARCH_X86 + bool has_osxsave = false; + { // ensures we do not accidentally use flags outside this block + uint64_t flags = 0; + uint32_t abcd[4]; + + Cpuid(0, 0, abcd); + const uint32_t max_level = abcd[0]; + + // Standard feature flags + Cpuid(1, 0, abcd); + flags |= IsBitSet(abcd[3], 25) ? Bit(FeatureIndex::kSSE) : 0; + flags |= IsBitSet(abcd[3], 26) ? Bit(FeatureIndex::kSSE2) : 0; + flags |= IsBitSet(abcd[2], 0) ? Bit(FeatureIndex::kSSE3) : 0; + flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kCLMUL) : 0; + flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kSSSE3) : 0; + flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kFMA) : 0; + flags |= IsBitSet(abcd[2], 19) ? Bit(FeatureIndex::kSSE41) : 0; + flags |= IsBitSet(abcd[2], 20) ? Bit(FeatureIndex::kSSE42) : 0; + flags |= IsBitSet(abcd[2], 25) ? Bit(FeatureIndex::kAES) : 0; + flags |= IsBitSet(abcd[2], 28) ? Bit(FeatureIndex::kAVX) : 0; + flags |= IsBitSet(abcd[2], 29) ? Bit(FeatureIndex::kF16C) : 0; + has_osxsave = IsBitSet(abcd[2], 27); + + // Extended feature flags + Cpuid(0x80000001U, 0, abcd); + flags |= IsBitSet(abcd[2], 5) ? Bit(FeatureIndex::kLZCNT) : 0; + + // Extended features + if (max_level >= 7) { + Cpuid(7, 0, abcd); + flags |= IsBitSet(abcd[1], 3) ? Bit(FeatureIndex::kBMI) : 0; + flags |= IsBitSet(abcd[1], 5) ? Bit(FeatureIndex::kAVX2) : 0; + flags |= IsBitSet(abcd[1], 8) ? Bit(FeatureIndex::kBMI2) : 0; + + flags |= IsBitSet(abcd[1], 16) ? Bit(FeatureIndex::kAVX512F) : 0; + flags |= IsBitSet(abcd[1], 17) ? Bit(FeatureIndex::kAVX512DQ) : 0; + flags |= IsBitSet(abcd[1], 30) ? Bit(FeatureIndex::kAVX512BW) : 0; + flags |= IsBitSet(abcd[1], 31) ? Bit(FeatureIndex::kAVX512VL) : 0; + + flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kVBMI) : 0; + flags |= IsBitSet(abcd[2], 6) ? Bit(FeatureIndex::kVBMI2) : 0; + flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kVAES) : 0; + flags |= IsBitSet(abcd[2], 10) ? Bit(FeatureIndex::kVPCLMULQDQ) : 0; + flags |= IsBitSet(abcd[2], 11) ? Bit(FeatureIndex::kVNNI) : 0; + flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kBITALG) : 0; + flags |= IsBitSet(abcd[2], 14) ? Bit(FeatureIndex::kPOPCNTDQ) : 0; + } + + // Set target bit(s) if all their group's flags are all set. + if ((flags & kGroupAVX3_DL) == kGroupAVX3_DL) { + bits |= HWY_AVX3_DL; + } + if ((flags & kGroupAVX3) == kGroupAVX3) { + bits |= HWY_AVX3; + } + if ((flags & kGroupAVX2) == kGroupAVX2) { + bits |= HWY_AVX2; + } + if ((flags & kGroupSSE4) == kGroupSSE4) { + bits |= HWY_SSE4; + } + if ((flags & kGroupSSSE3) == kGroupSSSE3) { + bits |= HWY_SSSE3; + } + } + + // Clear bits if the OS does not support XSAVE - otherwise, registers + // are not preserved across context switches. + if (has_osxsave) { + const uint32_t xcr0 = ReadXCR0(); + const int64_t min_avx3 = HWY_AVX3 | HWY_AVX3_DL; + const int64_t min_avx2 = HWY_AVX2 | min_avx3; + // XMM + if (!IsBitSet(xcr0, 1)) { + bits &= ~(HWY_SSSE3 | HWY_SSE4 | min_avx2); + } + // YMM + if (!IsBitSet(xcr0, 2)) { + bits &= ~min_avx2; + } + // opmask, ZMM lo/hi + if (!IsBitSet(xcr0, 5) || !IsBitSet(xcr0, 6) || !IsBitSet(xcr0, 7)) { + bits &= ~min_avx3; + } + } + + if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) { + fprintf(stderr, + "WARNING: CPU supports %" PRIx64 " but software requires %" PRIx64 + "\n", + bits, static_cast<int64_t>(HWY_ENABLED_BASELINE)); + } + +#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH + using CapBits = unsigned long; // NOLINT + const CapBits hw = getauxval(AT_HWCAP); + (void)hw; + +#if HWY_ARCH_ARM_A64 + +#if defined(HWCAP_AES) + // aarch64 always has NEON and VFPv4, but not necessarily AES, which we + // require and thus must still check for. + if (hw & HWCAP_AES) { + bits |= HWY_NEON; + } +#endif // HWCAP_AES + +#if defined(HWCAP_SVE) + if (hw & HWCAP_SVE) { + bits |= HWY_SVE; + } +#endif + +#if defined(HWCAP2_SVE2) && defined(HWCAP2_SVEAES) + const CapBits hw2 = getauxval(AT_HWCAP2); + if ((hw2 & HWCAP2_SVE2) && (hw2 & HWCAP2_SVEAES)) { + bits |= HWY_SVE2; + } +#endif + +#else // HWY_ARCH_ARM_A64 + +// Some old auxv.h / hwcap.h do not define these. If not, treat as unsupported. +// Note that AES has a different HWCAP bit compared to aarch64. +#if defined(HWCAP_NEON) && defined(HWCAP_VFPv4) + if ((hw & HWCAP_NEON) && (hw & HWCAP_VFPv4)) { + bits |= HWY_NEON; + } +#endif + +#endif // HWY_ARCH_ARM_A64 + if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) { + fprintf(stderr, + "WARNING: CPU supports %" PRIx64 " but software requires %" PRIx64 + "\n", + bits, static_cast<int64_t>(HWY_ENABLED_BASELINE)); + } +#else // HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH + // TODO(janwas): detect for other platforms and check for baseline + // This file is typically compiled without HWY_IS_TEST, but targets_test has + // it set, and will expect all of its HWY_TARGETS (= all attainable) to be + // supported. + bits |= HWY_ENABLED_BASELINE; +#endif // HWY_ARCH_X86 + + return bits; +} + +} // namespace + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...) { + char buf[2000]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + fprintf(stderr, "Abort at %s:%d: %s\n", file, line, buf); + +// If compiled with any sanitizer, they can also print a stack trace. +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + __sanitizer_print_stack_trace(); +#endif // HWY_IS_* + fflush(stderr); + +// Now terminate the program: +#if HWY_ARCH_RVV + exit(1); // trap/abort just freeze Spike. +#elif HWY_IS_DEBUG_BUILD && !HWY_COMPILER_MSVC + // Facilitates breaking into a debugger, but don't use this in non-debug + // builds because it looks like "illegal instruction", which is misleading. + __builtin_trap(); +#else + abort(); // Compile error without this due to HWY_NORETURN. +#endif +} + +HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets) { + supported_mask_ = static_cast<int64_t>(~disabled_targets); + // This will take effect on the next call to SupportedTargets, which is + // called right before GetChosenTarget::Update. However, calling Update here + // would make it appear that HWY_DYNAMIC_DISPATCH was called, which we want + // to check in tests. We instead de-initialize such that the next + // HWY_DYNAMIC_DISPATCH calls GetChosenTarget::Update via FunctionCache. + GetChosenTarget().DeInit(); +} + +HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets) { + supported_targets_for_test_ = targets; + GetChosenTarget().DeInit(); // see comment above +} + +HWY_DLLEXPORT int64_t SupportedTargets() { + int64_t targets = supported_targets_for_test_; + if (HWY_LIKELY(targets == 0)) { + // Mock not active. Re-detect instead of caching just in case we're on a + // heterogeneous ISA (also requires some app support to pin threads). This + // is only reached on the first HWY_DYNAMIC_DISPATCH or after each call to + // DisableTargets or SetSupportedTargetsForTest. + targets = DetectTargets(); + + // VectorBytes invokes HWY_DYNAMIC_DISPATCH. To prevent infinite recursion, + // first set up ChosenTarget. No need to Update() again afterwards with the + // final targets - that will be done by a caller of this function. + GetChosenTarget().Update(targets); + + // Now that we can call VectorBytes, check for targets with specific sizes. + if (HWY_ARCH_ARM_A64) { + const size_t vec_bytes = VectorBytes(); // uncached, see declaration + if ((targets & HWY_SVE) && vec_bytes == 32) { + targets = static_cast<int64_t>(targets | HWY_SVE_256); + } else { + targets = static_cast<int64_t>(targets & ~HWY_SVE_256); + } + if ((targets & HWY_SVE2) && vec_bytes == 16) { + targets = static_cast<int64_t>(targets | HWY_SVE2_128); + } else { + targets = static_cast<int64_t>(targets & ~HWY_SVE2_128); + } + } // HWY_ARCH_ARM_A64 + } + + targets &= supported_mask_; + return targets == 0 ? HWY_STATIC_TARGET : targets; +} + +HWY_DLLEXPORT ChosenTarget& GetChosenTarget() { + static ChosenTarget chosen_target; + return chosen_target; +} + +} // namespace hwy diff --git a/third_party/highway/hwy/targets.h b/third_party/highway/hwy/targets.h new file mode 100644 index 0000000000..5dba12ae96 --- /dev/null +++ b/third_party/highway/hwy/targets.h @@ -0,0 +1,326 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_TARGETS_H_ +#define HIGHWAY_HWY_TARGETS_H_ + +// Allows opting out of C++ standard library usage, which is not available in +// some Compiler Explorer environments. +#ifndef HWY_NO_LIBCXX +#include <vector> +#endif + +// For SIMD module implementations and their callers. Defines which targets to +// generate and call. + +#include "hwy/base.h" +#include "hwy/detect_targets.h" +#include "hwy/highway_export.h" + +#if !HWY_ARCH_RVV && !defined(HWY_NO_LIBCXX) +#include <atomic> +#endif + +namespace hwy { + +// Returns bitfield of enabled targets that are supported on this CPU; there is +// always at least one such target, hence the return value is never 0. The +// targets returned may change after calling DisableTargets. This function is +// always defined, but the HWY_SUPPORTED_TARGETS wrapper may allow eliding +// calls to it if there is only a single target enabled. +HWY_DLLEXPORT int64_t SupportedTargets(); + +// Evaluates to a function call, or literal if there is a single target. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) == 0 +#define HWY_SUPPORTED_TARGETS HWY_TARGETS +#else +#define HWY_SUPPORTED_TARGETS hwy::SupportedTargets() +#endif + +// Subsequent SupportedTargets will not return targets whose bit(s) are set in +// `disabled_targets`. Exception: if SupportedTargets would return 0, it will +// instead return HWY_STATIC_TARGET (there must always be one target to call). +// +// This function is useful for disabling targets known to be buggy, or if the +// best available target is undesirable (perhaps due to throttling or memory +// bandwidth limitations). Use SetSupportedTargetsForTest instead of this +// function for iteratively enabling specific targets for testing. +HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets); + +// Subsequent SupportedTargets will return the given set of targets, except +// those disabled via DisableTargets. Call with a mask of 0 to disable the mock +// and return to the normal SupportedTargets behavior. Used to run tests for +// all targets. +HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets); + +#ifndef HWY_NO_LIBCXX + +// Return the list of targets in HWY_TARGETS supported by the CPU as a list of +// individual HWY_* target macros such as HWY_SCALAR or HWY_NEON. This list +// is affected by the current SetSupportedTargetsForTest() mock if any. +HWY_INLINE std::vector<int64_t> SupportedAndGeneratedTargets() { + std::vector<int64_t> ret; + for (int64_t targets = SupportedTargets() & HWY_TARGETS; targets != 0; + targets = targets & (targets - 1)) { + int64_t current_target = targets & ~(targets - 1); + ret.push_back(current_target); + } + return ret; +} + +#endif // HWY_NO_LIBCXX + +static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { + switch (target) { +#if HWY_ARCH_X86 + case HWY_SSSE3: + return "SSSE3"; + case HWY_SSE4: + return "SSE4"; + case HWY_AVX2: + return "AVX2"; + case HWY_AVX3: + return "AVX3"; + case HWY_AVX3_DL: + return "AVX3_DL"; +#endif + +#if HWY_ARCH_ARM + case HWY_SVE2_128: + return "SVE2_128"; + case HWY_SVE_256: + return "SVE_256"; + case HWY_SVE2: + return "SVE2"; + case HWY_SVE: + return "SVE"; + case HWY_NEON: + return "NEON"; +#endif + +#if HWY_ARCH_PPC + case HWY_PPC8: + return "PPC8"; +#endif + +#if HWY_ARCH_WASM + case HWY_WASM: + return "WASM"; + case HWY_WASM_EMU256: + return "WASM_EMU256"; +#endif + +#if HWY_ARCH_RVV + case HWY_RVV: + return "RVV"; +#endif + + case HWY_EMU128: + return "EMU128"; + case HWY_SCALAR: + return "SCALAR"; + + default: + return "Unknown"; // must satisfy gtest IsValidParamName() + } +} + +// The maximum number of dynamic targets on any architecture is defined by +// HWY_MAX_DYNAMIC_TARGETS and depends on the arch. + +// For the ChosenTarget mask and index we use a different bit arrangement than +// in the HWY_TARGETS mask. Only the targets involved in the current +// architecture are used in this mask, and therefore only the least significant +// (HWY_MAX_DYNAMIC_TARGETS + 2) bits of the int64_t mask are used. The least +// significant bit is set when the mask is not initialized, the next +// HWY_MAX_DYNAMIC_TARGETS more significant bits are a range of bits from the +// HWY_TARGETS or SupportedTargets() mask for the given architecture shifted to +// that position and the next more significant bit is used for HWY_SCALAR (if +// HWY_COMPILE_ONLY_SCALAR is defined) or HWY_EMU128. Because of this we need to +// define equivalent values for HWY_TARGETS in this representation. +// This mask representation allows to use ctz() on this mask and obtain a small +// number that's used as an index of the table for dynamic dispatch. In this +// way the first entry is used when the mask is uninitialized, the following +// HWY_MAX_DYNAMIC_TARGETS are for dynamic dispatch and the last one is for +// scalar. + +// The HWY_SCALAR/HWY_EMU128 bit in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_SCALAR (1LL << (HWY_MAX_DYNAMIC_TARGETS + 1)) + +// Converts from a HWY_TARGETS mask to a ChosenTarget mask format for the +// current architecture. +#define HWY_CHOSEN_TARGET_SHIFT(X) \ + ((((X) >> (HWY_HIGHEST_TARGET_BIT + 1 - HWY_MAX_DYNAMIC_TARGETS)) & \ + ((1LL << HWY_MAX_DYNAMIC_TARGETS) - 1)) \ + << 1) + +// The HWY_TARGETS mask in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_TARGETS \ + (HWY_CHOSEN_TARGET_SHIFT(HWY_TARGETS) | HWY_CHOSEN_TARGET_MASK_SCALAR | 1LL) + +#if HWY_ARCH_X86 +// Maximum number of dynamic targets, changing this value is an ABI incompatible +// change +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_X86 +// These must match the order in which the HWY_TARGETS are defined +// starting by the least significant (HWY_HIGHEST_TARGET_BIT + 1 - +// HWY_MAX_DYNAMIC_TARGETS) bit. This list must contain exactly +// HWY_MAX_DYNAMIC_TARGETS elements and does not include SCALAR. The first entry +// corresponds to the best target. Don't include a "," at the end of the list. +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_AVX3_DL(func_name), /* AVX3_DL */ \ + HWY_CHOOSE_AVX3(func_name), /* AVX3 */ \ + HWY_CHOOSE_AVX2(func_name), /* AVX2 */ \ + nullptr, /* AVX */ \ + HWY_CHOOSE_SSE4(func_name), /* SSE4 */ \ + HWY_CHOOSE_SSSE3(func_name), /* SSSE3 */ \ + nullptr , /* reserved - SSE3? */ \ + nullptr /* reserved - SSE2? */ + +#elif HWY_ARCH_ARM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_ARM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2_128(func_name), /* SVE2 128-bit */ \ + HWY_CHOOSE_SVE_256(func_name), /* SVE 256-bit */ \ + HWY_CHOOSE_SVE2(func_name), /* SVE2 */ \ + HWY_CHOOSE_SVE(func_name), /* SVE */ \ + HWY_CHOOSE_NEON(func_name), /* NEON */ \ + nullptr /* reserved - Helium? */ + +#elif HWY_ARCH_RVV +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_RVV +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_RVV(func_name), /* RVV */ \ + nullptr /* reserved */ + +#elif HWY_ARCH_PPC +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_PPC +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_PPC8(func_name), /* PPC8 */ \ + nullptr, /* reserved (VSX or AltiVec) */ \ + nullptr /* reserved (VSX or AltiVec) */ + +#elif HWY_ARCH_WASM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_WASM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_WASM_EMU256(func_name), /* WASM_EMU256 */ \ + HWY_CHOOSE_WASM(func_name), /* WASM */ \ + nullptr /* reserved */ + +#else +// Unknown architecture, will use HWY_SCALAR without dynamic dispatch, though +// still creating single-entry tables in HWY_EXPORT to ensure portability. +#define HWY_MAX_DYNAMIC_TARGETS 1 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_SCALAR +#endif + +// Bitfield of supported and enabled targets. The format differs from that of +// HWY_TARGETS; the lowest bit governs the first function pointer (which is +// special in that it calls FunctionCache, then Update, then dispatches to the +// actual implementation) in the tables created by HWY_EXPORT. Monostate (see +// GetChosenTarget), thread-safe except on RVV. +struct ChosenTarget { + public: + // Reset bits according to `targets` (typically the return value of + // SupportedTargets()). Postcondition: IsInitialized() == true. + void Update(int64_t targets) { + // These are `targets` shifted downwards, see above. Also include SCALAR + // (corresponds to the last entry in the function table) as fallback. + StoreMask(HWY_CHOSEN_TARGET_SHIFT(targets) | HWY_CHOSEN_TARGET_MASK_SCALAR); + } + + // Reset to the uninitialized state, so that FunctionCache will call Update + // during the next HWY_DYNAMIC_DISPATCH, and IsInitialized returns false. + void DeInit() { StoreMask(1); } + + // Whether Update was called. This indicates whether any HWY_DYNAMIC_DISPATCH + // function was called, which we check in tests. + bool IsInitialized() const { return LoadMask() != 1; } + + // Return the index in the dynamic dispatch table to be used by the current + // CPU. Note that this method must be in the header file so it uses the value + // of HWY_CHOSEN_TARGET_MASK_TARGETS defined in the translation unit that + // calls it, which may be different from others. This means we only enable + // those targets that were actually compiled in this module. + size_t HWY_INLINE GetIndex() const { + return hwy::Num0BitsBelowLS1Bit_Nonzero64( + static_cast<uint64_t>(LoadMask() & HWY_CHOSEN_TARGET_MASK_TARGETS)); + } + + private: + // TODO(janwas): remove RVV once <atomic> is available +#if HWY_ARCH_RVV || defined(HWY_NO_LIBCXX) + int64_t LoadMask() const { return mask_; } + void StoreMask(int64_t mask) { mask_ = mask; } + + int64_t mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#else + int64_t LoadMask() const { return mask_.load(); } + void StoreMask(int64_t mask) { mask_.store(mask); } + + std::atomic<int64_t> mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#endif // HWY_ARCH_RVV +}; + +// For internal use (e.g. by FunctionCache and DisableTargets). +HWY_DLLEXPORT ChosenTarget& GetChosenTarget(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_TARGETS_H_ diff --git a/third_party/highway/hwy/targets_test.cc b/third_party/highway/hwy/targets_test.cc new file mode 100644 index 0000000000..f00b24546d --- /dev/null +++ b/third_party/highway/hwy/targets_test.cc @@ -0,0 +1,137 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/targets.h" + +#include "hwy/tests/test_util-inl.h" + +namespace fake { + +#define DECLARE_FUNCTION(TGT) \ + namespace N_##TGT { \ + /* Function argument is just to ensure/demonstrate they are possible. */ \ + int64_t FakeFunction(int) { return HWY_##TGT; } \ + } + +DECLARE_FUNCTION(AVX3_DL) +DECLARE_FUNCTION(AVX3) +DECLARE_FUNCTION(AVX2) +DECLARE_FUNCTION(SSE4) +DECLARE_FUNCTION(SSSE3) +DECLARE_FUNCTION(NEON) +DECLARE_FUNCTION(SVE) +DECLARE_FUNCTION(SVE2) +DECLARE_FUNCTION(SVE_256) +DECLARE_FUNCTION(SVE2_128) +DECLARE_FUNCTION(PPC8) +DECLARE_FUNCTION(WASM) +DECLARE_FUNCTION(WASM_EMU256) +DECLARE_FUNCTION(RVV) +DECLARE_FUNCTION(SCALAR) +DECLARE_FUNCTION(EMU128) + +HWY_EXPORT(FakeFunction); + +void CallFunctionForTarget(int64_t target, int line) { + if ((HWY_TARGETS & target) == 0) return; + hwy::SetSupportedTargetsForTest(target); + + // Call Update() first to make &HWY_DYNAMIC_DISPATCH() return + // the pointer to the already cached function. + hwy::GetChosenTarget().Update(hwy::SupportedTargets()); + + EXPECT_EQ(target, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; + + // Calling DeInit() will test that the initializer function + // also calls the right function. + hwy::GetChosenTarget().DeInit(); + +#if HWY_DISPATCH_WORKAROUND + EXPECT_EQ(HWY_STATIC_TARGET, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; +#else + EXPECT_EQ(target, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; +#endif + + // Second call uses the cached value from the previous call. + EXPECT_EQ(target, HWY_DYNAMIC_DISPATCH(FakeFunction)(42)) << line; +} + +void CheckFakeFunction() { + // When adding a target, also add to DECLARE_FUNCTION above. + CallFunctionForTarget(HWY_AVX3_DL, __LINE__); + CallFunctionForTarget(HWY_AVX3, __LINE__); + CallFunctionForTarget(HWY_AVX2, __LINE__); + CallFunctionForTarget(HWY_SSE4, __LINE__); + CallFunctionForTarget(HWY_SSSE3, __LINE__); + CallFunctionForTarget(HWY_NEON, __LINE__); + CallFunctionForTarget(HWY_SVE, __LINE__); + CallFunctionForTarget(HWY_SVE2, __LINE__); + CallFunctionForTarget(HWY_SVE_256, __LINE__); + CallFunctionForTarget(HWY_SVE2_128, __LINE__); + CallFunctionForTarget(HWY_PPC8, __LINE__); + CallFunctionForTarget(HWY_WASM, __LINE__); + CallFunctionForTarget(HWY_WASM_EMU256, __LINE__); + CallFunctionForTarget(HWY_RVV, __LINE__); + // The tables only have space for either HWY_SCALAR or HWY_EMU128; the former + // is opt-in only. +#if defined(HWY_COMPILE_ONLY_SCALAR) || HWY_BROKEN_EMU128 + CallFunctionForTarget(HWY_SCALAR, __LINE__); +#else + CallFunctionForTarget(HWY_EMU128, __LINE__); +#endif +} + +} // namespace fake + +namespace hwy { + +class HwyTargetsTest : public testing::Test { + protected: + void TearDown() override { + SetSupportedTargetsForTest(0); + DisableTargets(0); // Reset the mask. + } +}; + +// Test that the order in the HWY_EXPORT static array matches the expected +// value of the target bits. This is only checked for the targets that are +// enabled in the current compilation. +TEST_F(HwyTargetsTest, ChosenTargetOrderTest) { fake::CheckFakeFunction(); } + +TEST_F(HwyTargetsTest, DisabledTargetsTest) { + DisableTargets(~0LL); + // Check that disabling everything at least leaves the static target. + HWY_ASSERT(HWY_STATIC_TARGET == SupportedTargets()); + + DisableTargets(0); // Reset the mask. + const int64_t current_targets = SupportedTargets(); + const int64_t enabled_baseline = static_cast<int64_t>(HWY_ENABLED_BASELINE); + // Exclude these two because they are always returned by SupportedTargets. + const int64_t fallback = HWY_SCALAR | HWY_EMU128; + if ((current_targets & ~enabled_baseline & ~fallback) == 0) { + // We can't test anything else if the only compiled target is the baseline. + return; + } + + // Get the lowest bit in the mask (the best target) and disable that one. + const int64_t best_target = current_targets & (~current_targets + 1); + DisableTargets(best_target); + + // Check that the other targets are still enabled. + HWY_ASSERT((best_target ^ current_targets) == SupportedTargets()); + DisableTargets(0); // Reset the mask. +} + +} // namespace hwy diff --git a/third_party/highway/hwy/tests/arithmetic_test.cc b/third_party/highway/hwy/tests/arithmetic_test.cc new file mode 100644 index 0000000000..fa533228a0 --- /dev/null +++ b/third_party/highway/hwy/tests/arithmetic_test.cc @@ -0,0 +1,499 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/arithmetic_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestPlusMinus { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, T(2)); + const auto v3 = Iota(d, T(3)); + const auto v4 = Iota(d, T(4)); + + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + lanes[i] = static_cast<T>((2 + i) + (3 + i)); + } + HWY_ASSERT_VEC_EQ(d, lanes.get(), Add(v2, v3)); + HWY_ASSERT_VEC_EQ(d, Set(d, 2), Sub(v4, v2)); + + for (size_t i = 0; i < N; ++i) { + lanes[i] = static_cast<T>((2 + i) + (4 + i)); + } + auto sum = v2; + sum = Add(sum, v4); // sum == 6,8.. + HWY_ASSERT_VEC_EQ(d, Load(d, lanes.get()), sum); + + sum = Sub(sum, v4); + HWY_ASSERT_VEC_EQ(d, v2, sum); + } +}; + +struct TestPlusMinusOverflow { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Iota(d, T(1)); + const auto vMax = Iota(d, LimitsMax<T>()); + const auto vMin = Iota(d, LimitsMin<T>()); + + // Check that no UB triggered. + // "assert" here is formal - to avoid compiler dropping calculations + HWY_ASSERT_VEC_EQ(d, Add(v1, vMax), Add(vMax, v1)); + HWY_ASSERT_VEC_EQ(d, Add(vMax, vMax), Add(vMax, vMax)); + HWY_ASSERT_VEC_EQ(d, Sub(vMin, v1), Sub(vMin, v1)); + HWY_ASSERT_VEC_EQ(d, Sub(vMin, vMax), Sub(vMin, vMax)); + } +}; + +HWY_NOINLINE void TestAllPlusMinus() { + ForAllTypes(ForPartialVectors<TestPlusMinus>()); + ForIntegerTypes(ForPartialVectors<TestPlusMinusOverflow>()); +} + +struct TestUnsignedSaturatingArithmetic { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vi = Iota(d, 1); + const auto vm = Set(d, LimitsMax<T>()); + + HWY_ASSERT_VEC_EQ(d, Add(v0, v0), SaturatedAdd(v0, v0)); + HWY_ASSERT_VEC_EQ(d, Add(v0, vi), SaturatedAdd(v0, vi)); + HWY_ASSERT_VEC_EQ(d, Add(v0, vm), SaturatedAdd(v0, vm)); + HWY_ASSERT_VEC_EQ(d, vm, SaturatedAdd(vi, vm)); + HWY_ASSERT_VEC_EQ(d, vm, SaturatedAdd(vm, vm)); + + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(vi, vi)); + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(vi, vm)); + HWY_ASSERT_VEC_EQ(d, Sub(vm, vi), SaturatedSub(vm, vi)); + } +}; + +struct TestSignedSaturatingArithmetic { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vpm = Set(d, LimitsMax<T>()); + // Ensure all lanes are positive, even if Iota wraps around + const auto vi = Or(And(Iota(d, 0), vpm), Set(d, 1)); + const auto vn = Sub(v0, vi); + const auto vnm = Set(d, LimitsMin<T>()); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), Gt(vi, v0)); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), Lt(vn, v0)); + + HWY_ASSERT_VEC_EQ(d, v0, SaturatedAdd(v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, SaturatedAdd(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(v0, vpm)); + HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(vi, vpm)); + HWY_ASSERT_VEC_EQ(d, vpm, SaturatedAdd(vpm, vpm)); + + HWY_ASSERT_VEC_EQ(d, v0, SaturatedSub(v0, v0)); + HWY_ASSERT_VEC_EQ(d, Sub(v0, vi), SaturatedSub(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vn, SaturatedSub(vn, v0)); + HWY_ASSERT_VEC_EQ(d, vnm, SaturatedSub(vnm, vi)); + HWY_ASSERT_VEC_EQ(d, vnm, SaturatedSub(vnm, vpm)); + } +}; + +struct TestSaturatingArithmeticOverflow { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Iota(d, T(1)); + const auto vMax = Iota(d, LimitsMax<T>()); + const auto vMin = Iota(d, LimitsMin<T>()); + + // Check that no UB triggered. + // "assert" here is formal - to avoid compiler dropping calculations + HWY_ASSERT_VEC_EQ(d, SaturatedAdd(v1, vMax), SaturatedAdd(vMax, v1)); + HWY_ASSERT_VEC_EQ(d, SaturatedAdd(vMax, vMax), SaturatedAdd(vMax, vMax)); + HWY_ASSERT_VEC_EQ(d, SaturatedAdd(vMin, vMax), SaturatedAdd(vMin, vMax)); + HWY_ASSERT_VEC_EQ(d, SaturatedAdd(vMin, vMin), SaturatedAdd(vMin, vMin)); + HWY_ASSERT_VEC_EQ(d, SaturatedSub(vMin, v1), SaturatedSub(vMin, v1)); + HWY_ASSERT_VEC_EQ(d, SaturatedSub(vMin, vMax), SaturatedSub(vMin, vMax)); + HWY_ASSERT_VEC_EQ(d, SaturatedSub(vMax, vMin), SaturatedSub(vMax, vMin)); + HWY_ASSERT_VEC_EQ(d, SaturatedSub(vMin, vMin), SaturatedSub(vMin, vMin)); + } +}; + +HWY_NOINLINE void TestAllSaturatingArithmetic() { + const ForPartialVectors<TestUnsignedSaturatingArithmetic> test_unsigned; + test_unsigned(uint8_t()); + test_unsigned(uint16_t()); + + const ForPartialVectors<TestSignedSaturatingArithmetic> test_signed; + test_signed(int8_t()); + test_signed(int16_t()); + + const ForPartialVectors<TestSaturatingArithmeticOverflow> test_overflow; + test_overflow(int8_t()); + test_overflow(uint8_t()); + test_overflow(int16_t()); + test_overflow(uint16_t()); +} + +struct TestAverage { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto v2 = Set(d, T(2)); + + HWY_ASSERT_VEC_EQ(d, v0, AverageRound(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v0, v1)); + HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v1, v1)); + HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v2, v2)); + } +}; + +HWY_NOINLINE void TestAllAverage() { + const ForPartialVectors<TestAverage> test; + test(uint8_t()); + test(uint16_t()); +} + +struct TestAbs { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp1 = Set(d, T(1)); + const auto vn1 = Set(d, T(-1)); + const auto vpm = Set(d, LimitsMax<T>()); + const auto vnm = Set(d, LimitsMin<T>()); + + HWY_ASSERT_VEC_EQ(d, v0, Abs(v0)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vp1)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vn1)); + HWY_ASSERT_VEC_EQ(d, vpm, Abs(vpm)); + HWY_ASSERT_VEC_EQ(d, vnm, Abs(vnm)); + } +}; + +struct TestFloatAbs { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp1 = Set(d, T(1)); + const auto vn1 = Set(d, T(-1)); + const auto vp2 = Set(d, T(0.01)); + const auto vn2 = Set(d, T(-0.01)); + + HWY_ASSERT_VEC_EQ(d, v0, Abs(v0)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vp1)); + HWY_ASSERT_VEC_EQ(d, vp1, Abs(vn1)); + HWY_ASSERT_VEC_EQ(d, vp2, Abs(vp2)); + HWY_ASSERT_VEC_EQ(d, vp2, Abs(vn2)); + } +}; + +HWY_NOINLINE void TestAllAbs() { + ForSignedTypes(ForPartialVectors<TestAbs>()); + ForFloatTypes(ForPartialVectors<TestFloatAbs>()); +} + +struct TestNeg { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vn = Set(d, T(-3)); + const auto vp = Set(d, T(3)); + HWY_ASSERT_VEC_EQ(d, v0, Neg(v0)); + HWY_ASSERT_VEC_EQ(d, vp, Neg(vn)); + HWY_ASSERT_VEC_EQ(d, vn, Neg(vp)); + } +}; + +struct TestNegOverflow { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto vn = Set(d, LimitsMin<T>()); + const auto vp = Set(d, LimitsMax<T>()); + HWY_ASSERT_VEC_EQ(d, Neg(vn), Neg(vn)); + HWY_ASSERT_VEC_EQ(d, Neg(vp), Neg(vp)); + } +}; + +HWY_NOINLINE void TestAllNeg() { + ForSignedTypes(ForPartialVectors<TestNeg>()); + ForFloatTypes(ForPartialVectors<TestNeg>()); + ForSignedTypes(ForPartialVectors<TestNegOverflow>()); +} + +struct TestUnsignedMinMax { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + // Leave headroom such that v1 < v2 even after wraparound. + const auto mod = And(Iota(d, 0), Set(d, LimitsMax<T>() >> 1)); + const auto v1 = Add(mod, Set(d, 1)); + const auto v2 = Add(mod, Set(d, 2)); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v0, Min(v1, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v0)); + + const auto vmin = Set(d, LimitsMin<T>()); + const auto vmax = Set(d, LimitsMax<T>()); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +struct TestSignedMinMax { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Leave headroom such that v1 < v2 even after wraparound. + const auto mod = And(Iota(d, 0), Set(d, LimitsMax<T>() >> 1)); + const auto v1 = Add(mod, Set(d, 1)); + const auto v2 = Add(mod, Set(d, 2)); + const auto v_neg = Sub(Zero(d), v1); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v_neg, Min(v1, v_neg)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v_neg)); + + const auto v0 = Zero(d); + const auto vmin = Set(d, LimitsMin<T>()); + const auto vmax = Set(d, LimitsMax<T>()); + HWY_ASSERT_VEC_EQ(d, vmin, Min(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Max(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, v0, Max(vmin, v0)); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +struct TestFloatMinMax { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + const auto v_neg = Iota(d, -T(Lanes(d))); + HWY_ASSERT_VEC_EQ(d, v1, Min(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, Max(v1, v2)); + HWY_ASSERT_VEC_EQ(d, v_neg, Min(v1, v_neg)); + HWY_ASSERT_VEC_EQ(d, v1, Max(v1, v_neg)); + + const auto v0 = Zero(d); + const auto vmin = Set(d, T(-1E30)); + const auto vmax = Set(d, T(1E30)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Max(v0, vmin)); + HWY_ASSERT_VEC_EQ(d, v0, Max(vmin, v0)); + + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmin, Min(vmax, vmin)); + + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmin, vmax)); + HWY_ASSERT_VEC_EQ(d, vmax, Max(vmax, vmin)); + } +}; + +HWY_NOINLINE void TestAllMinMax() { + ForUnsignedTypes(ForPartialVectors<TestUnsignedMinMax>()); + ForSignedTypes(ForPartialVectors<TestSignedMinMax>()); + ForFloatTypes(ForPartialVectors<TestFloatMinMax>()); +} + +template <class D> +static HWY_NOINLINE Vec<D> Make128(D d, uint64_t hi, uint64_t lo) { + alignas(16) uint64_t in[2]; + in[0] = lo; + in[1] = hi; + return LoadDup128(d, in); +} + +struct TestMinMax128 { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec<D>; + const size_t N = Lanes(d); + auto a_lanes = AllocateAligned<T>(N); + auto b_lanes = AllocateAligned<T>(N); + auto min_lanes = AllocateAligned<T>(N); + auto max_lanes = AllocateAligned<T>(N); + RandomState rng; + + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + // Same arg + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Min128(d, v11, v11)); + HWY_ASSERT_VEC_EQ(d, v00, Max128(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v11, v11)); + + // First arg less + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v10, v11)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v10, v11)); + + // Second arg less + HWY_ASSERT_VEC_EQ(d, v00, Min128(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128(d, v11, v10)); + HWY_ASSERT_VEC_EQ(d, v01, Max128(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v10, Max128(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v11, Max128(d, v11, v10)); + + // Also check 128-bit blocks are independent + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + a_lanes[i] = Random64(&rng); + b_lanes[i] = Random64(&rng); + } + const V a = Load(d, a_lanes.get()); + const V b = Load(d, b_lanes.get()); + for (size_t i = 0; i < N; i += 2) { + const bool lt = a_lanes[i + 1] == b_lanes[i + 1] + ? (a_lanes[i] < b_lanes[i]) + : (a_lanes[i + 1] < b_lanes[i + 1]); + min_lanes[i + 0] = lt ? a_lanes[i + 0] : b_lanes[i + 0]; + min_lanes[i + 1] = lt ? a_lanes[i + 1] : b_lanes[i + 1]; + max_lanes[i + 0] = lt ? b_lanes[i + 0] : a_lanes[i + 0]; + max_lanes[i + 1] = lt ? b_lanes[i + 1] : a_lanes[i + 1]; + } + HWY_ASSERT_VEC_EQ(d, min_lanes.get(), Min128(d, a, b)); + HWY_ASSERT_VEC_EQ(d, max_lanes.get(), Max128(d, a, b)); + } + } +}; + +HWY_NOINLINE void TestAllMinMax128() { + ForGEVectors<128, TestMinMax128>()(uint64_t()); +} + +struct TestMinMax128Upper { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec<D>; + const size_t N = Lanes(d); + auto a_lanes = AllocateAligned<T>(N); + auto b_lanes = AllocateAligned<T>(N); + auto min_lanes = AllocateAligned<T>(N); + auto max_lanes = AllocateAligned<T>(N); + RandomState rng; + + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + // Same arg + HWY_ASSERT_VEC_EQ(d, v00, Min128Upper(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Min128Upper(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Min128Upper(d, v11, v11)); + HWY_ASSERT_VEC_EQ(d, v00, Max128Upper(d, v00, v00)); + HWY_ASSERT_VEC_EQ(d, v01, Max128Upper(d, v01, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v10, v10)); + HWY_ASSERT_VEC_EQ(d, v11, Max128Upper(d, v11, v11)); + + // Equivalent but not equal (chooses second arg) + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v11, Min128Upper(d, v10, v11)); + HWY_ASSERT_VEC_EQ(d, v00, Min128Upper(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v10, Min128Upper(d, v11, v10)); + HWY_ASSERT_VEC_EQ(d, v00, Max128Upper(d, v01, v00)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v11, v10)); + HWY_ASSERT_VEC_EQ(d, v01, Max128Upper(d, v00, v01)); + HWY_ASSERT_VEC_EQ(d, v11, Max128Upper(d, v10, v11)); + + // First arg less + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v01, v10)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v01, v10)); + + // Second arg less + HWY_ASSERT_VEC_EQ(d, v01, Min128Upper(d, v10, v01)); + HWY_ASSERT_VEC_EQ(d, v10, Max128Upper(d, v10, v01)); + + // Also check 128-bit blocks are independent + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + a_lanes[i] = Random64(&rng); + b_lanes[i] = Random64(&rng); + } + const V a = Load(d, a_lanes.get()); + const V b = Load(d, b_lanes.get()); + for (size_t i = 0; i < N; i += 2) { + const bool lt = a_lanes[i + 1] < b_lanes[i + 1]; + min_lanes[i + 0] = lt ? a_lanes[i + 0] : b_lanes[i + 0]; + min_lanes[i + 1] = lt ? a_lanes[i + 1] : b_lanes[i + 1]; + max_lanes[i + 0] = lt ? b_lanes[i + 0] : a_lanes[i + 0]; + max_lanes[i + 1] = lt ? b_lanes[i + 1] : a_lanes[i + 1]; + } + HWY_ASSERT_VEC_EQ(d, min_lanes.get(), Min128Upper(d, a, b)); + HWY_ASSERT_VEC_EQ(d, max_lanes.get(), Max128Upper(d, a, b)); + } + } +}; + +HWY_NOINLINE void TestAllMinMax128Upper() { + ForGEVectors<128, TestMinMax128Upper>()(uint64_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyArithmeticTest); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllPlusMinus); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllSaturatingArithmetic); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAverage); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllAbs); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllNeg); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax128); +HWY_EXPORT_AND_TEST_P(HwyArithmeticTest, TestAllMinMax128Upper); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/blockwise_shift_test.cc b/third_party/highway/hwy/tests/blockwise_shift_test.cc new file mode 100644 index 0000000000..4e5250841b --- /dev/null +++ b/third_party/highway/hwy/tests/blockwise_shift_test.cc @@ -0,0 +1,270 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcpy + +#include <algorithm> // std::fill + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/blockwise_shift_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestShiftBytes { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Bytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const Repartition<uint8_t, D> du8; + const size_t N8 = Lanes(du8); + + // Zero remains zero + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(v0)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(d, v0)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightBytes<1>(d, v0)); + + // Zero after shifting out the high/low byte + auto bytes = AllocateAligned<uint8_t>(N8); + std::fill(bytes.get(), bytes.get() + N8, 0); + bytes[N8 - 1] = 0x7F; + const auto vhi = BitCast(d, Load(du8, bytes.get())); + bytes[N8 - 1] = 0; + bytes[0] = 0x7F; + const auto vlo = BitCast(d, Load(du8, bytes.get())); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(vhi)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftLeftBytes<1>(d, vhi)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightBytes<1>(d, vlo)); + + // Check expected result with Iota + const size_t N = Lanes(d); + auto in = AllocateAligned<T>(N); + const uint8_t* in_bytes = reinterpret_cast<const uint8_t*>(in.get()); + const auto v = BitCast(d, Iota(du8, 1)); + Store(v, d, in.get()); + + auto expected = AllocateAligned<T>(N); + uint8_t* expected_bytes = reinterpret_cast<uint8_t*>(expected.get()); + + const size_t block_size = HWY_MIN(N8, 16); + for (size_t block = 0; block < N8; block += block_size) { + expected_bytes[block] = 0; + memcpy(expected_bytes + block + 1, in_bytes + block, block_size - 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftBytes<1>(v)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftBytes<1>(d, v)); + + for (size_t block = 0; block < N8; block += block_size) { + memcpy(expected_bytes + block, in_bytes + block + 1, block_size - 1); + expected_bytes[block + block_size - 1] = 0; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightBytes<1>(d, v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllShiftBytes() { + ForIntegerTypes(ForPartialVectors<TestShiftBytes>()); +} + +struct TestShiftLeftLanes { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Lanes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const auto v = Iota(d, T(1)); + const size_t N = Lanes(d); + if (N == 1) return; + auto expected = AllocateAligned<T>(N); + + HWY_ASSERT_VEC_EQ(d, v, ShiftLeftLanes<0>(v)); + HWY_ASSERT_VEC_EQ(d, v, ShiftLeftLanes<0>(d, v)); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + + for (size_t i = 0; i < N; ++i) { + expected[i] = (i % kLanesPerBlock) == 0 ? T(0) : T(i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftLanes<1>(v)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftLanes<1>(d, v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +struct TestShiftRightLanes { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define Shift*Lanes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + const auto v = Iota(d, T(1)); + const size_t N = Lanes(d); + if (N == 1) return; + auto expected = AllocateAligned<T>(N); + + HWY_ASSERT_VEC_EQ(d, v, ShiftRightLanes<0>(d, v)); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + + for (size_t i = 0; i < N; ++i) { + const size_t mod = i % kLanesPerBlock; + expected[i] = mod == (kLanesPerBlock - 1) || i >= N - 1 ? T(0) : T(2 + i); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightLanes<1>(d, v)); +#else + (void)d; +#endif // #if HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllShiftLeftLanes() { + ForAllTypes(ForPartialVectors<TestShiftLeftLanes>()); +} + +HWY_NOINLINE void TestAllShiftRightLanes() { + ForAllTypes(ForPartialVectors<TestShiftRightLanes>()); +} + +// Scalar does not define CombineShiftRightBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template <int kBytes> +struct TestCombineShiftRightBytes { + template <class T, class D> + HWY_NOINLINE void operator()(T, D d) { + constexpr size_t kBlockSize = 16; + static_assert(kBytes < kBlockSize, "Shift count is per block"); + const Repartition<uint8_t, D> d8; + const size_t N8 = Lanes(d8); + if (N8 < 16) return; + auto hi_bytes = AllocateAligned<uint8_t>(N8); + auto lo_bytes = AllocateAligned<uint8_t>(N8); + auto expected_bytes = AllocateAligned<uint8_t>(N8); + uint8_t combined[2 * kBlockSize]; + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(100); ++rep) { + for (size_t i = 0; i < N8; ++i) { + hi_bytes[i] = static_cast<uint8_t>(Random64(&rng) & 0xFF); + lo_bytes[i] = static_cast<uint8_t>(Random64(&rng) & 0xFF); + } + for (size_t i = 0; i < N8; i += kBlockSize) { + // Arguments are not the same size. + CopyBytes<kBlockSize>(&lo_bytes[i], combined); + CopyBytes<kBlockSize>(&hi_bytes[i], combined + kBlockSize); + CopyBytes<kBlockSize>(combined + kBytes, &expected_bytes[i]); + } + + const auto hi = BitCast(d, Load(d8, hi_bytes.get())); + const auto lo = BitCast(d, Load(d8, lo_bytes.get())); + const auto expected = BitCast(d, Load(d8, expected_bytes.get())); + HWY_ASSERT_VEC_EQ(d, expected, CombineShiftRightBytes<kBytes>(d, hi, lo)); + } + } +}; + +template <int kLanes> +struct TestCombineShiftRightLanes { + template <class T, class D> + HWY_NOINLINE void operator()(T, D d) { + const Repartition<uint8_t, D> d8; + const size_t N8 = Lanes(d8); + if (N8 < 16) return; + + auto hi_bytes = AllocateAligned<uint8_t>(N8); + auto lo_bytes = AllocateAligned<uint8_t>(N8); + auto expected_bytes = AllocateAligned<uint8_t>(N8); + constexpr size_t kBlockSize = 16; + uint8_t combined[2 * kBlockSize]; + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(100); ++rep) { + for (size_t i = 0; i < N8; ++i) { + hi_bytes[i] = static_cast<uint8_t>(Random64(&rng) & 0xFF); + lo_bytes[i] = static_cast<uint8_t>(Random64(&rng) & 0xFF); + } + for (size_t i = 0; i < N8; i += kBlockSize) { + // Arguments are not the same size. + CopyBytes<kBlockSize>(&lo_bytes[i], combined); + CopyBytes<kBlockSize>(&hi_bytes[i], combined + kBlockSize); + CopyBytes<kBlockSize>(combined + kLanes * sizeof(T), + &expected_bytes[i]); + } + + const auto hi = BitCast(d, Load(d8, hi_bytes.get())); + const auto lo = BitCast(d, Load(d8, lo_bytes.get())); + const auto expected = BitCast(d, Load(d8, expected_bytes.get())); + HWY_ASSERT_VEC_EQ(d, expected, CombineShiftRightLanes<kLanes>(d, hi, lo)); + } + } +}; + +#endif // #if HWY_TARGET != HWY_SCALAR + +struct TestCombineShiftRight { + template <class T, class D> + HWY_NOINLINE void operator()(T t, D d) { +// Scalar does not define CombineShiftRightBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + constexpr int kMaxBytes = + HWY_MIN(16, static_cast<int>(MaxLanes(d) * sizeof(T))); + constexpr int kMaxLanes = kMaxBytes / static_cast<int>(sizeof(T)); + TestCombineShiftRightBytes<kMaxBytes - 1>()(t, d); + TestCombineShiftRightBytes<HWY_MAX(kMaxBytes / 2, 1)>()(t, d); + TestCombineShiftRightBytes<1>()(t, d); + + TestCombineShiftRightLanes<kMaxLanes - 1>()(t, d); + TestCombineShiftRightLanes<HWY_MAX(kMaxLanes / 2, -1)>()(t, d); + TestCombineShiftRightLanes<1>()(t, d); +#else + (void)t; + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllCombineShiftRight() { + // Need at least 2 lanes. + ForAllTypes(ForShrinkableVectors<TestCombineShiftRight>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyBlockwiseShiftTest); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllShiftBytes); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllShiftLeftLanes); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllShiftRightLanes); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseShiftTest, TestAllCombineShiftRight); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/blockwise_test.cc b/third_party/highway/hwy/tests/blockwise_test.cc new file mode 100644 index 0000000000..e5ac9ab362 --- /dev/null +++ b/third_party/highway/hwy/tests/blockwise_test.cc @@ -0,0 +1,454 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <algorithm> // std::fill + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/blockwise_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <typename D, int kLane> +struct TestBroadcastR { + HWY_NOINLINE void operator()() const { + using T = typename D::T; + const D d; + const size_t N = Lanes(d); + if (kLane >= N) return; + auto in_lanes = AllocateAligned<T>(N); + std::fill(in_lanes.get(), in_lanes.get() + N, T(0)); + const size_t blockN = HWY_MIN(N * sizeof(T), 16) / sizeof(T); + // Need to set within each 128-bit block + for (size_t block = 0; block < N; block += blockN) { + in_lanes[block + kLane] = static_cast<T>(block + 1); + } + const auto in = Load(d, in_lanes.get()); + auto expected = AllocateAligned<T>(N); + for (size_t block = 0; block < N; block += blockN) { + for (size_t i = 0; i < blockN; ++i) { + expected[block + i] = T(block + 1); + } + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Broadcast<kLane>(in)); + + TestBroadcastR<D, kLane - 1>()(); + } +}; + +template <class D> +struct TestBroadcastR<D, -1> { + void operator()() const {} +}; + +struct TestBroadcast { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + TestBroadcastR<D, HWY_MIN(MaxLanes(d), 16 / sizeof(T)) - 1>()(); + } +}; + +HWY_NOINLINE void TestAllBroadcast() { + const ForPartialVectors<TestBroadcast> test; + // No u/i8. + test(uint16_t()); + test(int16_t()); + ForUIF3264(test); +} + +template <bool kFull> +struct ChooseTableSize { + template <typename T, typename DIdx> + using type = DIdx; +}; +template <> +struct ChooseTableSize<true> { + template <typename T, typename DIdx> + using type = ScalableTag<T>; +}; + +template <bool kFull> +struct TestTableLookupBytes { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + RandomState rng; + + const typename ChooseTableSize<kFull>::template type<T, D> d_tbl; + const Repartition<uint8_t, decltype(d_tbl)> d_tbl8; + const size_t NT8 = Lanes(d_tbl8); + + const Repartition<uint8_t, D> d8; + const size_t N8 = Lanes(d8); + + // Random input bytes + auto in_bytes = AllocateAligned<uint8_t>(NT8); + for (size_t i = 0; i < NT8; ++i) { + in_bytes[i] = Random32(&rng) & 0xFF; + } + const auto in = BitCast(d_tbl, Load(d_tbl8, in_bytes.get())); + + // Enough test data; for larger vectors, upper lanes will be zero. + const uint8_t index_bytes_source[64] = { + // Same index as source, multiple outputs from same input, + // unused input (9), ascending/descending and nonconsecutive neighbors. + 0, 2, 1, 2, 15, 12, 13, 14, 6, 7, 8, 5, 4, 3, 10, 11, + 11, 10, 3, 4, 5, 8, 7, 6, 14, 13, 12, 15, 2, 1, 2, 0, + 4, 3, 2, 2, 5, 6, 7, 7, 15, 15, 15, 15, 15, 15, 0, 1}; + auto index_bytes = AllocateAligned<uint8_t>(N8); + const size_t max_index = HWY_MIN(NT8, 16) - 1; + for (size_t i = 0; i < N8; ++i) { + index_bytes[i] = (i < 64) ? index_bytes_source[i] : 0; + // Avoid asan error for partial vectors. + index_bytes[i] = static_cast<uint8_t>(HWY_MIN(index_bytes[i], max_index)); + } + const auto indices = Load(d, reinterpret_cast<const T*>(index_bytes.get())); + + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + uint8_t* expected_bytes = reinterpret_cast<uint8_t*>(expected.get()); + + for (size_t block = 0; block < N8; block += 16) { + for (size_t i = 0; i < 16 && (block + i) < N8; ++i) { + const uint8_t index = index_bytes[block + i]; + HWY_ASSERT(index <= max_index); + // Note that block + index may exceed NT8 on RVV, which is fine because + // the operation uses the larger of the table and index vector size. + HWY_ASSERT(block + index < HWY_MAX(N8, NT8)); + // For large vectors, the lane index may wrap around due to block, + // also wrap around after 8-bit overflow. + expected_bytes[block + i] = + in_bytes[(block + index) % HWY_MIN(NT8, 256)]; + } + } + HWY_ASSERT_VEC_EQ(d, expected.get(), TableLookupBytes(in, indices)); + + // Individually test zeroing each byte position. + for (size_t i = 0; i < N8; ++i) { + const uint8_t prev_expected = expected_bytes[i]; + const uint8_t prev_index = index_bytes[i]; + expected_bytes[i] = 0; + + const int idx = 0x80 + (static_cast<int>(Random32(&rng) & 7) << 4); + HWY_ASSERT(0x80 <= idx && idx < 256); + index_bytes[i] = static_cast<uint8_t>(idx); + + const auto indices = + Load(d, reinterpret_cast<const T*>(index_bytes.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), TableLookupBytesOr0(in, indices)); + expected_bytes[i] = prev_expected; + index_bytes[i] = prev_index; + } +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllTableLookupBytesSame() { + // Partial index, same-sized table. + ForIntegerTypes(ForPartialVectors<TestTableLookupBytes<false>>()); +} + +HWY_NOINLINE void TestAllTableLookupBytesMixed() { + // Partial index, full-size table. + ForIntegerTypes(ForPartialVectors<TestTableLookupBytes<true>>()); +} + +struct TestInterleaveLower { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TU = MakeUnsigned<T>; + const size_t N = Lanes(d); + auto even_lanes = AllocateAligned<T>(N); + auto odd_lanes = AllocateAligned<T>(N); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast<T>(2 * i + 0); + odd_lanes[i] = static_cast<T>(2 * i + 1); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const size_t blockN = HWY_MIN(16 / sizeof(T), N); + for (size_t i = 0; i < Lanes(d); ++i) { + const size_t block = i / blockN; + const size_t index = (i % blockN) + block * 2 * blockN; + expected[i] = static_cast<T>(index & LimitsMax<TU>()); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveLower(even, odd)); + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveLower(d, even, odd)); + } +}; + +struct TestInterleaveUpper { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + if (N == 1) return; + auto even_lanes = AllocateAligned<T>(N); + auto odd_lanes = AllocateAligned<T>(N); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast<T>(2 * i + 0); + odd_lanes[i] = static_cast<T>(2 * i + 1); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const size_t blockN = HWY_MIN(16 / sizeof(T), N); + for (size_t i = 0; i < Lanes(d); ++i) { + const size_t block = i / blockN; + expected[i] = T((i % blockN) + block * 2 * blockN + blockN); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), InterleaveUpper(d, even, odd)); + } +}; + +HWY_NOINLINE void TestAllInterleave() { + // Not DemoteVectors because this cannot be supported by HWY_SCALAR. + ForAllTypes(ForShrinkableVectors<TestInterleaveLower>()); + ForAllTypes(ForShrinkableVectors<TestInterleaveUpper>()); +} + +struct TestZipLower { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using WideT = MakeWide<T>; + static_assert(sizeof(T) * 2 == sizeof(WideT), "Must be double-width"); + static_assert(IsSigned<T>() == IsSigned<WideT>(), "Must have same sign"); + const size_t N = Lanes(d); + auto even_lanes = AllocateAligned<T>(N); + auto odd_lanes = AllocateAligned<T>(N); + // At least 2 lanes for HWY_SCALAR + auto zip_lanes = AllocateAligned<T>(HWY_MAX(N, 2)); + const T kMaxT = LimitsMax<T>(); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast<T>((2 * i + 0) & kMaxT); + odd_lanes[i] = static_cast<T>((2 * i + 1) & kMaxT); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const Repartition<WideT, D> dw; +#if HWY_TARGET == HWY_SCALAR + // Safely handle big-endian + const auto expected = Set(dw, static_cast<WideT>(1ULL << (sizeof(T) * 8))); +#else + const size_t blockN = HWY_MIN(size_t(16) / sizeof(T), N); + for (size_t i = 0; i < N; i += 2) { + const size_t base = (i / blockN) * blockN; + const size_t mod = i % blockN; + zip_lanes[i + 0] = even_lanes[mod / 2 + base]; + zip_lanes[i + 1] = odd_lanes[mod / 2 + base]; + } + const auto expected = + Load(dw, reinterpret_cast<const WideT*>(zip_lanes.get())); +#endif // HWY_TARGET == HWY_SCALAR + HWY_ASSERT_VEC_EQ(dw, expected, ZipLower(even, odd)); + HWY_ASSERT_VEC_EQ(dw, expected, ZipLower(dw, even, odd)); + } +}; + +HWY_NOINLINE void TestAllZipLower() { + const ForDemoteVectors<TestZipLower> lower_unsigned; + lower_unsigned(uint8_t()); + lower_unsigned(uint16_t()); +#if HWY_HAVE_INTEGER64 + lower_unsigned(uint32_t()); // generates u64 +#endif + + const ForDemoteVectors<TestZipLower> lower_signed; + lower_signed(int8_t()); + lower_signed(int16_t()); +#if HWY_HAVE_INTEGER64 + lower_signed(int32_t()); // generates i64 +#endif + + // No float - concatenating f32 does not result in a f64 +} + +// Remove this test (so it does not show as having run) if the only target is +// HWY_SCALAR, which does not support this op. +#if HWY_TARGETS != HWY_SCALAR + +struct TestZipUpper { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET == HWY_SCALAR + (void)d; +#else + using WideT = MakeWide<T>; + static_assert(sizeof(T) * 2 == sizeof(WideT), "Must be double-width"); + static_assert(IsSigned<T>() == IsSigned<WideT>(), "Must have same sign"); + const size_t N = Lanes(d); + if (N < 16 / sizeof(T)) return; + auto even_lanes = AllocateAligned<T>(N); + auto odd_lanes = AllocateAligned<T>(N); + auto zip_lanes = AllocateAligned<T>(N); + const T kMaxT = LimitsMax<T>(); + for (size_t i = 0; i < N; ++i) { + even_lanes[i] = static_cast<T>((2 * i + 0) & kMaxT); + odd_lanes[i] = static_cast<T>((2 * i + 1) & kMaxT); + } + const auto even = Load(d, even_lanes.get()); + const auto odd = Load(d, odd_lanes.get()); + + const size_t blockN = HWY_MIN(size_t(16) / sizeof(T), N); + + for (size_t i = 0; i < N; i += 2) { + const size_t base = (i / blockN) * blockN + blockN / 2; + const size_t mod = i % blockN; + zip_lanes[i + 0] = even_lanes[mod / 2 + base]; + zip_lanes[i + 1] = odd_lanes[mod / 2 + base]; + } + const Repartition<WideT, D> dw; + const auto expected = + Load(dw, reinterpret_cast<const WideT*>(zip_lanes.get())); + HWY_ASSERT_VEC_EQ(dw, expected, ZipUpper(dw, even, odd)); +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllZipUpper() { + const ForShrinkableVectors<TestZipUpper> upper_unsigned; + upper_unsigned(uint8_t()); + upper_unsigned(uint16_t()); +#if HWY_HAVE_INTEGER64 + upper_unsigned(uint32_t()); // generates u64 +#endif + + const ForShrinkableVectors<TestZipUpper> upper_signed; + upper_signed(int8_t()); + upper_signed(int16_t()); +#if HWY_HAVE_INTEGER64 + upper_signed(int32_t()); // generates i64 +#endif + + // No float - concatenating f32 does not result in a f64 +} + +#endif // HWY_TARGETS != HWY_SCALAR + +class TestSpecialShuffle32 { + public: + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, 0); + VerifyLanes32(d, Shuffle2301(v), 2, 3, 0, 1, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle1032(v), 1, 0, 3, 2, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle0321(v), 0, 3, 2, 1, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle2103(v), 2, 1, 0, 3, __FILE__, __LINE__); + VerifyLanes32(d, Shuffle0123(v), 0, 1, 2, 3, __FILE__, __LINE__); + } + + private: + // HWY_INLINE works around a Clang SVE compiler bug where all but the first + // 128 bits (the NEON register) of actual are zero. + template <class D, class V> + HWY_INLINE void VerifyLanes32(D d, VecArg<V> actual, const size_t i3, + const size_t i2, const size_t i1, + const size_t i0, const char* filename, + const int line) { + using T = TFromD<D>; + constexpr size_t kBlockN = 16 / sizeof(T); + const size_t N = Lanes(d); + if (N < 4) return; + auto expected = AllocateAligned<T>(N); + for (size_t block = 0; block < N; block += kBlockN) { + expected[block + 3] = static_cast<T>(block + i3); + expected[block + 2] = static_cast<T>(block + i2); + expected[block + 1] = static_cast<T>(block + i1); + expected[block + 0] = static_cast<T>(block + i0); + } + AssertVecEqual(d, expected.get(), actual, filename, line); + } +}; + +class TestSpecialShuffle64 { + public: + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, 0); + VerifyLanes64(d, Shuffle01(v), 0, 1, __FILE__, __LINE__); + } + + private: + // HWY_INLINE works around a Clang SVE compiler bug where all but the first + // 128 bits (the NEON register) of actual are zero. + template <class D, class V> + HWY_INLINE void VerifyLanes64(D d, VecArg<V> actual, const size_t i1, + const size_t i0, const char* filename, + const int line) { + using T = TFromD<D>; + constexpr size_t kBlockN = 16 / sizeof(T); + const size_t N = Lanes(d); + if (N < 2) return; + auto expected = AllocateAligned<T>(N); + for (size_t block = 0; block < N; block += kBlockN) { + expected[block + 1] = static_cast<T>(block + i1); + expected[block + 0] = static_cast<T>(block + i0); + } + AssertVecEqual(d, expected.get(), actual, filename, line); + } +}; + +HWY_NOINLINE void TestAllSpecialShuffles() { + const ForGEVectors<128, TestSpecialShuffle32> test32; + test32(uint32_t()); + test32(int32_t()); + test32(float()); + +#if HWY_HAVE_INTEGER64 + const ForGEVectors<128, TestSpecialShuffle64> test64; + test64(uint64_t()); + test64(int64_t()); +#endif + +#if HWY_HAVE_FLOAT64 + const ForGEVectors<128, TestSpecialShuffle64> test_d; + test_d(double()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyBlockwiseTest); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllBroadcast); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllTableLookupBytesSame); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllTableLookupBytesMixed); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllInterleave); +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllZipLower); +#if HWY_TARGETS != HWY_SCALAR +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllZipUpper); +#endif +HWY_EXPORT_AND_TEST_P(HwyBlockwiseTest, TestAllSpecialShuffles); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/combine_test.cc b/third_party/highway/hwy/tests/combine_test.cc new file mode 100644 index 0000000000..e2f4cbeb00 --- /dev/null +++ b/third_party/highway/hwy/tests/combine_test.cc @@ -0,0 +1,275 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcpy + +#include <algorithm> // std::fill + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/combine_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLowerHalf { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Half<D> d2; + + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + auto lanes2 = AllocateAligned<T>(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + std::fill(lanes2.get(), lanes2.get() + N, T(0)); + const auto v = Iota(d, 1); + Store(LowerHalf(d2, v), d2, lanes.get()); + Store(LowerHalf(v), d2, lanes2.get()); // optionally without D + size_t i = 0; + for (; i < Lanes(d2); ++i) { + HWY_ASSERT_EQ(T(1 + i), lanes[i]); + HWY_ASSERT_EQ(T(1 + i), lanes2[i]); + } + // Other half remains unchanged + for (; i < N; ++i) { + HWY_ASSERT_EQ(T(0), lanes[i]); + HWY_ASSERT_EQ(T(0), lanes2[i]); + } + } +}; + +struct TestLowerQuarter { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Half<D> d2; + const Half<decltype(d2)> d4; + + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + auto lanes2 = AllocateAligned<T>(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + std::fill(lanes2.get(), lanes2.get() + N, T(0)); + const auto v = Iota(d, 1); + const auto lo = LowerHalf(d4, LowerHalf(d2, v)); + const auto lo2 = LowerHalf(LowerHalf(v)); // optionally without D + Store(lo, d4, lanes.get()); + Store(lo2, d4, lanes2.get()); + size_t i = 0; + for (; i < Lanes(d4); ++i) { + HWY_ASSERT_EQ(T(i + 1), lanes[i]); + HWY_ASSERT_EQ(T(i + 1), lanes2[i]); + } + // Upper 3/4 remain unchanged + for (; i < N; ++i) { + HWY_ASSERT_EQ(T(0), lanes[i]); + HWY_ASSERT_EQ(T(0), lanes2[i]); + } + } +}; + +HWY_NOINLINE void TestAllLowerHalf() { + ForAllTypes(ForHalfVectors<TestLowerHalf>()); + + // The minimum vector size is 128 bits, so there's no guarantee we can have + // quarters of 64-bit lanes, hence test 'all' other types. + ForHalfVectors<TestLowerQuarter, 2> test_quarter; + ForUI8(test_quarter); + ForUI16(test_quarter); // exclude float16_t - cannot compare + ForUIF32(test_quarter); +} + +struct TestUpperHalf { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define UpperHalf. +#if HWY_TARGET != HWY_SCALAR + const Half<D> d2; + const size_t N2 = Lanes(d2); + HWY_ASSERT(N2 * 2 == Lanes(d)); + auto expected = AllocateAligned<T>(N2); + size_t i = 0; + for (; i < N2; ++i) { + expected[i] = static_cast<T>(N2 + 1 + i); + } + HWY_ASSERT_VEC_EQ(d2, expected.get(), UpperHalf(d2, Iota(d, 1))); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllUpperHalf() { + ForAllTypes(ForHalfVectors<TestUpperHalf>()); +} + +struct TestZeroExtendVector { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Twice<D> d2; + + const auto v = Iota(d, 1); + const size_t N = Lanes(d); + const size_t N2 = Lanes(d2); + // If equal, then N was already MaxLanes(d) and it's not clear what + // Combine or ZeroExtendVector should return. + if (N2 == N) return; + HWY_ASSERT(N2 == 2 * N); + auto lanes = AllocateAligned<T>(N2); + Store(v, d, &lanes[0]); + Store(v, d, &lanes[N]); + + const auto ext = ZeroExtendVector(d2, v); + Store(ext, d2, lanes.get()); + + // Lower half is unchanged + HWY_ASSERT_VEC_EQ(d, v, Load(d, &lanes[0])); + // Upper half is zero + HWY_ASSERT_VEC_EQ(d, Zero(d), Load(d, &lanes[N])); + } +}; + +HWY_NOINLINE void TestAllZeroExtendVector() { + ForAllTypes(ForExtendableVectors<TestZeroExtendVector>()); +} + +struct TestCombine { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Twice<D> d2; + const size_t N2 = Lanes(d2); + auto lanes = AllocateAligned<T>(N2); + + const auto lo = Iota(d, 1); + const auto hi = Iota(d, static_cast<T>(N2 / 2 + 1)); + const auto combined = Combine(d2, hi, lo); + Store(combined, d2, lanes.get()); + + const auto expected = Iota(d2, 1); + HWY_ASSERT_VEC_EQ(d2, expected, combined); + } +}; + +HWY_NOINLINE void TestAllCombine() { + ForAllTypes(ForExtendableVectors<TestCombine>()); +} + +struct TestConcat { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + if (N == 1) return; + const size_t half_bytes = N * sizeof(T) / 2; + + auto hi = AllocateAligned<T>(N); + auto lo = AllocateAligned<T>(N); + auto expected = AllocateAligned<T>(N); + RandomState rng; + for (size_t rep = 0; rep < 10; ++rep) { + for (size_t i = 0; i < N; ++i) { + hi[i] = static_cast<T>(Random64(&rng) & 0xFF); + lo[i] = static_cast<T>(Random64(&rng) & 0xFF); + } + + { + memcpy(&expected[N / 2], &hi[N / 2], half_bytes); + memcpy(&expected[0], &lo[0], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatUpperLower(d, vhi, vlo)); + } + + { + memcpy(&expected[N / 2], &hi[N / 2], half_bytes); + memcpy(&expected[0], &lo[N / 2], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatUpperUpper(d, vhi, vlo)); + } + + { + memcpy(&expected[N / 2], &hi[0], half_bytes); + memcpy(&expected[0], &lo[N / 2], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatLowerUpper(d, vhi, vlo)); + } + + { + memcpy(&expected[N / 2], &hi[0], half_bytes); + memcpy(&expected[0], &lo[0], half_bytes); + const auto vhi = Load(d, hi.get()); + const auto vlo = Load(d, lo.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), ConcatLowerLower(d, vhi, vlo)); + } + } + } +}; + +HWY_NOINLINE void TestAllConcat() { + ForAllTypes(ForShrinkableVectors<TestConcat>()); +} + +struct TestConcatOddEven { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + const auto hi = Iota(d, static_cast<T>(N)); + const auto lo = Iota(d, 0); + const auto even = Add(Iota(d, 0), Iota(d, 0)); + const auto odd = Add(even, Set(d, 1)); + HWY_ASSERT_VEC_EQ(d, odd, ConcatOdd(d, hi, lo)); + HWY_ASSERT_VEC_EQ(d, even, ConcatEven(d, hi, lo)); + + // This test catches inadvertent saturation. + const auto min = Set(d, LowestValue<T>()); + const auto max = Set(d, HighestValue<T>()); + HWY_ASSERT_VEC_EQ(d, max, ConcatOdd(d, max, max)); + HWY_ASSERT_VEC_EQ(d, max, ConcatEven(d, max, max)); + HWY_ASSERT_VEC_EQ(d, min, ConcatOdd(d, min, min)); + HWY_ASSERT_VEC_EQ(d, min, ConcatEven(d, min, min)); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllConcatOddEven() { + ForAllTypes(ForShrinkableVectors<TestConcatOddEven>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCombineTest); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllLowerHalf); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllUpperHalf); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllZeroExtendVector); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllCombine); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllConcat); +HWY_EXPORT_AND_TEST_P(HwyCombineTest, TestAllConcatOddEven); +} // namespace hwy + +#endif // HWY_ONCE diff --git a/third_party/highway/hwy/tests/compare_test.cc b/third_party/highway/hwy/tests/compare_test.cc new file mode 100644 index 0000000000..a96e29fc62 --- /dev/null +++ b/third_party/highway/hwy/tests/compare_test.cc @@ -0,0 +1,509 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memset + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/compare_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// All types. +struct TestEquality { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, 2); + const auto v2b = Iota(d, 2); + const auto v3 = Iota(d, 3); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq(v2, v3)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq(v3, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq(v2, v2b)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ne(v2, v3)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne(v3, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne(v2, v2b)); + } +}; + +HWY_NOINLINE void TestAllEquality() { + ForAllTypes(ForPartialVectors<TestEquality>()); +} + +// a > b should be true, verify that for Gt/Lt and with swapped args. +template <class D> +void EnsureGreater(D d, TFromD<D> a, TFromD<D> b, const char* file, int line) { + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + const auto va = Set(d, a); + const auto vb = Set(d, b); + AssertMaskEqual(d, mask_true, Gt(va, vb), file, line); + AssertMaskEqual(d, mask_false, Lt(va, vb), file, line); + + // Swapped order + AssertMaskEqual(d, mask_false, Gt(vb, va), file, line); + AssertMaskEqual(d, mask_true, Lt(vb, va), file, line); + + // Also ensure irreflexive + AssertMaskEqual(d, mask_false, Gt(va, va), file, line); + AssertMaskEqual(d, mask_false, Gt(vb, vb), file, line); + AssertMaskEqual(d, mask_false, Lt(va, va), file, line); + AssertMaskEqual(d, mask_false, Lt(vb, vb), file, line); +} + +#define HWY_ENSURE_GREATER(d, a, b) EnsureGreater(d, a, b, __FILE__, __LINE__) + +struct TestStrictUnsigned { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T max = LimitsMax<T>(); + const auto v0 = Zero(d); + const auto v2 = And(Iota(d, T(2)), Set(d, 255)); // 0..255 + + const auto mask_false = MaskFalse(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 128, 127); + HWY_ENSURE_GREATER(d, max, max / 2); + HWY_ENSURE_GREATER(d, max, 1); + HWY_ENSURE_GREATER(d, max, 0); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + } +}; + +HWY_NOINLINE void TestAllStrictUnsigned() { + ForUnsignedTypes(ForPartialVectors<TestStrictUnsigned>()); +} + +struct TestStrictInt { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T min = LimitsMin<T>(); + const T max = LimitsMax<T>(); + const auto v0 = Zero(d); + const auto v2 = And(Iota(d, T(2)), Set(d, 127)); // 0..127 + const auto vn = Sub(Neg(v2), Set(d, 1)); // -1..-128 + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 0, -1); + HWY_ENSURE_GREATER(d, -1, -2); + HWY_ENSURE_GREATER(d, max, max / 2); + HWY_ENSURE_GREATER(d, max, 1); + HWY_ENSURE_GREATER(d, max, 0); + HWY_ENSURE_GREATER(d, max, -1); + HWY_ENSURE_GREATER(d, max, min); + HWY_ENSURE_GREATER(d, 0, min); + HWY_ENSURE_GREATER(d, min / 2, min); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_true, Gt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt(vn, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(vn, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, vn)); + } +}; + +// S-SSE3 bug (#795): same upper, differing MSB in lower +struct TestStrictInt64 { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto m0 = MaskFalse(d); + const auto m1 = MaskTrue(d); + HWY_ASSERT_MASK_EQ(d, m0, Lt(Set(d, 0x380000000LL), Set(d, 0x300000001LL))); + HWY_ASSERT_MASK_EQ(d, m1, Lt(Set(d, 0xF00000000LL), Set(d, 0xF80000000LL))); + HWY_ASSERT_MASK_EQ(d, m1, Lt(Set(d, 0xF00000000LL), Set(d, 0xF80000001LL))); + } +}; + +HWY_NOINLINE void TestAllStrictInt() { + ForSignedTypes(ForPartialVectors<TestStrictInt>()); + ForPartialVectors<TestStrictInt64>()(int64_t()); +} + +struct TestStrictFloat { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const T huge_neg = T(-1E35); + const T huge_pos = T(1E36); + const auto v0 = Zero(d); + const auto v2 = Iota(d, T(2)); + const auto vn = Neg(v2); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + // Individual values of interest + HWY_ENSURE_GREATER(d, 2, 1); + HWY_ENSURE_GREATER(d, 1, 0); + HWY_ENSURE_GREATER(d, 0, -1); + HWY_ENSURE_GREATER(d, -1, -2); + HWY_ENSURE_GREATER(d, huge_pos, 1); + HWY_ENSURE_GREATER(d, huge_pos, 0); + HWY_ENSURE_GREATER(d, huge_pos, -1); + HWY_ENSURE_GREATER(d, huge_pos, huge_neg); + HWY_ENSURE_GREATER(d, 0, huge_neg); + + // Also use Iota to ensure lanes are independent + HWY_ASSERT_MASK_EQ(d, mask_true, Gt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt(vn, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt(vn, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v0, v0)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_false, Gt(vn, vn)); + } +}; + +HWY_NOINLINE void TestAllStrictFloat() { + ForFloatTypes(ForPartialVectors<TestStrictFloat>()); +} + +struct TestWeakFloat { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v2 = Iota(d, T(2)); + const auto vn = Iota(d, -T(Lanes(d))); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ge(v2, v2)); + HWY_ASSERT_MASK_EQ(d, mask_true, Le(vn, vn)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ge(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_true, Le(vn, v2)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Le(v2, vn)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ge(vn, v2)); + } +}; + +HWY_NOINLINE void TestAllWeakFloat() { + ForFloatTypes(ForPartialVectors<TestWeakFloat>()); +} + +template <class D> +static HWY_NOINLINE Vec<D> Make128(D d, uint64_t hi, uint64_t lo) { + alignas(16) uint64_t in[2]; + in[0] = lo; + in[1] = hi; + return LoadDup128(d, in); +} + +struct TestLt128 { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec<D>; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax<T>(), LimitsMax<T>()); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllLt128() { ForGEVectors<128, TestLt128>()(uint64_t()); } + +struct TestLt128Upper { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec<D>; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax<T>(), LimitsMax<T>()); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Lt128Upper(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Lt128Upper(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllLt128Upper() { + ForGEVectors<128, TestLt128Upper>()(uint64_t()); +} + +struct TestEq128 { // Also Ne128 + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec<D>; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, v10, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v11, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, Add(iota, v10), iota)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax<T>(), LimitsMax<T>()); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128(d, vm, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128(d, v11, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllEq128() { ForGEVectors<128, TestEq128>()(uint64_t()); } + +struct TestEq128Upper { // Also Ne128Upper + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec<D>; + const V v00 = Zero(d); + const V v01 = Make128(d, 0, 1); + const V v10 = Make128(d, 1, 0); + const V v11 = Add(v01, v10); + + const auto mask_false = MaskFalse(d); + const auto mask_true = MaskTrue(d); + + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v10, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v00, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v01, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v10, v10)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v00, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v00, v01)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v01, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v01, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v01, v11)); + + // Reversed order + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, v01, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, v01, v00)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v11, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v10, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v11, v01)); + + // Also check 128-bit blocks are independent + const V iota = Iota(d, 1); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, iota, Add(iota, v01))); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, iota, Add(iota, v10))); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, Add(iota, v01), iota)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, Add(iota, v10), iota)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, Add(iota, v10), iota)); + + // Max value + const V vm = Make128(d, LimitsMax<T>(), LimitsMax<T>()); + HWY_ASSERT_MASK_EQ(d, mask_true, Eq128Upper(d, vm, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Ne128Upper(d, vm, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_false, Eq128Upper(d, v11, vm)); + + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v00)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v01)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v10)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, vm, v11)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v00, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v01, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v10, vm)); + HWY_ASSERT_MASK_EQ(d, mask_true, Ne128Upper(d, v11, vm)); + } +}; + +HWY_NOINLINE void TestAllEq128Upper() { + ForGEVectors<128, TestEq128Upper>()(uint64_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCompareTest); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEquality); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictUnsigned); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictInt); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllStrictFloat); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllWeakFloat); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128); +HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/compress_test.cc b/third_party/highway/hwy/tests/compress_test.cc new file mode 100644 index 0000000000..ae008b4dc4 --- /dev/null +++ b/third_party/highway/hwy/tests/compress_test.cc @@ -0,0 +1,833 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memset + +#include <array> // IWYU pragma: keep + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/compress_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Regenerate tables used in the implementation, instead of testing. +#define HWY_PRINT_TABLES 0 + +#if !HWY_PRINT_TABLES || HWY_IDE + +template <class D, class DI, typename T = TFromD<D>, typename TI = TFromD<DI>> +void CheckStored(D d, DI di, const char* op, size_t expected_pos, + size_t actual_pos, size_t num_to_check, + const AlignedFreeUniquePtr<T[]>& in, + const AlignedFreeUniquePtr<TI[]>& mask_lanes, + const AlignedFreeUniquePtr<T[]>& expected, const T* actual_u, + int line) { + if (expected_pos != actual_pos) { + hwy::Abort(__FILE__, line, + "%s: size mismatch for %s: expected %d, actual %d\n", op, + TypeName(T(), Lanes(d)).c_str(), static_cast<int>(expected_pos), + static_cast<int>(actual_pos)); + } + // Modified from AssertVecEqual - we may not be checking all lanes. + for (size_t i = 0; i < num_to_check; ++i) { + if (!IsEqual(expected[i], actual_u[i])) { + const size_t N = Lanes(d); + fprintf(stderr, "%s: mismatch at i=%d of %d, line %d:\n\n", op, + static_cast<int>(i), static_cast<int>(num_to_check), line); + Print(di, "mask", Load(di, mask_lanes.get()), 0, N); + Print(d, "in", Load(d, in.get()), 0, N); + Print(d, "expect", Load(d, expected.get()), 0, num_to_check); + Print(d, "actual", Load(d, actual_u), 0, num_to_check); + HWY_ASSERT(false); + } + } +} + +struct TestCompress { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned<T>; // For mask > 0 comparison + using TU = MakeUnsigned<T>; + const Rebind<TI, D> di; + const size_t N = Lanes(d); + + for (int frac : {0, 2, 3}) { + // For CompressStore + const size_t misalign = static_cast<size_t>(frac) * N / 4; + + auto in_lanes = AllocateAligned<T>(N); + auto mask_lanes = AllocateAligned<TI>(N); + auto garbage = AllocateAligned<TU>(N); + auto expected = AllocateAligned<T>(N); + auto actual_a = AllocateAligned<T>(misalign + N); + T* actual_u = actual_a.get() + misalign; + + const size_t bits_size = RoundUpTo((N + 7) / 8, 8); + auto bits = AllocateAligned<uint8_t>(bits_size); + memset(bits.get(), 0, bits_size); // for MSAN + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + size_t expected_pos = 0; + for (size_t i = 0; i < N; ++i) { + const uint64_t r = Random32(&rng); + in_lanes[i] = T(); // cannot initialize float16_t directly. + CopyBytes<sizeof(T)>(&r, &in_lanes[i]); // not same size + mask_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + if (mask_lanes[i] > 0) { + expected[expected_pos++] = in_lanes[i]; + } + garbage[i] = static_cast<TU>(Random64(&rng)); + } + size_t num_to_check; + if (CompressIsPartition<T>::value) { + // For non-native Compress, also check that mask=false lanes were + // moved to the back of the vector (highest indices). + size_t extra = expected_pos; + for (size_t i = 0; i < N; ++i) { + if (mask_lanes[i] == 0) { + expected[extra++] = in_lanes[i]; + } + } + HWY_ASSERT(extra == N); + num_to_check = N; + } else { + // For native Compress, only the mask=true lanes are defined. + num_to_check = expected_pos; + } + + const auto in = Load(d, in_lanes.get()); + const auto mask = + RebindMask(d, Gt(Load(di, mask_lanes.get()), Zero(di))); + StoreMaskBits(d, mask, bits.get()); + + // Compress + memset(actual_u, 0, N * sizeof(T)); + StoreU(Compress(in, mask), d, actual_u); + CheckStored(d, di, "Compress", expected_pos, expected_pos, num_to_check, + in_lanes, mask_lanes, expected, actual_u, __LINE__); + + // CompressNot + memset(actual_u, 0, N * sizeof(T)); + StoreU(CompressNot(in, Not(mask)), d, actual_u); + CheckStored(d, di, "CompressNot", expected_pos, expected_pos, + num_to_check, in_lanes, mask_lanes, expected, actual_u, + __LINE__); + + // CompressStore + memset(actual_u, 0, N * sizeof(T)); + const size_t size1 = CompressStore(in, mask, d, actual_u); + // expected_pos instead of num_to_check because this op is not + // affected by CompressIsPartition. + CheckStored(d, di, "CompressStore", expected_pos, size1, expected_pos, + in_lanes, mask_lanes, expected, actual_u, __LINE__); + + // CompressBlendedStore + memcpy(actual_u, garbage.get(), N * sizeof(T)); + const size_t size2 = CompressBlendedStore(in, mask, d, actual_u); + // expected_pos instead of num_to_check because this op only writes + // the mask=true lanes. + CheckStored(d, di, "CompressBlendedStore", expected_pos, size2, + expected_pos, in_lanes, mask_lanes, expected, actual_u, + __LINE__); + // Subsequent lanes are untouched. + for (size_t i = size2; i < N; ++i) { +#if HWY_COMPILER_MSVC && HWY_TARGET == HWY_AVX2 + // TODO(eustas): re-enable when compiler is fixed +#else + HWY_ASSERT_EQ(garbage[i], reinterpret_cast<TU*>(actual_u)[i]); +#endif + } + + // CompressBits + memset(actual_u, 0, N * sizeof(T)); + StoreU(CompressBits(in, bits.get()), d, actual_u); + CheckStored(d, di, "CompressBits", expected_pos, expected_pos, + num_to_check, in_lanes, mask_lanes, expected, actual_u, + __LINE__); + + // CompressBitsStore + memset(actual_u, 0, N * sizeof(T)); + const size_t size3 = CompressBitsStore(in, bits.get(), d, actual_u); + // expected_pos instead of num_to_check because this op is not + // affected by CompressIsPartition. + CheckStored(d, di, "CompressBitsStore", expected_pos, size3, + expected_pos, in_lanes, mask_lanes, expected, actual_u, + __LINE__); + } // rep + } // frac + } // operator() +}; + +HWY_NOINLINE void TestAllCompress() { + ForAllTypes(ForPartialVectors<TestCompress>()); +} + +struct TestCompressBlocks { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET == HWY_SCALAR + (void)d; +#else + static_assert(sizeof(T) == 8 && !IsSigned<T>(), "Should be u64"); + RandomState rng; + + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(d); + + auto in_lanes = AllocateAligned<T>(N); + auto mask_lanes = AllocateAligned<TI>(N); + auto expected = AllocateAligned<T>(N); + auto actual = AllocateAligned<T>(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + size_t expected_pos = 0; + for (size_t i = 0; i < N; i += 2) { + const uint64_t bits = Random32(&rng); + in_lanes[i + 1] = in_lanes[i] = T(); // cannot set float16_t directly. + CopyBytes<sizeof(T)>(&bits, &in_lanes[i]); // not same size + CopyBytes<sizeof(T)>(&bits, &in_lanes[i + 1]); // not same size + mask_lanes[i + 1] = mask_lanes[i] = TI{(Random32(&rng) & 8) ? 1 : 0}; + if (mask_lanes[i] > 0) { + expected[expected_pos++] = in_lanes[i]; + expected[expected_pos++] = in_lanes[i + 1]; + } + } + size_t num_to_check; + if (CompressIsPartition<T>::value) { + // For non-native Compress, also check that mask=false lanes were + // moved to the back of the vector (highest indices). + size_t extra = expected_pos; + for (size_t i = 0; i < N; ++i) { + if (mask_lanes[i] == 0) { + expected[extra++] = in_lanes[i]; + } + } + HWY_ASSERT(extra == N); + num_to_check = N; + } else { + // For native Compress, only the mask=true lanes are defined. + num_to_check = expected_pos; + } + + const auto in = Load(d, in_lanes.get()); + const auto mask = RebindMask(d, Gt(Load(di, mask_lanes.get()), Zero(di))); + + // CompressBlocksNot + memset(actual.get(), 0, N * sizeof(T)); + StoreU(CompressBlocksNot(in, Not(mask)), d, actual.get()); + CheckStored(d, di, "CompressBlocksNot", expected_pos, expected_pos, + num_to_check, in_lanes, mask_lanes, expected, actual.get(), + __LINE__); + } // rep +#endif // HWY_TARGET == HWY_SCALAR + } // operator() +}; + +HWY_NOINLINE void TestAllCompressBlocks() { + ForGE128Vectors<TestCompressBlocks>()(uint64_t()); +} + +#endif // !HWY_PRINT_TABLES + +#if HWY_PRINT_TABLES || HWY_IDE +namespace detail { // for code folding + +void PrintCompress8x8Tables() { + printf("======================================= 8x8\n"); + constexpr size_t N = 8; + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array<uint8_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + printf("%d,", indices[i]); + } + printf(code & 1 ? "//\n" : "/**/"); + } + printf("\n"); +} + +void PrintCompress16x8Tables() { + printf("======================================= 16x8\n"); + constexpr size_t N = 8; // 128-bit SIMD + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array<uint8_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Doubled (for converting lane to byte indices) + for (size_t i = 0; i < N; ++i) { + printf("%d,", 2 * indices[i]); + } + printf(code & 1 ? "//\n" : "/**/"); + } + printf("\n"); +} + +void PrintCompressNot16x8Tables() { + printf("======================================= Not 16x8\n"); + constexpr size_t N = 8; // 128-bit SIMD + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array<uint8_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Doubled (for converting lane to byte indices) + for (size_t i = 0; i < N; ++i) { + printf("%d,", 2 * indices[i]); + } + printf(not_code & 1 ? "//\n" : "/**/"); + } + printf("\n"); +} + +// Compressed to nibbles, unpacked via variable right shift. Also includes +// FirstN bits in the nibble MSB. +void PrintCompress32x8Tables() { + printf("======================================= 32/64x8\n"); + constexpr size_t N = 8; // AVX2 or 64-bit AVX3 + for (uint64_t code = 0; code < (1ull << N); ++code) { + const size_t count = PopCount(code); + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + if (i < count) { + indices[i] |= N; + HWY_ASSERT(indices[i] < 0x10); + } + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast<uint32_t>(packed)); + } + printf("\n"); +} + +void PrintCompressNot32x8Tables() { + printf("======================================= Not 32/64x8\n"); + constexpr size_t N = 8; // AVX2 or 64-bit AVX3 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + const size_t count = PopCount(code); + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + if (i < count) { + indices[i] |= N; + HWY_ASSERT(indices[i] < 0x10); + } + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast<uint32_t>(packed)); + } + printf("\n"); +} + +// Compressed to nibbles (for AVX3 64x4) +void PrintCompress64x4NibbleTables() { + printf("======================================= 64x4Nibble\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast<uint32_t>(packed)); + } + printf("\n"); +} + +void PrintCompressNot64x4NibbleTables() { + printf("======================================= Not 64x4Nibble\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast<uint32_t>(packed)); + } + printf("\n"); +} + +void PrintCompressNot64x2NibbleTables() { + printf("======================================= Not 64x2Nibble\n"); + constexpr size_t N = 2; // 128-bit + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Convert to nibbles + uint64_t packed = 0; + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT(indices[i] < N); + packed += indices[i] << (i * 4); + } + + HWY_ASSERT(packed < (1ull << (N * 4))); + printf("0x%08x,", static_cast<uint32_t>(packed)); + } + printf("\n"); +} + +void PrintCompress64x4Tables() { + printf("======================================= 64x4 uncompressed\n"); + constexpr size_t N = 4; // SVE_256 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array<size_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + printf("%d,", static_cast<int>(indices[i])); + } + } + printf("\n"); +} + +void PrintCompressNot64x4Tables() { + printf("======================================= Not 64x4 uncompressed\n"); + constexpr size_t N = 4; // SVE_256 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array<size_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + printf("%d,", static_cast<int>(indices[i])); + } + } + printf("\n"); +} + +// Same as above, but prints pairs of u32 indices (for AVX2). Also includes +// FirstN bits in the nibble MSB. +void PrintCompress64x4PairTables() { + printf("======================================= 64x4 u32 index\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t code = 0; code < (1ull << N); ++code) { + const size_t count = PopCount(code); + std::array<size_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + const int first_n_bit = i < count ? 8 : 0; + const int low = static_cast<int>(2 * indices[i]) + first_n_bit; + HWY_ASSERT(low < 0x10); + printf("%d, %d, ", low, low + 1); + } + } + printf("\n"); +} + +void PrintCompressNot64x4PairTables() { + printf("======================================= Not 64x4 u32 index\n"); + constexpr size_t N = 4; // AVX2 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + const size_t count = PopCount(code); + std::array<size_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + // Store uncompressed indices because SVE TBL returns 0 if an index is out + // of bounds. On AVX3 we simply variable-shift because permute indices are + // interpreted modulo N. Compression is not worth the extra shift+AND + // because the table is anyway only 512 bytes. + for (size_t i = 0; i < N; ++i) { + const int first_n_bit = i < count ? 8 : 0; + const int low = static_cast<int>(2 * indices[i]) + first_n_bit; + HWY_ASSERT(low < 0x10); + printf("%d, %d, ", low, low + 1); + } + } + printf("\n"); +} + +// 4-tuple of byte indices +void PrintCompress32x4Tables() { + printf("======================================= 32x4\n"); + using T = uint32_t; + constexpr size_t N = 4; // SSE4 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast<int>(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +void PrintCompressNot32x4Tables() { + printf("======================================= Not 32x4\n"); + using T = uint32_t; + constexpr size_t N = 4; // SSE4 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast<int>(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +// 8-tuple of byte indices +void PrintCompress64x2Tables() { + printf("======================================= 64x2\n"); + using T = uint64_t; + constexpr size_t N = 2; // SSE4 + for (uint64_t code = 0; code < (1ull << N); ++code) { + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast<int>(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +void PrintCompressNot64x2Tables() { + printf("======================================= Not 64x2\n"); + using T = uint64_t; + constexpr size_t N = 2; // SSE4 + for (uint64_t not_code = 0; not_code < (1ull << N); ++not_code) { + const uint64_t code = ~not_code; + std::array<uint32_t, N> indices{0}; + size_t pos = 0; + // All lanes where mask = true + for (size_t i = 0; i < N; ++i) { + if (code & (1ull << i)) { + indices[pos++] = i; + } + } + // All lanes where mask = false + for (size_t i = 0; i < N; ++i) { + if (!(code & (1ull << i))) { + indices[pos++] = i; + } + } + HWY_ASSERT(pos == N); + + for (size_t i = 0; i < N; ++i) { + for (size_t idx_byte = 0; idx_byte < sizeof(T); ++idx_byte) { + printf("%d,", static_cast<int>(sizeof(T) * indices[i] + idx_byte)); + } + } + } + printf("\n"); +} + +} // namespace detail + +HWY_NOINLINE void PrintTables() { + // Only print once. +#if HWY_TARGET == HWY_STATIC_TARGET + detail::PrintCompress32x8Tables(); + detail::PrintCompressNot32x8Tables(); + detail::PrintCompress64x4NibbleTables(); + detail::PrintCompressNot64x4NibbleTables(); + detail::PrintCompressNot64x2NibbleTables(); + detail::PrintCompress64x4Tables(); + detail::PrintCompressNot64x4Tables(); + detail::PrintCompress32x4Tables(); + detail::PrintCompressNot32x4Tables(); + detail::PrintCompress64x2Tables(); + detail::PrintCompressNot64x2Tables(); + detail::PrintCompress64x4PairTables(); + detail::PrintCompressNot64x4PairTables(); + detail::PrintCompress16x8Tables(); + detail::PrintCompress8x8Tables(); + detail::PrintCompressNot16x8Tables(); +#endif +} + +#endif // HWY_PRINT_TABLES + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCompressTest); +#if HWY_PRINT_TABLES +// Only print instead of running tests; this will be visible in the log. +HWY_EXPORT_AND_TEST_P(HwyCompressTest, PrintTables); +#else +HWY_EXPORT_AND_TEST_P(HwyCompressTest, TestAllCompress); +HWY_EXPORT_AND_TEST_P(HwyCompressTest, TestAllCompressBlocks); +#endif +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/convert_test.cc b/third_party/highway/hwy/tests/convert_test.cc new file mode 100644 index 0000000000..a7aea5fe9e --- /dev/null +++ b/third_party/highway/hwy/tests/convert_test.cc @@ -0,0 +1,643 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <cmath> // std::isfinite + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/convert_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Cast and ensure bytes are the same. Called directly from TestAllBitCast or +// via TestBitCastFrom. +template <typename ToT> +struct TestBitCast { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const Repartition<ToT, D> dto; + const size_t N = Lanes(d); + const size_t Nto = Lanes(dto); + if (N == 0 || Nto == 0) return; + HWY_ASSERT_EQ(N * sizeof(T), Nto * sizeof(ToT)); + const auto vf = Iota(d, 1); + const auto vt = BitCast(dto, vf); + // Must return the same bits + auto from_lanes = AllocateAligned<T>(Lanes(d)); + auto to_lanes = AllocateAligned<ToT>(Lanes(dto)); + Store(vf, d, from_lanes.get()); + Store(vt, dto, to_lanes.get()); + HWY_ASSERT( + BytesEqual(from_lanes.get(), to_lanes.get(), Lanes(d) * sizeof(T))); + } +}; + +// From D to all types. +struct TestBitCastFrom { + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + TestBitCast<uint8_t>()(t, d); + TestBitCast<uint16_t>()(t, d); + TestBitCast<uint32_t>()(t, d); +#if HWY_HAVE_INTEGER64 + TestBitCast<uint64_t>()(t, d); +#endif + TestBitCast<int8_t>()(t, d); + TestBitCast<int16_t>()(t, d); + TestBitCast<int32_t>()(t, d); +#if HWY_HAVE_INTEGER64 + TestBitCast<int64_t>()(t, d); +#endif + TestBitCast<float>()(t, d); +#if HWY_HAVE_FLOAT64 + TestBitCast<double>()(t, d); +#endif + } +}; + +HWY_NOINLINE void TestAllBitCast() { + // For HWY_SCALAR and partial vectors, we can only cast to same-sized types: + // the former can't partition its single lane, and the latter can be smaller + // than a destination type. + const ForPartialVectors<TestBitCast<uint8_t>> to_u8; + to_u8(uint8_t()); + to_u8(int8_t()); + + const ForPartialVectors<TestBitCast<int8_t>> to_i8; + to_i8(uint8_t()); + to_i8(int8_t()); + + const ForPartialVectors<TestBitCast<uint16_t>> to_u16; + to_u16(uint16_t()); + to_u16(int16_t()); + + const ForPartialVectors<TestBitCast<int16_t>> to_i16; + to_i16(uint16_t()); + to_i16(int16_t()); + + const ForPartialVectors<TestBitCast<uint32_t>> to_u32; + to_u32(uint32_t()); + to_u32(int32_t()); + to_u32(float()); + + const ForPartialVectors<TestBitCast<int32_t>> to_i32; + to_i32(uint32_t()); + to_i32(int32_t()); + to_i32(float()); + +#if HWY_HAVE_INTEGER64 + const ForPartialVectors<TestBitCast<uint64_t>> to_u64; + to_u64(uint64_t()); + to_u64(int64_t()); +#if HWY_HAVE_FLOAT64 + to_u64(double()); +#endif + + const ForPartialVectors<TestBitCast<int64_t>> to_i64; + to_i64(uint64_t()); + to_i64(int64_t()); +#if HWY_HAVE_FLOAT64 + to_i64(double()); +#endif +#endif // HWY_HAVE_INTEGER64 + + const ForPartialVectors<TestBitCast<float>> to_float; + to_float(uint32_t()); + to_float(int32_t()); + to_float(float()); + +#if HWY_HAVE_FLOAT64 + const ForPartialVectors<TestBitCast<double>> to_double; + to_double(double()); +#if HWY_HAVE_INTEGER64 + to_double(uint64_t()); + to_double(int64_t()); +#endif // HWY_HAVE_INTEGER64 +#endif // HWY_HAVE_FLOAT64 + +#if HWY_TARGET != HWY_SCALAR + // For non-scalar vectors, we can cast all types to all. + ForAllTypes(ForGEVectors<64, TestBitCastFrom>()); +#endif +} + +template <typename ToT> +struct TestPromoteTo { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + static_assert(sizeof(T) < sizeof(ToT), "Input type must be narrower"); + const Rebind<ToT, D> to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned<T>(N); + auto expected = AllocateAligned<ToT>(N); + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + const uint64_t bits = rng(); + CopyBytes<sizeof(T)>(&bits, &from[i]); // not same size + expected[i] = from[i]; + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + PromoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllPromoteTo() { + const ForPromoteVectors<TestPromoteTo<uint16_t>, 1> to_u16div2; + to_u16div2(uint8_t()); + + const ForPromoteVectors<TestPromoteTo<uint32_t>, 2> to_u32div4; + to_u32div4(uint8_t()); + + const ForPromoteVectors<TestPromoteTo<uint32_t>, 1> to_u32div2; + to_u32div2(uint16_t()); + + const ForPromoteVectors<TestPromoteTo<int16_t>, 1> to_i16div2; + to_i16div2(uint8_t()); + to_i16div2(int8_t()); + + const ForPromoteVectors<TestPromoteTo<int32_t>, 1> to_i32div2; + to_i32div2(uint16_t()); + to_i32div2(int16_t()); + + const ForPromoteVectors<TestPromoteTo<int32_t>, 2> to_i32div4; + to_i32div4(uint8_t()); + to_i32div4(int8_t()); + + // Must test f16/bf16 separately because we can only load/store/convert them. + +#if HWY_HAVE_INTEGER64 + const ForPromoteVectors<TestPromoteTo<uint64_t>, 1> to_u64div2; + to_u64div2(uint32_t()); + + const ForPromoteVectors<TestPromoteTo<int64_t>, 1> to_i64div2; + to_i64div2(int32_t()); +#endif + +#if HWY_HAVE_FLOAT64 + const ForPromoteVectors<TestPromoteTo<double>, 1> to_f64div2; + to_f64div2(int32_t()); + to_f64div2(float()); +#endif +} + +template <typename T, HWY_IF_FLOAT(T)> +bool IsFinite(T t) { + return std::isfinite(t); +} +// Wrapper avoids calling std::isfinite for integer types (ambiguous). +template <typename T, HWY_IF_NOT_FLOAT(T)> +bool IsFinite(T /*unused*/) { + return true; +} + +template <class D> +AlignedFreeUniquePtr<float[]> F16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // +/- 1 + 1.0f, -1.0f, + // +/- 0 + 0.0f, -0.0f, + // near 0 + 0.25f, -0.25f, + // +/- integer + 4.0f, -32.0f, + // positive near limit + 65472.0f, 65504.0f, + // negative near limit + -65472.0f, -65504.0f, + // positive +/- delta + 2.00390625f, 3.99609375f, + // negative +/- delta + -2.00390625f, -3.99609375f, + // No infinity/NaN - implementation-defined due to ARM. + }; + constexpr size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + HWY_ASSERT(N != 0); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned<float>(padded); + auto expected = AllocateAligned<float>(padded); + size_t i = 0; + for (; i < kNumTestCases; ++i) { + in[i] = test_cases[i]; + } + for (; i < padded; ++i) { + in[i] = 0.0f; + } + return in; +} + +struct TestF16 { + template <typename TF32, class DF32> + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { +#if HWY_HAVE_FLOAT16 + size_t padded; + const size_t N = Lanes(d32); // same count for f16 + HWY_ASSERT(N != 0); + auto in = F16TestCases(d32, padded); + using TF16 = float16_t; + const Rebind<TF16, DF32> d16; + auto temp16 = AllocateAligned<TF16>(N); + + for (size_t i = 0; i < padded; i += N) { + const auto loaded = Load(d32, &in[i]); + Store(DemoteTo(d16, loaded), d16, temp16.get()); + HWY_ASSERT_VEC_EQ(d32, loaded, PromoteTo(d32, Load(d16, temp16.get()))); + } +#else + (void)d32; +#endif + } +}; + +HWY_NOINLINE void TestAllF16() { ForDemoteVectors<TestF16>()(float()); } + +template <class D> +AlignedFreeUniquePtr<float[]> BF16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // +/- 1 + 1.0f, -1.0f, + // +/- 0 + 0.0f, -0.0f, + // near 0 + 0.25f, -0.25f, + // +/- integer + 4.0f, -32.0f, + // positive near limit + 3.389531389251535E38f, 1.99384199368e+38f, + // negative near limit + -3.389531389251535E38f, -1.99384199368e+38f, + // positive +/- delta + 2.015625f, 3.984375f, + // negative +/- delta + -2.015625f, -3.984375f, + }; + constexpr size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + HWY_ASSERT(N != 0); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned<float>(padded); + auto expected = AllocateAligned<float>(padded); + size_t i = 0; + for (; i < kNumTestCases; ++i) { + in[i] = test_cases[i]; + } + for (; i < padded; ++i) { + in[i] = 0.0f; + } + return in; +} + +struct TestBF16 { + template <typename TF32, class DF32> + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { +#if !defined(HWY_EMULATE_SVE) + size_t padded; + auto in = BF16TestCases(d32, padded); + using TBF16 = bfloat16_t; +#if HWY_TARGET == HWY_SCALAR + const Rebind<TBF16, DF32> dbf16; // avoid 4/2 = 2 lanes +#else + const Repartition<TBF16, DF32> dbf16; +#endif + const Half<decltype(dbf16)> dbf16_half; + const size_t N = Lanes(d32); + HWY_ASSERT(Lanes(dbf16_half) <= N); + auto temp16 = AllocateAligned<TBF16>(N); + + for (size_t i = 0; i < padded; i += N) { + const auto loaded = Load(d32, &in[i]); + const auto v16 = DemoteTo(dbf16_half, loaded); + Store(v16, dbf16_half, temp16.get()); + const auto v16_loaded = Load(dbf16_half, temp16.get()); + HWY_ASSERT_VEC_EQ(d32, loaded, PromoteTo(d32, v16_loaded)); + } +#else + (void)d32; +#endif + } +}; + +HWY_NOINLINE void TestAllBF16() { ForShrinkableVectors<TestBF16>()(float()); } + +struct TestConvertU8 { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, const D du32) { + const Rebind<uint8_t, D> du8; + const auto wrap = Set(du32, 0xFF); + HWY_ASSERT_VEC_EQ(du8, Iota(du8, 0), U8FromU32(And(Iota(du32, 0), wrap))); + HWY_ASSERT_VEC_EQ(du8, Iota(du8, 0x7F), + U8FromU32(And(Iota(du32, 0x7F), wrap))); + } +}; + +HWY_NOINLINE void TestAllConvertU8() { + ForDemoteVectors<TestConvertU8, 2>()(uint32_t()); +} + +template <typename From, typename To, class D> +constexpr bool IsSupportedTruncation() { + return (sizeof(To) < sizeof(From)) && + (Pow2(Rebind<To, D>()) + 3 >= static_cast<int>(CeilLog2(sizeof(To)))); +} + +struct TestTruncateTo { + template <typename From, typename To, class D, + hwy::EnableIf<!IsSupportedTruncation<From, To, D>()>* = nullptr> + HWY_NOINLINE void testTo(From, To, const D) { + // do nothing + } + + template <typename From, typename To, class D, + hwy::EnableIf<IsSupportedTruncation<From, To, D>()>* = nullptr> + HWY_NOINLINE void testTo(From, To, const D d) { + constexpr uint32_t base = 0xFA578D00; + const Rebind<To, D> dTo; + const auto src = Iota(d, static_cast<From>(base)); + const auto expected = Iota(dTo, static_cast<To>(base)); + const VFromD<decltype(dTo)> actual = TruncateTo(dTo, src); + HWY_ASSERT_VEC_EQ(dTo, expected, actual); + } + + template <typename T, class D> + HWY_NOINLINE void operator()(T from, const D d) { + testTo<T, uint8_t, D>(from, uint8_t(), d); + testTo<T, uint16_t, D>(from, uint16_t(), d); + testTo<T, uint32_t, D>(from, uint32_t(), d); + } +}; + +HWY_NOINLINE void TestAllTruncate() { + ForUnsignedTypes(ForPartialVectors<TestTruncateTo>()); +} + +// Separate function to attempt to work around a compiler bug on ARM: when this +// is merged with TestIntFromFloat, outputs match a previous Iota(-(N+1)) input. +struct TestIntFromFloatHuge { + template <typename TF, class DF> + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + // The ARMv7 manual says that float->int saturates, i.e. chooses the + // nearest representable value. This works correctly on armhf with GCC, but + // not with clang. For reasons unknown, MSVC also runs into an out-of-memory + // error here. +#if HWY_COMPILER_CLANG || HWY_COMPILER_MSVC + (void)df; +#else + using TI = MakeSigned<TF>; + const Rebind<TI, DF> di; + + // Workaround for incorrect 32-bit GCC codegen for SSSE3 - Print-ing + // the expected lvalue also seems to prevent the issue. + const size_t N = Lanes(df); + auto expected = AllocateAligned<TI>(N); + + // Huge positive + Store(Set(di, LimitsMax<TI>()), di, expected.get()); + HWY_ASSERT_VEC_EQ(di, expected.get(), ConvertTo(di, Set(df, TF(1E20)))); + + // Huge negative + Store(Set(di, LimitsMin<TI>()), di, expected.get()); + HWY_ASSERT_VEC_EQ(di, expected.get(), ConvertTo(di, Set(df, TF(-1E20)))); +#endif + } +}; + +class TestIntFromFloat { + template <typename TF, class DF> + static HWY_NOINLINE void TestPowers(TF /*unused*/, const DF df) { + using TI = MakeSigned<TF>; + const Rebind<TI, DF> di; + constexpr size_t kBits = sizeof(TF) * 8; + + // Powers of two, plus offsets to set some mantissa bits. + const int64_t ofs_table[3] = {0LL, 3LL << (kBits / 2), 1LL << (kBits - 15)}; + for (int sign = 0; sign < 2; ++sign) { + for (size_t shift = 0; shift < kBits - 1; ++shift) { + for (int64_t ofs : ofs_table) { + const int64_t mag = (int64_t{1} << shift) + ofs; + const int64_t val = sign ? mag : -mag; + HWY_ASSERT_VEC_EQ(di, Set(di, static_cast<TI>(val)), + ConvertTo(di, Set(df, static_cast<TF>(val)))); + } + } + } + } + + template <typename TF, class DF> + static HWY_NOINLINE void TestRandom(TF /*unused*/, const DF df) { + using TI = MakeSigned<TF>; + const Rebind<TI, DF> di; + const size_t N = Lanes(df); + + // TF does not have enough precision to represent TI. + const double min = static_cast<double>(LimitsMin<TI>()); + const double max = static_cast<double>(LimitsMax<TI>()); + + // Also check random values. + auto from = AllocateAligned<TF>(N); + auto expected = AllocateAligned<TI>(N); + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + CopyBytes<sizeof(TF)>(&bits, &from[i]); // not same size + } while (!std::isfinite(from[i])); + if (from[i] >= max) { + expected[i] = LimitsMax<TI>(); + } else if (from[i] <= min) { + expected[i] = LimitsMin<TI>(); + } else { + expected[i] = static_cast<TI>(from[i]); + } + } + + HWY_ASSERT_VEC_EQ(di, expected.get(), + ConvertTo(di, Load(df, from.get()))); + } + } + + public: + template <typename TF, class DF> + HWY_NOINLINE void operator()(TF tf, const DF df) { + using TI = MakeSigned<TF>; + const Rebind<TI, DF> di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), ConvertTo(di, Iota(df, TF(4.0)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), ConvertTo(di, Iota(df, -TF(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), ConvertTo(di, Iota(df, TF(2.001)))); + + // Below positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), ConvertTo(di, Iota(df, TF(3.9999)))); + + const TF eps = static_cast<TF>(0.0001); + // Above negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), + ConvertTo(di, Iota(df, -TF(N + 1) + eps))); + + // Below negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), + ConvertTo(di, Iota(df, -TF(N + 1) - eps))); + + TestPowers(tf, df); + TestRandom(tf, df); + } +}; + +HWY_NOINLINE void TestAllIntFromFloat() { + ForFloatTypes(ForPartialVectors<TestIntFromFloatHuge>()); + ForFloatTypes(ForPartialVectors<TestIntFromFloat>()); +} + +struct TestFloatFromInt { + template <typename TF, class DF> + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = MakeSigned<TF>; + const RebindToSigned<DF> di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), ConvertTo(df, Iota(di, TI(4)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(df, Iota(df, -TF(N)), ConvertTo(df, Iota(di, -TI(N)))); + + // Max positive + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax<TI>())), + ConvertTo(df, Set(di, LimitsMax<TI>()))); + + // Min negative + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMin<TI>())), + ConvertTo(df, Set(di, LimitsMin<TI>()))); + } +}; + +HWY_NOINLINE void TestAllFloatFromInt() { + ForFloatTypes(ForPartialVectors<TestFloatFromInt>()); +} + +struct TestFloatFromUint { + template <typename TF, class DF> + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TU = MakeUnsigned<TF>; + const RebindToUnsigned<DF> du; + + // Integer positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), ConvertTo(df, Iota(du, TU(4)))); + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(65535.0)), + ConvertTo(df, Iota(du, 65535))); // 2^16-1 + if (sizeof(TF) > 4) { + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4294967295.0)), + ConvertTo(df, Iota(du, 4294967295ULL))); // 2^32-1 + } + + // Max positive + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax<TU>())), + ConvertTo(df, Set(du, LimitsMax<TU>()))); + + // Zero + HWY_ASSERT_VEC_EQ(df, Zero(df), ConvertTo(df, Zero(du))); + } +}; + +HWY_NOINLINE void TestAllFloatFromUint() { + ForFloatTypes(ForPartialVectors<TestFloatFromUint>()); +} + +struct TestI32F64 { + template <typename TF, class DF> + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = int32_t; + const Rebind<TI, DF> di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(df, Iota(df, -TF(N)), PromoteTo(df, Iota(di, -TI(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(2.0)), PromoteTo(df, Iota(di, TI(2)))); + + // Below positive + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(4.0)), PromoteTo(df, Iota(di, TI(4)))); + + // Above negative + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-4.0)), PromoteTo(df, Iota(di, TI(-4)))); + + // Below negative + HWY_ASSERT_VEC_EQ(df, Iota(df, TF(-2.0)), PromoteTo(df, Iota(di, TI(-2)))); + + // Max positive int + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMax<TI>())), + PromoteTo(df, Set(di, LimitsMax<TI>()))); + + // Min negative int + HWY_ASSERT_VEC_EQ(df, Set(df, TF(LimitsMin<TI>())), + PromoteTo(df, Set(di, LimitsMin<TI>()))); + } +}; + +HWY_NOINLINE void TestAllI32F64() { +#if HWY_HAVE_FLOAT64 + ForDemoteVectors<TestI32F64>()(double()); +#endif +} + + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyConvertTest); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllBitCast); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllPromoteTo); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllF16); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllBF16); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllConvertU8); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllTruncate); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllIntFromFloat); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllFloatFromInt); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllFloatFromUint); +HWY_EXPORT_AND_TEST_P(HwyConvertTest, TestAllI32F64); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/crypto_test.cc b/third_party/highway/hwy/tests/crypto_test.cc new file mode 100644 index 0000000000..b7dfb198a3 --- /dev/null +++ b/third_party/highway/hwy/tests/crypto_test.cc @@ -0,0 +1,553 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcpy + +#include "hwy/aligned_allocator.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/crypto_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#define HWY_PRINT_CLMUL_GOLDEN 0 + +#if HWY_TARGET != HWY_SCALAR + +class TestAES { + template <typename T, class D> + HWY_NOINLINE void TestSBox(T /*unused*/, D d) { + // The generic implementation of the S-box is difficult to verify by + // inspection, so we add a white-box test that verifies it using enumeration + // (outputs for 0..255 vs. https://en.wikipedia.org/wiki/Rijndael_S-box). + const uint8_t sbox[256] = { + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, + 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, + 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, + 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, + 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, + 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, + 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, + 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, + 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, + 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, + 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, + 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, + 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, + 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, + 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, + 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, + 0xb0, 0x54, 0xbb, 0x16}; + + // Ensure it's safe to load an entire vector by padding. + const size_t N = Lanes(d); + const size_t padded = RoundUpTo(256, N); + auto expected = AllocateAligned<T>(padded); + // Must wrap around to match the input (Iota). + for (size_t pos = 0; pos < padded;) { + const size_t remaining = HWY_MIN(padded - pos, size_t(256)); + memcpy(expected.get() + pos, sbox, remaining); + pos += remaining; + } + + for (size_t i = 0; i < 256; i += N) { + const auto in = Iota(d, static_cast<T>(i)); + HWY_ASSERT_VEC_EQ(d, expected.get() + i, detail::SubBytes(in)); + } + } + + public: + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + // Test vector (after first KeyAddition) from + // https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/AES_Core128.pdf + alignas(16) constexpr uint8_t test_lanes[16] = { + 0x40, 0xBF, 0xAB, 0xF4, 0x06, 0xEE, 0x4D, 0x30, + 0x42, 0xCA, 0x6B, 0x99, 0x7A, 0x5C, 0x58, 0x16}; + const auto test = LoadDup128(d, test_lanes); + + // = ShiftRow result + alignas(16) constexpr uint8_t expected_sr_lanes[16] = { + 0x09, 0x28, 0x7F, 0x47, 0x6F, 0x74, 0x6A, 0xBF, + 0x2C, 0x4A, 0x62, 0x04, 0xDA, 0x08, 0xE3, 0xEE}; + const auto expected_sr = LoadDup128(d, expected_sr_lanes); + + // = MixColumn result + alignas(16) constexpr uint8_t expected_mc_lanes[16] = { + 0x52, 0x9F, 0x16, 0xC2, 0x97, 0x86, 0x15, 0xCA, + 0xE0, 0x1A, 0xAE, 0x54, 0xBA, 0x1A, 0x26, 0x59}; + const auto expected_mc = LoadDup128(d, expected_mc_lanes); + + // = KeyAddition result + alignas(16) constexpr uint8_t expected_lanes[16] = { + 0xF2, 0x65, 0xE8, 0xD5, 0x1F, 0xD2, 0x39, 0x7B, + 0xC3, 0xB9, 0x97, 0x6D, 0x90, 0x76, 0x50, 0x5C}; + const auto expected = LoadDup128(d, expected_lanes); + + alignas(16) uint8_t key_lanes[16]; + for (size_t i = 0; i < 16; ++i) { + key_lanes[i] = expected_mc_lanes[i] ^ expected_lanes[i]; + } + const auto round_key = LoadDup128(d, key_lanes); + + HWY_ASSERT_VEC_EQ(d, expected_mc, AESRound(test, Zero(d))); + HWY_ASSERT_VEC_EQ(d, expected, AESRound(test, round_key)); + HWY_ASSERT_VEC_EQ(d, expected_sr, AESLastRound(test, Zero(d))); + HWY_ASSERT_VEC_EQ(d, Xor(expected_sr, round_key), + AESLastRound(test, round_key)); + + TestSBox(t, d); + } +}; +HWY_NOINLINE void TestAllAES() { ForGEVectors<128, TestAES>()(uint8_t()); } + +#else +HWY_NOINLINE void TestAllAES() {} +#endif // HWY_TARGET != HWY_SCALAR + +struct TestCLMul { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // needs 64 bit lanes and 128-bit result +#if HWY_TARGET != HWY_SCALAR && HWY_HAVE_INTEGER64 + const size_t N = Lanes(d); + if (N == 1) return; + + auto in1 = AllocateAligned<T>(N); + auto in2 = AllocateAligned<T>(N); + + constexpr size_t kCLMulNum = 512; + // Depends on rng! + static constexpr uint64_t kCLMulLower[kCLMulNum] = { + 0x24511d4ce34d6350ULL, 0x4ca582edde1236bbULL, 0x537e58f72dac25a8ULL, + 0x4e942d5e130b9225ULL, 0x75a906c519257a68ULL, 0x1df9f85126d96c5eULL, + 0x464e7c13f4ad286aULL, 0x138535ee35dabc40ULL, 0xb2f7477b892664ecULL, + 0x01557b077167c25dULL, 0xf32682490ee49624ULL, 0x0025bac603b9e140ULL, + 0xcaa86aca3e3daf40ULL, 0x1fbcfe4af73eb6c4ULL, 0x8ee8064dd0aae5dcULL, + 0x1248cb547858c213ULL, 0x37a55ee5b10fb34cULL, 0x6eb5c97b958f86e2ULL, + 0x4b1ab3eb655ea7cdULL, 0x1d66645a85627520ULL, 0xf8728e96daa36748ULL, + 0x38621043e6ff5e3bULL, 0xd1d28b5da5ffefb4ULL, 0x0a5cd65931546df7ULL, + 0x2a0639be3d844150ULL, 0x0e2d0f18c8d6f045ULL, 0xfacc770b963326c1ULL, + 0x19611b31ca2ef141ULL, 0xabea29510dd87518ULL, 0x18a7dc4b205f2768ULL, + 0x9d3975ea5612dc86ULL, 0x06319c139e374773ULL, 0x6641710400b4c390ULL, + 0x356c29b6001c3670ULL, 0xe9e04d851e040a00ULL, 0x21febe561222d79aULL, + 0xc071eaae6e148090ULL, 0x0eed351a0af94f5bULL, 0x04324eedb3c03688ULL, + 0x39e89b136e0d6ccdULL, 0x07d0fd2777a31600ULL, 0x44b8573827209822ULL, + 0x6d690229ea177d78ULL, 0x1b9749d960ba9f18ULL, 0x190945271c0fbb94ULL, + 0x189aea0e07d2c88eULL, 0xf18eab6b65a6beb2ULL, 0x57744b21c13d0d84ULL, + 0xf63050a613e95c2eULL, 0x12cd20d25f97102fULL, 0x5a5df0678dbcba60ULL, + 0x0b08fb80948bfafcULL, 0x44cf1cbe7c6fc3c8ULL, 0x166a470ef25da288ULL, + 0x2c498a609204e48cULL, 0x261b0a22585697ecULL, 0x737750574af7dde4ULL, + 0x4079959c60b01e0cULL, 0x06ed8aac13f782d6ULL, 0x019d454ba9b5ef20ULL, + 0xea1edbf96d49e858ULL, 0x17c2f3ebde9ac469ULL, 0x5cf72706e3d6f5e4ULL, + 0x16e856aa3c841516ULL, 0x256f7e3cef83368eULL, 0x47e17c8eb2774e77ULL, + 0x9b48ac150a804821ULL, 0x584523f61ccfdf22ULL, 0xedcb6a2a75d9e7f2ULL, + 0x1fe3d1838e537aa7ULL, 0x778872e9f64549caULL, 0x2f1cea6f0d3faf92ULL, + 0x0e8c4b6a9343f326ULL, 0x01902d1ba3048954ULL, 0xc5c1fd5269e91dc0ULL, + 0x0ef8a4707817eb9cULL, 0x1f696f09a5354ca4ULL, 0x369cd9de808b818cULL, + 0xf6917d1dd43fd784ULL, 0x7f4b76bf40dc166fULL, 0x4ce67698724ace12ULL, + 0x02c3bf60e6e9cd92ULL, 0xb8229e45b21458e8ULL, 0x415efd41e91adf49ULL, + 0x5edfcd516bb921cdULL, 0x5ff2c29429fd187eULL, 0x0af666b17103b3e0ULL, + 0x1f5e4ff8f54c9a5bULL, 0x429253d8a5544ba6ULL, 0x19de2fdf9f4d9dcaULL, + 0x29bf3d37ddc19a40ULL, 0x04d4513a879552baULL, 0x5cc7476cf71ee155ULL, + 0x40011f8c238784a5ULL, 0x1a3ae50b0fd2ee2bULL, 0x7db22f432ba462baULL, + 0x417290b0bee2284aULL, 0x055a6bd5bb853db2ULL, 0xaa667daeed8c2a34ULL, + 0x0d6b316bda7f3577ULL, 0x72d35598468e3d5dULL, 0x375b594804bfd33aULL, + 0x16ed3a319b540ae8ULL, 0x093bace4b4695afdULL, 0xc7118754ec2737ceULL, + 0x0fff361f0505c81aULL, 0x996e9e7291321af0ULL, 0x496b1d9b0b89ba8cULL, + 0x65a98b2e9181da9cULL, 0x70759c8dd45575dfULL, 0x3446fe727f5e2cbbULL, + 0x1121ae609d195e74ULL, 0x5ff5d68ce8a21018ULL, 0x0e27eca3825b60d6ULL, + 0x82f628bceca3d1daULL, 0x2756a0914e344047ULL, 0xa460406c1c708d50ULL, + 0x63ce32a0c083e491ULL, 0xc883e5a685c480e0ULL, 0x602c951891e600f9ULL, + 0x02ecb2e3911ca5f8ULL, 0x0d8675f4bb70781aULL, 0x43545cc3c78ea496ULL, + 0x04164b01d6b011c2ULL, 0x3acbb323dcab2c9bULL, 0x31c5ba4e22793082ULL, + 0x5a6484af5f7c2d10ULL, 0x1a929b16194e8078ULL, 0x7a6a75d03b313924ULL, + 0x0553c73a35b1d525ULL, 0xf18628c51142be34ULL, 0x1b51cf80d7efd8f5ULL, + 0x52e0ca4df63ee258ULL, 0x0e977099160650c9ULL, 0x6be1524e92024f70ULL, + 0x0ee2152625438b9dULL, 0xfa32af436f6d8eb4ULL, 0x5ecf49c2154287e5ULL, + 0x6b72f4ae3590569dULL, 0x086c5ee6e87bfb68ULL, 0x737a4f0dc04b6187ULL, + 0x08c3439280edea41ULL, 0x9547944f01636c5cULL, 0x6acfbfc2571cd71fULL, + 0x85d7842972449637ULL, 0x252ea5e5a7fad86aULL, 0x4e41468f99ba1632ULL, + 0x095e0c3ae63b25a2ULL, 0xb005ce88fd1c9425ULL, 0x748e668abbe09f03ULL, + 0xb2cfdf466b187d18ULL, 0x60b11e633d8fe845ULL, 0x07144c4d246db604ULL, + 0x139bcaac55e96125ULL, 0x118679b5a6176327ULL, 0x1cebe90fa4d9f83fULL, + 0x22244f52f0d312acULL, 0x669d4e17c9bfb713ULL, 0x96390e0b834bb0d0ULL, + 0x01f7f0e82ba08071ULL, 0x2dffeee31ca6d284ULL, 0x1f4738745ef039feULL, + 0x4ce0dd2b603b6420ULL, 0x0035fc905910a4d5ULL, 0x07df2b533df6fb04ULL, + 0x1cee2735c9b910ddULL, 0x2bc4af565f7809eaULL, 0x2f876c1f5cb1076cULL, + 0x33e079524099d056ULL, 0x169e0405d2f9efbaULL, 0x018643ab548a358cULL, + 0x1bb6fc4331cffe92ULL, 0x05111d3a04e92faaULL, 0x23c27ecf0d638b73ULL, + 0x1b79071dc1685d68ULL, 0x0662d20aba8e1e0cULL, 0xe7f6440277144c6fULL, + 0x4ca38b64c22196c0ULL, 0x43c05f6d1936fbeeULL, 0x0654199d4d1faf0fULL, + 0xf2014054e71c2d04ULL, 0x0a103e47e96b4c84ULL, 0x7986e691dd35b040ULL, + 0x4e1ebb53c306a341ULL, 0x2775bb3d75d65ba6ULL, 0x0562ab0adeff0f15ULL, + 0x3c2746ad5eba3eacULL, 0x1facdb5765680c60ULL, 0xb802a60027d81d00ULL, + 0x1191d0f6366ae3a9ULL, 0x81a97b5ae0ea5d14ULL, 0x06bee05b6178a770ULL, + 0xc7baeb2fe1d6aeb3ULL, 0x594cb5b867d04fdfULL, 0xf515a80138a4e350ULL, + 0x646417ad8073cf38ULL, 0x4a229a43373fb8d4ULL, 0x10fa6eafff1ca453ULL, + 0x9f060700895cc731ULL, 0x00521133d11d11f4ULL, 0xb940a2bb912a7a5cULL, + 0x3fab180670ad2a3cULL, 0x45a5f0e5b6fdb95dULL, 0x27c1baad6f946b15ULL, + 0x336c6bdbe527cf58ULL, 0x3b83aa602a5baea3ULL, 0xdf749153f9bcc376ULL, + 0x1a05513a6c0b4a90ULL, 0xb81e0b570a075c47ULL, 0x471fabb40bdc27ceULL, + 0x9dec9472f6853f60ULL, 0x361f71b88114193bULL, 0x3b550a8c4feeff00ULL, + 0x0f6cde5a68bc9bc0ULL, 0x3f50121a925703e0ULL, 0x6967ff66d6d343a9ULL, + 0xff6b5bd2ce7bc3ccULL, 0x05474cea08bf6cd8ULL, 0xf76eabbfaf108eb0ULL, + 0x067529be4fc6d981ULL, 0x4d766b137cf8a988ULL, 0x2f09c7395c5cfbbdULL, + 0x388793712da06228ULL, 0x02c9ff342c8f339aULL, 0x152c734139a860a3ULL, + 0x35776eb2b270c04dULL, 0x0f8d8b41f11c4608ULL, 0x0c2071665be6b288ULL, + 0xc034e212b3f71d88ULL, 0x071d961ef3276f99ULL, 0xf98598ee75b60773ULL, + 0x062062c58c6724e4ULL, 0xd156438e2125572cULL, 0x38552d59a7f0f7c8ULL, + 0x1a402178206e413cULL, 0x1f1f996c68293b26ULL, 0x8bce3cafe1730f7eULL, + 0x2d0480a0828f6bf5ULL, 0x6c99cffa171f92f6ULL, 0x0087f842bb0ac681ULL, + 0x11d7ed06e1e7fd3eULL, 0x07cb1186f2385dc6ULL, 0x5d7763ebff1e170fULL, + 0x2dacc870231ac292ULL, 0x8486317a9ffb390cULL, 0x1c3a6dd20c959ac6ULL, + 0x90dc96e3992e06b8ULL, 0x70d60bfa33e72b67ULL, 0x70c9bddd0985ee63ULL, + 0x012c9767b3673093ULL, 0xfcd3bc5580f6a88aULL, 0x0ac80017ef6308c3ULL, + 0xdb67d709ef4bba09ULL, 0x4c63e324f0e247ccULL, 0xa15481d3fe219d60ULL, + 0x094c4279cdccb501ULL, 0x965a28c72575cb82ULL, 0x022869db25e391ebULL, + 0x37f528c146023910ULL, 0x0c1290636917deceULL, 0x9aee25e96251ca9cULL, + 0x728ac5ba853b69c2ULL, 0x9f272c93c4be20c8ULL, 0x06c1aa6319d28124ULL, + 0x4324496b1ca8a4f7ULL, 0x0096ecfe7dfc0189ULL, 0x9e06131b19ae0020ULL, + 0x15278b15902f4597ULL, 0x2a9fece8c13842d8ULL, 0x1d4e6781f0e1355eULL, + 0x6855b712d3dbf7c0ULL, 0x06a07fad99be6f46ULL, 0x3ed9d7957e4d1d7cULL, + 0x0c326f7cbc248bb2ULL, 0xe6363ad2c537cf51ULL, 0x0e12eb1c40723f13ULL, + 0xf5c6ac850afba803ULL, 0x0322a79d615fa9f0ULL, 0x6116696ed97bd5f8ULL, + 0x0d438080fbbdc9f1ULL, 0x2e4dc42c38f1e243ULL, 0x64948e9104f3a5bfULL, + 0x9fd622371bdb5f00ULL, 0x0f12bf082b2a1b6eULL, 0x4b1f8d867d78031cULL, + 0x134392ea9f5ef832ULL, 0xf3d70472321bc23eULL, 0x05fcbe5e9eea268eULL, + 0x136dede7175a22cfULL, 0x1308f8baac2cbcccULL, 0xd691026f0915eb64ULL, + 0x0e49a668345c3a38ULL, 0x24ddbbe8bc96f331ULL, 0x4d2ec9479b640578ULL, + 0x450f0697327b359cULL, 0x32b45360f4488ee0ULL, 0x4f6d9ecec46a105aULL, + 0x5500c63401ae8e80ULL, 0x47dea495cf6f98baULL, 0x13dc9a2dfca80babULL, + 0xe6f8a93f7b24ca92ULL, 0x073f57a6d900a87fULL, 0x9ddb935fd3aa695aULL, + 0x101e98d24b39e8aaULL, 0x6b8d0eb95a507ddcULL, 0x45a908b3903d209bULL, + 0x6c96a3e119e617d4ULL, 0x2442787543d3be48ULL, 0xd3bc055c7544b364ULL, + 0x7693bb042ca8653eULL, 0xb95e3a4ea5d0101eULL, 0x116f0d459bb94a73ULL, + 0x841244b72cdc5e90ULL, 0x1271acced6cb34d3ULL, 0x07d289106524d638ULL, + 0x537c9cf49c01b5bbULL, 0x8a8e16706bb7a5daULL, 0x12e50a9c499dc3a9ULL, + 0x1cade520db2ba830ULL, 0x1add52f000d7db70ULL, 0x12cf15db2ce78e30ULL, + 0x0657eaf606bfc866ULL, 0x4026816d3b05b1d0ULL, 0x1ba0ebdf90128e4aULL, + 0xdfd649375996dd6eULL, 0x0f416e906c23d9aeULL, 0x384273cad0582a24ULL, + 0x2ff27b0378a46189ULL, 0xc4ecd18a2d7a7616ULL, 0x35cef0b5cd51d640ULL, + 0x7d582363643f48b7ULL, 0x0984ad746ad0ab7cULL, 0x2990a999835f9688ULL, + 0x2d4df66a97b19e05ULL, 0x592c79720af99aa2ULL, 0x052863c230602cd3ULL, + 0x5f5e2b15edcf2840ULL, 0x01dff1b694b978b0ULL, 0x14345a48b622025eULL, + 0x028fab3b6407f715ULL, 0x3455d188e6feca50ULL, 0x1d0d40288fb1b5fdULL, + 0x4685c5c2b6a1e5aeULL, 0x3a2077b1e5fe5adeULL, 0x1bc55d611445a0d8ULL, + 0x05480ae95f3f83feULL, 0xbbb59cfcf7e17fb6ULL, 0x13f7f10970bbb990ULL, + 0x6d00ac169425a352ULL, 0x7da0db397ef2d5d3ULL, 0x5b512a247f8d2479ULL, + 0x637eaa6a977c3c32ULL, 0x3720f0ae37cba89cULL, 0x443df6e6aa7f525bULL, + 0x28664c287dcef321ULL, 0x03c267c00cf35e49ULL, 0x690185572d4021deULL, + 0x2707ff2596e321c2ULL, 0xd865f5af7722c380ULL, 0x1ea285658e33aafbULL, + 0xc257c5e88755bef4ULL, 0x066f67275cfcc31eULL, 0xb09931945cc0fed0ULL, + 0x58c1dc38d6e3a03fULL, 0xf99489678fc94ee8ULL, 0x75045bb99be5758aULL, + 0x6c163bc34b40feefULL, 0x0420063ce7bdd3b4ULL, 0xf86ef10582bf2e28ULL, + 0x162c3449ca14858cULL, 0x94106aa61dfe3280ULL, 0x4073ae7a4e7e4941ULL, + 0x32b13fd179c250b4ULL, 0x0178fbb216a7e744ULL, 0xf840ae2f1cf92669ULL, + 0x18fc709acc80243dULL, 0x20ac2ebd69f4d558ULL, 0x6e580ad9c73ad46aULL, + 0x76d2b535b541c19dULL, 0x6c7a3fb9dd0ce0afULL, 0xc3481689b9754f28ULL, + 0x156e813b6557abdbULL, 0x6ee372e31276eb10ULL, 0x19cf37c038c8d381ULL, + 0x00d4d906c9ae3072ULL, 0x09f03cbb6dfbfd40ULL, 0x461ba31c4125f3cfULL, + 0x25b29fc63ad9f05bULL, 0x6808c95c2dddede9ULL, 0x0564224337066d9bULL, + 0xc87eb5f4a4d966f2ULL, 0x66fc66e1701f5847ULL, 0xc553a3559f74da28ULL, + 0x1dfd841be574df43ULL, 0x3ee2f100c3ebc082ULL, 0x1a2c4f9517b56e89ULL, + 0x502f65c4b535c8ffULL, 0x1da5663ab6f96ec0ULL, 0xba1f80b73988152cULL, + 0x364ff12182ac8dc1ULL, 0xe3457a3c4871db31ULL, 0x6ae9cadf92fd7e84ULL, + 0x9621ba3d6ca15186ULL, 0x00ff5af878c144ceULL, 0x918464dc130101a4ULL, + 0x036511e6b187efa6ULL, 0x06667d66550ff260ULL, 0x7fd18913f9b51bc1ULL, + 0x3740e6b27af77aa8ULL, 0x1f546c2fd358ff8aULL, 0x42f1424e3115c891ULL, + 0x03767db4e3a1bb33ULL, 0xa171a1c564345060ULL, 0x0afcf632fd7b1324ULL, + 0xb59508d933ffb7d0ULL, 0x57d766c42071be83ULL, 0x659f0447546114a2ULL, + 0x4070364481c460aeULL, 0xa2b9752280644d52ULL, 0x04ab884bea5771bdULL, + 0x87cd135602a232b4ULL, 0x15e54cd9a8155313ULL, 0x1e8005efaa3e1047ULL, + 0x696b93f4ab15d39fULL, 0x0855a8e540de863aULL, 0x0bb11799e79f9426ULL, + 0xeffa61e5c1b579baULL, 0x1e060a1d11808219ULL, 0x10e219205667c599ULL, + 0x2f7b206091c49498ULL, 0xb48854c820064860ULL, 0x21c4aaa3bfbe4a38ULL, + 0x8f4a032a3fa67e9cULL, 0x3146b3823401e2acULL, 0x3afee26f19d88400ULL, + 0x167087c485791d38ULL, 0xb67a1ed945b0fb4bULL, 0x02436eb17e27f1c0ULL, + 0xe05afce2ce2d2790ULL, 0x49c536fc6224cfebULL, 0x178865b3b862b856ULL, + 0x1ce530de26acde5bULL, 0x87312c0b30a06f38ULL, 0x03e653b578558d76ULL, + 0x4d3663c21d8b3accULL, 0x038003c23626914aULL, 0xd9d5a2c052a09451ULL, + 0x39b5acfe08a49384ULL, 0x40f349956d5800e4ULL, 0x0968b6950b1bd8feULL, + 0xd60b2ca030f3779cULL, 0x7c8bc11a23ce18edULL, 0xcc23374e27630bc2ULL, + 0x2e38fc2a8bb33210ULL, 0xe421357814ee5c44ULL, 0x315fb65ea71ec671ULL, + 0xfb1b0223f70ed290ULL, 0x30556c9f983eaf07ULL, 0x8dd438c3d0cd625aULL, + 0x05a8fd0c7ffde71bULL, 0x764d1313b5aeec7aULL, 0x2036af5de9622f47ULL, + 0x508a5bfadda292feULL, 0x3f77f04ba2830e90ULL, 0x9047cd9c66ca66d2ULL, + 0x1168b5318a54eb21ULL, 0xc93462d221da2e15ULL, 0x4c2c7cc54abc066eULL, + 0x767a56fec478240eULL, 0x095de72546595bd3ULL, 0xc9da535865158558ULL, + 0x1baccf36f33e73fbULL, 0xf3d7dbe64df77f18ULL, 0x1f8ebbb7be4850b8ULL, + 0x043c5ed77bce25a1ULL, 0x07d401041b2a178aULL, 0x9181ebb8bd8d5618ULL, + 0x078b935dc3e4034aULL, 0x7b59c08954214300ULL, 0x03570dc2a4f84421ULL, + 0xdd8715b82f6b4078ULL, 0x2bb49c8bb544163bULL, 0xc9eb125564d59686ULL, + 0x5fdc7a38f80b810aULL, 0x3a4a6d8fff686544ULL, 0x28360e2418627d3aULL, + 0x60874244c95ed992ULL, 0x2115cc1dd9c34ed3ULL, 0xfaa3ef61f55e9efcULL, + 0x27ac9b1ef1adc7e6ULL, 0x95ea00478fec3f54ULL, 0x5aea808b2d99ab43ULL, + 0xc8f79e51fe43a580ULL, 0x5dbccd714236ce25ULL, 0x783fa76ed0753458ULL, + 0x48cb290f19d84655ULL, 0xc86a832f7696099aULL, 0x52f30c6fec0e71d3ULL, + 0x77d4e91e8cdeb886ULL, 0x7169a703c6a79ccdULL, 0x98208145b9596f74ULL, + 0x0945695c761c0796ULL, 0x0be897830d17bae0ULL, 0x033ad3924caeeeb4ULL, + 0xedecb6cfa2d303a8ULL, 0x3f86b074818642e7ULL, 0xeefa7c878a8b03f4ULL, + 0x093c101b80922551ULL, 0xfb3b4e6c26ac0034ULL, 0x162bf87999b94f5eULL, + 0xeaedae76e975b17cULL, 0x1852aa090effe18eULL}; + + static constexpr uint64_t kCLMulUpper[kCLMulNum] = { + 0xbb41199b1d587c69ULL, 0x514d94d55894ee29ULL, 0xebc6cd4d2efd5d16ULL, + 0x042044ad2de477fdULL, 0xb865c8b0fcdf4b15ULL, 0x0724d7e551cc40f3ULL, + 0xb15a16f39edb0bccULL, 0x37d64419ede7a171ULL, 0x2aa01bb80c753401ULL, + 0x06ff3f8a95fdaf4dULL, 0x79898cc0838546deULL, 0x776acbd1b237c60aULL, + 0x4c1753be4f4e0064ULL, 0x0ba9243601206ed3ULL, 0xd567c3b1bf3ec557ULL, + 0x043fac7bcff61fb3ULL, 0x49356232b159fb2fULL, 0x3910c82038102d4dULL, + 0x30592fef753eb300ULL, 0x7b2660e0c92a9e9aULL, 0x8246c9248d671ef0ULL, + 0x5a0dcd95147af5faULL, 0x43fde953909cc0eaULL, 0x06147b972cb96e1bULL, + 0xd84193a6b2411d80ULL, 0x00cd7711b950196fULL, 0x1088f9f4ade7fa64ULL, + 0x05a13096ec113cfbULL, 0x958d816d53b00edcULL, 0x3846154a7cdba9cbULL, + 0x8af516db6b27d1e6ULL, 0x1a1d462ab8a33b13ULL, 0x4040b0ac1b2c754cULL, + 0x05127fe9af2fe1d6ULL, 0x9f96e79374321fa6ULL, 0x06ff64a4d9c326f3ULL, + 0x28709566e158ac15ULL, 0x301701d7111ca51cULL, 0x31e0445d1b9d9544ULL, + 0x0a95aff69bf1d03eULL, 0x7c298c8414ecb879ULL, 0x00801499b4143195ULL, + 0x91521a00dd676a5cULL, 0x2777526a14c2f723ULL, 0xfa26aac6a6357dddULL, + 0x1d265889b0187a4bULL, 0xcd6e70fa8ed283e4ULL, 0x18a815aa50ea92caULL, + 0xc01e082694a263c6ULL, 0x4b40163ba53daf25ULL, 0xbc658caff6501673ULL, + 0x3ba35359586b9652ULL, 0x74f96acc97a4936cULL, 0x3989dfdb0cf1d2cfULL, + 0x358a01eaa50dda32ULL, 0x01109a5ed8f0802bULL, 0x55b84922e63c2958ULL, + 0x55b14843d87551d5ULL, 0x1db8ec61b1b578d8ULL, 0x79a2d49ef8c3658fULL, + 0xa304516816b3fbe0ULL, 0x163ecc09cc7b82f9ULL, 0xab91e8d22aabef00ULL, + 0x0ed6b09262de8354ULL, 0xcfd47d34cf73f6f2ULL, 0x7dbd1db2390bc6c3ULL, + 0x5ae789d3875e7b00ULL, 0x1d60fd0e70fe8fa4ULL, 0x690bc15d5ae4f6f5ULL, + 0x121ef5565104fb44ULL, 0x6e98e89297353b54ULL, 0x42554949249d62edULL, + 0xd6d6d16b12df78d2ULL, 0x320b33549b74975dULL, 0xd2a0618763d22e00ULL, + 0x0808deb93cba2017ULL, 0x01bd3b2302a2cc70ULL, 0x0b7b8dd4d71c8dd6ULL, + 0x34d60a3382a0756cULL, 0x40984584c8219629ULL, 0xf1152cba10093a66ULL, + 0x068001c6b2159ccbULL, 0x3d70f13c6cda0800ULL, 0x0e6b6746a322b956ULL, + 0x83a494319d8c770bULL, 0x0faecf64a8553e9aULL, 0xa34919222c39b1bcULL, + 0x0c63850d89e71c6fULL, 0x585f0bee92e53dc8ULL, 0x10f222b13b4fa5deULL, + 0x61573114f94252f2ULL, 0x09d59c311fba6c27ULL, 0x014effa7da49ed4eULL, + 0x4a400a1bc1c31d26ULL, 0xc9091c047b484972ULL, 0x3989f341ec2230ccULL, + 0xdcb03a98b3aee41eULL, 0x4a54a676a33a95e1ULL, 0xe499b7753951ef7cULL, + 0x2f43b1d1061d8b48ULL, 0xc3313bdc68ceb146ULL, 0x5159f6bc0e99227fULL, + 0x98128e6d9c05efcaULL, 0x15ea32b27f77815bULL, 0xe882c054e2654eecULL, + 0x003d2cdb8faee8c6ULL, 0xb416dd333a9fe1dfULL, 0x73f6746aefcfc98bULL, + 0x93dc114c10a38d70ULL, 0x05055941657845eaULL, 0x2ed7351347349334ULL, + 0x26fb1ee2c69ae690ULL, 0xa4575d10dc5b28e0ULL, 0x3395b11295e485ebULL, + 0xe840f198a224551cULL, 0x78e6e5a431d941d4ULL, 0xa1fee3ceab27f391ULL, + 0x07d35b3c5698d0dcULL, 0x983c67fca9174a29ULL, 0x2bb6bbae72b5144aULL, + 0xa7730b8d13ce58efULL, 0x51b5272883de1998ULL, 0xb334e128bb55e260ULL, + 0x1cacf5fbbe1b9974ULL, 0x71a9df4bb743de60ULL, 0x5176fe545c2d0d7aULL, + 0xbe592ecf1a16d672ULL, 0x27aa8a30c3efe460ULL, 0x4c78a32f47991e06ULL, + 0x383459294312f26aULL, 0x97ba789127f1490cULL, 0x51c9aa8a3abd1ef1ULL, + 0xcc7355188121e50fULL, 0x0ecb3a178ae334c1ULL, 0x84879a5e574b7160ULL, + 0x0765298f6389e8f3ULL, 0x5c6750435539bb22ULL, 0x11a05cf056c937b5ULL, + 0xb5dc2172dbfb7662ULL, 0x3ffc17915d9f40e8ULL, 0xbc7904daf3b431b0ULL, + 0x71f2088490930a7cULL, 0xa89505fd9efb53c4ULL, 0x02e194afd61c5671ULL, + 0x99a97f4abf35fcecULL, 0x26830aad30fae96fULL, 0x4b2abc16b25cf0b0ULL, + 0x07ec6fffa1cafbdbULL, 0xf38188fde97a280cULL, 0x121335701afff64dULL, + 0xea5ef38b4e672a64ULL, 0x477edbcae3eabf03ULL, 0xa32813cc0e0d244dULL, + 0x13346d2af4972eefULL, 0xcbc18357af1cfa9aULL, 0x561b630316e73fa6ULL, + 0xe9dfb53249249305ULL, 0x5d2b9dd1479312eeULL, 0x3458008119b56d04ULL, + 0x50e6790b49801385ULL, 0x5bb9febe2349492bULL, 0x0c2813954299098fULL, + 0xf747b0c890a071d5ULL, 0x417e8f82cc028d77ULL, 0xa134fee611d804f8ULL, + 0x24c99ee9a0408761ULL, 0x3ebb224e727137f3ULL, 0x0686022073ceb846ULL, + 0xa05e901fb82ad7daULL, 0x0ece7dc43ab470fcULL, 0x2d334ecc58f7d6a3ULL, + 0x23166fadacc54e40ULL, 0x9c3a4472f839556eULL, 0x071717ab5267a4adULL, + 0xb6600ac351ba3ea0ULL, 0x30ec748313bb63d4ULL, 0xb5374e39287b23ccULL, + 0x074d75e784238aebULL, 0x77315879243914a4ULL, 0x3bbb1971490865f1ULL, + 0xa355c21f4fbe02d3ULL, 0x0027f4bb38c8f402ULL, 0xeef8708e652bc5f0ULL, + 0x7b9aa56cf9440050ULL, 0x113ac03c16cfc924ULL, 0x395db36d3e4bef9fULL, + 0x5d826fabcaa597aeULL, 0x2a77d3c58786d7e0ULL, 0x85996859a3ba19d4ULL, + 0x01e7e3c904c2d97fULL, 0x34f90b9b98d51fd0ULL, 0x243aa97fd2e99bb7ULL, + 0x40a0cebc4f65c1e8ULL, 0x46d3922ed4a5503eULL, 0x446e7ecaf1f9c0a4ULL, + 0x49dc11558bc2e6aeULL, 0xe7a9f20881793af8ULL, 0x5771cc4bc98103f1ULL, + 0x2446ea6e718fce90ULL, 0x25d14aca7f7da198ULL, 0x4347af186f9af964ULL, + 0x10cb44fc9146363aULL, 0x8a35587afce476b4ULL, 0x575144662fee3d3aULL, + 0x69f41177a6bc7a05ULL, 0x02ff8c38d6b3c898ULL, 0x57c73589a226ca40ULL, + 0x732f6b5baae66683ULL, 0x00c008bbedd4bb34ULL, 0x7412ff09524d6cadULL, + 0xb8fd0b5ad8c145a8ULL, 0x74bd9f94b6cdc7dfULL, 0x68233b317ca6c19cULL, + 0x314b9c2c08b15c54ULL, 0x5bd1ad72072ebd08ULL, 0x6610e6a6c07030e4ULL, + 0xa4fc38e885ead7ceULL, 0x36975d1ca439e034ULL, 0xa358f0fe358ffb1aULL, + 0x38e247ad663acf7dULL, 0x77daed3643b5deb8ULL, 0x5507c2aeae1ec3d0ULL, + 0xfdec226c73acf775ULL, 0x1b87ff5f5033492dULL, 0xa832dee545d9033fULL, + 0x1cee43a61e41783bULL, 0xdff82b2e2d822f69ULL, 0x2bbc9a376cb38cf2ULL, + 0x117b1cdaf765dc02ULL, 0x26a407f5682be270ULL, 0x8eb664cf5634af28ULL, + 0x17cb4513bec68551ULL, 0xb0df6527900cbfd0ULL, 0x335a2dc79c5afdfcULL, + 0xa2f0ca4cd38dca88ULL, 0x1c370713b81a2de1ULL, 0x849d5df654d1adfcULL, + 0x2fd1f7675ae14e44ULL, 0x4ff64dfc02247f7bULL, 0x3a2bcf40e395a48dULL, + 0x436248c821b187c1ULL, 0x29f4337b1c7104c0ULL, 0xfc317c46e6630ec4ULL, + 0x2774bccc4e3264c7ULL, 0x2d03218d9d5bee23ULL, 0x36a0ed04d659058aULL, + 0x452484461573cab6ULL, 0x0708edf87ed6272bULL, 0xf07960a1587446cbULL, + 0x3660167b067d84e0ULL, 0x65990a6993ddf8c4ULL, 0x0b197cd3d0b40b3fULL, + 0x1dcec4ab619f3a05ULL, 0x722ab223a84f9182ULL, 0x0822d61a81e7c38fULL, + 0x3d22ad75da563201ULL, 0x93cef6979fd35e0fULL, 0x05c3c25ae598b14cULL, + 0x1338df97dd496377ULL, 0x15bc324dc9c20acfULL, 0x96397c6127e6e8cfULL, + 0x004d01069ef2050fULL, 0x2fcf2e27893fdcbcULL, 0x072f77c3e44f4a5cULL, + 0x5eb1d80b3fe44918ULL, 0x1f59e7c28cc21f22ULL, 0x3390ce5df055c1f8ULL, + 0x4c0ef11df92cb6bfULL, 0x50f82f9e0848c900ULL, 0x08d0fde3ffc0ae38ULL, + 0xbd8d0089a3fbfb73ULL, 0x118ba5b0f311ef59ULL, 0x9be9a8407b926a61ULL, + 0x4ea04fbb21318f63ULL, 0xa1c8e7bb07b871ffULL, 0x1253a7262d5d3b02ULL, + 0x13e997a0512e5b29ULL, 0x54318460ce9055baULL, 0x4e1d8a4db0054798ULL, + 0x0b235226e2cade32ULL, 0x2588732c1476b315ULL, 0x16a378750ba8ac68ULL, + 0xba0b116c04448731ULL, 0x4dd02bd47694c2f1ULL, 0x16d6797b218b6b25ULL, + 0x769eb3709cfbf936ULL, 0x197746a0ce396f38ULL, 0x7d17ad8465961d6eULL, + 0xfe58f4998ae19bb4ULL, 0x36df24305233ce69ULL, 0xb88a4eb008f4ee72ULL, + 0x302b2eb923334787ULL, 0x15a4e3edbe13d448ULL, 0x39a4bf64dd7730ceULL, + 0xedf25421b31090c4ULL, 0x4d547fc131be3b69ULL, 0x2b316e120ca3b90eULL, + 0x0faf2357bf18a169ULL, 0x71f34b54ee2c1d62ULL, 0x18eaf6e5c93a3824ULL, + 0x7e168ba03c1b4c18ULL, 0x1a534dd586d9e871ULL, 0xa2cccd307f5f8c38ULL, + 0x2999a6fb4dce30f6ULL, 0x8f6d3b02c1d549a6ULL, 0x5cf7f90d817aac5aULL, + 0xd2a4ceefe66c8170ULL, 0x11560edc4ca959feULL, 0x89e517e6f0dc464dULL, + 0x75bb8972dddd2085ULL, 0x13859ed1e459d65aULL, 0x057114653326fa84ULL, + 0xe2e6f465173cc86cULL, 0x0ada4076497d7de4ULL, 0xa856fa10ec6dbf8aULL, + 0x41505d9a7c25d875ULL, 0x3091b6278382eccdULL, 0x055737185b2c3f13ULL, + 0x2f4df8ecd6f9c632ULL, 0x0633e89c33552d98ULL, 0xf7673724d16db440ULL, + 0x7331bd08e636c391ULL, 0x0252f29672fee426ULL, 0x1fc384946b6b9ddeULL, + 0x03460c12c901443aULL, 0x003a0792e10abcdaULL, 0x8dbec31f624e37d0ULL, + 0x667420d5bfe4dcbeULL, 0xfbfa30e874ed7641ULL, 0x46d1ae14db7ecef6ULL, + 0x216bd7e8f5448768ULL, 0x32bcd40d3d69cc88ULL, 0x2e991dbc39b65abeULL, + 0x0e8fb123a502f553ULL, 0x3d2d486b2c7560c0ULL, 0x09aba1db3079fe03ULL, + 0xcb540c59398c9bceULL, 0x363970e5339ed600ULL, 0x2caee457c28af00eULL, + 0x005e7d7ee47f41a0ULL, 0x69fad3eb10f44100ULL, 0x048109388c75beb3ULL, + 0x253dddf96c7a6fb8ULL, 0x4c47f705b9d47d09ULL, 0x6cec894228b5e978ULL, + 0x04044bb9f8ff45c2ULL, 0x079e75704d775caeULL, 0x073bd54d2a9e2c33ULL, + 0xcec7289270a364fbULL, 0x19e7486f19cd9e4eULL, 0xb50ac15b86b76608ULL, + 0x0620cf81f165c812ULL, 0x63eaaf13be7b11d4ULL, 0x0e0cf831948248c2ULL, + 0xf0412df8f46e7957ULL, 0x671c1fe752517e3fULL, 0x8841bfb04dd3f540ULL, + 0x122de4142249f353ULL, 0x40a4959fb0e76870ULL, 0x25cfd3d4b4bbc459ULL, + 0x78a07c82930c60d0ULL, 0x12c2de24d4cbc969ULL, 0x85d44866096ad7f4ULL, + 0x1fd917ca66b2007bULL, 0x01fbbb0751764764ULL, 0x3d2a4953c6fe0fdcULL, + 0xcc1489c5737afd94ULL, 0x1817c5b6a5346f41ULL, 0xe605a6a7e9985644ULL, + 0x3c50412328ff1946ULL, 0xd8c7fd65817f1291ULL, 0x0bd66975ab66339bULL, + 0x2baf8fa1c7d10fa9ULL, 0x24abdf06ddef848dULL, 0x14df0c9b2ea4f6c2ULL, + 0x2be950edfd2cb1f7ULL, 0x21911e21094178b6ULL, 0x0fa54d518a93b379ULL, + 0xb52508e0ac01ab42ULL, 0x0e035b5fd8cb79beULL, 0x1c1c6d1a3b3c8648ULL, + 0x286037b42ea9871cULL, 0xfe67bf311e48a340ULL, 0x02324131e932a472ULL, + 0x2486dc2dd919e2deULL, 0x008aec7f1da1d2ebULL, 0x63269ba0e8d3eb3aULL, + 0x23c0f11154adb62fULL, 0xc6052393ecd4c018ULL, 0x523585b7d2f5b9fcULL, + 0xf7e6f8c1e87564c9ULL, 0x09eb9fe5dd32c1a3ULL, 0x4d4f86886e055472ULL, + 0x67ea17b58a37966bULL, 0x3d3ce8c23b1ed1a8ULL, 0x0df97c5ac48857ceULL, + 0x9b6992623759eb12ULL, 0x275aa9551ae091f2ULL, 0x08855e19ac5e62e5ULL, + 0x1155fffe0ae083ccULL, 0xbc9c78db7c570240ULL, 0x074560c447dd2418ULL, + 0x3bf78d330bcf1e70ULL, 0x49867cd4b7ed134bULL, 0x8e6eee0cb4470accULL, + 0x1dabafdf59233dd6ULL, 0xea3a50d844fc3fb8ULL, 0x4f03f4454764cb87ULL, + 0x1f2f41cc36c9e6ecULL, 0x53cba4df42963441ULL, 0x10883b70a88d91fbULL, + 0x62b1fc77d4eb9481ULL, 0x893d8f2604b362e1ULL, 0x0933b7855368b440ULL, + 0x9351b545703b2fceULL, 0x59c1d489b9bdd3b4ULL, 0xe72a9c4311417b18ULL, + 0x5355df77e88eb226ULL, 0xe802c37aa963d7e1ULL, 0x381c3747bd6c3bc3ULL, + 0x378565573444258cULL, 0x37848b1e52b43c18ULL, 0x5da2cd32bdce12b6ULL, + 0x13166c5da615f6fdULL, 0xa51ef95efcc66ac8ULL, 0x640c95e473f1e541ULL, + 0x6ec68def1f217500ULL, 0x49ce3543c76a4079ULL, 0x5fc6fd3cddc706b5ULL, + 0x05c3c0f0f6a1fb0dULL, 0xe7820c0996ad1bddULL, 0x21f0d752a088f35cULL, + 0x755405b51d6fc4a0ULL, 0x7ec7649ca4b0e351ULL, 0x3d2b6a46a251f790ULL, + 0x23e1176b19f418adULL, 0x06056575efe8ac05ULL, 0x0f75981b6966e477ULL, + 0x06e87ec41ad437e4ULL, 0x43f6c255d5e1cb84ULL, 0xe4e67d1120ceb580ULL, + 0x2cd67b9e12c26d7bULL, 0xcd00b5ff7fd187f1ULL, 0x3f6cd40accdc4106ULL, + 0x3e895c835459b330ULL, 0x0814d53a217c0850ULL, 0xc9111fe78bc3a62dULL, + 0x719967e351473204ULL, 0xe757707d24282aa4ULL, 0x7226b7f5607f98e6ULL, + 0x7b268ffae3c08d96ULL, 0x16d3917c8b86020eULL, 0x5128bca51c49ea64ULL, + 0x345ffea02bb1698dULL, 0x9460f5111fe4fbc8ULL, 0x60dd1aa5762852cbULL, + 0xbb7440ed3c81667cULL, 0x0a4b12affa7f6f5cULL, 0x95cbcb0ae03861b6ULL, + 0x07ab3b0591db6070ULL, 0xc6476a4c3de78982ULL, 0x204e82e8623ad725ULL, + 0x569a5b4e8ac2a5ccULL, 0x425a1d77d72ebae2ULL, 0xcdaad5551ab33830ULL, + 0x0b7c68fd8422939eULL, 0x46d9a01f53ec3020ULL, 0x102871edbb29e852ULL, + 0x7a8e8084039075a5ULL, 0x40eaede8615e376aULL, 0x4dc67d757a1c751fULL, + 0x1176ef33063f9145ULL, 0x4ea230285b1c8156ULL, 0x6b2aa46ce0027392ULL, + 0x32b13230fba1b068ULL, 0x0e69796851bb984fULL, 0xb749f4542db698c0ULL, + 0x19ad0241ffffd49cULL, 0x2f41e92ef6caff52ULL, 0x4d0b068576747439ULL, + 0x14d607aef7463e00ULL, 0x1443d00d85fb440eULL, 0x529b43bf68688780ULL, + 0x21133a6bc3a3e378ULL, 0x865b6436dae0e7e5ULL, 0x6b4fe83dc1d6defcULL, + 0x03a5858a0ca0be46ULL, 0x1e841b187e67f312ULL, 0x61ee22ef40a66940ULL, + 0x0494bd2e9e741ef8ULL, 0x4eb59e323010e72cULL, 0x19f2abcfb749810eULL, + 0xb30f1e4f994ef9bcULL, 0x53cf6cdd51bd2d96ULL, 0x263943036497a514ULL, + 0x0d4b52170aa2edbaULL, 0x0c4758a1c7b4f758ULL, 0x178dadb1b502b51aULL, + 0x1ddbb20a602eb57aULL, 0x1fc2e2564a9f27fdULL, 0xd5f8c50a0e3d6f90ULL, + 0x0081da3bbe72ac09ULL, 0xcf140d002ccdb200ULL, 0x0ae8389f09b017feULL, + 0x17cc9ffdc03f4440ULL, 0x04eb921d704bcdddULL, 0x139a0ce4cdc521abULL, + 0x0bfce00c145cb0f0ULL, 0x99925ff132eff707ULL, 0x063f6e5da50c3d35ULL, + 0xa0c25dea3f0e6e29ULL, 0x0c7a9048cc8e040fULL, + }; + + const size_t padded = RoundUpTo(kCLMulNum, N); + auto expected_lower = AllocateAligned<T>(padded); + auto expected_upper = AllocateAligned<T>(padded); + CopyBytes<kCLMulNum * sizeof(T)>(kCLMulLower, expected_lower.get()); + CopyBytes<kCLMulNum * sizeof(T)>(kCLMulUpper, expected_upper.get()); + const size_t padding_size = (padded - kCLMulNum) * sizeof(T); + memset(expected_lower.get() + kCLMulNum, 0, padding_size); + memset(expected_upper.get() + kCLMulNum, 0, padding_size); + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < kCLMulNum / N; ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = Random64(&rng); + in2[i] = Random64(&rng); + } + + const auto a = Load(d, in1.get()); + const auto b = Load(d, in2.get()); +#if HWY_PRINT_CLMUL_GOLDEN + Store(CLMulLower(a, b), d, expected_lower.get() + rep * N); + Store(CLMulUpper(a, b), d, expected_upper.get() + rep * N); +#else + HWY_ASSERT_VEC_EQ(d, expected_lower.get() + rep * N, CLMulLower(a, b)); + HWY_ASSERT_VEC_EQ(d, expected_upper.get() + rep * N, CLMulUpper(a, b)); +#endif + } + +#if HWY_PRINT_CLMUL_GOLDEN + // RVV lacks PRIu64, so print 32-bit halves. + for (size_t i = 0; i < kCLMulNum; ++i) { + printf("0x%08x%08xULL,", static_cast<uint32_t>(expected_lower[i] >> 32), + static_cast<uint32_t>(expected_lower[i] & 0xFFFFFFFFU)); + } + printf("\n"); + for (size_t i = 0; i < kCLMulNum; ++i) { + printf("0x%08x%08xULL,", static_cast<uint32_t>(expected_upper[i] >> 32), + static_cast<uint32_t>(expected_upper[i] & 0xFFFFFFFFU)); + } +#endif // HWY_PRINT_CLMUL_GOLDEN +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllCLMul() { ForGEVectors<128, TestCLMul>()(uint64_t()); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyCryptoTest); +HWY_EXPORT_AND_TEST_P(HwyCryptoTest, TestAllAES); +HWY_EXPORT_AND_TEST_P(HwyCryptoTest, TestAllCLMul); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/demote_test.cc b/third_party/highway/hwy/tests/demote_test.cc new file mode 100644 index 0000000000..22469113d5 --- /dev/null +++ b/third_party/highway/hwy/tests/demote_test.cc @@ -0,0 +1,328 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#include <cmath> // std::isfinite + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/demote_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +// Causes build timeout. +#if !HWY_IS_MSAN + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <typename T, HWY_IF_FLOAT(T)> +bool IsFiniteT(T t) { + return std::isfinite(t); +} +// Wrapper avoids calling std::isfinite for integer types (ambiguous). +template <typename T, HWY_IF_NOT_FLOAT(T)> +bool IsFiniteT(T /*unused*/) { + return true; +} + +template <typename ToT> +struct TestDemoteTo { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + static_assert(!IsFloat<ToT>(), "Use TestDemoteToFloat for float output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind<ToT, D> to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned<T>(N); + auto expected = AllocateAligned<ToT>(N); + + // Narrower range in the wider type, for clamping before we cast + const T min = LimitsMin<ToT>(); + const T max = LimitsMax<ToT>(); + + const auto value_ok = [&](T& value) { + if (!IsFiniteT(value)) return false; + return true; + }; + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + CopyBytes<sizeof(T)>(&bits, &from[i]); // not same size + } while (!value_ok(from[i])); + expected[i] = static_cast<ToT>(HWY_MIN(HWY_MAX(min, from[i]), max)); + } + + const auto in = Load(from_d, from.get()); + HWY_ASSERT_VEC_EQ(to_d, expected.get(), DemoteTo(to_d, in)); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToInt() { + ForDemoteVectors<TestDemoteTo<uint8_t>>()(int16_t()); + ForDemoteVectors<TestDemoteTo<uint8_t>, 2>()(int32_t()); + + ForDemoteVectors<TestDemoteTo<int8_t>>()(int16_t()); + ForDemoteVectors<TestDemoteTo<int8_t>, 2>()(int32_t()); + + const ForDemoteVectors<TestDemoteTo<uint16_t>> to_u16; + to_u16(int32_t()); + + const ForDemoteVectors<TestDemoteTo<int16_t>> to_i16; + to_i16(int32_t()); +} + +HWY_NOINLINE void TestAllDemoteToMixed() { +#if HWY_HAVE_FLOAT64 + const ForDemoteVectors<TestDemoteTo<int32_t>> to_i32; + to_i32(double()); +#endif +} + +template <typename ToT> +struct TestDemoteToFloat { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D from_d) { + // For floats, we clamp differently and cannot call LimitsMin. + static_assert(IsFloat<ToT>(), "Use TestDemoteTo for integer output"); + static_assert(sizeof(T) > sizeof(ToT), "Input type must be wider"); + const Rebind<ToT, D> to_d; + + const size_t N = Lanes(from_d); + auto from = AllocateAligned<T>(N); + auto expected = AllocateAligned<ToT>(N); + + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + do { + const uint64_t bits = rng(); + CopyBytes<sizeof(T)>(&bits, &from[i]); // not same size + } while (!IsFiniteT(from[i])); + const T magn = std::abs(from[i]); + const T max_abs = HighestValue<ToT>(); + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + const T clipped = copysign(HWY_MIN(magn, max_abs), from[i]); + expected[i] = static_cast<ToT>(clipped); + } + + HWY_ASSERT_VEC_EQ(to_d, expected.get(), + DemoteTo(to_d, Load(from_d, from.get()))); + } + } +}; + +HWY_NOINLINE void TestAllDemoteToFloat() { + // Must test f16 separately because we can only load/store/convert them. + +#if HWY_HAVE_FLOAT64 + const ForDemoteVectors<TestDemoteToFloat<float>, 1> to_float; + to_float(double()); +#endif +} + +template <class D> +AlignedFreeUniquePtr<float[]> ReorderBF16TestCases(D d, size_t& padded) { + const float test_cases[] = { + // Same as BF16TestCases: + // +/- 1 + 1.0f, + -1.0f, + // +/- 0 + 0.0f, + -0.0f, + // near 0 + 0.25f, + -0.25f, + // +/- integer + 4.0f, + -32.0f, + // positive +/- delta + 2.015625f, + 3.984375f, + // negative +/- delta + -2.015625f, + -3.984375f, + + // No huge values - would interfere with sum. But add more to fill 2 * N: + -2.0f, + -10.0f, + 0.03125f, + 1.03125f, + 1.5f, + 2.0f, + 4.0f, + 5.0f, + 6.0f, + 8.0f, + 10.0f, + 256.0f, + 448.0f, + 2080.0f, + }; + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, 2 * N); // allow loading pairs of vectors + auto in = AllocateAligned<float>(padded); + auto expected = AllocateAligned<float>(padded); + std::copy(test_cases, test_cases + kNumTestCases, in.get()); + std::fill(in.get() + kNumTestCases, in.get() + padded, 0.0f); + return in; +} + +class TestReorderDemote2To { + // In-place N^2 selection sort to avoid dependencies + void Sort(float* p, size_t count) { + for (size_t i = 0; i < count - 1; ++i) { + // Find min_element + size_t idx_min = i; + for (size_t j = i + 1; j < count; j++) { + if (p[j] < p[idx_min]) { + idx_min = j; + } + } + + // Swap with current + const float tmp = p[i]; + p[i] = p[idx_min]; + p[idx_min] = tmp; + } + } + + public: + template <typename TF32, class DF32> + HWY_NOINLINE void operator()(TF32 /*t*/, DF32 d32) { +#if HWY_TARGET != HWY_SCALAR + size_t padded; + auto in = ReorderBF16TestCases(d32, padded); + + using TBF16 = bfloat16_t; + const Repartition<TBF16, DF32> dbf16; + const Half<decltype(dbf16)> dbf16_half; + const size_t N = Lanes(d32); + auto temp16 = AllocateAligned<TBF16>(2 * N); + auto expected = AllocateAligned<float>(2 * N); + auto actual = AllocateAligned<float>(2 * N); + + for (size_t i = 0; i < padded; i += 2 * N) { + const auto f0 = Load(d32, &in[i + 0]); + const auto f1 = Load(d32, &in[i + N]); + const auto v16 = ReorderDemote2To(dbf16, f0, f1); + Store(v16, dbf16, temp16.get()); + const auto promoted0 = PromoteTo(d32, Load(dbf16_half, temp16.get() + 0)); + const auto promoted1 = PromoteTo(d32, Load(dbf16_half, temp16.get() + N)); + + // Smoke test: sum should be same (with tolerance for non-associativity) + const auto sum_expected = GetLane(SumOfLanes(d32, Add(f0, f1))); + const auto sum_actual = + GetLane(SumOfLanes(d32, Add(promoted0, promoted1))); + + HWY_ASSERT(sum_expected - 1E-4 <= sum_actual && + sum_actual <= sum_expected + 1E-4); + + // Ensure values are the same after sorting to undo the Reorder + Store(f0, d32, expected.get() + 0); + Store(f1, d32, expected.get() + N); + Store(promoted0, d32, actual.get() + 0); + Store(promoted1, d32, actual.get() + N); + Sort(expected.get(), 2 * N); + Sort(actual.get(), 2 * N); + HWY_ASSERT_VEC_EQ(d32, expected.get() + 0, Load(d32, actual.get() + 0)); + HWY_ASSERT_VEC_EQ(d32, expected.get() + N, Load(d32, actual.get() + N)); + } +#else // HWY_SCALAR + (void)d32; +#endif + } +}; + +HWY_NOINLINE void TestAllReorderDemote2To() { + ForShrinkableVectors<TestReorderDemote2To>()(float()); +} + +struct TestI32F64 { + template <typename TF, class DF> + HWY_NOINLINE void operator()(TF /*unused*/, const DF df) { + using TI = int32_t; + const Rebind<TI, DF> di; + const size_t N = Lanes(df); + + // Integer positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(4)), DemoteTo(di, Iota(df, TF(4.0)))); + + // Integer negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), DemoteTo(di, Iota(df, -TF(N)))); + + // Above positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(2)), DemoteTo(di, Iota(df, TF(2.001)))); + + // Below positive + HWY_ASSERT_VEC_EQ(di, Iota(di, TI(3)), DemoteTo(di, Iota(df, TF(3.9999)))); + + const TF eps = static_cast<TF>(0.0001); + // Above negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N)), + DemoteTo(di, Iota(df, -TF(N + 1) + eps))); + + // Below negative + HWY_ASSERT_VEC_EQ(di, Iota(di, -TI(N + 1)), + DemoteTo(di, Iota(df, -TF(N + 1) - eps))); + + // Huge positive float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMax<TI>()), + DemoteTo(di, Set(df, TF(1E12)))); + + // Huge negative float + HWY_ASSERT_VEC_EQ(di, Set(di, LimitsMin<TI>()), + DemoteTo(di, Set(df, TF(-1E12)))); + } +}; + +HWY_NOINLINE void TestAllI32F64() { +#if HWY_HAVE_FLOAT64 + ForDemoteVectors<TestI32F64>()(double()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // !HWY_IS_MSAN + +#if HWY_ONCE + +namespace hwy { +#if !HWY_IS_MSAN +HWY_BEFORE_TEST(HwyDemoteTest); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToInt); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToMixed); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllDemoteToFloat); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllReorderDemote2To); +HWY_EXPORT_AND_TEST_P(HwyDemoteTest, TestAllI32F64); +#endif // !HWY_IS_MSAN +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/float_test.cc b/third_party/highway/hwy/tests/float_test.cc new file mode 100644 index 0000000000..bc6d9020e6 --- /dev/null +++ b/third_party/highway/hwy/tests/float_test.cc @@ -0,0 +1,350 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests some ops specific to floating-point types (Div, Round etc.) + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> // std::copy, std::fill +#include <limits> +#include <cmath> // std::abs, std::isnan, std::isinf, std::ceil, std::floor + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/float_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestDiv { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(-2)); + const auto v1 = Set(d, T(1)); + + // Unchanged after division by 1. + HWY_ASSERT_VEC_EQ(d, v, Div(v, v1)); + + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = (T(i) - 2) / T(2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Div(v, Set(d, T(2)))); + } +}; + +HWY_NOINLINE void TestAllDiv() { ForFloatTypes(ForPartialVectors<TestDiv>()); } + +struct TestApproximateReciprocal { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(-2)); + const auto nonzero = IfThenElse(Eq(v, Zero(d)), Set(d, T(1)), v); + const size_t N = Lanes(d); + auto input = AllocateAligned<T>(N); + Store(nonzero, d, input.get()); + + auto actual = AllocateAligned<T>(N); + Store(ApproximateReciprocal(nonzero), d, actual.get()); + + double max_l1 = 0.0; + double worst_expected = 0.0; + double worst_actual = 0.0; + for (size_t i = 0; i < N; ++i) { + const double expected = 1.0 / input[i]; + const double l1 = std::abs(expected - actual[i]); + if (l1 > max_l1) { + max_l1 = l1; + worst_expected = expected; + worst_actual = actual[i]; + } + } + const double abs_worst_expected = std::abs(worst_expected); + if (abs_worst_expected > 1E-5) { + const double max_rel = max_l1 / abs_worst_expected; + fprintf(stderr, "max l1 %f rel %f (%f vs %f)\n", max_l1, max_rel, + worst_expected, worst_actual); + HWY_ASSERT(max_rel < 0.004); + } + } +}; + +HWY_NOINLINE void TestAllApproximateReciprocal() { + ForPartialVectors<TestApproximateReciprocal>()(float()); +} + +struct TestSquareRoot { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto vi = Iota(d, 0); + HWY_ASSERT_VEC_EQ(d, vi, Sqrt(Mul(vi, vi))); + } +}; + +HWY_NOINLINE void TestAllSquareRoot() { + ForFloatTypes(ForPartialVectors<TestSquareRoot>()); +} + +struct TestReciprocalSquareRoot { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Set(d, 123.0f); + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + Store(ApproximateReciprocalSqrt(v), d, lanes.get()); + for (size_t i = 0; i < N; ++i) { + float err = lanes[i] - 0.090166f; + if (err < 0.0f) err = -err; + if (err >= 4E-4f) { + HWY_ABORT("Lane %d (%d): actual %f err %f\n", static_cast<int>(i), + static_cast<int>(N), lanes[i], err); + } + } + } +}; + +HWY_NOINLINE void TestAllReciprocalSquareRoot() { + ForPartialVectors<TestReciprocalSquareRoot>()(float()); +} + +template <typename T, class D> +AlignedFreeUniquePtr<T[]> RoundTestCases(T /*unused*/, D d, size_t& padded) { + const T eps = std::numeric_limits<T>::epsilon(); + const T test_cases[] = { + // +/- 1 + T(1), + T(-1), + // +/- 0 + T(0), + T(-0), + // near 0 + T(0.4), + T(-0.4), + // +/- integer + T(4), + T(-32), + // positive near limit + MantissaEnd<T>() - T(1.5), + MantissaEnd<T>() + T(1.5), + // negative near limit + -MantissaEnd<T>() - T(1.5), + -MantissaEnd<T>() + T(1.5), + // positive tiebreak + T(1.5), + T(2.5), + // negative tiebreak + T(-1.5), + T(-2.5), + // positive +/- delta + T(2.0001), + T(3.9999), + // negative +/- delta + T(-999.9999), + T(-998.0001), + // positive +/- epsilon + T(1) + eps, + T(1) - eps, + // negative +/- epsilon + T(-1) + eps, + T(-1) - eps, + // +/- huge (but still fits in float) + T(1E34), + T(-1E35), + // +/- infinity + std::numeric_limits<T>::infinity(), + -std::numeric_limits<T>::infinity(), + // qNaN + GetLane(NaN(d)) + }; + const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]); + const size_t N = Lanes(d); + padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors + auto in = AllocateAligned<T>(padded); + auto expected = AllocateAligned<T>(padded); + std::copy(test_cases, test_cases + kNumTestCases, in.get()); + std::fill(in.get() + kNumTestCases, in.get() + padded, T(0)); + return in; +} + +struct TestRound { + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned<T>(padded); + + for (size_t i = 0; i < padded; ++i) { + // Avoid [std::]round, which does not round to nearest *even*. + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + expected[i] = static_cast<T>(nearbyint(in[i])); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Round(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllRound() { + ForFloatTypes(ForPartialVectors<TestRound>()); +} + +struct TestNearestInt { + template <typename TF, class DF> + HWY_NOINLINE void operator()(TF tf, const DF df) { + using TI = MakeSigned<TF>; + const RebindToSigned<DF> di; + + size_t padded; + auto in = RoundTestCases(tf, df, padded); + auto expected = AllocateAligned<TI>(padded); + + constexpr double max = static_cast<double>(LimitsMax<TI>()); + for (size_t i = 0; i < padded; ++i) { + if (std::isnan(in[i])) { + // We replace NaN with 0 below (no_nan) + expected[i] = 0; + } else if (std::isinf(in[i]) || double{std::abs(in[i])} >= max) { + // Avoid undefined result for lrintf + expected[i] = std::signbit(in[i]) ? LimitsMin<TI>() : LimitsMax<TI>(); + } else { + expected[i] = static_cast<TI>(lrintf(in[i])); + } + } + for (size_t i = 0; i < padded; i += Lanes(df)) { + const auto v = Load(df, &in[i]); + const auto no_nan = IfThenElse(Eq(v, v), v, Zero(df)); + HWY_ASSERT_VEC_EQ(di, &expected[i], NearestInt(no_nan)); + } + } +}; + +HWY_NOINLINE void TestAllNearestInt() { + ForPartialVectors<TestNearestInt>()(float()); +} + +struct TestTrunc { + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned<T>(padded); + + for (size_t i = 0; i < padded; ++i) { + // NOTE: std:: version from C++11 cmath is not defined in RVV GCC, see + // https://lists.freebsd.org/pipermail/freebsd-current/2014-January/048130.html + expected[i] = static_cast<T>(trunc(in[i])); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Trunc(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllTrunc() { + ForFloatTypes(ForPartialVectors<TestTrunc>()); +} + +struct TestCeil { + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned<T>(padded); + + for (size_t i = 0; i < padded; ++i) { + expected[i] = std::ceil(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Ceil(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllCeil() { + ForFloatTypes(ForPartialVectors<TestCeil>()); +} + +struct TestFloor { + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + size_t padded; + auto in = RoundTestCases(t, d, padded); + auto expected = AllocateAligned<T>(padded); + + for (size_t i = 0; i < padded; ++i) { + expected[i] = std::floor(in[i]); + } + for (size_t i = 0; i < padded; i += Lanes(d)) { + HWY_ASSERT_VEC_EQ(d, &expected[i], Floor(Load(d, &in[i]))); + } + } +}; + +HWY_NOINLINE void TestAllFloor() { + ForFloatTypes(ForPartialVectors<TestFloor>()); +} + +struct TestAbsDiff { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes_a = AllocateAligned<T>(N); + auto in_lanes_b = AllocateAligned<T>(N); + auto out_lanes = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + in_lanes_a[i] = static_cast<T>((i ^ 1u) << i); + in_lanes_b[i] = static_cast<T>(i << i); + out_lanes[i] = std::abs(in_lanes_a[i] - in_lanes_b[i]); + } + const auto a = Load(d, in_lanes_a.get()); + const auto b = Load(d, in_lanes_b.get()); + const auto expected = Load(d, out_lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected, AbsDiff(a, b)); + HWY_ASSERT_VEC_EQ(d, expected, AbsDiff(b, a)); + } +}; + +HWY_NOINLINE void TestAllAbsDiff() { + ForPartialVectors<TestAbsDiff>()(float()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyFloatTest); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllDiv); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllApproximateReciprocal); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllSquareRoot); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllReciprocalSquareRoot); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllRound); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllNearestInt); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllTrunc); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllCeil); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllFloor); +HWY_EXPORT_AND_TEST_P(HwyFloatTest, TestAllAbsDiff); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/hwy_gtest.h b/third_party/highway/hwy/tests/hwy_gtest.h new file mode 100644 index 0000000000..a4c21cd171 --- /dev/null +++ b/third_party/highway/hwy/tests/hwy_gtest.h @@ -0,0 +1,157 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_TESTS_HWY_GTEST_H_ +#define HWY_TESTS_HWY_GTEST_H_ + +// Adapters for GUnit to run tests for all targets. + +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <tuple> + +#include "gtest/gtest.h" +#include "hwy/highway.h" + +namespace hwy { + +// googletest before 1.10 didn't define INSTANTIATE_TEST_SUITE_P() but instead +// used INSTANTIATE_TEST_CASE_P which is now deprecated. +#ifdef INSTANTIATE_TEST_SUITE_P +#define HWY_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_SUITE_P +#else +#define HWY_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_CASE_P +#endif + +// Helper class to run parametric tests using the hwy target as parameter. To +// use this define the following in your test: +// class MyTestSuite : public TestWithParamTarget { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P(MyTestSuite); +// TEST_P(MyTestSuite, MyTest) { ... } +class TestWithParamTarget : public testing::TestWithParam<int64_t> { + protected: + void SetUp() override { SetSupportedTargetsForTest(GetParam()); } + + void TearDown() override { + // Check that the parametric test calls SupportedTargets() when the source + // was compiled with more than one target. In the single-target case only + // static dispatch will be used anyway. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) != 0 + EXPECT_TRUE(GetChosenTarget().IsInitialized()) + << "This hwy target parametric test doesn't use dynamic-dispatch and " + "doesn't need to be parametric."; +#endif + SetSupportedTargetsForTest(0); + } +}; + +// Function to convert the test parameter of a TestWithParamTarget for +// displaying it in the gtest test name. +static inline std::string TestParamTargetName( + const testing::TestParamInfo<int64_t>& info) { + return TargetName(info.param); +} + +#define HWY_TARGET_INSTANTIATE_TEST_SUITE_P(suite) \ + HWY_GTEST_INSTANTIATE_TEST_SUITE_P( \ + suite##Group, suite, \ + testing::ValuesIn(::hwy::SupportedAndGeneratedTargets()), \ + ::hwy::TestParamTargetName) + +// Helper class similar to TestWithParamTarget to run parametric tests that +// depend on the target and another parametric test. If you need to use multiple +// extra parameters use a std::tuple<> of them and ::testing::Generate(...) as +// the generator. To use this class define the following in your test: +// class MyTestSuite : public TestWithParamTargetT<int> { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(MyTestSuite, ::testing::Range(0, 9)); +// TEST_P(MyTestSuite, MyTest) { ... GetParam() .... } +template <typename T> +class TestWithParamTargetAndT + : public ::testing::TestWithParam<std::tuple<int64_t, T>> { + public: + // Expose the parametric type here so it can be used by the + // HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T macro. + using HwyParamType = T; + + protected: + void SetUp() override { + SetSupportedTargetsForTest(std::get<0>( + ::testing::TestWithParam<std::tuple<int64_t, T>>::GetParam())); + } + + void TearDown() override { + // Check that the parametric test calls SupportedTargets() when the source + // was compiled with more than one target. In the single-target case only + // static dispatch will be used anyway. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) != 0 + EXPECT_TRUE(GetChosenTarget().IsInitialized()) + << "This hwy target parametric test doesn't use dynamic-dispatch and " + "doesn't need to be parametric."; +#endif + SetSupportedTargetsForTest(0); + } + + T GetParam() { + return std::get<1>( + ::testing::TestWithParam<std::tuple<int64_t, T>>::GetParam()); + } +}; + +template <typename T> +std::string TestParamTargetNameAndT( + const testing::TestParamInfo<std::tuple<int64_t, T>>& info) { + return std::string(TargetName(std::get<0>(info.param))) + "_" + + ::testing::PrintToString(std::get<1>(info.param)); +} + +#define HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(suite, generator) \ + HWY_GTEST_INSTANTIATE_TEST_SUITE_P( \ + suite##Group, suite, \ + ::testing::Combine( \ + testing::ValuesIn(::hwy::SupportedAndGeneratedTargets()), \ + generator), \ + ::hwy::TestParamTargetNameAndT<suite::HwyParamType>) + +// Helper macro to export a function and define a test that tests it. This is +// equivalent to do a HWY_EXPORT of a void(void) function and run it in a test: +// class MyTestSuite : public TestWithParamTarget { +// ... +// }; +// HWY_TARGET_INSTANTIATE_TEST_SUITE_P(MyTestSuite); +// HWY_EXPORT_AND_TEST_P(MyTestSuite, MyTest); +#define HWY_EXPORT_AND_TEST_P(suite, func_name) \ + HWY_EXPORT(func_name); \ + TEST_P(suite, func_name) { HWY_DYNAMIC_DISPATCH(func_name)(); } \ + static_assert(true, "For requiring trailing semicolon") + +#define HWY_EXPORT_AND_TEST_P_T(suite, func_name) \ + HWY_EXPORT(func_name); \ + TEST_P(suite, func_name) { HWY_DYNAMIC_DISPATCH(func_name)(GetParam()); } \ + static_assert(true, "For requiring trailing semicolon") + +#define HWY_BEFORE_TEST(suite) \ + class suite : public hwy::TestWithParamTarget {}; \ + HWY_TARGET_INSTANTIATE_TEST_SUITE_P(suite); \ + static_assert(true, "For requiring trailing semicolon") + +} // namespace hwy + +#endif // HWY_TESTS_HWY_GTEST_H_ diff --git a/third_party/highway/hwy/tests/if_test.cc b/third_party/highway/hwy/tests/if_test.cc new file mode 100644 index 0000000000..e44a878a0c --- /dev/null +++ b/third_party/highway/hwy/tests/if_test.cc @@ -0,0 +1,175 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#include "hwy/aligned_allocator.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/if_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestIfThenElse { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(d); + auto in1 = AllocateAligned<T>(N); + auto in2 = AllocateAligned<T>(N); + auto bool_lanes = AllocateAligned<TI>(N); + auto expected = AllocateAligned<T>(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast<T>(Random32(&rng)); + in2[i] = static_cast<T>(Random32(&rng)); + bool_lanes[i] = (Random32(&rng) & 16) ? TI(1) : TI(0); + } + + const auto v1 = Load(d, in1.get()); + const auto v2 = Load(d, in2.get()); + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + + for (size_t i = 0; i < N; ++i) { + expected[i] = bool_lanes[i] ? in1[i] : in2[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenElse(mask, v1, v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = bool_lanes[i] ? in1[i] : T(0); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenElseZero(mask, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = bool_lanes[i] ? T(0) : in2[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfThenZeroElse(mask, v2)); + } + } +}; + +HWY_NOINLINE void TestAllIfThenElse() { + ForAllTypes(ForPartialVectors<TestIfThenElse>()); +} + +struct TestIfVecThenElse { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TU = MakeUnsigned<T>; // For all-one mask + const Rebind<TU, D> du; + const size_t N = Lanes(d); + auto in1 = AllocateAligned<T>(N); + auto in2 = AllocateAligned<T>(N); + auto vec_lanes = AllocateAligned<TU>(N); + auto expected = AllocateAligned<T>(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast<T>(Random32(&rng)); + in2[i] = static_cast<T>(Random32(&rng)); + vec_lanes[i] = (Random32(&rng) & 16) ? static_cast<TU>(~TU(0)) : TU(0); + } + + const auto v1 = Load(d, in1.get()); + const auto v2 = Load(d, in2.get()); + const auto vec = BitCast(d, Load(du, vec_lanes.get())); + + for (size_t i = 0; i < N; ++i) { + expected[i] = vec_lanes[i] ? in1[i] : in2[i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), IfVecThenElse(vec, v1, v2)); + } + } +}; + +HWY_NOINLINE void TestAllIfVecThenElse() { + ForAllTypes(ForPartialVectors<TestIfVecThenElse>()); +} + +struct TestZeroIfNegative { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Iota(d, T(-1E5)); // assumes N < 10^5 + + // Zero and positive remain unchanged + HWY_ASSERT_VEC_EQ(d, v0, ZeroIfNegative(v0)); + HWY_ASSERT_VEC_EQ(d, vp, ZeroIfNegative(vp)); + + // Negative are all replaced with zero + HWY_ASSERT_VEC_EQ(d, v0, ZeroIfNegative(vn)); + } +}; + +HWY_NOINLINE void TestAllZeroIfNegative() { + ForFloatTypes(ForPartialVectors<TestZeroIfNegative>()); +} + +struct TestIfNegative { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Or(vp, SignBit(d)); + + // Zero and positive remain unchanged + HWY_ASSERT_VEC_EQ(d, v0, IfNegativeThenElse(v0, vn, v0)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(v0, v0, vn)); + HWY_ASSERT_VEC_EQ(d, vp, IfNegativeThenElse(vp, vn, vp)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(vp, vp, vn)); + + // Negative are replaced with 2nd arg + HWY_ASSERT_VEC_EQ(d, v0, IfNegativeThenElse(vn, v0, vp)); + HWY_ASSERT_VEC_EQ(d, vn, IfNegativeThenElse(vn, vn, v0)); + HWY_ASSERT_VEC_EQ(d, vp, IfNegativeThenElse(vn, vp, vn)); + } +}; + +HWY_NOINLINE void TestAllIfNegative() { + ForFloatTypes(ForPartialVectors<TestIfNegative>()); + ForSignedTypes(ForPartialVectors<TestIfNegative>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyIfTest); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllIfThenElse); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllIfVecThenElse); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllZeroIfNegative); +HWY_EXPORT_AND_TEST_P(HwyIfTest, TestAllIfNegative); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/interleaved_test.cc b/third_party/highway/hwy/tests/interleaved_test.cc new file mode 100644 index 0000000000..4d1fbd5ac5 --- /dev/null +++ b/third_party/highway/hwy/tests/interleaved_test.cc @@ -0,0 +1,256 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/interleaved_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLoadStoreInterleaved2 { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned<T>(2 * N); + for (size_t i = 0; i < 2 * N; ++i) { + bytes[i] = static_cast<T>(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned<T>(3 * N); + auto actual_aligned = AllocateAligned<T>(3 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[2 * i + 0] = bytes[0 * N + i]; + expected[2 * i + 1] = bytes[1 * N + i]; + // Ensure we do not write more than 2*N bytes + expected[2 * N + i] = actual[2 * N + i] = 0; + } + StoreInterleaved2(in0, in1, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 3 * N * sizeof(T), &pos)) { + Print(d, "in0", in0, pos / 4); + Print(d, "in1", in1, pos / 4); + const size_t i = pos; + fprintf(stderr, "interleaved i=%d %f %f %f %f %f %f %f %f\n", + static_cast<int>(i), static_cast<double>(actual[i]), + static_cast<double>(actual[i + 1]), + static_cast<double>(actual[i + 2]), + static_cast<double>(actual[i + 3]), + static_cast<double>(actual[i + 4]), + static_cast<double>(actual[i + 5]), + static_cast<double>(actual[i + 6]), + static_cast<double>(actual[i + 7])); + HWY_ASSERT(false); + } + + Vec<D> out0, out1; + LoadInterleaved2(d, actual, out0, out1); + HWY_ASSERT_VEC_EQ(d, in0, out0); + HWY_ASSERT_VEC_EQ(d, in1, out1); + } + } +}; + +HWY_NOINLINE void TestAllLoadStoreInterleaved2() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. + const ForExtendableVectors<TestLoadStoreInterleaved2, 2> test; +#else + const ForPartialVectors<TestLoadStoreInterleaved2> test; +#endif + ForAllTypes(test); +} + +// Workaround for build timeout on GCC 12 aarch64, see #776 +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && HWY_ARCH_ARM_A64 +#define HWY_BROKEN_LOAD34 1 +#else +#define HWY_BROKEN_LOAD34 0 +#endif + +#if !HWY_BROKEN_LOAD34 + +struct TestLoadStoreInterleaved3 { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned<T>(3 * N); + for (size_t i = 0; i < 3 * N; ++i) { + bytes[i] = static_cast<T>(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + const auto in2 = Load(d, &bytes[2 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned<T>(4 * N); + auto actual_aligned = AllocateAligned<T>(4 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[3 * i + 0] = bytes[0 * N + i]; + expected[3 * i + 1] = bytes[1 * N + i]; + expected[3 * i + 2] = bytes[2 * N + i]; + // Ensure we do not write more than 3*N bytes + expected[3 * N + i] = actual[3 * N + i] = 0; + } + StoreInterleaved3(in0, in1, in2, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 4 * N * sizeof(T), &pos)) { + Print(d, "in0", in0, pos / 3, N); + Print(d, "in1", in1, pos / 3, N); + Print(d, "in2", in2, pos / 3, N); + const size_t i = pos; + fprintf(stderr, "interleaved i=%d %f %f %f %f %f %f\n", + static_cast<int>(i), static_cast<double>(actual[i]), + static_cast<double>(actual[i + 1]), + static_cast<double>(actual[i + 2]), + static_cast<double>(actual[i + 3]), + static_cast<double>(actual[i + 4]), + static_cast<double>(actual[i + 5])); + HWY_ASSERT(false); + } + + Vec<D> out0, out1, out2; + LoadInterleaved3(d, actual, out0, out1, out2); + HWY_ASSERT_VEC_EQ(d, in0, out0); + HWY_ASSERT_VEC_EQ(d, in1, out1); + HWY_ASSERT_VEC_EQ(d, in2, out2); + } + } +}; + +HWY_NOINLINE void TestAllLoadStoreInterleaved3() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. + const ForExtendableVectors<TestLoadStoreInterleaved3, 2> test; +#else + const ForPartialVectors<TestLoadStoreInterleaved3> test; +#endif + ForAllTypes(test); +} + +struct TestLoadStoreInterleaved4 { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + + RandomState rng; + + // Data to be interleaved + auto bytes = AllocateAligned<T>(4 * N); + + for (size_t i = 0; i < 4 * N; ++i) { + bytes[i] = static_cast<T>(Random32(&rng) & 0xFF); + } + const auto in0 = Load(d, &bytes[0 * N]); + const auto in1 = Load(d, &bytes[1 * N]); + const auto in2 = Load(d, &bytes[2 * N]); + const auto in3 = Load(d, &bytes[3 * N]); + + // Interleave here, ensure vector results match scalar + auto expected = AllocateAligned<T>(5 * N); + auto actual_aligned = AllocateAligned<T>(5 * N + 1); + T* actual = actual_aligned.get() + 1; + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + expected[4 * i + 0] = bytes[0 * N + i]; + expected[4 * i + 1] = bytes[1 * N + i]; + expected[4 * i + 2] = bytes[2 * N + i]; + expected[4 * i + 3] = bytes[3 * N + i]; + // Ensure we do not write more than 4*N bytes + expected[4 * N + i] = actual[4 * N + i] = 0; + } + StoreInterleaved4(in0, in1, in2, in3, d, actual); + size_t pos = 0; + if (!BytesEqual(expected.get(), actual, 5 * N * sizeof(T), &pos)) { + Print(d, "in0", in0, pos / 4); + Print(d, "in1", in1, pos / 4); + Print(d, "in2", in2, pos / 4); + Print(d, "in3", in3, pos / 4); + const size_t i = pos; + fprintf(stderr, "interleaved i=%d %f %f %f %f %f %f %f %f\n", + static_cast<int>(i), static_cast<double>(actual[i]), + static_cast<double>(actual[i + 1]), + static_cast<double>(actual[i + 2]), + static_cast<double>(actual[i + 3]), + static_cast<double>(actual[i + 4]), + static_cast<double>(actual[i + 5]), + static_cast<double>(actual[i + 6]), + static_cast<double>(actual[i + 7])); + HWY_ASSERT(false); + } + + Vec<D> out0, out1, out2, out3; + LoadInterleaved4(d, actual, out0, out1, out2, out3); + HWY_ASSERT_VEC_EQ(d, in0, out0); + HWY_ASSERT_VEC_EQ(d, in1, out1); + HWY_ASSERT_VEC_EQ(d, in2, out2); + HWY_ASSERT_VEC_EQ(d, in3, out3); + } + } +}; + +HWY_NOINLINE void TestAllLoadStoreInterleaved4() { +#if HWY_TARGET == HWY_RVV + // Segments are limited to 8 registers, so we can only go up to LMUL=2. + const ForExtendableVectors<TestLoadStoreInterleaved4, 2> test; +#else + const ForPartialVectors<TestLoadStoreInterleaved4> test; +#endif + ForAllTypes(test); +} + +#endif // !HWY_BROKEN_LOAD34 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyInterleavedTest); +HWY_EXPORT_AND_TEST_P(HwyInterleavedTest, TestAllLoadStoreInterleaved2); +#if !HWY_BROKEN_LOAD34 +HWY_EXPORT_AND_TEST_P(HwyInterleavedTest, TestAllLoadStoreInterleaved3); +HWY_EXPORT_AND_TEST_P(HwyInterleavedTest, TestAllLoadStoreInterleaved4); +#endif +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/list_targets.cc b/third_party/highway/hwy/tests/list_targets.cc new file mode 100644 index 0000000000..d09ee4fe86 --- /dev/null +++ b/third_party/highway/hwy/tests/list_targets.cc @@ -0,0 +1,71 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Simple tool to print the list of targets that were compiled in when building +// this tool. + +#include <stdio.h> + +#include "hwy/highway.h" + +void PrintTargets(const char* msg, int64_t targets) { + fprintf(stderr, "%s", msg); + // For each bit: + for (int64_t x = targets; x != 0; x = x & (x - 1)) { + // Extract value of least-significant bit. + fprintf(stderr, " %s", hwy::TargetName(x & (~x + 1))); + } + fprintf(stderr, "\n"); +} + +int main() { +#ifdef HWY_COMPILE_ONLY_EMU128 + const int only_emu128 = 1; +#else + const int only_emu128 = 0; +#endif +#ifdef HWY_COMPILE_ONLY_SCALAR + const int only_scalar = 1; +#else + const int only_scalar = 0; +#endif +#ifdef HWY_COMPILE_ONLY_STATIC + const int only_static = 1; +#else + const int only_static = 0; +#endif +#ifdef HWY_COMPILE_ALL_ATTAINABLE + const int all_attain = 1; +#else + const int all_attain = 0; +#endif +#ifdef HWY_IS_TEST + const int is_test = 1; +#else + const int is_test = 0; +#endif + + fprintf(stderr, + "Config: emu128:%d scalar:%d static:%d all_attain:%d is_test:%d\n", + only_emu128, only_scalar, only_static, all_attain, is_test); + PrintTargets("Compiled HWY_TARGETS: ", HWY_TARGETS); + PrintTargets("HWY_ATTAINABLE_TARGETS:", HWY_ATTAINABLE_TARGETS); + PrintTargets("HWY_BASELINE_TARGETS: ", HWY_BASELINE_TARGETS); + PrintTargets("HWY_STATIC_TARGET: ", HWY_STATIC_TARGET); + PrintTargets("HWY_BROKEN_TARGETS: ", HWY_BROKEN_TARGETS); + PrintTargets("HWY_DISABLED_TARGETS: ", HWY_DISABLED_TARGETS); + PrintTargets("Current CPU supports: ", hwy::SupportedTargets()); + return 0; +} diff --git a/third_party/highway/hwy/tests/logical_test.cc b/third_party/highway/hwy/tests/logical_test.cc new file mode 100644 index 0000000000..b646f5ff4b --- /dev/null +++ b/third_party/highway/hwy/tests/logical_test.cc @@ -0,0 +1,246 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcmp + +#include "hwy/aligned_allocator.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/logical_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestNot { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto ones = VecFromMask(d, Eq(v0, v0)); + const auto v1 = Set(d, 1); + const auto vnot1 = Set(d, T(~T(1))); + + HWY_ASSERT_VEC_EQ(d, v0, Not(ones)); + HWY_ASSERT_VEC_EQ(d, ones, Not(v0)); + HWY_ASSERT_VEC_EQ(d, v1, Not(vnot1)); + HWY_ASSERT_VEC_EQ(d, vnot1, Not(v1)); + } +}; + +HWY_NOINLINE void TestAllNot() { + ForIntegerTypes(ForPartialVectors<TestNot>()); +} + +struct TestLogical { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vi = Iota(d, 0); + + auto v = vi; + v = And(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = And(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + + v = Or(v, vi); + HWY_ASSERT_VEC_EQ(d, vi, v); + v = Or(v, v0); + HWY_ASSERT_VEC_EQ(d, vi, v); + + v = Xor(v, vi); + HWY_ASSERT_VEC_EQ(d, v0, v); + v = Xor(v, v0); + HWY_ASSERT_VEC_EQ(d, v0, v); + + HWY_ASSERT_VEC_EQ(d, v0, And(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, And(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, And(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Or(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, Xor(v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Xor(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Xor(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, vi, AndNot(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, AndNot(vi, vi)); + + HWY_ASSERT_VEC_EQ(d, v0, Or3(v0, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(v0, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(v0, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(v0, vi, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Or3(vi, vi, vi)); + + HWY_ASSERT_VEC_EQ(d, v0, Xor3(v0, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Xor3(v0, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, Xor3(v0, v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, Xor3(v0, vi, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Xor3(vi, v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Xor3(vi, vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, Xor3(vi, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Xor3(vi, vi, vi)); + + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, vi, v0)); + HWY_ASSERT_VEC_EQ(d, v0, OrAnd(v0, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(v0, vi, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, v0, v0)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, vi, v0)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, v0, vi)); + HWY_ASSERT_VEC_EQ(d, vi, OrAnd(vi, vi, vi)); + } +}; + +HWY_NOINLINE void TestAllLogical() { + ForAllTypes(ForPartialVectors<TestLogical>()); +} + +struct TestCopySign { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto vp = Iota(d, 1); + const auto vn = Iota(d, T(-1E5)); // assumes N < 10^5 + + // Zero remains zero regardless of sign + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, vp)); + HWY_ASSERT_VEC_EQ(d, v0, CopySign(v0, vn)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, vp)); + HWY_ASSERT_VEC_EQ(d, v0, CopySignToAbs(v0, vn)); + + // Positive input, positive sign => unchanged + HWY_ASSERT_VEC_EQ(d, vp, CopySign(vp, vp)); + HWY_ASSERT_VEC_EQ(d, vp, CopySignToAbs(vp, vp)); + + // Positive input, negative sign => negated + HWY_ASSERT_VEC_EQ(d, Neg(vp), CopySign(vp, vn)); + HWY_ASSERT_VEC_EQ(d, Neg(vp), CopySignToAbs(vp, vn)); + + // Negative input, negative sign => unchanged + HWY_ASSERT_VEC_EQ(d, vn, CopySign(vn, vn)); + + // Negative input, positive sign => negated + HWY_ASSERT_VEC_EQ(d, Neg(vn), CopySign(vn, vp)); + } +}; + +HWY_NOINLINE void TestAllCopySign() { + ForFloatTypes(ForPartialVectors<TestCopySign>()); +} + +struct TestBroadcastSignBit { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto s0 = Zero(d); + const auto s1 = Set(d, -1); // all bit set + const auto vpos = And(Iota(d, 0), Set(d, LimitsMax<T>())); + const auto vneg = Sub(s1, vpos); + + HWY_ASSERT_VEC_EQ(d, s0, BroadcastSignBit(vpos)); + HWY_ASSERT_VEC_EQ(d, s0, BroadcastSignBit(Set(d, LimitsMax<T>()))); + + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(vneg)); + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(Set(d, LimitsMin<T>()))); + HWY_ASSERT_VEC_EQ(d, s1, BroadcastSignBit(Set(d, LimitsMin<T>() / 2))); + } +}; + +HWY_NOINLINE void TestAllBroadcastSignBit() { + ForSignedTypes(ForPartialVectors<TestBroadcastSignBit>()); +} + +struct TestTestBit { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t kNumBits = sizeof(T) * 8; + for (size_t i = 0; i < kNumBits; ++i) { + const auto bit1 = Set(d, T(1ull << i)); + const auto bit2 = Set(d, T(1ull << ((i + 1) % kNumBits))); + const auto bit3 = Set(d, T(1ull << ((i + 2) % kNumBits))); + const auto bits12 = Or(bit1, bit2); + const auto bits23 = Or(bit2, bit3); + HWY_ASSERT(AllTrue(d, TestBit(bit1, bit1))); + HWY_ASSERT(AllTrue(d, TestBit(bits12, bit1))); + HWY_ASSERT(AllTrue(d, TestBit(bits12, bit2))); + + HWY_ASSERT(AllFalse(d, TestBit(bits12, bit3))); + HWY_ASSERT(AllFalse(d, TestBit(bits23, bit1))); + HWY_ASSERT(AllFalse(d, TestBit(bit1, bit2))); + HWY_ASSERT(AllFalse(d, TestBit(bit2, bit1))); + HWY_ASSERT(AllFalse(d, TestBit(bit1, bit3))); + HWY_ASSERT(AllFalse(d, TestBit(bit3, bit1))); + HWY_ASSERT(AllFalse(d, TestBit(bit2, bit3))); + HWY_ASSERT(AllFalse(d, TestBit(bit3, bit2))); + } + } +}; + +HWY_NOINLINE void TestAllTestBit() { + ForIntegerTypes(ForPartialVectors<TestTestBit>()); +} + +struct TestPopulationCount { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + size_t N = Lanes(d); + auto data = AllocateAligned<T>(N); + auto popcnt = AllocateAligned<T>(N); + for (size_t i = 0; i < AdjustedReps(1 << 18) / N; i++) { + for (size_t i = 0; i < N; i++) { + data[i] = static_cast<T>(rng()); + popcnt[i] = static_cast<T>(PopCount(data[i])); + } + HWY_ASSERT_VEC_EQ(d, popcnt.get(), PopulationCount(Load(d, data.get()))); + } + } +}; + +HWY_NOINLINE void TestAllPopulationCount() { + ForUnsignedTypes(ForPartialVectors<TestPopulationCount>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyLogicalTest); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllCopySign); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllBroadcastSignBit); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit); +HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllPopulationCount); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/mask_mem_test.cc b/third_party/highway/hwy/tests/mask_mem_test.cc new file mode 100644 index 0000000000..c44119dcd7 --- /dev/null +++ b/third_party/highway/hwy/tests/mask_mem_test.cc @@ -0,0 +1,197 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include <inttypes.h> +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcmp + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/mask_mem_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestMaskedLoad { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned<TI>(N); + + auto lanes = AllocateAligned<T>(N); + Store(Iota(d, T{1}), d, lanes.get()); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + const auto expected = IfThenElseZero(mask, Load(d, lanes.get())); + const auto actual = MaskedLoad(mask, d, lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected, actual); + } + } +}; + +HWY_NOINLINE void TestAllMaskedLoad() { + ForAllTypes(ForPartialVectors<TestMaskedLoad>()); +} + +struct TestBlendedStore { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned<TI>(N); + + const Vec<D> v = Iota(d, T{1}); + auto actual = AllocateAligned<T>(N); + auto expected = AllocateAligned<T>(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + // Re-initialize to something distinct from v[i]. + actual[i] = static_cast<T>(127 - (i & 127)); + expected[i] = bool_lanes[i] ? static_cast<T>(i + 1) : actual[i]; + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + BlendedStore(v, mask, d, actual.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), Load(d, actual.get())); + } + } +}; + +HWY_NOINLINE void TestAllBlendedStore() { + ForAllTypes(ForPartialVectors<TestBlendedStore>()); +} + +class TestStoreMaskBits { + public: + template <class T, class D> + HWY_NOINLINE void operator()(T /*t*/, D /*d*/) { + RandomState rng; + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned<TI>(N); + + const ScalableTag<uint8_t, -3> d_bits; + const size_t expected_num_bytes = (N + 7) / 8; + auto expected = AllocateAligned<uint8_t>(expected_num_bytes); + auto actual = AllocateAligned<uint8_t>(HWY_MAX(8, expected_num_bytes)); + + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + // Generate random mask pattern. + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = static_cast<TI>((rng() & 1024) ? 1 : 0); + } + const auto bools = Load(di, bool_lanes.get()); + const auto mask = Gt(bools, Zero(di)); + + // Requires at least 8 bytes, ensured above. + const size_t bytes_written = StoreMaskBits(di, mask, actual.get()); + if (bytes_written != expected_num_bytes) { + fprintf(stderr, "%s expected %" PRIu64 " bytes, actual %" PRIu64 "\n", + TypeName(T(), N).c_str(), + static_cast<uint64_t>(expected_num_bytes), + static_cast<uint64_t>(bytes_written)); + + HWY_ASSERT(false); + } + + // Requires at least 8 bytes, ensured above. + const auto mask2 = LoadMaskBits(di, actual.get()); + HWY_ASSERT_MASK_EQ(di, mask, mask2); + + memset(expected.get(), 0, expected_num_bytes); + for (size_t i = 0; i < N; ++i) { + expected[i / 8] = + static_cast<uint8_t>(expected[i / 8] | (bool_lanes[i] << (i % 8))); + } + + size_t i = 0; + // Stored bits must match original mask + for (; i < N; ++i) { + const TI is_set = (actual[i / 8] & (1 << (i % 8))) ? 1 : 0; + if (is_set != bool_lanes[i]) { + fprintf(stderr, "%s lane %" PRIu64 ": expected %d, actual %d\n", + TypeName(T(), N).c_str(), static_cast<uint64_t>(i), + static_cast<int>(bool_lanes[i]), static_cast<int>(is_set)); + Print(di, "bools", bools, 0, N); + Print(d_bits, "expected bytes", Load(d_bits, expected.get()), 0, + expected_num_bytes); + Print(d_bits, "actual bytes", Load(d_bits, actual.get()), 0, + expected_num_bytes); + + HWY_ASSERT(false); + } + } + // Any partial bits in the last byte must be zero + for (; i < 8 * bytes_written; ++i) { + const int bit = (actual[i / 8] & (1 << (i % 8))); + if (bit != 0) { + fprintf(stderr, "%s: bit #%" PRIu64 " should be zero\n", + TypeName(T(), N).c_str(), static_cast<uint64_t>(i)); + Print(di, "bools", bools, 0, N); + Print(d_bits, "expected bytes", Load(d_bits, expected.get()), 0, + expected_num_bytes); + Print(d_bits, "actual bytes", Load(d_bits, actual.get()), 0, + expected_num_bytes); + + HWY_ASSERT(false); + } + } + } + } +}; + +HWY_NOINLINE void TestAllStoreMaskBits() { + ForAllTypes(ForPartialVectors<TestStoreMaskBits>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMaskTest); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllMaskedLoad); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllBlendedStore); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllStoreMaskBits); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/mask_test.cc b/third_party/highway/hwy/tests/mask_test.cc new file mode 100644 index 0000000000..cf0d2d4ee8 --- /dev/null +++ b/third_party/highway/hwy/tests/mask_test.cc @@ -0,0 +1,295 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcmp + +#include <algorithm> // std::fill + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/mask_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// All types. +struct TestFromVec { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + + memset(lanes.get(), 0, N * sizeof(T)); + const auto actual_false = MaskFromVec(Load(d, lanes.get())); + HWY_ASSERT_MASK_EQ(d, MaskFalse(d), actual_false); + + memset(lanes.get(), 0xFF, N * sizeof(T)); + const auto actual_true = MaskFromVec(Load(d, lanes.get())); + HWY_ASSERT_MASK_EQ(d, MaskTrue(d), actual_true); + } +}; + +HWY_NOINLINE void TestAllFromVec() { + ForAllTypes(ForPartialVectors<TestFromVec>()); +} + +struct TestFirstN { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned<T>(N); + + using TN = SignedFromSize<HWY_MIN(sizeof(size_t), sizeof(T))>; + const size_t max_len = static_cast<size_t>(LimitsMax<TN>()); + + const size_t max_lanes = HWY_MIN(2 * N, AdjustedReps(512)); + for (size_t len = 0; len <= HWY_MIN(max_lanes, max_len); ++len) { + // Loop instead of Iota+Lt to avoid wraparound for 8-bit T. + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (i < len) ? T{1} : 0; + } + const auto expected = Eq(Load(d, bool_lanes.get()), Set(d, T{1})); + HWY_ASSERT_MASK_EQ(d, expected, FirstN(d, len)); + } + + // Also ensure huge values yield all-true (unless the vector is actually + // larger than max_len). + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (i < max_len) ? T{1} : 0; + } + const auto expected = Eq(Load(d, bool_lanes.get()), Set(d, T{1})); + HWY_ASSERT_MASK_EQ(d, expected, FirstN(d, max_len)); + } +}; + +HWY_NOINLINE void TestAllFirstN() { + ForAllTypes(ForPartialVectors<TestFirstN>()); +} + +struct TestMaskVec { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(d); + auto bool_lanes = AllocateAligned<TI>(N); + + // Each lane should have a chance of having mask=true. + for (size_t rep = 0; rep < AdjustedReps(200); ++rep) { + for (size_t i = 0; i < N; ++i) { + bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0); + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + HWY_ASSERT_MASK_EQ(d, mask, MaskFromVec(VecFromMask(d, mask))); + } + } +}; + +HWY_NOINLINE void TestAllMaskVec() { + const ForPartialVectors<TestMaskVec> test; + + test(uint16_t()); + test(int16_t()); + // TODO(janwas): float16_t - cannot compare yet + + ForUIF3264(test); +} + +struct TestAllTrueFalse { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto zero = Zero(d); + auto v = zero; + + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + std::fill(lanes.get(), lanes.get() + N, T(0)); + + HWY_ASSERT(AllTrue(d, Eq(v, zero))); + HWY_ASSERT(!AllFalse(d, Eq(v, zero))); + + // Single lane implies AllFalse = !AllTrue. Otherwise, there are multiple + // lanes and one is nonzero. + const bool expected_all_false = (N != 1); + + // Set each lane to nonzero and back to zero + for (size_t i = 0; i < N; ++i) { + lanes[i] = T(1); + v = Load(d, lanes.get()); + + HWY_ASSERT(!AllTrue(d, Eq(v, zero))); + + HWY_ASSERT(expected_all_false ^ AllFalse(d, Eq(v, zero))); + + lanes[i] = T(-1); + v = Load(d, lanes.get()); + HWY_ASSERT(!AllTrue(d, Eq(v, zero))); + HWY_ASSERT(expected_all_false ^ AllFalse(d, Eq(v, zero))); + + // Reset to all zero + lanes[i] = T(0); + v = Load(d, lanes.get()); + HWY_ASSERT(AllTrue(d, Eq(v, zero))); + HWY_ASSERT(!AllFalse(d, Eq(v, zero))); + } + } +}; + +HWY_NOINLINE void TestAllAllTrueFalse() { + ForAllTypes(ForPartialVectors<TestAllTrueFalse>()); +} + +struct TestCountTrue { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned<TI>(N); + memset(bool_lanes.get(), 0, N * sizeof(TI)); + + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = HWY_MIN(N, size_t(10)); + + for (size_t code = 0; code < (1ull << max_lanes); ++code) { + // Number of zeros written = number of mask lanes that are true. + size_t expected = 0; + for (size_t i = 0; i < max_lanes; ++i) { + const bool is_true = (code & (1ull << i)) != 0; + bool_lanes[i] = is_true ? TI(1) : TI(0); + expected += is_true; + } + + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + const size_t actual = CountTrue(d, mask); + HWY_ASSERT_EQ(expected, actual); + } + } +}; + +HWY_NOINLINE void TestAllCountTrue() { + ForAllTypes(ForPartialVectors<TestCountTrue>()); +} + +struct TestFindFirstTrue { // Also FindKnownFirstTrue + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned<TI>(N); + memset(bool_lanes.get(), 0, N * sizeof(TI)); + + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = AdjustedLog2Reps(HWY_MIN(N, size_t(9))); + + HWY_ASSERT_EQ(intptr_t(-1), FindFirstTrue(d, MaskFalse(d))); + HWY_ASSERT_EQ(intptr_t(0), FindFirstTrue(d, MaskTrue(d))); + HWY_ASSERT_EQ(size_t(0), FindKnownFirstTrue(d, MaskTrue(d))); + + for (size_t code = 1; code < (1ull << max_lanes); ++code) { + for (size_t i = 0; i < max_lanes; ++i) { + bool_lanes[i] = (code & (1ull << i)) ? TI(1) : TI(0); + } + + const size_t expected = + Num0BitsBelowLS1Bit_Nonzero32(static_cast<uint32_t>(code)); + const auto mask = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + HWY_ASSERT_EQ(static_cast<intptr_t>(expected), FindFirstTrue(d, mask)); + HWY_ASSERT_EQ(expected, FindKnownFirstTrue(d, mask)); + } + } +}; + +HWY_NOINLINE void TestAllFindFirstTrue() { + ForAllTypes(ForPartialVectors<TestFindFirstTrue>()); +} + +struct TestLogicalMask { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto m0 = MaskFalse(d); + const auto m_all = MaskTrue(d); + + using TI = MakeSigned<T>; // For mask > 0 comparison + const Rebind<TI, D> di; + const size_t N = Lanes(di); + auto bool_lanes = AllocateAligned<TI>(N); + memset(bool_lanes.get(), 0, N * sizeof(TI)); + + HWY_ASSERT_MASK_EQ(d, m0, Not(m_all)); + HWY_ASSERT_MASK_EQ(d, m_all, Not(m0)); + + Print(d, ".", VecFromMask(d, ExclusiveNeither(m0, m0))); + HWY_ASSERT_MASK_EQ(d, m_all, ExclusiveNeither(m0, m0)); + HWY_ASSERT_MASK_EQ(d, m0, ExclusiveNeither(m_all, m0)); + HWY_ASSERT_MASK_EQ(d, m0, ExclusiveNeither(m0, m_all)); + + // For all combinations of zero/nonzero state of subset of lanes: + const size_t max_lanes = AdjustedLog2Reps(HWY_MIN(N, size_t(6))); + for (size_t code = 0; code < (1ull << max_lanes); ++code) { + for (size_t i = 0; i < max_lanes; ++i) { + bool_lanes[i] = (code & (1ull << i)) ? TI(1) : TI(0); + } + + const auto m = RebindMask(d, Gt(Load(di, bool_lanes.get()), Zero(di))); + + HWY_ASSERT_MASK_EQ(d, m0, Xor(m, m)); + HWY_ASSERT_MASK_EQ(d, m0, AndNot(m, m)); + HWY_ASSERT_MASK_EQ(d, m0, AndNot(m_all, m)); + + HWY_ASSERT_MASK_EQ(d, m, Or(m, m)); + HWY_ASSERT_MASK_EQ(d, m, Or(m0, m)); + HWY_ASSERT_MASK_EQ(d, m, Or(m, m0)); + HWY_ASSERT_MASK_EQ(d, m, Xor(m0, m)); + HWY_ASSERT_MASK_EQ(d, m, Xor(m, m0)); + HWY_ASSERT_MASK_EQ(d, m, And(m, m)); + HWY_ASSERT_MASK_EQ(d, m, And(m_all, m)); + HWY_ASSERT_MASK_EQ(d, m, And(m, m_all)); + HWY_ASSERT_MASK_EQ(d, m, AndNot(m0, m)); + } + } +}; + +HWY_NOINLINE void TestAllLogicalMask() { + ForAllTypes(ForPartialVectors<TestLogicalMask>()); +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMaskTest); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFromVec); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFirstN); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllMaskVec); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllAllTrueFalse); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllCountTrue); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllFindFirstTrue); +HWY_EXPORT_AND_TEST_P(HwyMaskTest, TestAllLogicalMask); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/memory_test.cc b/third_party/highway/hwy/tests/memory_test.cc new file mode 100644 index 0000000000..d17addf544 --- /dev/null +++ b/third_party/highway/hwy/tests/memory_test.cc @@ -0,0 +1,343 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Ensure incompabilities with Windows macros (e.g. #define StoreFence) are +// detected. Must come before Highway headers. +#include "hwy/base.h" +#if defined(_WIN32) || defined(_WIN64) +#include <windows.h> +#endif + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> // std::fill + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/memory_test.cc" +#include "hwy/cache_control.h" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestLoadStore { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto hi = Iota(d, static_cast<T>(1 + N)); + const auto lo = Iota(d, 1); + auto lanes = AllocateAligned<T>(2 * N); + Store(hi, d, &lanes[N]); + Store(lo, d, &lanes[0]); + + // Aligned load + const auto lo2 = Load(d, &lanes[0]); + HWY_ASSERT_VEC_EQ(d, lo2, lo); + + // Aligned store + auto lanes2 = AllocateAligned<T>(2 * N); + Store(lo2, d, &lanes2[0]); + Store(hi, d, &lanes2[N]); + for (size_t i = 0; i < 2 * N; ++i) { + HWY_ASSERT_EQ(lanes[i], lanes2[i]); + } + + // Unaligned load + const auto vu = LoadU(d, &lanes[1]); + auto lanes3 = AllocateAligned<T>(N); + Store(vu, d, lanes3.get()); + for (size_t i = 0; i < N; ++i) { + HWY_ASSERT_EQ(T(i + 2), lanes3[i]); + } + + // Unaligned store + StoreU(lo2, d, &lanes2[N / 2]); + size_t i = 0; + for (; i < N / 2; ++i) { + HWY_ASSERT_EQ(lanes[i], lanes2[i]); + } + for (; i < 3 * N / 2; ++i) { + HWY_ASSERT_EQ(T(i - N / 2 + 1), lanes2[i]); + } + // Subsequent values remain unchanged. + for (; i < 2 * N; ++i) { + HWY_ASSERT_EQ(T(i + 1), lanes2[i]); + } + } +}; + +HWY_NOINLINE void TestAllLoadStore() { + ForAllTypes(ForPartialVectors<TestLoadStore>()); +} + +struct TestSafeCopyN { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto v = Iota(d, 1); + auto from = AllocateAligned<T>(N + 2); + auto to = AllocateAligned<T>(N + 2); + Store(v, d, from.get()); + + // 0: nothing changes + to[0] = T(); + SafeCopyN(0, d, from.get(), to.get()); + HWY_ASSERT_EQ(T(), to[0]); + + // 1: only first changes + to[1] = T(); + SafeCopyN(1, d, from.get(), to.get()); + HWY_ASSERT_EQ(static_cast<T>(1), to[0]); + HWY_ASSERT_EQ(T(), to[1]); + + // N-1: last does not change + to[N - 1] = T(); + SafeCopyN(N - 1, d, from.get(), to.get()); + HWY_ASSERT_EQ(T(), to[N - 1]); + // Also check preceding lanes + to[N - 1] = static_cast<T>(N); + HWY_ASSERT_VEC_EQ(d, to.get(), v); + + // N: all change + to[N] = T(); + SafeCopyN(N, d, from.get(), to.get()); + HWY_ASSERT_VEC_EQ(d, to.get(), v); + HWY_ASSERT_EQ(T(), to[N]); + + // N+1: subsequent lane does not change if using masked store + to[N + 1] = T(); + SafeCopyN(N + 1, d, from.get(), to.get()); + HWY_ASSERT_VEC_EQ(d, to.get(), v); +#if !HWY_MEM_OPS_MIGHT_FAULT + HWY_ASSERT_EQ(T(), to[N + 1]); +#endif + } +}; + +HWY_NOINLINE void TestAllSafeCopyN() { + ForAllTypes(ForPartialVectors<TestSafeCopyN>()); +} + +struct TestLoadDup128 { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + // Scalar does not define LoadDup128. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + constexpr size_t N128 = 16 / sizeof(T); + alignas(16) T lanes[N128]; + for (size_t i = 0; i < N128; ++i) { + lanes[i] = static_cast<T>(1 + i); + } + + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>(i % N128 + 1); + } + + HWY_ASSERT_VEC_EQ(d, expected.get(), LoadDup128(d, lanes)); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllLoadDup128() { + ForAllTypes(ForGEVectors<128, TestLoadDup128>()); +} + +struct TestStream { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(1)); + const size_t affected_bytes = + (Lanes(d) * sizeof(T) + HWY_STREAM_MULTIPLE - 1) & + ~size_t(HWY_STREAM_MULTIPLE - 1); + const size_t affected_lanes = affected_bytes / sizeof(T); + auto out = AllocateAligned<T>(2 * affected_lanes); + std::fill(out.get(), out.get() + 2 * affected_lanes, T(0)); + + Stream(v, d, out.get()); + FlushStream(); + const auto actual = Load(d, out.get()); + HWY_ASSERT_VEC_EQ(d, v, actual); + // Ensure Stream didn't modify more memory than expected + for (size_t i = affected_lanes; i < 2 * affected_lanes; ++i) { + HWY_ASSERT_EQ(T(0), out[i]); + } + } +}; + +HWY_NOINLINE void TestAllStream() { + const ForPartialVectors<TestStream> test; + // No u8,u16. + test(uint32_t()); + test(uint64_t()); + // No i8,i16. + test(int32_t()); + test(int64_t()); + ForFloatTypes(test); +} + +// Assumes little-endian byte order! +struct TestScatter { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Offset = MakeSigned<T>; + + const size_t N = Lanes(d); + const size_t range = 4 * N; // number of items to scatter + const size_t max_bytes = range * sizeof(T); // upper bound on offset + + RandomState rng; + + // Data to be scattered + auto bytes = AllocateAligned<uint8_t>(max_bytes); + for (size_t i = 0; i < max_bytes; ++i) { + bytes[i] = static_cast<uint8_t>(Random32(&rng) & 0xFF); + } + const auto data = Load(d, reinterpret_cast<const T*>(bytes.get())); + + // Scatter into these regions, ensure vector results match scalar + auto expected = AllocateAligned<T>(range); + auto actual = AllocateAligned<T>(range); + + const Rebind<Offset, D> d_offsets; + auto offsets = AllocateAligned<Offset>(N); // or indices + + for (size_t rep = 0; rep < 100; ++rep) { + // Byte offsets + std::fill(expected.get(), expected.get() + range, T(0)); + std::fill(actual.get(), actual.get() + range, T(0)); + for (size_t i = 0; i < N; ++i) { + // Must be aligned + offsets[i] = static_cast<Offset>((Random32(&rng) % range) * sizeof(T)); + CopyBytes<sizeof(T)>( + bytes.get() + i * sizeof(T), + reinterpret_cast<uint8_t*>(expected.get()) + offsets[i]); + } + const auto voffsets = Load(d_offsets, offsets.get()); + ScatterOffset(data, d, actual.get(), voffsets); + if (!BytesEqual(expected.get(), actual.get(), max_bytes)) { + Print(d, "Data", data); + Print(d_offsets, "Offsets", voffsets); + HWY_ASSERT(false); + } + + // Indices + std::fill(expected.get(), expected.get() + range, T(0)); + std::fill(actual.get(), actual.get() + range, T(0)); + for (size_t i = 0; i < N; ++i) { + offsets[i] = static_cast<Offset>(Random32(&rng) % range); + CopyBytes<sizeof(T)>(bytes.get() + i * sizeof(T), + &expected[size_t(offsets[i])]); + } + const auto vindices = Load(d_offsets, offsets.get()); + ScatterIndex(data, d, actual.get(), vindices); + if (!BytesEqual(expected.get(), actual.get(), max_bytes)) { + Print(d, "Data", data); + Print(d_offsets, "Indices", vindices); + HWY_ASSERT(false); + } + } + } +}; + +HWY_NOINLINE void TestAllScatter() { + ForUIF3264(ForPartialVectors<TestScatter>()); +} + +struct TestGather { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Offset = MakeSigned<T>; + + const size_t N = Lanes(d); + const size_t range = 4 * N; // number of items to gather + const size_t max_bytes = range * sizeof(T); // upper bound on offset + + RandomState rng; + + // Data to be gathered from + auto bytes = AllocateAligned<uint8_t>(max_bytes); + for (size_t i = 0; i < max_bytes; ++i) { + bytes[i] = static_cast<uint8_t>(Random32(&rng) & 0xFF); + } + + auto expected = AllocateAligned<T>(N); + auto offsets = AllocateAligned<Offset>(N); + auto indices = AllocateAligned<Offset>(N); + + for (size_t rep = 0; rep < 100; ++rep) { + // Offsets + for (size_t i = 0; i < N; ++i) { + // Must be aligned + offsets[i] = static_cast<Offset>((Random32(&rng) % range) * sizeof(T)); + CopyBytes<sizeof(T)>(bytes.get() + offsets[i], &expected[i]); + } + + const Rebind<Offset, D> d_offset; + const T* base = reinterpret_cast<const T*>(bytes.get()); + auto actual = GatherOffset(d, base, Load(d_offset, offsets.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + + // Indices + for (size_t i = 0; i < N; ++i) { + indices[i] = + static_cast<Offset>(Random32(&rng) % (max_bytes / sizeof(T))); + CopyBytes<sizeof(T)>(base + indices[i], &expected[i]); + } + actual = GatherIndex(d, base, Load(d_offset, indices.get())); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual); + } + } +}; + +HWY_NOINLINE void TestAllGather() { + ForUIF3264(ForPartialVectors<TestGather>()); +} + +HWY_NOINLINE void TestAllCache() { + LoadFence(); + FlushStream(); + int test = 0; + Prefetch(&test); + FlushCacheline(&test); + Pause(); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMemoryTest); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadStore); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllSafeCopyN); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadDup128); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStream); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllScatter); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllGather); +HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllCache); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/mul_test.cc b/third_party/highway/hwy/tests/mul_test.cc new file mode 100644 index 0000000000..5622983cee --- /dev/null +++ b/third_party/highway/hwy/tests/mul_test.cc @@ -0,0 +1,526 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/mul_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <size_t kBits> +constexpr uint64_t FirstBits() { + return (1ull << kBits) - 1; +} +template <> +constexpr uint64_t FirstBits<64>() { + return ~uint64_t{0}; +} + +struct TestUnsignedMul { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto vi = Iota(d, 1); + const auto vj = Iota(d, 3); + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + HWY_ASSERT_VEC_EQ(d, v0, Mul(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Mul(v1, v1)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(v1, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(vi, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((1 + i) * (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vi)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((1 + i) * (3 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vj)); + + const T max = LimitsMax<T>(); + const auto vmax = Set(d, max); + HWY_ASSERT_VEC_EQ(d, vmax, Mul(vmax, v1)); + HWY_ASSERT_VEC_EQ(d, vmax, Mul(v1, vmax)); + + constexpr uint64_t kMask = FirstBits<sizeof(T) * 8>(); + const T max2 = (static_cast<uint64_t>(max) * max) & kMask; + HWY_ASSERT_VEC_EQ(d, Set(d, max2), Mul(vmax, vmax)); + } +}; + +struct TestSignedMul { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, T(1)); + const auto vi = Iota(d, 1); + const auto vn = Iota(d, -T(N)); // no i8 supported, so no wraparound + HWY_ASSERT_VEC_EQ(d, v0, Mul(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v1, Mul(v1, v1)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(v1, vi)); + HWY_ASSERT_VEC_EQ(d, vi, Mul(vi, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((1 + i) * (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vi)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((-T(N) + T(i)) * T(1u + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vn, vi)); + HWY_ASSERT_VEC_EQ(d, expected.get(), Mul(vi, vn)); + } +}; + +struct TestMulOverflow { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto vMax = Set(d, LimitsMax<T>()); + HWY_ASSERT_VEC_EQ(d, Mul(vMax, vMax), Mul(vMax, vMax)); + } +}; + +struct TestDivOverflow { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto vZero = Set(d, T(0)); + const auto v1 = Set(d, T(1)); + HWY_ASSERT_VEC_EQ(d, Div(v1, vZero), Div(v1, vZero)); + } +}; + +HWY_NOINLINE void TestAllMul() { + const ForPartialVectors<TestUnsignedMul> test_unsigned; + // No u8. + test_unsigned(uint16_t()); + test_unsigned(uint32_t()); + test_unsigned(uint64_t()); + + const ForPartialVectors<TestSignedMul> test_signed; + // No i8. + test_signed(int16_t()); + test_signed(int32_t()); + test_signed(int64_t()); + + const ForPartialVectors<TestMulOverflow> test_mul_overflow; + test_mul_overflow(int16_t()); + test_mul_overflow(int32_t()); +#if HWY_HAVE_INTEGER64 + test_mul_overflow(int64_t()); +#endif + + const ForPartialVectors<TestDivOverflow> test_div_overflow; + test_div_overflow(float()); +#if HWY_HAVE_FLOAT64 + test_div_overflow(double()); +#endif +} + +struct TestMulHigh { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Wide = MakeWide<T>; + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned<T>(N); + auto expected_lanes = AllocateAligned<T>(N); + + const auto vi = Iota(d, 1); + // no i8 supported, so no wraparound + const auto vni = Iota(d, T(static_cast<T>(~N + 1))); + + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(v0, vi)); + HWY_ASSERT_VEC_EQ(d, v0, MulHigh(vi, v0)); + + // Large positive squared + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = T(LimitsMax<T>() >> i); + expected_lanes[i] = T((Wide(in_lanes[i]) * in_lanes[i]) >> 16); + } + auto v = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, v)); + + // Large positive * small positive + for (size_t i = 0; i < N; ++i) { + expected_lanes[i] = T((Wide(in_lanes[i]) * T(1u + i)) >> 16); + } + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, vi)); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(vi, v)); + + // Large positive * small negative + for (size_t i = 0; i < N; ++i) { + expected_lanes[i] = T((Wide(in_lanes[i]) * T(i - N)) >> 16); + } + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(v, vni)); + HWY_ASSERT_VEC_EQ(d, expected_lanes.get(), MulHigh(vni, v)); + } +}; + +HWY_NOINLINE void TestAllMulHigh() { + ForPartialVectors<TestMulHigh> test; + test(int16_t()); + test(uint16_t()); +} + +struct TestMulFixedPoint15 { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, v0, MulFixedPoint15(v0, v0)); + HWY_ASSERT_VEC_EQ(d, v0, MulFixedPoint15(v0, v0)); + + const size_t N = Lanes(d); + auto in1 = AllocateAligned<T>(N); + auto in2 = AllocateAligned<T>(N); + auto expected = AllocateAligned<T>(N); + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(10000); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = static_cast<T>(Random64(&rng) & 0xFFFF); + in2[i] = static_cast<T>(Random64(&rng) & 0xFFFF); + } + + for (size_t i = 0; i < N; ++i) { + // There are three ways to compute the results. x86 and ARM are defined + // using 32-bit multiplication results: + const int arm = (2 * in1[i] * in2[i] + 0x8000) >> 16; + const int x86 = (((in1[i] * in2[i]) >> 14) + 1) >> 1; + // On other platforms, split the result into upper and lower 16 bits. + const auto v1 = Set(d, in1[i]); + const auto v2 = Set(d, in2[i]); + const int hi = GetLane(MulHigh(v1, v2)); + const int lo = GetLane(Mul(v1, v2)) & 0xFFFF; + const int split = 2 * hi + ((lo + 0x4000) >> 15); + expected[i] = static_cast<T>(arm); + if (in1[i] != -32768 || in2[i] != -32768) { + HWY_ASSERT_EQ(arm, x86); + HWY_ASSERT_EQ(arm, split); + } + } + + const auto a = Load(d, in1.get()); + const auto b = Load(d, in2.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulFixedPoint15(a, b)); + } + } +}; + +HWY_NOINLINE void TestAllMulFixedPoint15() { + ForPartialVectors<TestMulFixedPoint15>()(int16_t()); +} + +struct TestMulEven { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using Wide = MakeWide<T>; + const Repartition<Wide, D> d2; + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d2, Zero(d2), MulEven(v0, v0)); + + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned<T>(N); + auto expected = AllocateAligned<Wide>(Lanes(d2)); + for (size_t i = 0; i < N; i += 2) { + in_lanes[i + 0] = LimitsMax<T>() >> i; + if (N != 1) { + in_lanes[i + 1] = 1; // unused + } + expected[i / 2] = Wide(in_lanes[i + 0]) * in_lanes[i + 0]; + } + + const auto v = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(d2, expected.get(), MulEven(v, v)); + } +}; + +struct TestMulEvenOdd64 { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const auto v0 = Zero(d); + HWY_ASSERT_VEC_EQ(d, Zero(d), MulEven(v0, v0)); + HWY_ASSERT_VEC_EQ(d, Zero(d), MulOdd(v0, v0)); + + const size_t N = Lanes(d); + if (N == 1) return; + + auto in1 = AllocateAligned<T>(N); + auto in2 = AllocateAligned<T>(N); + auto expected_even = AllocateAligned<T>(N); + auto expected_odd = AllocateAligned<T>(N); + + // Random inputs in each lane + RandomState rng; + for (size_t rep = 0; rep < AdjustedReps(1000); ++rep) { + for (size_t i = 0; i < N; ++i) { + in1[i] = Random64(&rng); + in2[i] = Random64(&rng); + } + + for (size_t i = 0; i < N; i += 2) { + expected_even[i] = Mul128(in1[i], in2[i], &expected_even[i + 1]); + expected_odd[i] = Mul128(in1[i + 1], in2[i + 1], &expected_odd[i + 1]); + } + + const auto a = Load(d, in1.get()); + const auto b = Load(d, in2.get()); + HWY_ASSERT_VEC_EQ(d, expected_even.get(), MulEven(a, b)); + HWY_ASSERT_VEC_EQ(d, expected_odd.get(), MulOdd(a, b)); + } +#else + (void)d; +#endif // HWY_TARGET != HWY_SCALAR + } +}; + +HWY_NOINLINE void TestAllMulEven() { + ForGEVectors<64, TestMulEven> test; + test(int32_t()); + test(uint32_t()); + + ForGEVectors<128, TestMulEvenOdd64>()(uint64_t()); +} + +#ifndef HWY_NATIVE_FMA +#error "Bug in set_macros-inl.h, did not set HWY_NATIVE_FMA" +#endif + +struct TestMulAdd { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto k0 = Zero(d); + const auto kNeg0 = Set(d, T(-0.0)); + const auto v1 = Iota(d, 1); + const auto v2 = Iota(d, 2); + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + HWY_ASSERT_VEC_EQ(d, k0, MulAdd(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, v2, MulAdd(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, MulAdd(v1, k0, v2)); + HWY_ASSERT_VEC_EQ(d, k0, NegMulAdd(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, v2, NegMulAdd(v1, k0, v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((i + 1) * (i + 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v1, v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(Neg(v2), v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v1, Neg(v2), k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((i + 2) * (i + 2) + (i + 1)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulAdd(v2, v2, v1)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(Neg(v2), v2, v1)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = + T(-T(i + 2u) * static_cast<T>(i + 2) + static_cast<T>(1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulAdd(v2, v2, v1)); + + HWY_ASSERT_VEC_EQ(d, k0, MulSub(k0, k0, k0)); + HWY_ASSERT_VEC_EQ(d, kNeg0, NegMulSub(k0, k0, k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = -T(i + 2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(k0, v1, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v1, k0, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(k0), v1, v2)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(v1, Neg(k0), v2)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((i + 1) * (i + 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v1, v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v2, v1, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(v1), v2, k0)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(v2, Neg(v1), k0)); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((i + 2) * (i + 2) - (1 + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), MulSub(v2, v2, v1)); + HWY_ASSERT_VEC_EQ(d, expected.get(), NegMulSub(Neg(v2), v2, v1)); + } +}; + +HWY_NOINLINE void TestAllMulAdd() { + ForFloatTypes(ForPartialVectors<TestMulAdd>()); +} + +struct TestReorderWidenMulAccumulate { + template <typename TN, class DN> + HWY_NOINLINE void operator()(TN /*unused*/, DN dn) { + using TW = MakeWide<TN>; + const RepartitionToWide<DN> dw; + const Half<DN> dnh; + using VW = Vec<decltype(dw)>; + using VN = Vec<decltype(dn)>; + const size_t NN = Lanes(dn); + + const VW f0 = Zero(dw); + const VW f1 = Set(dw, TW{1}); + const VN bf0 = Zero(dn); + // Cannot Set() bfloat16_t directly. + const VN bf1 = ReorderDemote2To(dn, f1, f1); + + // Any input zero => both outputs zero + VW sum1 = f0; + HWY_ASSERT_VEC_EQ(dw, f0, + ReorderWidenMulAccumulate(dw, bf0, bf0, f0, sum1)); + HWY_ASSERT_VEC_EQ(dw, f0, sum1); + HWY_ASSERT_VEC_EQ(dw, f0, + ReorderWidenMulAccumulate(dw, bf0, bf1, f0, sum1)); + HWY_ASSERT_VEC_EQ(dw, f0, sum1); + HWY_ASSERT_VEC_EQ(dw, f0, + ReorderWidenMulAccumulate(dw, bf1, bf0, f0, sum1)); + HWY_ASSERT_VEC_EQ(dw, f0, sum1); + + // delta[p] := 1, all others zero. For each p: Dot(delta, all-ones) == 1. + auto delta_w = AllocateAligned<TW>(NN); + for (size_t p = 0; p < NN; ++p) { + // Workaround for incorrect Clang wasm codegen: re-initialize the entire + // array rather than zero-initialize once and then toggle lane p. + for (size_t i = 0; i < NN; ++i) { + delta_w[i] = static_cast<TW>(i == p); + } + const VW delta0 = Load(dw, delta_w.get()); + const VW delta1 = Load(dw, delta_w.get() + NN / 2); + const VN delta = ReorderDemote2To(dn, delta0, delta1); + + { + sum1 = f0; + const VW sum0 = ReorderWidenMulAccumulate(dw, delta, bf1, f0, sum1); + HWY_ASSERT_EQ(TW{1}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + // Swapped arg order + { + sum1 = f0; + const VW sum0 = ReorderWidenMulAccumulate(dw, bf1, delta, f0, sum1); + HWY_ASSERT_EQ(TW{1}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + // Start with nonzero sum0 or sum1 + { + VW sum0 = PromoteTo(dw, LowerHalf(dnh, delta)); + sum1 = PromoteTo(dw, UpperHalf(dnh, delta)); + sum0 = ReorderWidenMulAccumulate(dw, delta, bf1, sum0, sum1); + HWY_ASSERT_EQ(TW{2}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + // Start with nonzero sum0 or sum1, and swap arg order + { + VW sum0 = PromoteTo(dw, LowerHalf(dnh, delta)); + sum1 = PromoteTo(dw, UpperHalf(dnh, delta)); + sum0 = ReorderWidenMulAccumulate(dw, bf1, delta, sum0, sum1); + HWY_ASSERT_EQ(TW{2}, GetLane(SumOfLanes(dw, Add(sum0, sum1)))); + } + } + } +}; + +HWY_NOINLINE void TestAllReorderWidenMulAccumulate() { + ForShrinkableVectors<TestReorderWidenMulAccumulate>()(bfloat16_t()); + ForShrinkableVectors<TestReorderWidenMulAccumulate>()(int16_t()); +} + +struct TestRearrangeToOddPlusEven { + template <typename TN, class DN> + HWY_NOINLINE void operator()(TN /*unused*/, DN dn) { + using TW = MakeWide<TN>; + const RebindToUnsigned<DN> du; + const RepartitionToWide<DN> dw; + const Half<DN> dnh; + const RebindToUnsigned<decltype(dnh)> duh; + using VW = Vec<decltype(dw)>; + using VN = Vec<decltype(dn)>; + const size_t NW = Lanes(dw); + + const VW up0 = Iota(dw, TW{1}); + const VW up1 = Iota(dw, static_cast<TW>(1 + NW)); + // We will compute i * (N-i) to avoid per-lane overflow. + const VW down0 = Reverse(dw, up1); + const VW down1 = Reverse(dw, up0); + + // Combine is not available for bf16, so cast to u16. + const auto a0 = BitCast(duh, DemoteTo(dnh, up0)); + const auto a1 = BitCast(duh, DemoteTo(dnh, up1)); + const VN a = BitCast(dn, Combine(du, a1, a0)); + const auto b0 = BitCast(duh, DemoteTo(dnh, down0)); + const auto b1 = BitCast(duh, DemoteTo(dnh, down1)); + const VN b = BitCast(dn, Combine(du, b1, b0)); + + const auto expected = AllocateAligned<TW>(NW); + for (size_t iw = 0; iw < NW; ++iw) { + const size_t in = iw * 2; // even, odd is +1 + const size_t a0 = 1 + in; + const size_t b0 = 1 + 2 * NW - a0; + const size_t a1 = a0 + 1; + const size_t b1 = b0 - 1; + expected[iw] = static_cast<TW>(a0 * b0 + a1 * b1); + } + + VW sum1 = Zero(dw); + const VW sum0 = ReorderWidenMulAccumulate(dw, a, b, Zero(dw), sum1); + const VW sum_odd_even = RearrangeToOddPlusEven(sum0, sum1); + HWY_ASSERT_VEC_EQ(dw, expected.get(), sum_odd_even); + } +}; + +HWY_NOINLINE void TestAllRearrangeToOddPlusEven() { + ForShrinkableVectors<TestRearrangeToOddPlusEven>()(bfloat16_t()); + ForShrinkableVectors<TestRearrangeToOddPlusEven>()(int16_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyMulTest); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMul); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulHigh); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulFixedPoint15); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulEven); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulAdd); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllReorderWidenMulAccumulate); +HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllRearrangeToOddPlusEven); + +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/reduction_test.cc b/third_party/highway/hwy/tests/reduction_test.cc new file mode 100644 index 0000000000..5cc051ef1c --- /dev/null +++ b/third_party/highway/hwy/tests/reduction_test.cc @@ -0,0 +1,261 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/reduction_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestSumOfLanes { + template <typename T, size_t N, int P, + hwy::EnableIf<!IsSigned<T>() || ((N & 1) != 0)>* = nullptr> + HWY_NOINLINE void SignedEvenLengthVectorTests(Simd<T, N, P>) { + // do nothing + } + template <typename T, size_t N, int P, + hwy::EnableIf<IsSigned<T>() && ((N & 1) == 0)>* = nullptr> + HWY_NOINLINE void SignedEvenLengthVectorTests(Simd<T, N, P> d) { + const T pairs = static_cast<T>(Lanes(d) / 2); + + // Lanes are the repeated sequence -2, 1, [...]; each pair sums to -1, + // so the eventual total is just -(N/2). + Vec<decltype(d)> v = + InterleaveLower(Set(d, static_cast<T>(-2)), Set(d, T{1})); + HWY_ASSERT_VEC_EQ(d, Set(d, static_cast<T>(-pairs)), SumOfLanes(d, v)); + + // Similar test with a positive result. + v = InterleaveLower(Set(d, static_cast<T>(-2)), Set(d, T{4})); + HWY_ASSERT_VEC_EQ(d, Set(d, static_cast<T>(pairs * 2)), SumOfLanes(d, v)); + } + + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned<T>(N); + + // Lane i = bit i, higher lanes 0 + double sum = 0.0; + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast<T>(1ull << i) : 0; + sum += static_cast<double>(in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, T(sum)), + SumOfLanes(d, Load(d, in_lanes.get()))); + + // Lane i = i (iota) to include upper lanes + sum = 0.0; + for (size_t i = 0; i < N; ++i) { + sum += static_cast<double>(i); + } + HWY_ASSERT_VEC_EQ(d, Set(d, T(sum)), SumOfLanes(d, Iota(d, 0))); + + // Run more tests only for signed types with even vector lengths. Some of + // this code may not otherwise compile, so put it in a templated function. + SignedEvenLengthVectorTests(d); + } +}; + +HWY_NOINLINE void TestAllSumOfLanes() { + ForUIF3264(ForPartialVectors<TestSumOfLanes>()); + ForUI16(ForPartialVectors<TestSumOfLanes>()); + +#if HWY_TARGET == HWY_NEON || HWY_TARGET == HWY_SSE4 || HWY_TARGET == HWY_SSSE3 + ForUI8(ForGEVectors<64, TestSumOfLanes>()); +#endif +} + +struct TestMinOfLanes { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned<T>(N); + + // Lane i = bit i, higher lanes = 2 (not the minimum) + T min = HighestValue<T>(); + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast<T>(1ull << i) : 2; + min = HWY_MIN(min, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(d, Load(d, in_lanes.get()))); + + // Lane i = N - i to include upper lanes + min = HighestValue<T>(); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = static_cast<T>(N - i); // no 8-bit T so no wraparound + min = HWY_MIN(min, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(d, Load(d, in_lanes.get()))); + + // Bug #910: also check negative values + min = HighestValue<T>(); + const T input_copy[] = {static_cast<T>(-1), + static_cast<T>(-2), + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14}; + size_t i = 0; + for (; i < HWY_MIN(N, sizeof(input_copy) / sizeof(T)); ++i) { + in_lanes[i] = input_copy[i]; + min = HWY_MIN(min, input_copy[i]); + } + // Pad with neutral element to full vector (so we can load) + for (; i < N; ++i) { + in_lanes[i] = min; + } + HWY_ASSERT_VEC_EQ(d, Set(d, min), MinOfLanes(d, Load(d, in_lanes.get()))); + } +}; + +struct TestMaxOfLanes { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto in_lanes = AllocateAligned<T>(N); + + T max = LowestValue<T>(); + // Avoid setting sign bit and cap at double precision + constexpr size_t kBits = HWY_MIN(sizeof(T) * 8 - 1, 51); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = i < kBits ? static_cast<T>(1ull << i) : 0; + max = HWY_MAX(max, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(d, Load(d, in_lanes.get()))); + + // Lane i = i to include upper lanes + max = LowestValue<T>(); + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = static_cast<T>(i); // no 8-bit T so no wraparound + max = HWY_MAX(max, in_lanes[i]); + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(d, Load(d, in_lanes.get()))); + + // Bug #910: also check negative values + max = LowestValue<T>(); + const T input_copy[] = {static_cast<T>(-1), + static_cast<T>(-2), + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14}; + size_t i = 0; + for (; i < HWY_MIN(N, sizeof(input_copy) / sizeof(T)); ++i) { + in_lanes[i] = input_copy[i]; + max = HWY_MAX(max, in_lanes[i]); + } + // Pad with neutral element to full vector (so we can load) + for (; i < N; ++i) { + in_lanes[i] = max; + } + HWY_ASSERT_VEC_EQ(d, Set(d, max), MaxOfLanes(d, Load(d, in_lanes.get()))); + } +}; + +HWY_NOINLINE void TestAllMinMaxOfLanes() { + const ForPartialVectors<TestMinOfLanes> test_min; + const ForPartialVectors<TestMaxOfLanes> test_max; + ForUIF3264(test_min); + ForUIF3264(test_max); + ForUI16(test_min); + ForUI16(test_max); + +#if HWY_TARGET == HWY_NEON || HWY_TARGET == HWY_SSE4 || HWY_TARGET == HWY_SSSE3 + ForUI8(ForGEVectors<64, TestMinOfLanes>()); + ForUI8(ForGEVectors<64, TestMaxOfLanes>()); +#endif +} + +struct TestSumsOf8 { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + const size_t N = Lanes(d); + if (N < 8) return; + const Repartition<uint64_t, D> du64; + + auto in_lanes = AllocateAligned<T>(N); + auto sum_lanes = AllocateAligned<uint64_t>(N / 8); + + for (size_t rep = 0; rep < 100; ++rep) { + for (size_t i = 0; i < N; ++i) { + in_lanes[i] = Random64(&rng) & 0xFF; + } + + for (size_t idx_sum = 0; idx_sum < N / 8; ++idx_sum) { + uint64_t sum = 0; + for (size_t i = 0; i < 8; ++i) { + sum += in_lanes[idx_sum * 8 + i]; + } + sum_lanes[idx_sum] = sum; + } + + const Vec<D> in = Load(d, in_lanes.get()); + HWY_ASSERT_VEC_EQ(du64, sum_lanes.get(), SumsOf8(in)); + } + } +}; + +HWY_NOINLINE void TestAllSumsOf8() { + ForGEVectors<64, TestSumsOf8>()(uint8_t()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyReductionTest); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumOfLanes); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMinMaxOfLanes); +HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf8); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/reverse_test.cc b/third_party/highway/hwy/tests/reverse_test.cc new file mode 100644 index 0000000000..b1572c03fe --- /dev/null +++ b/third_party/highway/hwy/tests/reverse_test.cc @@ -0,0 +1,186 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/reverse_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestReverse { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned<D> du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned<T>(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned<T>(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[N - 1 - i]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse(d, v)); + } +}; + +struct TestReverse2 { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned<D> du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned<T>(N); + if (N == 1) { + Store(v, d, expected.get()); + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse2(d, v)); + return; + } + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned<T>(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 1]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse2(d, v)); + } +}; + +struct TestReverse4 { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned<D> du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned<T>(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned<T>(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 3]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse4(d, v)); + } +}; + +struct TestReverse8 { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned<D> du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned<T>(N); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned<T>(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = copy[i ^ 7]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Reverse8(d, v)); + } +}; + +HWY_NOINLINE void TestAllReverse() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF163264(ForPartialVectors<TestReverse>()); +} + +HWY_NOINLINE void TestAllReverse2() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<128, TestReverse2>()); + ForUIF32(ForGEVectors<64, TestReverse2>()); + ForUIF16(ForGEVectors<32, TestReverse2>()); + +#if HWY_TARGET == HWY_SSSE3 + // Implemented mainly for internal use. + ForUI8(ForPartialVectors<TestReverse2>()); +#endif +} + +HWY_NOINLINE void TestAllReverse4() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<256, TestReverse4>()); + ForUIF32(ForGEVectors<128, TestReverse4>()); + ForUIF16(ForGEVectors<64, TestReverse4>()); +} + +HWY_NOINLINE void TestAllReverse8() { + // 8-bit is not supported because Risc-V uses rgather of Lanes - Iota, + // which requires 16 bits. + ForUIF64(ForGEVectors<512, TestReverse8>()); + ForUIF32(ForGEVectors<256, TestReverse8>()); + ForUIF16(ForGEVectors<128, TestReverse8>()); +} + +struct TestReverseBlocks { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const RebindToUnsigned<D> du; // Iota does not support float16_t. + const auto v = BitCast(d, Iota(du, 1)); + auto expected = AllocateAligned<T>(N); + + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + const size_t num_blocks = N / kLanesPerBlock; + HWY_ASSERT(num_blocks != 0); + + // Can't set float16_t value directly, need to permute in memory. + auto copy = AllocateAligned<T>(N); + Store(v, d, copy.get()); + for (size_t i = 0; i < N; ++i) { + const size_t idx_block = i / kLanesPerBlock; + const size_t base = (num_blocks - 1 - idx_block) * kLanesPerBlock; + expected[i] = copy[base + (i % kLanesPerBlock)]; + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ReverseBlocks(d, v)); + } +}; + +HWY_NOINLINE void TestAllReverseBlocks() { + ForAllTypes(ForGEVectors<128, TestReverseBlocks>()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyReverseTest); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse2); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse4); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverse8); +HWY_EXPORT_AND_TEST_P(HwyReverseTest, TestAllReverseBlocks); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/shift_test.cc b/third_party/highway/hwy/tests/shift_test.cc new file mode 100644 index 0000000000..585eba761c --- /dev/null +++ b/third_party/highway/hwy/tests/shift_test.cc @@ -0,0 +1,428 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <limits> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/shift_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <bool kSigned> +struct TestLeftShifts { + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + if (kSigned) { + // Also test positive values + TestLeftShifts</*kSigned=*/false>()(t, d); + } + + using TI = MakeSigned<T>; + using TU = MakeUnsigned<T>; + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + // Values to shift + const auto values = Iota(d, static_cast<T>(kSigned ? -TI(N) : TI(0))); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftLeft<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftLeftSame(values, 0)); + + // 1 + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(T(i) - T(N)) : T(i); + expected[i] = T(TU(value) << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, 1)); + + // max + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(T(i) - T(N)) : T(i); + expected[i] = T(TU(value) << kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeft<kMaxShift>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftLeftSame(values, kMaxShift)); + } +}; + +template <bool kSigned> +struct TestVariableLeftShifts { + template <typename T, class D> + HWY_NOINLINE void operator()(T t, D d) { + if (kSigned) { + // Also test positive values + TestVariableLeftShifts</*kSigned=*/false>()(t, d); + } + + using TI = MakeSigned<T>; + using TU = MakeUnsigned<T>; + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto values = Iota(d, kSigned ? -TI(N) : TI(0)); // value to shift + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + // Same: 0 + HWY_ASSERT_VEC_EQ(d, values, Shl(values, v0)); + + // Same: 1 + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, v1)); + + // Same: max + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, max_shift)); + + // Variable: small + for (size_t i = 0; i < N; ++i) { + const T value = kSigned ? T(i) - T(N) : T(i); + expected[i] = T(TU(value) << (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(values, small_shifts)); + + // Variable: large + for (size_t i = 0; i < N; ++i) { + expected[i] = T(TU(1) << (kMaxShift - (i & kMaxShift))); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shl(v1, large_shifts)); + } +}; + +struct TestUnsignedRightShifts { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + const auto values = Iota(d, 0); + + const T kMax = LimitsMax<T>(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<kMaxShift>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, kMaxShift)); + } +}; + +struct TestRotateRight { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + constexpr size_t kBits = sizeof(T) * 8; + const auto mask_shift = Set(d, T{kBits}); + // Cover as many bit positions as possible to test shifting out + const auto values = Shl(Set(d, T{1}), And(Iota(d, 0), mask_shift)); + + // Rotate by 0 + HWY_ASSERT_VEC_EQ(d, values, RotateRight<0>(values)); + + // Rotate by 1 + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> 1) | (expected[i] << (kBits - 1)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight<1>(values)); + + // Rotate by half + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> (kBits / 2)) | (expected[i] << (kBits / 2)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight<kBits / 2>(values)); + + // Rotate by max + Store(values, d, expected.get()); + for (size_t i = 0; i < N; ++i) { + expected[i] = (expected[i] >> (kBits - 1)) | (expected[i] << 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), RotateRight<kBits - 1>(values)); + } +}; + +struct TestVariableUnsignedRightShifts { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + const auto v0 = Zero(d); + const auto v1 = Set(d, 1); + const auto values = Iota(d, 0); + + const T kMax = LimitsMax<T>(); + const auto max = Set(d, kMax); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + // Same: 0 + HWY_ASSERT_VEC_EQ(d, values, Shr(values, v0)); + + // Same: 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, v1)); + + // Same: max + HWY_ASSERT_VEC_EQ(d, v0, Shr(values, max_shift)); + + // Variable: small + for (size_t i = 0; i < N; ++i) { + expected[i] = T(i) >> (i & kMaxShift); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(values, small_shifts)); + + // Variable: Large + for (size_t i = 0; i < N; ++i) { + expected[i] = kMax >> (kMaxShift - (i & kMaxShift)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(max, large_shifts)); + } +}; + +template <int kAmount, typename T> +T RightShiftNegative(T val) { + // C++ shifts are implementation-defined for negative numbers, and we have + // seen divisions replaced with shifts, so resort to bit operations. + using TU = hwy::MakeUnsigned<T>; + TU bits; + CopySameSize(&val, &bits); + + const TU shifted = TU(bits >> kAmount); + + const TU all = TU(~TU(0)); + const size_t num_zero = sizeof(TU) * 8 - 1 - kAmount; + const TU sign_extended = static_cast<TU>((all << num_zero) & LimitsMax<TU>()); + + bits = shifted | sign_extended; + CopySameSize(&bits, &val); + return val; +} + +class TestSignedRightShifts { + public: + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + constexpr T kMin = LimitsMin<T>(); + constexpr T kMax = LimitsMax<T>(); + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. + const auto v0 = Zero(d); + const auto values = And(Iota(d, 0), Set(d, kMax)); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, values, ShiftRight<0>(values)); + HWY_ASSERT_VEC_EQ(d, values, ShiftRightSame(values, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(values)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(values, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight<kMaxShift>(values)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(values, kMaxShift)); + + // Even negative value + Test<0>(kMin, d, __LINE__); + Test<1>(kMin, d, __LINE__); + Test<2>(kMin, d, __LINE__); + Test<kMaxShift>(kMin, d, __LINE__); + + const T odd = static_cast<T>(kMin + 1); + Test<0>(odd, d, __LINE__); + Test<1>(odd, d, __LINE__); + Test<2>(odd, d, __LINE__); + Test<kMaxShift>(odd, d, __LINE__); + } + + private: + template <int kAmount, typename T, class D> + void Test(T val, D d, int line) { + const auto expected = Set(d, RightShiftNegative<kAmount>(val)); + const auto in = Set(d, val); + const char* file = __FILE__; + AssertVecEqual(d, expected, ShiftRight<kAmount>(in), file, line); + AssertVecEqual(d, expected, ShiftRightSame(in, kAmount), file, line); + } +}; + +struct TestVariableSignedRightShifts { + template <typename T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using TU = MakeUnsigned<T>; + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + + constexpr T kMin = LimitsMin<T>(); + constexpr T kMax = LimitsMax<T>(); + + constexpr size_t kMaxShift = (sizeof(T) * 8) - 1; + + // First test positive values, negative are checked below. + const auto v0 = Zero(d); + const auto positive = Iota(d, 0) & Set(d, kMax); + + // Shift by 0 + HWY_ASSERT_VEC_EQ(d, positive, ShiftRight<0>(positive)); + HWY_ASSERT_VEC_EQ(d, positive, ShiftRightSame(positive, 0)); + + // Shift by 1 + for (size_t i = 0; i < N; ++i) { + expected[i] = T(T(i & kMax) >> 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRight<1>(positive)); + HWY_ASSERT_VEC_EQ(d, expected.get(), ShiftRightSame(positive, 1)); + + // max + HWY_ASSERT_VEC_EQ(d, v0, ShiftRight<kMaxShift>(positive)); + HWY_ASSERT_VEC_EQ(d, v0, ShiftRightSame(positive, kMaxShift)); + + const auto max_shift = Set(d, kMaxShift); + const auto small_shifts = And(Iota(d, 0), max_shift); + const auto large_shifts = max_shift - small_shifts; + + const auto negative = Iota(d, kMin); + + // Test varying negative to shift + for (size_t i = 0; i < N; ++i) { + expected[i] = RightShiftNegative<1>(static_cast<T>(kMin + i)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(negative, Set(d, 1))); + + // Shift MSB right by small amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = i & kMaxShift; + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopySameSize(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), small_shifts)); + + // Shift MSB right by large amounts + for (size_t i = 0; i < N; ++i) { + const size_t amount = kMaxShift - (i & kMaxShift); + const TU shifted = ~((1ull << (kMaxShift - amount)) - 1); + CopySameSize(&shifted, &expected[i]); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), Shr(Set(d, kMin), large_shifts)); + } +}; + +HWY_NOINLINE void TestAllShifts() { + ForUnsignedTypes(ForPartialVectors<TestLeftShifts</*kSigned=*/false>>()); + ForSignedTypes(ForPartialVectors<TestLeftShifts</*kSigned=*/true>>()); + ForUnsignedTypes(ForPartialVectors<TestUnsignedRightShifts>()); + ForSignedTypes(ForPartialVectors<TestSignedRightShifts>()); +} + +HWY_NOINLINE void TestAllVariableShifts() { + const ForPartialVectors<TestLeftShifts</*kSigned=*/false>> shl_u; + const ForPartialVectors<TestLeftShifts</*kSigned=*/true>> shl_s; + const ForPartialVectors<TestUnsignedRightShifts> shr_u; + const ForPartialVectors<TestSignedRightShifts> shr_s; + + shl_u(uint16_t()); + shr_u(uint16_t()); + + shl_u(uint32_t()); + shr_u(uint32_t()); + + shl_s(int16_t()); + shr_s(int16_t()); + + shl_s(int32_t()); + shr_s(int32_t()); + +#if HWY_HAVE_INTEGER64 + shl_u(uint64_t()); + shr_u(uint64_t()); + + shl_s(int64_t()); + shr_s(int64_t()); +#endif +} + +HWY_NOINLINE void TestAllRotateRight() { + const ForPartialVectors<TestRotateRight> test; + test(uint32_t()); +#if HWY_HAVE_INTEGER64 + test(uint64_t()); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwyShiftTest); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllShifts); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllVariableShifts); +HWY_EXPORT_AND_TEST_P(HwyShiftTest, TestAllRotateRight); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/swizzle_test.cc b/third_party/highway/hwy/tests/swizzle_test.cc new file mode 100644 index 0000000000..f447f7a800 --- /dev/null +++ b/third_party/highway/hwy/tests/swizzle_test.cc @@ -0,0 +1,272 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <string.h> // memset + +#include "hwy/base.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/swizzle_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestGetLane { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(1)); + HWY_ASSERT_EQ(T(1), GetLane(v)); + } +}; + +HWY_NOINLINE void TestAllGetLane() { + ForAllTypes(ForPartialVectors<TestGetLane>()); +} + +struct TestExtractLane { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const auto v = Iota(d, T(1)); + for (size_t i = 0; i < Lanes(d); ++i) { + const T actual = ExtractLane(v, i); + HWY_ASSERT_EQ(static_cast<T>(i + 1), actual); + } + } +}; + +HWY_NOINLINE void TestAllExtractLane() { + ForAllTypes(ForPartialVectors<TestExtractLane>()); +} + +struct TestInsertLane { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + using V = Vec<D>; + const V v = Iota(d, T(1)); + const size_t N = Lanes(d); + auto lanes = AllocateAligned<T>(N); + Store(v, d, lanes.get()); + + for (size_t i = 0; i < Lanes(d); ++i) { + lanes[i] = T{0}; + const V actual = InsertLane(v, i, static_cast<T>(i + 1)); + HWY_ASSERT_VEC_EQ(d, v, actual); + Store(v, d, lanes.get()); // restore lane i + } + } +}; + +HWY_NOINLINE void TestAllInsertLane() { + ForAllTypes(ForPartialVectors<TestInsertLane>()); +} + +struct TestDupEven { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((static_cast<int>(i) & ~1) + 1); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), DupEven(Iota(d, 1))); + } +}; + +HWY_NOINLINE void TestAllDupEven() { + ForUIF3264(ForShrinkableVectors<TestDupEven>()); +} + +struct TestDupOdd { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>((static_cast<int>(i) & ~1) + 2); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), DupOdd(Iota(d, 1))); +#else + (void)d; +#endif + } +}; + +HWY_NOINLINE void TestAllDupOdd() { + ForUIF3264(ForShrinkableVectors<TestDupOdd>()); +} + +struct TestOddEven { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto even = Iota(d, 1); + const auto odd = Iota(d, static_cast<T>(1 + N)); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>(1 + i + ((i & 1) ? N : 0)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), OddEven(odd, even)); + } +}; + +HWY_NOINLINE void TestAllOddEven() { + ForAllTypes(ForShrinkableVectors<TestOddEven>()); +} + +struct TestOddEvenBlocks { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + const auto even = Iota(d, 1); + const auto odd = Iota(d, static_cast<T>(1 + N)); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + const size_t idx_block = i / (16 / sizeof(T)); + expected[i] = static_cast<T>(1 + i + ((idx_block & 1) ? N : 0)); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), OddEvenBlocks(odd, even)); + } +}; + +HWY_NOINLINE void TestAllOddEvenBlocks() { + ForAllTypes(ForGEVectors<128, TestOddEvenBlocks>()); +} + +struct TestSwapAdjacentBlocks { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const size_t N = Lanes(d); + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + if (N < 2 * kLanesPerBlock) return; + const auto vi = Iota(d, 1); + auto expected = AllocateAligned<T>(N); + for (size_t i = 0; i < N; ++i) { + const size_t idx_block = i / kLanesPerBlock; + const size_t base = (idx_block ^ 1) * kLanesPerBlock; + const size_t mod = i % kLanesPerBlock; + expected[i] = static_cast<T>(1 + base + mod); + } + HWY_ASSERT_VEC_EQ(d, expected.get(), SwapAdjacentBlocks(vi)); + } +}; + +HWY_NOINLINE void TestAllSwapAdjacentBlocks() { + ForAllTypes(ForGEVectors<128, TestSwapAdjacentBlocks>()); +} + +struct TestTableLookupLanes { + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + const RebindToSigned<D> di; + using TI = TFromD<decltype(di)>; +#if HWY_TARGET != HWY_SCALAR + const size_t N = Lanes(d); + auto idx = AllocateAligned<TI>(N); + memset(idx.get(), 0, N * sizeof(TI)); + auto expected = AllocateAligned<T>(N); + const auto v = Iota(d, 1); + + if (N <= 8) { // Test all permutations + for (size_t i0 = 0; i0 < N; ++i0) { + idx[0] = static_cast<TI>(i0); + + for (size_t i1 = 0; i1 < N; ++i1) { + if (N >= 2) idx[1] = static_cast<TI>(i1); + for (size_t i2 = 0; i2 < N; ++i2) { + if (N >= 4) idx[2] = static_cast<TI>(i2); + for (size_t i3 = 0; i3 < N; ++i3) { + if (N >= 4) idx[3] = static_cast<TI>(i3); + + for (size_t i = 0; i < N; ++i) { + expected[i] = static_cast<T>(idx[i] + 1); // == v[idx[i]] + } + + const auto opaque1 = IndicesFromVec(d, Load(di, idx.get())); + const auto actual1 = TableLookupLanes(v, opaque1); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual1); + + const auto opaque2 = SetTableIndices(d, idx.get()); + const auto actual2 = TableLookupLanes(v, opaque2); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual2); + } + } + } + } + } else { + // Too many permutations to test exhaustively; choose one with repeated + // and cross-block indices and ensure indices do not exceed #lanes. + // For larger vectors, upper lanes will be zero. + HWY_ALIGN TI idx_source[16] = {1, 3, 2, 2, 8, 1, 7, 6, + 15, 14, 14, 15, 4, 9, 8, 5}; + for (size_t i = 0; i < N; ++i) { + idx[i] = (i < 16) ? idx_source[i] : 0; + // Avoid undefined results / asan error for scalar by capping indices. + if (idx[i] >= static_cast<TI>(N)) { + idx[i] = static_cast<TI>(N - 1); + } + expected[i] = static_cast<T>(idx[i] + 1); // == v[idx[i]] + } + + const auto opaque1 = IndicesFromVec(d, Load(di, idx.get())); + const auto actual1 = TableLookupLanes(v, opaque1); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual1); + + const auto opaque2 = SetTableIndices(d, idx.get()); + const auto actual2 = TableLookupLanes(v, opaque2); + HWY_ASSERT_VEC_EQ(d, expected.get(), actual2); + } +#else + const TI index = 0; + const auto v = Set(d, 1); + const auto opaque1 = SetTableIndices(d, &index); + HWY_ASSERT_VEC_EQ(d, v, TableLookupLanes(v, opaque1)); + const auto opaque2 = IndicesFromVec(d, Zero(di)); + HWY_ASSERT_VEC_EQ(d, v, TableLookupLanes(v, opaque2)); +#endif + } +}; + +HWY_NOINLINE void TestAllTableLookupLanes() { + ForUIF3264(ForPartialVectors<TestTableLookupLanes>()); +} + + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(HwySwizzleTest); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllGetLane); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllExtractLane); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllInsertLane); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllDupEven); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllDupOdd); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllOddEven); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllOddEvenBlocks); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllSwapAdjacentBlocks); +HWY_EXPORT_AND_TEST_P(HwySwizzleTest, TestAllTableLookupLanes); +} // namespace hwy + +#endif diff --git a/third_party/highway/hwy/tests/test_util-inl.h b/third_party/highway/hwy/tests/test_util-inl.h new file mode 100644 index 0000000000..972b3361e0 --- /dev/null +++ b/third_party/highway/hwy/tests/test_util-inl.h @@ -0,0 +1,665 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-specific helper functions for use by *_test.cc. + +#include <stdint.h> + +#include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" +#include "hwy/tests/test_util.h" + +// After test_util (also includes highway.h) +#include "hwy/print-inl.h" + +// Per-target include guard +#if defined(HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_ +#undef HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_ +#else +#define HIGHWAY_HWY_TESTS_TEST_UTIL_INL_H_ +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Compare expected vector to vector. +// HWY_INLINE works around a Clang SVE compiler bug where all but the first +// 128 bits (the NEON register) of actual are zero. +template <class D, typename T = TFromD<D>, class V = Vec<D>> +HWY_INLINE void AssertVecEqual(D d, const T* expected, VecArg<V> actual, + const char* filename, const int line) { + const size_t N = Lanes(d); + auto actual_lanes = AllocateAligned<T>(N); + Store(actual, d, actual_lanes.get()); + + const auto info = hwy::detail::MakeTypeInfo<T>(); + const char* target_name = hwy::TargetName(HWY_TARGET); + hwy::detail::AssertArrayEqual(info, expected, actual_lanes.get(), N, + target_name, filename, line); +} + +// Compare expected lanes to vector. +// HWY_INLINE works around a Clang SVE compiler bug where all but the first +// 128 bits (the NEON register) of actual are zero. +template <class D, typename T = TFromD<D>, class V = Vec<D>> +HWY_INLINE void AssertVecEqual(D d, VecArg<V> expected, VecArg<V> actual, + const char* filename, int line) { + auto expected_lanes = AllocateAligned<T>(Lanes(d)); + Store(expected, d, expected_lanes.get()); + AssertVecEqual(d, expected_lanes.get(), actual, filename, line); +} + +// Only checks the valid mask elements (those whose index < Lanes(d)). +template <class D> +HWY_NOINLINE void AssertMaskEqual(D d, VecArg<Mask<D>> a, VecArg<Mask<D>> b, + const char* filename, int line) { + // lvalues prevented MSAN failure in farm_sve. + const Vec<D> va = VecFromMask(d, a); + const Vec<D> vb = VecFromMask(d, b); + AssertVecEqual(d, va, vb, filename, line); + + const char* target_name = hwy::TargetName(HWY_TARGET); + AssertEqual(CountTrue(d, a), CountTrue(d, b), target_name, filename, line); + AssertEqual(AllTrue(d, a), AllTrue(d, b), target_name, filename, line); + AssertEqual(AllFalse(d, a), AllFalse(d, b), target_name, filename, line); + + const size_t N = Lanes(d); +#if HWY_TARGET == HWY_SCALAR + const Rebind<uint8_t, D> d8; +#else + const Repartition<uint8_t, D> d8; +#endif + const size_t N8 = Lanes(d8); + auto bits_a = AllocateAligned<uint8_t>(HWY_MAX(size_t{8}, N8)); + auto bits_b = AllocateAligned<uint8_t>(size_t{HWY_MAX(8, N8)}); + memset(bits_a.get(), 0, N8); + memset(bits_b.get(), 0, N8); + const size_t num_bytes_a = StoreMaskBits(d, a, bits_a.get()); + const size_t num_bytes_b = StoreMaskBits(d, b, bits_b.get()); + AssertEqual(num_bytes_a, num_bytes_b, target_name, filename, line); + size_t i = 0; + // First check whole bytes (if that many elements are still valid) + for (; i < N / 8; ++i) { + if (bits_a[i] != bits_b[i]) { + fprintf(stderr, "Mismatch in byte %d: %d != %d\n", static_cast<int>(i), + bits_a[i], bits_b[i]); + Print(d8, "expect", Load(d8, bits_a.get()), 0, N8); + Print(d8, "actual", Load(d8, bits_b.get()), 0, N8); + hwy::Abort(filename, line, "Masks not equal"); + } + } + // Then the valid bit(s) in the last byte. + const size_t remainder = N % 8; + if (remainder != 0) { + const int mask = (1 << remainder) - 1; + const int valid_a = bits_a[i] & mask; + const int valid_b = bits_b[i] & mask; + if (valid_a != valid_b) { + fprintf(stderr, "Mismatch in last byte %d: %d != %d\n", + static_cast<int>(i), valid_a, valid_b); + Print(d8, "expect", Load(d8, bits_a.get()), 0, N8); + Print(d8, "actual", Load(d8, bits_b.get()), 0, N8); + hwy::Abort(filename, line, "Masks not equal"); + } + } +} + +// Only sets valid elements (those whose index < Lanes(d)). This helps catch +// tests that are not masking off the (undefined) upper mask elements. +// +// TODO(janwas): with HWY_NOINLINE GCC zeros the upper half of AVX2 masks. +template <class D> +HWY_INLINE Mask<D> MaskTrue(const D d) { + return FirstN(d, Lanes(d)); +} + +template <class D> +HWY_INLINE Mask<D> MaskFalse(const D d) { + const auto zero = Zero(RebindToSigned<D>()); + return RebindMask(d, Lt(zero, zero)); +} + +#ifndef HWY_ASSERT_EQ + +#define HWY_ASSERT_EQ(expected, actual) \ + hwy::AssertEqual(expected, actual, hwy::TargetName(HWY_TARGET), __FILE__, \ + __LINE__) + +#define HWY_ASSERT_ARRAY_EQ(expected, actual, count) \ + hwy::AssertArrayEqual(expected, actual, count, hwy::TargetName(HWY_TARGET), \ + __FILE__, __LINE__) + +#define HWY_ASSERT_STRING_EQ(expected, actual) \ + hwy::AssertStringEqual(expected, actual, hwy::TargetName(HWY_TARGET), \ + __FILE__, __LINE__) + +#define HWY_ASSERT_VEC_EQ(d, expected, actual) \ + AssertVecEqual(d, expected, actual, __FILE__, __LINE__) + +#define HWY_ASSERT_MASK_EQ(d, expected, actual) \ + AssertMaskEqual(d, expected, actual, __FILE__, __LINE__) + +#endif // HWY_ASSERT_EQ + +namespace detail { + +// Helpers for instantiating tests with combinations of lane types / counts. + +// Calls Test for each CappedTag<T, N> where N is in [kMinLanes, kMul * kMinArg] +// and the resulting Lanes() is in [min_lanes, max_lanes]. The upper bound +// is required to ensure capped vectors remain extendable. Implemented by +// recursively halving kMul until it is zero. +template <typename T, size_t kMul, size_t kMinArg, class Test> +struct ForeachCappedR { + static void Do(size_t min_lanes, size_t max_lanes) { + const CappedTag<T, kMul * kMinArg> d; + + // If we already don't have enough lanes, stop. + const size_t lanes = Lanes(d); + if (lanes < min_lanes) return; + + if (lanes <= max_lanes) { + Test()(T(), d); + } + ForeachCappedR<T, kMul / 2, kMinArg, Test>::Do(min_lanes, max_lanes); + } +}; + +// Base case to stop the recursion. +template <typename T, size_t kMinArg, class Test> +struct ForeachCappedR<T, 0, kMinArg, Test> { + static void Do(size_t, size_t) {} +}; + +#if HWY_HAVE_SCALABLE + +template <typename T> +constexpr int MinPow2() { + // Highway follows RVV LMUL in that the smallest fraction is 1/8th (encoded + // as kPow2 == -3). The fraction also must not result in zero lanes for the + // smallest possible vector size, which is 128 bits even on RISC-V (with the + // application processor profile). + return HWY_MAX(-3, -static_cast<int>(CeilLog2(16 / sizeof(T)))); +} + +// Iterates kPow2 upward through +3. +template <typename T, int kPow2, int kAddPow2, class Test> +struct ForeachShiftR { + static void Do(size_t min_lanes) { + const ScalableTag<T, kPow2 + kAddPow2> d; + + // Precondition: [kPow2, 3] + kAddPow2 is a valid fraction of the minimum + // vector size, so we always have enough lanes, except ForGEVectors. + if (Lanes(d) >= min_lanes) { + Test()(T(), d); + } else { + fprintf(stderr, "%d lanes < %d: T=%d pow=%d\n", + static_cast<int>(Lanes(d)), static_cast<int>(min_lanes), + static_cast<int>(sizeof(T)), kPow2 + kAddPow2); + HWY_ASSERT(min_lanes != 1); + } + + ForeachShiftR<T, kPow2 + 1, kAddPow2, Test>::Do(min_lanes); + } +}; + +// Base case to stop the recursion. +template <typename T, int kAddPow2, class Test> +struct ForeachShiftR<T, 4, kAddPow2, Test> { + static void Do(size_t) {} +}; +#else +// ForeachCappedR already handled all possible sizes. +#endif // HWY_HAVE_SCALABLE + +} // namespace detail + +// These 'adapters' call a test for all possible N or kPow2 subject to +// constraints such as "vectors must be extendable" or "vectors >= 128 bits". +// They may be called directly, or via For*Types. Note that for an adapter C, +// `C<Test>(T())` does not call the test - the correct invocation is +// `C<Test>()(T())`, or preferably `ForAllTypes(C<Test>())`. We check at runtime +// that operator() is called to prevent such bugs. Note that this is not +// thread-safe, but that is fine because C are typically local variables. + +// Calls Test for all power of two N in [1, Lanes(d) >> kPow2]. This is for +// ops that widen their input, e.g. Combine (not supported by HWY_SCALAR). +template <class Test, int kPow2 = 1> +class ForExtendableVectors { + mutable bool called_ = false; + + public: + ~ForExtendableVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template <typename T> + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMaxCapped = HWY_LANES(T); + // Skip CappedTag that are already full vectors. + const size_t max_lanes = Lanes(ScalableTag<T>()) >> kPow2; + (void)kMaxCapped; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + // not supported +#else + detail::ForeachCappedR<T, (kMaxCapped >> kPow2), 1, Test>::Do(1, max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2, 3 - kPow2]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2, -kPow2, Test>::Do(1); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2, 0 - kPow2]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2 + 3, -kPow2 - 3, + Test>::Do(1); +#endif +#endif // HWY_SCALAR + } +}; + +// Calls Test for all power of two N in [1 << kPow2, Lanes(d)]. This is for ops +// that narrow their input, e.g. UpperHalf. +template <class Test, int kPow2 = 1> +class ForShrinkableVectors { + mutable bool called_ = false; + + public: + ~ForShrinkableVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template <typename T> + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMinLanes = size_t{1} << kPow2; + constexpr size_t kMaxCapped = HWY_LANES(T); + // For shrinking, an upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + + (void)kMinLanes; + (void)max_lanes; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + // not supported +#else + detail::ForeachCappedR<T, (kMaxCapped >> kPow2), kMinLanes, Test>::Do( + kMinLanes, max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2, 0, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2 + 3, -3, Test>::Do( + kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +// Calls Test for all supported power of two vectors of at least kMinBits. +// Examples: AES or 64x64 require 128 bits, casts may require 64 bits. +template <size_t kMinBits, class Test> +class ForGEVectors { + mutable bool called_ = false; + + public: + ~ForGEVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template <typename T> + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMaxCapped = HWY_LANES(T); + constexpr size_t kMinLanes = kMinBits / 8 / sizeof(T); + // An upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + (void)kMinLanes; // not supported +#else + detail::ForeachCappedR<T, HWY_LANES(T) / kMinLanes, kMinLanes, Test>::Do( + kMinLanes, max_lanes); +#if HWY_TARGET == HWY_RVV + // Can be 0 (handled below) if kMinBits > 64. + constexpr size_t kRatio = 128 / kMinBits; + constexpr int kMinPow2 = + kRatio == 0 ? 0 : -static_cast<int>(CeilLog2(kRatio)); + // For each [kMinPow2, 3]; counter is [kMinPow2, 3]. + detail::ForeachShiftR<T, kMinPow2, 0, Test>::Do(kMinLanes); +#elif HWY_HAVE_SCALABLE + // Can be 0 (handled below) if kMinBits > 128. + constexpr size_t kRatio = 128 / kMinBits; + constexpr int kMinPow2 = + kRatio == 0 ? 0 : -static_cast<int>(CeilLog2(kRatio)); + // For each [kMinPow2, 0]; counter is [kMinPow2 + 3, 3]. + detail::ForeachShiftR<T, kMinPow2 + 3, -3, Test>::Do(kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +template <class Test> +using ForGE128Vectors = ForGEVectors<128, Test>; + +// Calls Test for all N that can be promoted (not the same as Extendable because +// HWY_SCALAR has one lane). Also used for ZipLower, but not ZipUpper. +template <class Test, int kPow2 = 1> +class ForPromoteVectors { + mutable bool called_ = false; + + public: + ~ForPromoteVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template <typename T> + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kFactor = size_t{1} << kPow2; + static_assert(kFactor >= 2 && kFactor * sizeof(T) <= sizeof(uint64_t), ""); + constexpr size_t kMaxCapped = HWY_LANES(T); + constexpr size_t kMinLanes = kFactor; + // Skip CappedTag that are already full vectors. + const size_t max_lanes = Lanes(ScalableTag<T>()) >> kPow2; + (void)kMaxCapped; + (void)kMinLanes; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + detail::ForeachCappedR<T, 1, 1, Test>::Do(1, 1); +#else + // TODO(janwas): call Extendable if kMinLanes check not required? + detail::ForeachCappedR<T, (kMaxCapped >> kPow2), 1, Test>::Do(kMinLanes, + max_lanes); +#if HWY_TARGET == HWY_RVV + // For each [MinPow2, 3 - kPow2]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2, -kPow2, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2, 0 - kPow2]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2 + 3, -kPow2 - 3, + Test>::Do(kMinLanes); +#endif +#endif // HWY_SCALAR + } +}; + +// Calls Test for all N than can be demoted (not the same as Shrinkable because +// HWY_SCALAR has one lane). +template <class Test, int kPow2 = 1> +class ForDemoteVectors { + mutable bool called_ = false; + + public: + ~ForDemoteVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template <typename T> + void operator()(T /*unused*/) const { + called_ = true; + constexpr size_t kMinLanes = size_t{1} << kPow2; + constexpr size_t kMaxCapped = HWY_LANES(T); + // For shrinking, an upper limit is unnecessary. + constexpr size_t max_lanes = kMaxCapped; + + (void)kMinLanes; + (void)max_lanes; + (void)max_lanes; +#if HWY_TARGET == HWY_SCALAR + detail::ForeachCappedR<T, 1, 1, Test>::Do(1, 1); +#else + detail::ForeachCappedR<T, (kMaxCapped >> kPow2), kMinLanes, Test>::Do( + kMinLanes, max_lanes); + +// TODO(janwas): call Extendable if kMinLanes check not required? +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2, 0, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2 + 3, -3, Test>::Do( + kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +// For LowerHalf/Quarter. +template <class Test, int kPow2 = 1> +class ForHalfVectors { + mutable bool called_ = false; + + public: + ~ForHalfVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template <typename T> + void operator()(T /*unused*/) const { + called_ = true; +#if HWY_TARGET == HWY_SCALAR + detail::ForeachCappedR<T, 1, 1, Test>::Do(1, 1); +#else + constexpr size_t kMinLanes = size_t{1} << kPow2; + // For shrinking, an upper limit is unnecessary. + constexpr size_t kMaxCapped = HWY_LANES(T); + detail::ForeachCappedR<T, (kMaxCapped >> kPow2), kMinLanes, Test>::Do( + kMinLanes, kMaxCapped); + +// TODO(janwas): call Extendable if kMinLanes check not required? +#if HWY_TARGET == HWY_RVV + // For each [MinPow2 + kPow2, 3]; counter is [MinPow2 + kPow2, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2, 0, Test>::Do( + kMinLanes); +#elif HWY_HAVE_SCALABLE + // For each [MinPow2 + kPow2, 0]; counter is [MinPow2 + kPow2 + 3, 3]. + detail::ForeachShiftR<T, detail::MinPow2<T>() + kPow2 + 3, -3, Test>::Do( + kMinLanes); +#endif +#endif // HWY_TARGET == HWY_SCALAR + } +}; + +// Calls Test for all power of two N in [1, Lanes(d)]. This is the default +// for ops that do not narrow nor widen their input, nor require 128 bits. +template <class Test> +class ForPartialVectors { + mutable bool called_ = false; + + public: + ~ForPartialVectors() { + if (!called_) { + HWY_ABORT("Test is incorrect, ensure operator() is called"); + } + } + + template <typename T> + void operator()(T t) const { + called_ = true; +#if HWY_TARGET == HWY_SCALAR + (void)t; + detail::ForeachCappedR<T, 1, 1, Test>::Do(1, 1); +#else + ForExtendableVectors<Test, 0>()(t); +#endif + } +}; + +// Type lists to shorten call sites: + +template <class Func> +void ForSignedTypes(const Func& func) { + func(int8_t()); + func(int16_t()); + func(int32_t()); +#if HWY_HAVE_INTEGER64 + func(int64_t()); +#endif +} + +template <class Func> +void ForUnsignedTypes(const Func& func) { + func(uint8_t()); + func(uint16_t()); + func(uint32_t()); +#if HWY_HAVE_INTEGER64 + func(uint64_t()); +#endif +} + +template <class Func> +void ForIntegerTypes(const Func& func) { + ForSignedTypes(func); + ForUnsignedTypes(func); +} + +template <class Func> +void ForFloatTypes(const Func& func) { + func(float()); +#if HWY_HAVE_FLOAT64 + func(double()); +#endif +} + +template <class Func> +void ForAllTypes(const Func& func) { + ForIntegerTypes(func); + ForFloatTypes(func); +} + +template <class Func> +void ForUI8(const Func& func) { + func(uint8_t()); + func(int8_t()); +} + +template <class Func> +void ForUI16(const Func& func) { + func(uint16_t()); + func(int16_t()); +} + +template <class Func> +void ForUIF16(const Func& func) { + ForUI16(func); +#if HWY_HAVE_FLOAT16 + func(float16_t()); +#endif +} + +template <class Func> +void ForUI32(const Func& func) { + func(uint32_t()); + func(int32_t()); +} + +template <class Func> +void ForUIF32(const Func& func) { + ForUI32(func); + func(float()); +} + +template <class Func> +void ForUI64(const Func& func) { +#if HWY_HAVE_INTEGER64 + func(uint64_t()); + func(int64_t()); +#endif +} + +template <class Func> +void ForUIF64(const Func& func) { + ForUI64(func); +#if HWY_HAVE_FLOAT64 + func(double()); +#endif +} + +template <class Func> +void ForUI3264(const Func& func) { + ForUI32(func); + ForUI64(func); +} + +template <class Func> +void ForUIF3264(const Func& func) { + ForUIF32(func); + ForUIF64(func); +} + +template <class Func> +void ForUI163264(const Func& func) { + ForUI16(func); + ForUI3264(func); +} + +template <class Func> +void ForUIF163264(const Func& func) { + ForUIF16(func); + ForUIF3264(func); +} + +// For tests that involve loops, adjust the trip count so that emulated tests +// finish quickly (but always at least 2 iterations to ensure some diversity). +constexpr size_t AdjustedReps(size_t max_reps) { +#if HWY_ARCH_RVV + return HWY_MAX(max_reps / 32, 2); +#elif HWY_IS_DEBUG_BUILD + return HWY_MAX(max_reps / 8, 2); +#elif HWY_ARCH_ARM + return HWY_MAX(max_reps / 4, 2); +#else + return HWY_MAX(max_reps, 2); +#endif +} + +// Same as above, but the loop trip count will be 1 << max_pow2. +constexpr size_t AdjustedLog2Reps(size_t max_pow2) { + // If "negative" (unsigned wraparound), use original. +#if HWY_ARCH_RVV + return HWY_MIN(max_pow2 - 4, max_pow2); +#elif HWY_IS_DEBUG_BUILD + return HWY_MIN(max_pow2 - 1, max_pow2); +#elif HWY_ARCH_ARM + return HWY_MIN(max_pow2 - 1, max_pow2); +#else + return max_pow2; +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/third_party/highway/hwy/tests/test_util.cc b/third_party/highway/hwy/tests/test_util.cc new file mode 100644 index 0000000000..a0796b15f9 --- /dev/null +++ b/third_party/highway/hwy/tests/test_util.cc @@ -0,0 +1,117 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/tests/test_util.h" + +#include <stddef.h> +#include <stdio.h> + +#include <cmath> + +#include "hwy/base.h" +#include "hwy/print.h" + +namespace hwy { + +HWY_TEST_DLLEXPORT bool BytesEqual(const void* p1, const void* p2, + const size_t size, size_t* pos) { + const uint8_t* bytes1 = reinterpret_cast<const uint8_t*>(p1); + const uint8_t* bytes2 = reinterpret_cast<const uint8_t*>(p2); + for (size_t i = 0; i < size; ++i) { + if (bytes1[i] != bytes2[i]) { + if (pos != nullptr) { + *pos = i; + } + return false; + } + } + return true; +} + +void AssertStringEqual(const char* expected, const char* actual, + const char* target_name, const char* filename, + int line) { + while (*expected == *actual++) { + if (*expected++ == '\0') return; + } + + Abort(filename, line, "%s string mismatch: expected '%s', got '%s'.\n", + target_name, expected, actual); +} + +namespace detail { + +HWY_TEST_DLLEXPORT bool IsEqual(const TypeInfo& info, const void* expected_ptr, + const void* actual_ptr) { + if (!info.is_float) { + return BytesEqual(expected_ptr, actual_ptr, info.sizeof_t); + } + + if (info.sizeof_t == 4) { + float expected, actual; + CopyBytes<4>(expected_ptr, &expected); + CopyBytes<4>(actual_ptr, &actual); + return ComputeUlpDelta(expected, actual) <= 1; + } else if (info.sizeof_t == 8) { + double expected, actual; + CopyBytes<8>(expected_ptr, &expected); + CopyBytes<8>(actual_ptr, &actual); + return ComputeUlpDelta(expected, actual) <= 1; + } else { + HWY_ABORT("Unexpected float size %d\n", static_cast<int>(info.sizeof_t)); + return false; + } +} + +HWY_TEST_DLLEXPORT HWY_NORETURN void PrintMismatchAndAbort( + const TypeInfo& info, const void* expected_ptr, const void* actual_ptr, + const char* target_name, const char* filename, int line, size_t lane, + size_t num_lanes) { + char type_name[100]; + TypeName(info, 1, type_name); + char expected_str[100]; + ToString(info, expected_ptr, expected_str); + char actual_str[100]; + ToString(info, actual_ptr, actual_str); + Abort(filename, line, + "%s, %sx%d lane %d mismatch: expected '%s', got '%s'.\n", target_name, + type_name, static_cast<int>(num_lanes), static_cast<int>(lane), + expected_str, actual_str); +} + +HWY_TEST_DLLEXPORT void AssertArrayEqual(const TypeInfo& info, + const void* expected_void, + const void* actual_void, size_t N, + const char* target_name, + const char* filename, int line) { + const uint8_t* expected_array = + reinterpret_cast<const uint8_t*>(expected_void); + const uint8_t* actual_array = reinterpret_cast<const uint8_t*>(actual_void); + for (size_t i = 0; i < N; ++i) { + const void* expected_ptr = expected_array + i * info.sizeof_t; + const void* actual_ptr = actual_array + i * info.sizeof_t; + if (!IsEqual(info, expected_ptr, actual_ptr)) { + fprintf(stderr, "\n\n"); + PrintArray(info, "expect", expected_array, N, i); + PrintArray(info, "actual", actual_array, N, i); + + PrintMismatchAndAbort(info, expected_ptr, actual_ptr, target_name, + filename, line, i, N); + } + } +} + +} // namespace detail +} // namespace hwy diff --git a/third_party/highway/hwy/tests/test_util.h b/third_party/highway/hwy/tests/test_util.h new file mode 100644 index 0000000000..558d1bcfba --- /dev/null +++ b/third_party/highway/hwy/tests/test_util.h @@ -0,0 +1,173 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_TESTS_TEST_UTIL_H_ +#define HWY_TESTS_TEST_UTIL_H_ + +// Target-independent helper functions for use by *_test.cc. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <cmath> // std::isnan +#include <string> + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/highway.h" +#include "hwy/highway_export.h" +#include "hwy/print.h" + +namespace hwy { + +// The maximum vector size used in tests when defining test data. DEPRECATED. +constexpr size_t kTestMaxVectorSize = 64; + +// 64-bit random generator (Xorshift128+). Much smaller state than std::mt19937, +// which triggers a compiler bug. +class RandomState { + public: + explicit RandomState(const uint64_t seed = 0x123456789ull) { + s0_ = SplitMix64(seed + 0x9E3779B97F4A7C15ull); + s1_ = SplitMix64(s0_); + } + + HWY_INLINE uint64_t operator()() { + uint64_t s1 = s0_; + const uint64_t s0 = s1_; + const uint64_t bits = s1 + s0; + s0_ = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s1_ = s1; + return bits; + } + + private: + static uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + uint64_t s0_; + uint64_t s1_; +}; + +static HWY_INLINE uint32_t Random32(RandomState* rng) { + return static_cast<uint32_t>((*rng)()); +} + +static HWY_INLINE uint64_t Random64(RandomState* rng) { return (*rng)(); } + +// Prevents the compiler from eliding the computations that led to "output". +// Works by indicating to the compiler that "output" is being read and modified. +// The +r constraint avoids unnecessary writes to memory, but only works for +// built-in types. +template <class T> +inline void PreventElision(T&& output) { +#if HWY_COMPILER_MSVC + (void)output; +#else // HWY_COMPILER_MSVC + asm volatile("" : "+r"(output) : : "memory"); +#endif // HWY_COMPILER_MSVC +} + +HWY_TEST_DLLEXPORT bool BytesEqual(const void* p1, const void* p2, + const size_t size, size_t* pos = nullptr); + +void AssertStringEqual(const char* expected, const char* actual, + const char* target_name, const char* filename, int line); + +namespace detail { + +template <typename T, typename TU = MakeUnsigned<T>> +TU ComputeUlpDelta(const T expected, const T actual) { + // Handle -0 == 0 and infinities. + if (expected == actual) return 0; + + // Consider "equal" if both are NaN, so we can verify an expected NaN. + // Needs a special case because there are many possible NaN representations. + if (std::isnan(expected) && std::isnan(actual)) return 0; + + // Compute the difference in units of last place. We do not need to check for + // differing signs; they will result in large differences, which is fine. + TU ux, uy; + CopySameSize(&expected, &ux); + CopySameSize(&actual, &uy); + + // Avoid unsigned->signed cast: 2's complement is only guaranteed by C++20. + const TU ulp = HWY_MAX(ux, uy) - HWY_MIN(ux, uy); + return ulp; +} + +HWY_TEST_DLLEXPORT bool IsEqual(const TypeInfo& info, const void* expected_ptr, + const void* actual_ptr); + +HWY_TEST_DLLEXPORT HWY_NORETURN void PrintMismatchAndAbort( + const TypeInfo& info, const void* expected_ptr, const void* actual_ptr, + const char* target_name, const char* filename, int line, size_t lane = 0, + size_t num_lanes = 1); + +HWY_TEST_DLLEXPORT void AssertArrayEqual(const TypeInfo& info, + const void* expected_void, + const void* actual_void, size_t N, + const char* target_name, + const char* filename, int line); + +} // namespace detail + +// Returns a name for the vector/part/scalar. The type prefix is u/i/f for +// unsigned/signed/floating point, followed by the number of bits per lane; +// then 'x' followed by the number of lanes. Example: u8x16. This is useful for +// understanding which instantiation of a generic test failed. +template <typename T> +std::string TypeName(T /*unused*/, size_t N) { + char string100[100]; + detail::TypeName(detail::MakeTypeInfo<T>(), N, string100); + return string100; +} + +// Compare non-vector, non-string T. +template <typename T> +HWY_INLINE bool IsEqual(const T expected, const T actual) { + const auto info = detail::MakeTypeInfo<T>(); + return detail::IsEqual(info, &expected, &actual); +} + +template <typename T> +HWY_INLINE void AssertEqual(const T expected, const T actual, + const char* target_name, const char* filename, + int line, size_t lane = 0) { + const auto info = detail::MakeTypeInfo<T>(); + if (!detail::IsEqual(info, &expected, &actual)) { + detail::PrintMismatchAndAbort(info, &expected, &actual, target_name, + filename, line, lane); + } +} + +template <typename T> +HWY_INLINE void AssertArrayEqual(const T* expected, const T* actual, + size_t count, const char* target_name, + const char* filename, int line) { + const auto info = hwy::detail::MakeTypeInfo<T>(); + detail::AssertArrayEqual(info, expected, actual, count, target_name, filename, + line); +} + +} // namespace hwy + +#endif // HWY_TESTS_TEST_UTIL_H_ diff --git a/third_party/highway/hwy/tests/test_util_test.cc b/third_party/highway/hwy/tests/test_util_test.cc new file mode 100644 index 0000000000..1911467c34 --- /dev/null +++ b/third_party/highway/hwy/tests/test_util_test.cc @@ -0,0 +1,107 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stddef.h> +#include <stdint.h> + +#include <string> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "tests/test_util_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct TestName { + template <class T, class D> + HWY_NOINLINE void operator()(T t, D d) { + char num[10]; + std::string expected = IsFloat<T>() ? "f" : (IsSigned<T>() ? "i" : "u"); + snprintf(num, sizeof(num), "%u" , static_cast<unsigned>(sizeof(T) * 8)); + expected += num; + + const size_t N = Lanes(d); + if (N != 1) { + expected += 'x'; + snprintf(num, sizeof(num), "%u", static_cast<unsigned>(N)); + expected += num; + } + const std::string actual = TypeName(t, N); + if (expected != actual) { + HWY_ABORT("%s mismatch: expected '%s', got '%s'.\n", + hwy::TargetName(HWY_TARGET), expected.c_str(), actual.c_str()); + } + } +}; + +HWY_NOINLINE void TestAllName() { ForAllTypes(ForPartialVectors<TestName>()); } + +struct TestEqualInteger { + template <class T> + HWY_NOINLINE void operator()(T /*t*/) const { + HWY_ASSERT_EQ(T(0), T(0)); + HWY_ASSERT_EQ(T(1), T(1)); + HWY_ASSERT_EQ(T(-1), T(-1)); + HWY_ASSERT_EQ(LimitsMin<T>(), LimitsMin<T>()); + + HWY_ASSERT(!IsEqual(T(0), T(1))); + HWY_ASSERT(!IsEqual(T(1), T(0))); + HWY_ASSERT(!IsEqual(T(1), T(-1))); + HWY_ASSERT(!IsEqual(T(-1), T(1))); + HWY_ASSERT(!IsEqual(LimitsMin<T>(), LimitsMax<T>())); + HWY_ASSERT(!IsEqual(LimitsMax<T>(), LimitsMin<T>())); + } +}; + +struct TestEqualFloat { + template <class T> + HWY_NOINLINE void operator()(T /*t*/) const { + HWY_ASSERT(IsEqual(T(0), T(0))); + HWY_ASSERT(IsEqual(T(1), T(1))); + HWY_ASSERT(IsEqual(T(-1), T(-1))); + HWY_ASSERT(IsEqual(MantissaEnd<T>(), MantissaEnd<T>())); + + HWY_ASSERT(!IsEqual(T(0), T(1))); + HWY_ASSERT(!IsEqual(T(1), T(0))); + HWY_ASSERT(!IsEqual(T(1), T(-1))); + HWY_ASSERT(!IsEqual(T(-1), T(1))); + HWY_ASSERT(!IsEqual(LowestValue<T>(), HighestValue<T>())); + HWY_ASSERT(!IsEqual(HighestValue<T>(), LowestValue<T>())); + } +}; + +HWY_NOINLINE void TestAllEqual() { + ForIntegerTypes(TestEqualInteger()); + ForFloatTypes(TestEqualFloat()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(TestUtilTest); +HWY_EXPORT_AND_TEST_P(TestUtilTest, TestAllName); +HWY_EXPORT_AND_TEST_P(TestUtilTest, TestAllEqual); +} // namespace hwy + +#endif |