summaryrefslogtreecommitdiffstats
path: root/third_party/highway/hwy/contrib/dot
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/highway/hwy/contrib/dot')
-rw-r--r--third_party/highway/hwy/contrib/dot/dot-inl.h252
-rw-r--r--third_party/highway/hwy/contrib/dot/dot_test.cc167
2 files changed, 419 insertions, 0 deletions
diff --git a/third_party/highway/hwy/contrib/dot/dot-inl.h b/third_party/highway/hwy/contrib/dot/dot-inl.h
new file mode 100644
index 0000000000..e04636f1b8
--- /dev/null
+++ b/third_party/highway/hwy/contrib/dot/dot-inl.h
@@ -0,0 +1,252 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Include guard (still compiled once per target)
+#include <cmath>
+
+#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \
+ defined(HWY_TARGET_TOGGLE)
+#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
+#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
+#else
+#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
+#endif
+
+#include "hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+struct Dot {
+ // Specify zero or more of these, ORed together, as the kAssumptions template
+ // argument to Compute. Each one may improve performance or reduce code size,
+ // at the cost of additional requirements on the arguments.
+ enum Assumptions {
+ // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T).
+ kAtLeastOneVector = 1,
+ // num_elements is divisible by N (a power of two, so this can be used if
+ // the problem size is known to be a power of two >= HWY_MAX_BYTES /
+ // sizeof(T)).
+ kMultipleOfVector = 2,
+ // RoundUpTo(num_elements, N) elements are accessible; their value does not
+ // matter (will be treated as if they were zero).
+ kPaddedToVector = 4,
+ };
+
+ // Returns sum{pa[i] * pb[i]} for float or double inputs. Aligning the
+ // pointers to a multiple of N elements is helpful but not required.
+ template <int kAssumptions, class D, typename T = TFromD<D>,
+ HWY_IF_NOT_LANE_SIZE_D(D, 2)>
+ static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
+ const T* const HWY_RESTRICT pb,
+ const size_t num_elements) {
+ static_assert(IsFloat<T>(), "MulAdd requires float type");
+ using V = decltype(Zero(d));
+
+ const size_t N = Lanes(d);
+ size_t i = 0;
+
+ constexpr bool kIsAtLeastOneVector =
+ (kAssumptions & kAtLeastOneVector) != 0;
+ constexpr bool kIsMultipleOfVector =
+ (kAssumptions & kMultipleOfVector) != 0;
+ constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
+
+ // Won't be able to do a full vector load without padding => scalar loop.
+ if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
+ HWY_UNLIKELY(num_elements < N)) {
+ // Only 2x unroll to avoid excessive code size.
+ T sum0 = T(0);
+ T sum1 = T(0);
+ for (; i + 2 <= num_elements; i += 2) {
+ sum0 += pa[i + 0] * pb[i + 0];
+ sum1 += pa[i + 1] * pb[i + 1];
+ }
+ if (i < num_elements) {
+ sum1 += pa[i] * pb[i];
+ }
+ return sum0 + sum1;
+ }
+
+ // Compiler doesn't make independent sum* accumulators, so unroll manually.
+ // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
+ // for unaligned inputs (each unaligned pointer halves the throughput
+ // because it occupies both L1 load ports for a cycle). We cannot have
+ // arrays of vectors on RVV/SVE, so always unroll 4x.
+ V sum0 = Zero(d);
+ V sum1 = Zero(d);
+ V sum2 = Zero(d);
+ V sum3 = Zero(d);
+
+ // Main loop: unrolled
+ for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop
+ const auto a0 = LoadU(d, pa + i);
+ const auto b0 = LoadU(d, pb + i);
+ i += N;
+ sum0 = MulAdd(a0, b0, sum0);
+ const auto a1 = LoadU(d, pa + i);
+ const auto b1 = LoadU(d, pb + i);
+ i += N;
+ sum1 = MulAdd(a1, b1, sum1);
+ const auto a2 = LoadU(d, pa + i);
+ const auto b2 = LoadU(d, pb + i);
+ i += N;
+ sum2 = MulAdd(a2, b2, sum2);
+ const auto a3 = LoadU(d, pa + i);
+ const auto b3 = LoadU(d, pb + i);
+ i += N;
+ sum3 = MulAdd(a3, b3, sum3);
+ }
+
+ // Up to 3 iterations of whole vectors
+ for (; i + N <= num_elements; i += N) {
+ const auto a = LoadU(d, pa + i);
+ const auto b = LoadU(d, pb + i);
+ sum0 = MulAdd(a, b, sum0);
+ }
+
+ if (!kIsMultipleOfVector) {
+ const size_t remaining = num_elements - i;
+ if (remaining != 0) {
+ if (kIsPaddedToVector) {
+ const auto mask = FirstN(d, remaining);
+ const auto a = LoadU(d, pa + i);
+ const auto b = LoadU(d, pb + i);
+ sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
+ } else {
+ // Unaligned load such that the last element is in the highest lane -
+ // ensures we do not touch any elements outside the valid range.
+ // If we get here, then num_elements >= N.
+ HWY_DASSERT(i >= N);
+ i += remaining - N;
+ const auto skip = FirstN(d, N - remaining);
+ const auto a = LoadU(d, pa + i); // always unaligned
+ const auto b = LoadU(d, pb + i);
+ sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
+ }
+ }
+ } // kMultipleOfVector
+
+ // Reduction tree: sum of all accumulators by pairs, then across lanes.
+ sum0 = Add(sum0, sum1);
+ sum2 = Add(sum2, sum3);
+ sum0 = Add(sum0, sum2);
+ return GetLane(SumOfLanes(d, sum0));
+ }
+
+ // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a
+ // multiple of N elements is helpful but not required.
+ template <int kAssumptions, class D>
+ static HWY_INLINE float Compute(const D d,
+ const bfloat16_t* const HWY_RESTRICT pa,
+ const bfloat16_t* const HWY_RESTRICT pb,
+ const size_t num_elements) {
+ const RebindToUnsigned<D> du16;
+ const Repartition<float, D> df32;
+
+ using V = decltype(Zero(df32));
+ const size_t N = Lanes(d);
+ size_t i = 0;
+
+ constexpr bool kIsAtLeastOneVector =
+ (kAssumptions & kAtLeastOneVector) != 0;
+ constexpr bool kIsMultipleOfVector =
+ (kAssumptions & kMultipleOfVector) != 0;
+ constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
+
+ // Won't be able to do a full vector load without padding => scalar loop.
+ if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
+ HWY_UNLIKELY(num_elements < N)) {
+ float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for..
+ float sum1 = 0.0f; // this unlikely(?) case.
+ for (; i + 2 <= num_elements; i += 2) {
+ sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
+ sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
+ }
+ if (i < num_elements) {
+ sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
+ }
+ return sum0 + sum1;
+ }
+
+ // See comment in the other Compute() overload. Unroll 2x, but we need
+ // twice as many sums for ReorderWidenMulAccumulate.
+ V sum0 = Zero(df32);
+ V sum1 = Zero(df32);
+ V sum2 = Zero(df32);
+ V sum3 = Zero(df32);
+
+ // Main loop: unrolled
+ for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop
+ const auto a0 = LoadU(d, pa + i);
+ const auto b0 = LoadU(d, pb + i);
+ i += N;
+ sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
+ const auto a1 = LoadU(d, pa + i);
+ const auto b1 = LoadU(d, pb + i);
+ i += N;
+ sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
+ }
+
+ // Possibly one more iteration of whole vectors
+ if (i + N <= num_elements) {
+ const auto a0 = LoadU(d, pa + i);
+ const auto b0 = LoadU(d, pb + i);
+ i += N;
+ sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
+ }
+
+ if (!kIsMultipleOfVector) {
+ const size_t remaining = num_elements - i;
+ if (remaining != 0) {
+ if (kIsPaddedToVector) {
+ const auto mask = FirstN(du16, remaining);
+ const auto va = LoadU(d, pa + i);
+ const auto vb = LoadU(d, pb + i);
+ const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
+ const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
+ sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
+
+ } else {
+ // Unaligned load such that the last element is in the highest lane -
+ // ensures we do not touch any elements outside the valid range.
+ // If we get here, then num_elements >= N.
+ HWY_DASSERT(i >= N);
+ i += remaining - N;
+ const auto skip = FirstN(du16, N - remaining);
+ const auto va = LoadU(d, pa + i); // always unaligned
+ const auto vb = LoadU(d, pb + i);
+ const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
+ const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
+ sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
+ }
+ }
+ } // kMultipleOfVector
+
+ // Reduction tree: sum of all accumulators by pairs, then across lanes.
+ sum0 = Add(sum0, sum1);
+ sum2 = Add(sum2, sum3);
+ sum0 = Add(sum0, sum2);
+ return GetLane(SumOfLanes(df32, sum0));
+ }
+};
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
diff --git a/third_party/highway/hwy/contrib/dot/dot_test.cc b/third_party/highway/hwy/contrib/dot/dot_test.cc
new file mode 100644
index 0000000000..12d7ab270d
--- /dev/null
+++ b/third_party/highway/hwy/contrib/dot/dot_test.cc
@@ -0,0 +1,167 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "hwy/aligned_allocator.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc"
+#include "hwy/foreach_target.h" // IWYU pragma: keep
+
+#include "hwy/contrib/dot/dot-inl.h"
+#include "hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+template <typename T>
+HWY_NOINLINE T SimpleDot(const T* pa, const T* pb, size_t num) {
+ double sum = 0.0;
+ for (size_t i = 0; i < num; ++i) {
+ sum += pa[i] * pb[i];
+ }
+ return static_cast<T>(sum);
+}
+
+HWY_NOINLINE float SimpleDot(const bfloat16_t* pa, const bfloat16_t* pb,
+ size_t num) {
+ float sum = 0.0f;
+ for (size_t i = 0; i < num; ++i) {
+ sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
+ }
+ return sum;
+}
+
+template <typename T>
+void SetValue(const float value, T* HWY_RESTRICT ptr) {
+ *ptr = static_cast<T>(value);
+}
+void SetValue(const float value, bfloat16_t* HWY_RESTRICT ptr) {
+ *ptr = BF16FromF32(value);
+}
+
+class TestDot {
+ // Computes/verifies one dot product.
+ template <int kAssumptions, class D>
+ void Test(D d, size_t num, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ const size_t N = Lanes(d);
+ const auto random_t = [&rng]() {
+ const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
+ return static_cast<float>(bits - 512) * (1.0f / 64);
+ };
+
+ const size_t padded =
+ (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num;
+ AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded);
+ AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + padded);
+ T* a = pa.get() + misalign_a;
+ T* b = pb.get() + misalign_b;
+ size_t i = 0;
+ for (; i < num; ++i) {
+ SetValue(random_t(), a + i);
+ SetValue(random_t(), b + i);
+ }
+ // Fill padding with NaN - the values are not used, but avoids MSAN errors.
+ for (; i < padded; ++i) {
+ ScalableTag<float> df1;
+ SetValue(GetLane(NaN(df1)), a + i);
+ SetValue(GetLane(NaN(df1)), b + i);
+ }
+
+ const auto expected = SimpleDot(a, b, num);
+ const auto actual = Dot::Compute<kAssumptions>(d, a, b, num);
+ const auto max = static_cast<decltype(actual)>(8 * 8 * num);
+ HWY_ASSERT(-max <= actual && actual <= max);
+ HWY_ASSERT(expected - 1E-4 <= actual && actual <= expected + 1E-4);
+ }
+
+ // Runs tests with various alignments.
+ template <int kAssumptions, class D>
+ void ForeachMisalign(D d, size_t num, RandomState& rng) {
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+ for (size_t ma : misalignments) {
+ for (size_t mb : misalignments) {
+ Test<kAssumptions>(d, num, ma, mb, rng);
+ }
+ }
+ }
+
+ // Runs tests with various lengths compatible with the given assumptions.
+ template <int kAssumptions, class D>
+ void ForeachCount(D d, RandomState& rng) {
+ const size_t N = Lanes(d);
+ const size_t counts[] = {1,
+ 3,
+ 7,
+ 16,
+ HWY_MAX(N / 2, 1),
+ HWY_MAX(2 * N / 3, 1),
+ N,
+ N + 1,
+ 4 * N / 3,
+ 3 * N,
+ 8 * N,
+ 8 * N + 2};
+ for (size_t num : counts) {
+ if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue;
+ if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue;
+ ForeachMisalign<kAssumptions>(d, num, rng);
+ }
+ }
+
+ public:
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+ // All 8 combinations of the three length-related flags:
+ ForeachCount<0>(d, rng);
+ ForeachCount<Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kMultipleOfVector>(d, rng);
+ ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector |
+ Dot::kAtLeastOneVector>(d, rng);
+ }
+};
+
+void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); }
+void TestAllDotBF16() { ForShrinkableVectors<TestDot>()(bfloat16_t()); }
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+
+namespace hwy {
+HWY_BEFORE_TEST(DotTest);
+HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot);
+HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16);
+} // namespace hwy
+
+#endif