diff options
Diffstat (limited to 'third_party/highway/hwy/contrib/sort/algo-inl.h')
-rw-r--r-- | third_party/highway/hwy/contrib/sort/algo-inl.h | 513 |
1 files changed, 513 insertions, 0 deletions
diff --git a/third_party/highway/hwy/contrib/sort/algo-inl.h b/third_party/highway/hwy/contrib/sort/algo-inl.h new file mode 100644 index 0000000000..1ebbbd5745 --- /dev/null +++ b/third_party/highway/hwy/contrib/sort/algo-inl.h @@ -0,0 +1,513 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +#include <stdint.h> +#include <string.h> // memcpy + +#include <algorithm> // std::sort, std::min, std::max +#include <functional> // std::less, std::greater +#include <thread> // NOLINT +#include <vector> + +#include "hwy/base.h" +#include "hwy/contrib/sort/vqsort.h" + +// Third-party algorithms +#define HAVE_AVX2SORT 0 +#define HAVE_IPS4O 0 +// When enabling, consider changing max_threads (required for Table 1a) +#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1) +#define HAVE_PDQSORT 0 +#define HAVE_SORT512 0 +#define HAVE_VXSORT 0 + +#if HAVE_AVX2SORT +HWY_PUSH_ATTRIBUTES("avx2,avx") +#include "avx2sort.h" //NOLINT +HWY_POP_ATTRIBUTES +#endif +#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O +#include "third_party/ips4o/include/ips4o.hpp" +#include "third_party/ips4o/include/ips4o/thread_pool.hpp" +#endif +#if HAVE_PDQSORT +#include "third_party/boost/allowed/sort/sort.hpp" +#endif +#if HAVE_SORT512 +#include "sort512.h" //NOLINT +#endif + +// vxsort is difficult to compile for multiple targets because it also uses +// .cpp files, and we'd also have to #undef its include guards. Instead, compile +// only for AVX2 or AVX3 depending on this macro. +#define VXSORT_AVX3 1 +#if HAVE_VXSORT +// inlined from vxsort_targets_enable_avx512 (must close before end of header) +#ifdef __GNUC__ +#ifdef __clang__ +#if VXSORT_AVX3 +#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \ + apply_to = any(function)) +#else +#pragma clang attribute push(__attribute__((target("avx2"))), \ + apply_to = any(function)) +#endif // VXSORT_AVX3 + +#else +#pragma GCC push_options +#if VXSORT_AVX3 +#pragma GCC target("avx512f,avx512dq") +#else +#pragma GCC target("avx2") +#endif // VXSORT_AVX3 +#endif +#endif + +#if VXSORT_AVX3 +#include "vxsort/machine_traits.avx512.h" +#else +#include "vxsort/machine_traits.avx2.h" +#endif // VXSORT_AVX3 +#include "vxsort/vxsort.h" +#ifdef __GNUC__ +#ifdef __clang__ +#pragma clang attribute pop +#else +#pragma GCC pop_options +#endif +#endif +#endif // HAVE_VXSORT + +namespace hwy { + +enum class Dist { kUniform8, kUniform16, kUniform32 }; + +static inline std::vector<Dist> AllDist() { + return {/*Dist::kUniform8, Dist::kUniform16,*/ Dist::kUniform32}; +} + +static inline const char* DistName(Dist dist) { + switch (dist) { + case Dist::kUniform8: + return "uniform8"; + case Dist::kUniform16: + return "uniform16"; + case Dist::kUniform32: + return "uniform32"; + } + return "unreachable"; +} + +template <typename T> +class InputStats { + public: + void Notify(T value) { + min_ = std::min(min_, value); + max_ = std::max(max_, value); + // Converting to integer would truncate floats, multiplying to save digits + // risks overflow especially when casting, so instead take the sum of the + // bit representations as the checksum. + uint64_t bits = 0; + static_assert(sizeof(T) <= 8, "Expected a built-in type"); + CopyBytes<sizeof(T)>(&value, &bits); // not same size + sum_ += bits; + count_ += 1; + } + + bool operator==(const InputStats& other) const { + if (count_ != other.count_) { + HWY_ABORT("count %d vs %d\n", static_cast<int>(count_), + static_cast<int>(other.count_)); + } + + if (min_ != other.min_ || max_ != other.max_) { + HWY_ABORT("minmax %f/%f vs %f/%f\n", static_cast<double>(min_), + static_cast<double>(max_), static_cast<double>(other.min_), + static_cast<double>(other.max_)); + } + + // Sum helps detect duplicated/lost values + if (sum_ != other.sum_) { + HWY_ABORT("Sum mismatch %g %g; min %g max %g\n", + static_cast<double>(sum_), static_cast<double>(other.sum_), + static_cast<double>(min_), static_cast<double>(max_)); + } + + return true; + } + + private: + T min_ = hwy::HighestValue<T>(); + T max_ = hwy::LowestValue<T>(); + uint64_t sum_ = 0; + size_t count_ = 0; +}; + +enum class Algo { +#if HAVE_AVX2SORT + kSEA, +#endif +#if HAVE_IPS4O + kIPS4O, +#endif +#if HAVE_PARALLEL_IPS4O + kParallelIPS4O, +#endif +#if HAVE_PDQSORT + kPDQ, +#endif +#if HAVE_SORT512 + kSort512, +#endif +#if HAVE_VXSORT + kVXSort, +#endif + kStd, + kVQSort, + kHeap, +}; + +static inline const char* AlgoName(Algo algo) { + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return "sea"; +#endif +#if HAVE_IPS4O + case Algo::kIPS4O: + return "ips4o"; +#endif +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + return "par_ips4o"; +#endif +#if HAVE_PDQSORT + case Algo::kPDQ: + return "pdq"; +#endif +#if HAVE_SORT512 + case Algo::kSort512: + return "sort512"; +#endif +#if HAVE_VXSORT + case Algo::kVXSort: + return "vxsort"; +#endif + case Algo::kStd: + return "std"; + case Algo::kVQSort: + return "vq"; + case Algo::kHeap: + return "heap"; + } + return "unreachable"; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#endif + +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // HeapSort +#include "hwy/tests/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +class Xorshift128Plus { + static HWY_INLINE uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + public: + // Generates two vectors of 64-bit seeds via SplitMix64 and stores into + // `seeds`. Generating these afresh in each ChoosePivot is too expensive. + template <class DU64> + static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) { + seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < 2 * Lanes(du64); ++i) { + seeds[i] = SplitMix64(seeds[i - 1]); + } + } + + // Need to pass in the state because vector cannot be class members. + template <class VU64> + static VU64 RandomBits(VU64& state0, VU64& state1) { + VU64 s1 = state0; + VU64 s0 = state1; + const VU64 bits = Add(s1, s0); + state0 = s0; + s1 = Xor(s1, ShiftLeft<23>(s1)); + state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + return bits; + } +}; + +template <class D, class VU64, HWY_IF_NOT_FLOAT_D(D)> +Vec<D> RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) { + const VU64 bits = Xorshift128Plus::RandomBits(s0, s1); + return BitCast(d, And(bits, mask)); +} + +// It is important to avoid denormals, which are flushed to zero by SIMD but not +// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD. +template <class DF, class VU64, HWY_IF_FLOAT_D(DF)> +Vec<DF> RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) { + using TF = TFromD<DF>; + const RebindToUnsigned<decltype(df)> du; + using VU = Vec<decltype(du)>; + + const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask); + +#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to smaller types + using TU = MakeUnsigned<TF>; + const VU bits = Set(du, static_cast<TU>(GetLane(bits64) & LimitsMax<TU>())); +#else + const VU bits = BitCast(du, bits64); +#endif + // Avoid NaN/denormal by only generating values in [1, 2), i.e. random + // mantissas with the exponent taken from the representation of 1.0. + const VU k1 = BitCast(du, Set(df, TF{1.0})); + const VU mantissa_mask = Set(du, MantissaMask<TF>()); + const VU representation = OrAnd(k1, bits, mantissa_mask); + return BitCast(df, representation); +} + +template <class DU64> +Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) { + switch (sizeof_t) { + case 2: + return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull + : 0xFFFFFFFFFFFFFFFFull); + case 4: + return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull + : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull + : 0xFFFFFFFFFFFFFFFFull); + case 8: + return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull + : (dist == Dist::kUniform16) ? 0x000000000000FFFFull + : 0x00000000FFFFFFFFull); + default: + HWY_ABORT("Logic error"); + return Zero(du64); + } +} + +template <typename T> +InputStats<T> GenerateInput(const Dist dist, T* v, size_t num) { + SortTag<uint64_t> du64; + using VU64 = Vec<decltype(du64)>; + const size_t N64 = Lanes(du64); + auto seeds = hwy::AllocateAligned<uint64_t>(2 * N64); + Xorshift128Plus::GenerateSeeds(du64, seeds.get()); + VU64 s0 = Load(du64, seeds.get()); + VU64 s1 = Load(du64, seeds.get() + N64); + +#if HWY_TARGET == HWY_SCALAR + const Sisd<T> d; +#else + const Repartition<T, decltype(du64)> d; +#endif + using V = Vec<decltype(d)>; + const size_t N = Lanes(d); + const VU64 mask = MaskForDist(du64, dist, sizeof(T)); + auto buf = hwy::AllocateAligned<T>(N); + + size_t i = 0; + for (; i + N <= num; i += N) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, v + i); + } + if (i < num) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, buf.get()); + memcpy(v + i, buf.get(), (num - i) * sizeof(T)); + } + + InputStats<T> input_stats; + for (size_t i = 0; i < num; ++i) { + input_stats.Notify(v[i]); + } + return input_stats; +} + +struct ThreadLocal { + Sorter sorter; +}; + +struct SharedState { +#if HAVE_PARALLEL_IPS4O + const unsigned max_threads = hwy::LimitsMax<unsigned>(); // 16 for Table 1a + ips4o::StdThreadPool pool{static_cast<int>( + HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))}; +#endif + std::vector<ThreadLocal> tls{1}; +}; + +// Bridge from keys (passed to Run) to lanes as expected by HeapSort. For +// non-128-bit keys they are the same: +template <class Order, typename KeyType, HWY_IF_NOT_LANE_SIZE(KeyType, 16)> +void CallHeapSort(KeyType* HWY_RESTRICT keys, const size_t num_keys) { + using detail::TraitsLane; + using detail::SharedTraits; + if (Order().IsAscending()) { + const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st; + return detail::HeapSort(st, keys, num_keys); + } else { + const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st; + return detail::HeapSort(st, keys, num_keys); + } +} + +#if VQSORT_ENABLED +template <class Order> +void CallHeapSort(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys) { + using detail::SharedTraits; + using detail::Traits128; + uint64_t* lanes = reinterpret_cast<uint64_t*>(keys); + const size_t num_lanes = num_keys * 2; + if (Order().IsAscending()) { + const SharedTraits<Traits128<detail::OrderAscending128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } else { + const SharedTraits<Traits128<detail::OrderDescending128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } +} + +template <class Order> +void CallHeapSort(K64V64* HWY_RESTRICT keys, const size_t num_keys) { + using detail::SharedTraits; + using detail::Traits128; + uint64_t* lanes = reinterpret_cast<uint64_t*>(keys); + const size_t num_lanes = num_keys * 2; + if (Order().IsAscending()) { + const SharedTraits<Traits128<detail::OrderAscendingKV128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } else { + const SharedTraits<Traits128<detail::OrderDescendingKV128>> st; + return detail::HeapSort(st, lanes, num_lanes); + } +} +#endif // VQSORT_ENABLED + +template <class Order, typename KeyType> +void Run(Algo algo, KeyType* HWY_RESTRICT inout, size_t num, + SharedState& shared, size_t thread) { + const std::less<KeyType> less; + const std::greater<KeyType> greater; + + switch (algo) { +#if HAVE_AVX2SORT + case Algo::kSEA: + return avx2::quicksort(inout, static_cast<int>(num)); +#endif + +#if HAVE_IPS4O + case Algo::kIPS4O: + if (Order().IsAscending()) { + return ips4o::sort(inout, inout + num, less); + } else { + return ips4o::sort(inout, inout + num, greater); + } +#endif + +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + if (Order().IsAscending()) { + return ips4o::parallel::sort(inout, inout + num, less, shared.pool); + } else { + return ips4o::parallel::sort(inout, inout + num, greater, shared.pool); + } +#endif + +#if HAVE_SORT512 + case Algo::kSort512: + HWY_ABORT("not supported"); + // return Sort512::Sort(inout, num); +#endif + +#if HAVE_PDQSORT + case Algo::kPDQ: + if (Order().IsAscending()) { + return boost::sort::pdqsort_branchless(inout, inout + num, less); + } else { + return boost::sort::pdqsort_branchless(inout, inout + num, greater); + } +#endif + +#if HAVE_VXSORT + case Algo::kVXSort: { +#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \ + (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2) + fprintf(stderr, "Do not call for target %s\n", + hwy::TargetName(HWY_TARGET)); + return; +#else +#if VXSORT_AVX3 + vxsort::vxsort<KeyType, vxsort::AVX512> vx; +#else + vxsort::vxsort<KeyType, vxsort::AVX2> vx; +#endif + if (Order().IsAscending()) { + return vx.sort(inout, inout + num - 1); + } else { + fprintf(stderr, "Skipping VX - does not support descending order\n"); + return; + } +#endif // enabled for this target + } +#endif // HAVE_VXSORT + + case Algo::kStd: + if (Order().IsAscending()) { + return std::sort(inout, inout + num, less); + } else { + return std::sort(inout, inout + num, greater); + } + + case Algo::kVQSort: + return shared.tls[thread].sorter(inout, num, Order()); + + case Algo::kHeap: + return CallHeapSort<Order>(inout, num); + + default: + HWY_ABORT("Not implemented"); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE |