diff options
Diffstat (limited to 'third_party/highway/hwy/contrib/dot')
-rw-r--r-- | third_party/highway/hwy/contrib/dot/dot-inl.h | 252 | ||||
-rw-r--r-- | third_party/highway/hwy/contrib/dot/dot_test.cc | 167 |
2 files changed, 419 insertions, 0 deletions
diff --git a/third_party/highway/hwy/contrib/dot/dot-inl.h b/third_party/highway/hwy/contrib/dot/dot-inl.h new file mode 100644 index 0000000000..e04636f1b8 --- /dev/null +++ b/third_party/highway/hwy/contrib/dot/dot-inl.h @@ -0,0 +1,252 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#include <cmath> + +#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +struct Dot { + // Specify zero or more of these, ORed together, as the kAssumptions template + // argument to Compute. Each one may improve performance or reduce code size, + // at the cost of additional requirements on the arguments. + enum Assumptions { + // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T). + kAtLeastOneVector = 1, + // num_elements is divisible by N (a power of two, so this can be used if + // the problem size is known to be a power of two >= HWY_MAX_BYTES / + // sizeof(T)). + kMultipleOfVector = 2, + // RoundUpTo(num_elements, N) elements are accessible; their value does not + // matter (will be treated as if they were zero). + kPaddedToVector = 4, + }; + + // Returns sum{pa[i] * pb[i]} for float or double inputs. Aligning the + // pointers to a multiple of N elements is helpful but not required. + template <int kAssumptions, class D, typename T = TFromD<D>, + HWY_IF_NOT_LANE_SIZE_D(D, 2)> + static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa, + const T* const HWY_RESTRICT pb, + const size_t num_elements) { + static_assert(IsFloat<T>(), "MulAdd requires float type"); + using V = decltype(Zero(d)); + + const size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + // Only 2x unroll to avoid excessive code size. + T sum0 = T(0); + T sum1 = T(0); + for (; i + 2 <= num_elements; i += 2) { + sum0 += pa[i + 0] * pb[i + 0]; + sum1 += pa[i + 1] * pb[i + 1]; + } + if (i < num_elements) { + sum1 += pa[i] * pb[i]; + } + return sum0 + sum1; + } + + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive + // for unaligned inputs (each unaligned pointer halves the throughput + // because it occupies both L1 load ports for a cycle). We cannot have + // arrays of vectors on RVV/SVE, so always unroll 4x. + V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); + + // Main loop: unrolled + for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = LoadU(d, pa + i); + const auto b2 = LoadU(d, pb + i); + i += N; + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = LoadU(d, pa + i); + const auto b3 = LoadU(d, pb + i); + i += N; + sum3 = MulAdd(a3, b3, sum3); + } + + // Up to 3 iterations of whole vectors + for (; i + N <= num_elements; i += N) { + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum0 = MulAdd(a, b, sum0); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(d, remaining); + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1); + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(d, N - remaining); + const auto a = LoadU(d, pa + i); // always unaligned + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return GetLane(SumOfLanes(d, sum0)); + } + + // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a + // multiple of N elements is helpful but not required. + template <int kAssumptions, class D> + static HWY_INLINE float Compute(const D d, + const bfloat16_t* const HWY_RESTRICT pa, + const bfloat16_t* const HWY_RESTRICT pb, + const size_t num_elements) { + const RebindToUnsigned<D> du16; + const Repartition<float, D> df32; + + using V = decltype(Zero(df32)); + const size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for.. + float sum1 = 0.0f; // this unlikely(?) case. + for (; i + 2 <= num_elements; i += 2) { + sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]); + sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]); + } + if (i < num_elements) { + sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum0 + sum1; + } + + // See comment in the other Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. + V sum0 = Zero(df32); + V sum1 = Zero(df32); + V sum2 = Zero(df32); + V sum3 = Zero(df32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(du16, remaining); + const auto va = LoadU(d, pa + i); + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(du16, N - remaining); + const auto va = LoadU(d, pa + i); // always unaligned + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return GetLane(SumOfLanes(df32, sum0)); + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ diff --git a/third_party/highway/hwy/contrib/dot/dot_test.cc b/third_party/highway/hwy/contrib/dot/dot_test.cc new file mode 100644 index 0000000000..12d7ab270d --- /dev/null +++ b/third_party/highway/hwy/contrib/dot/dot_test.cc @@ -0,0 +1,167 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> + +#include "hwy/aligned_allocator.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/contrib/dot/dot-inl.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template <typename T> +HWY_NOINLINE T SimpleDot(const T* pa, const T* pb, size_t num) { + double sum = 0.0; + for (size_t i = 0; i < num; ++i) { + sum += pa[i] * pb[i]; + } + return static_cast<T>(sum); +} + +HWY_NOINLINE float SimpleDot(const bfloat16_t* pa, const bfloat16_t* pb, + size_t num) { + float sum = 0.0f; + for (size_t i = 0; i < num; ++i) { + sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum; +} + +template <typename T> +void SetValue(const float value, T* HWY_RESTRICT ptr) { + *ptr = static_cast<T>(value); +} +void SetValue(const float value, bfloat16_t* HWY_RESTRICT ptr) { + *ptr = BF16FromF32(value); +} + +class TestDot { + // Computes/verifies one dot product. + template <int kAssumptions, class D> + void Test(D d, size_t num, size_t misalign_a, size_t misalign_b, + RandomState& rng) { + using T = TFromD<D>; + const size_t N = Lanes(d); + const auto random_t = [&rng]() { + const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023; + return static_cast<float>(bits - 512) * (1.0f / 64); + }; + + const size_t padded = + (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num; + AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded); + AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + padded); + T* a = pa.get() + misalign_a; + T* b = pb.get() + misalign_b; + size_t i = 0; + for (; i < num; ++i) { + SetValue(random_t(), a + i); + SetValue(random_t(), b + i); + } + // Fill padding with NaN - the values are not used, but avoids MSAN errors. + for (; i < padded; ++i) { + ScalableTag<float> df1; + SetValue(GetLane(NaN(df1)), a + i); + SetValue(GetLane(NaN(df1)), b + i); + } + + const auto expected = SimpleDot(a, b, num); + const auto actual = Dot::Compute<kAssumptions>(d, a, b, num); + const auto max = static_cast<decltype(actual)>(8 * 8 * num); + HWY_ASSERT(-max <= actual && actual <= max); + HWY_ASSERT(expected - 1E-4 <= actual && actual <= expected + 1E-4); + } + + // Runs tests with various alignments. + template <int kAssumptions, class D> + void ForeachMisalign(D d, size_t num, RandomState& rng) { + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test<kAssumptions>(d, num, ma, mb, rng); + } + } + } + + // Runs tests with various lengths compatible with the given assumptions. + template <int kAssumptions, class D> + void ForeachCount(D d, RandomState& rng) { + const size_t N = Lanes(d); + const size_t counts[] = {1, + 3, + 7, + 16, + HWY_MAX(N / 2, 1), + HWY_MAX(2 * N / 3, 1), + N, + N + 1, + 4 * N / 3, + 3 * N, + 8 * N, + 8 * N + 2}; + for (size_t num : counts) { + if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue; + if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue; + ForeachMisalign<kAssumptions>(d, num, rng); + } + } + + public: + template <class T, class D> + HWY_NOINLINE void operator()(T /*unused*/, D d) { + RandomState rng; + + // All 8 combinations of the three length-related flags: + ForeachCount<0>(d, rng); + ForeachCount<Dot::kAtLeastOneVector>(d, rng); + ForeachCount<Dot::kMultipleOfVector>(d, rng); + ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng); + ForeachCount<Dot::kPaddedToVector>(d, rng); + ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng); + ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng); + ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector | + Dot::kAtLeastOneVector>(d, rng); + } +}; + +void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); } +void TestAllDotBF16() { ForShrinkableVectors<TestDot>()(bfloat16_t()); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +HWY_BEFORE_TEST(DotTest); +HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot); +HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16); +} // namespace hwy + +#endif |