path: root/third_party/intgemm/test/add127_test.cc
author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-07 09:22:09 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-07 09:22:09 +0000
commit     43a97878ce14b72f0981164f87f2e35e14151312 (patch)
tree       620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/intgemm/test/add127_test.cc
parent     Initial commit. (diff)
download   firefox-43a97878ce14b72f0981164f87f2e35e14151312.tar.xz
           firefox-43a97878ce14b72f0981164f87f2e35e14151312.zip

Adding upstream version 110.0.1. (upstream/110.0.1, upstream)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/intgemm/test/add127_test.cc')
-rw-r--r--  third_party/intgemm/test/add127_test.cc  492
1 files changed, 492 insertions, 0 deletions
diff --git a/third_party/intgemm/test/add127_test.cc b/third_party/intgemm/test/add127_test.cc
new file mode 100644
index 0000000000..c31732c56e
--- /dev/null
+++ b/third_party/intgemm/test/add127_test.cc
@@ -0,0 +1,492 @@
+#include "test.h"
+
+namespace intgemm {
+namespace {
+
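+// These tests exercise the "Shift" (Add127) path: A is quantized to uint8_t by
+// shifting the usual int8_t quantization up by 127, and the extra
+// 127 * colsum(B) term this introduces is cancelled by folding a correction
+// into the bias with PrepareBias. CompareAs below checks that the unsigned
+// PrepareA output equals the signed output plus 127, element by element.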
+void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols) {
+ for (Index r = 0; r<rows; r++) {
+ for (Index c = 0; c<cols; c++) {
+ int a = int(output_old[rows*c + r]);
+ int b = int(output_new[rows*c + r]);
+ INFO("Inaccurate at row: " << r << " column " << c << ' '
+ << a << ' ' << b);
+ CHECK(a+127 == b);
+ }
+ }
+}
+
+template <class Routine> void TestPrepareA(Index rows, Index cols) {
+ std::mt19937 gen;
+ // Go somewhat out of range too.
+ std::uniform_real_distribution<float> dist(-2, 2);
+ // Create array.
+ AlignedVector<float> inputA(rows * cols);
+ for (auto& it : inputA) {
+ it = dist(gen);
+ }
+ AlignedVector<int8_t> oldA(rows * cols);
+ AlignedVector<uint8_t> newA(rows * cols);
+ float quant_mult = 64; //From example
+ Routine::PrepareA(inputA.begin(), oldA.begin(), quant_mult, rows, cols);
+ Routine::PrepareA(inputA.begin(), newA.begin(), quant_mult, rows, cols);
+ CompareAs(oldA.begin(), newA.begin(), rows, cols);
+}
+
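+// PrepareBias should produce, for each column j,
+//   bias[j] - (alpha*alpha/127) * colsum_j(B_quant),
+// i.e. the correction the shifted multiply needs. The reference below
+// reproduces it with a slow multiply of a row of ones against B_quant using
+// unquant_mult_forprep = -alpha*alpha/127.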
+template <class Routine> void TestPrepareBias(Index rows, Index cols) {
+ std::mt19937 gen;
+ // Go somewhat out of range too.
+ std::uniform_real_distribution<float> dist(-30.0, 30.0);
+ // Create array.
+ AlignedVector<float> inputB(rows * cols);
+ for (auto& it : inputB) {
+ it = dist(gen);
+ }
+
+ float alpha = 25;
+ float quant_mult = 127/alpha;
+
+ AlignedVector<int8_t> B_prep(inputB.size());
+ AlignedVector<int8_t> B_quant(inputB.size());
+ Routine::PrepareB(inputB.begin(), B_prep.begin(), quant_mult, rows, cols);
+ Routine::Quantize(inputB.begin(), B_quant.begin(), quant_mult, static_cast<intgemm::Index>(inputB.size()));
+
+
+ AlignedVector<float> inputBias(cols);
+ AlignedVector<float> goldBias(cols);
+
+ for (auto& it : goldBias) {
+ it = dist(gen);
+ }
+ int i = 0;
+ for (auto& it : inputBias) {
+ it = goldBias[i];
+ i++;
+ }
+
+ float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f);
+
+ Routine::PrepareBias(B_prep.begin(), rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
+
+ int A_rows = 1;
+ AlignedVector<int8_t> A_prep2(A_rows*rows);
+ for (auto& it : A_prep2) {
+ it = 1;
+ }
+ //Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin()));
+ //CompareEps(goldBias.begin(), inputBias.begin(), cols, 0.0001f);
+ AlignedVector<float> slowint_C(cols);
+ references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, rows, cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult_forprep + goldBias[info.col_idx];
+ });
+ CompareEps(slowint_C.begin(), inputBias.begin(), cols, 0.0001f);
+}
+
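+// End-to-end check of the shifted 8-bit multiply: A is prepared as uint8_t,
+// PrepareBias folds the -(alpha*alpha/127) * colsum(B) correction into the
+// bias, and Multiply8Shift is compared (via CompareMSE) against both a slow
+// integer reference and a float reference.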
+template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127.0f / alpha;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<uint8_t> A_prep(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ /*
+ * Reference multiplication
+ */
+ AlignedVector<int8_t> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Use the original (non-shifted) A preparation, so A is quantized to int8_t.
+ AlignedVector<int8_t> A_prep2(A.size());
+ Routine::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width);
+ references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult + bias[info.col_idx];
+ });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+ return static_cast<float>(sum) + bias[info.col_idx];
+ });
+
+ /*
+ * Actual multiplication
+ */
+ float unquant_mult_forprep = (-1.0f)*(alpha)*(alpha)/(127.0f); // Negative so the later add_ps effectively subtracts the correction
+ Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+ //Routine::PrepareBias(B.begin(), bias.begin(), alpha, width, B_cols);
+ Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
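+// Compares the shifted multiply (Multiply8Shift on uint8_t A) against the old
+// non-shifted int8_t Multiply on the same inputs. The bias starts at zero, so
+// after PrepareBias it holds only the shift correction.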
+template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = 0;
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127.0f / alpha;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<uint8_t> A_prep(A.size());
+ AlignedVector<int8_t> A_prep_old(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); // Non-shifted version
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ /*
+ * Reference non-shift multiplication (used instead of a slow integer reference)
+ */
+ AlignedVector<float> slowint_C(test_C.size());
+ Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+ return static_cast<float>(sum) + bias[info.col_idx];
+ });
+
+ /*
+ * Multiply8 shift multiplication
+ */
+ float unquant_mult_forprep = (-1.0f)*(alpha)*(alpha)/(127.0f); // Negative so the later add_ps effectively subtracts the correction
+ Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+ Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
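+// Compares the shifted multiply against a slow integer reference that applies
+// the same correction: ShiftedBias is built by multiplying a row of ones with
+// B_quant and scaling by -alpha*alpha/127, mirroring what PrepareBias computes.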
+template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = 0;
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127.0f / alpha;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<uint8_t> A_prep(A.size());
+ AlignedVector<int8_t> A_prep_old(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); // Non-shifted version
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ /*
+ * Reference float multiplication
+ */
+ AlignedVector<int8_t> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Using the original (non-shifted) A preparation, where A would be int8_t:
+ // references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ // return sum * unquant_mult + bias[info.col_idx];
+ // });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+ return static_cast<float>(sum) + bias[info.col_idx];
+ });
+ /*
+ * Multiply8 shift multiplication
+ */
+ // First prepare the slow-integer (reference) bias:
+ AlignedVector<int8_t> A_prep2(1*width);
+ for (auto& it : A_prep2) {
+ it = 1;
+ }
+ AlignedVector<float> ShiftedBias(B_cols);
+ float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); // Negative so the later add_ps effectively subtracts the correction
+ references::Multiply(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), 1, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult_forprep + bias[info.col_idx];
+ });
+
+
+ // Now prepare the fast integer bias
+ Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+ Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+ // Reference integer version with the +127 shift:
+ // uses the shifted A preparation (A_prep is uint8_t), so ShiftedBias cancels the extra 127*colsum(B) term.
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult + ShiftedBias[info.col_idx];
+ });
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
+
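+// Each instruction set below gets its own TEST_CASE, guarded at run time by
+// the kCPU check and at compile time by the INTGEMM_COMPILER_SUPPORTS_* macros.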
+// Bias
+TEST_CASE("PrepareBias SSSE3", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestPrepareBias<SSSE3::Kernels8>(256,256);
+ TestPrepareBias<SSSE3::Kernels8>(2048,256);
+ TestPrepareBias<SSSE3::Kernels8>(512,512);
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE("PrepareBias AVX2", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestPrepareBias<AVX2::Kernels8>(256,256);
+ TestPrepareBias<AVX2::Kernels8>(2048,256);
+ TestPrepareBias<AVX2::Kernels8>(512,512);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE("PrepareBias AVX512F", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestPrepareBias<AVX512BW::Kernels8>(256,256);
+ TestPrepareBias<AVX512BW::Kernels8>(2048,256);
+ TestPrepareBias<AVX512BW::Kernels8>(512,512);
+}
+#endif
+
+//A
+TEST_CASE("PrepareA SSSE3", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestPrepareA<SSSE3::Kernels8>(64,64);
+ TestPrepareA<SSSE3::Kernels8>(256,256);
+ TestPrepareA<SSSE3::Kernels8>(512,512);
+ TestPrepareA<SSSE3::Kernels8>(2048,256);
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE("PrepareA AVX2", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestPrepareA<AVX2::Kernels8>(64,64);
+ TestPrepareA<AVX2::Kernels8>(256,256);
+ TestPrepareA<AVX2::Kernels8>(512,512);
+ TestPrepareA<AVX2::Kernels8>(2048,256);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE("PrepareA AVX512F", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestPrepareA<AVX512BW::Kernels8>(64,64);
+ TestPrepareA<AVX512BW::Kernels8>(256,256);
+ TestPrepareA<AVX512BW::Kernels8>(512,512);
+ TestPrepareA<AVX512BW::Kernels8>(2048,256);
+}
+#endif
+
+// Multiply
+
+TEST_CASE ("Multiply SSSE3 8bit Shift with bias", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyBiasNew<SSSE3::Kernels8>(1, 64, 8, 0.11f, 0.1f, 0.06f, 0.05f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(8, 256, 256, 0.45f, 0.54f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(8, 2048, 256, 1.7f, 1.7f, 0.46f, 0.43f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(320, 256, 256, 0.56f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f);
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE ("Multiply AVX2 8bit Shift with bias", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasNew<AVX2::Kernels8>(1, 64, 8, 0.11f, 0.11f, 0.06f, 0.05f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(8, 256, 256, 0.49f, 0.54f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(8, 2048, 256, 1.57f, 1.66f, 0.46f, 0.46f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(320, 256, 256, 0.49f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE ("Multiply AVX512F 8bit Shift with bias", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ TEST_CASE ("Multiply AVX512VNNI 8bit Shift with bias", "[Add127]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
+ }
+#endif
+
+//Multiply old vs new
+TEST_CASE ("Multiply SSSE3 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(1, 64, 8, 0.00001f, 0.1f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(8, 2048, 256, 17.9f, 1.7f, 0.46f, 4.2f); //Big difference here because the non-shift version is very bad
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(320, 256, 256, 1.2f, 0.64f, 0.16f, 0.006f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(472, 256, 256, 1.1f, 0.62f, 0.17f, 0.006f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(248, 256, 256, 0.9f, 0.64f, 0.16f, 0.007f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(200, 256, 256, 1, 0.74f, 0.17f, 0.006f);
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE ("Multiply AVX2 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(1, 64, 8, 0.00001f, 0.11f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(8, 2048, 256, 9.4f, 1.66f, 0.46f, 1.67f); //Big difference here because the non-shift version is very bad
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE ("Multiply AVX512F 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(8, 2048, 256, 3.51f, 0.61f, 0.17f, 0.3f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(1, 64, 8, 0.00001f, 0.05f, 0.03f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(8, 256, 256, 0.00001f, 0.22f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(320, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(472, 256, 256, 0.00001f, 0.33f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(248, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(200, 256, 256, 0.00001f, 0.28f, 0.06f, 0.00001f);
+ }
+#endif
+
+//Multiply Shift vs int shift implementation
+TEST_CASE ("Multiply SSSE3 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyShiftInt<SSSE3::Kernels8>(1, 64, 8, 0.0001f, 0.1f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(8, 2048, 256, 0.0001f, 1.7f, 0.46f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE ("Multiply AVX2 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyShiftInt<AVX2::Kernels8>(1, 64, 8, 0.0001f, 0.11f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(8, 2048, 256, 0.0001f, 1.66f, 0.46f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE ("Multiply AVX512F 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f);
+}
+#endif
+
+} // namespace
+} // namespace intgemm