summaryrefslogtreecommitdiffstats
path: root/third_party/intgemm/test/kernels/relu_test.cc
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
commit43a97878ce14b72f0981164f87f2e35e14151312 (patch)
tree620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/intgemm/test/kernels/relu_test.cc
parentInitial commit. (diff)
downloadfirefox-43a97878ce14b72f0981164f87f2e35e14151312.tar.xz
firefox-43a97878ce14b72f0981164f87f2e35e14151312.zip
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/intgemm/test/kernels/relu_test.cc')
-rw-r--r--third_party/intgemm/test/kernels/relu_test.cc65
1 files changed, 65 insertions, 0 deletions
diff --git a/third_party/intgemm/test/kernels/relu_test.cc b/third_party/intgemm/test/kernels/relu_test.cc
new file mode 100644
index 0000000000..8fd30ae25a
--- /dev/null
+++ b/third_party/intgemm/test/kernels/relu_test.cc
@@ -0,0 +1,65 @@
+#include "../test.h"
+#include "../../intgemm/aligned.h"
+#include "../../intgemm/kernels.h"
+
+#include <cstdint>
+#include <numeric>
+
+namespace intgemm {
+
+template <CPUType CPUType_, typename ElemType_>
+void kernel_relu_test() {
+ if (kCPU < CPUType_)
+ return;
+
+ using vec_t = vector_t<CPUType_, ElemType_>;
+ constexpr int VECTOR_LENGTH = sizeof(vec_t) / sizeof(ElemType_);
+
+ AlignedVector<ElemType_> input(VECTOR_LENGTH);
+ AlignedVector<ElemType_> output(VECTOR_LENGTH);
+
+ std::iota(input.begin(), input.end(), static_cast<ElemType_>(-VECTOR_LENGTH / 2));
+
+ *output.template as<vec_t>() = kernels::relu<ElemType_>(*input.template as<vec_t>());
+ for (std::size_t i = 0; i < output.size(); ++i)
+ CHECK(output[i] == (input[i] < 0 ? 0 : input[i]));
+}
+
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int8_t>();
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int16_t>();
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, int>();
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, float>();
+template INTGEMM_SSE2 void kernel_relu_test<CPUType::SSE2, double>();
+KERNEL_TEST_CASE("relu/int8 SSE2") { return kernel_relu_test<CPUType::SSE2, int8_t>(); }
+KERNEL_TEST_CASE("relu/int16 SSE2") { return kernel_relu_test<CPUType::SSE2, int16_t>(); }
+KERNEL_TEST_CASE("relu/int SSE2") { return kernel_relu_test<CPUType::SSE2, int>(); }
+KERNEL_TEST_CASE("relu/float SSE2") { return kernel_relu_test<CPUType::SSE2, float>(); }
+KERNEL_TEST_CASE("relu/double SSE2") { return kernel_relu_test<CPUType::SSE2, double>(); }
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int8_t>();
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int16_t>();
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int>();
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, float>();
+template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, double>();
+KERNEL_TEST_CASE("relu/int8 AVX2") { return kernel_relu_test<CPUType::AVX2, int8_t>(); }
+KERNEL_TEST_CASE("relu/int16 AVX2") { return kernel_relu_test<CPUType::AVX2, int16_t>(); }
+KERNEL_TEST_CASE("relu/int AVX2") { return kernel_relu_test<CPUType::AVX2, int>(); }
+KERNEL_TEST_CASE("relu/float AVX2") { return kernel_relu_test<CPUType::AVX2, float>(); }
+KERNEL_TEST_CASE("relu/double AVX2") { return kernel_relu_test<CPUType::AVX2, double>(); }
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int8_t>();
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int16_t>();
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int>();
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, float>();
+template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, double>();
+KERNEL_TEST_CASE("relu/int8 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int8_t>(); }
+KERNEL_TEST_CASE("relu/int16 AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int16_t>(); }
+KERNEL_TEST_CASE("relu/int AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, int>(); }
+KERNEL_TEST_CASE("relu/float AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, float>(); }
+KERNEL_TEST_CASE("relu/double AVX512BW") { return kernel_relu_test<CPUType::AVX512BW, double>(); }
+#endif
+
+}