Diffstat (limited to 'third_party/intgemm/test/multiply_test.cc')
-rw-r--r-- | third_party/intgemm/test/multiply_test.cc | 761
1 files changed, 761 insertions, 0 deletions
diff --git a/third_party/intgemm/test/multiply_test.cc b/third_party/intgemm/test/multiply_test.cc
new file mode 100644
index 0000000000..f72758fe19
--- /dev/null
+++ b/third_party/intgemm/test/multiply_test.cc
@@ -0,0 +1,761 @@
+#include "test.h"
+#include "../intgemm/aligned.h"
+#include "../intgemm/callbacks.h"
+#include "../intgemm/interleave.h"
+#include "../intgemm/intgemm.h"
+#include "../intgemm/multiply.h"
+#include "../intgemm/stats.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <numeric>
+#include <random>
+
+namespace intgemm {
+
+#ifndef __INTEL_COMPILER
+INTGEMM_SSE2
+#endif
+TEST_CASE("Transpose 16", "[transpose]") {
+  if (kCPU < CPUType::SSE2) return;
+  const unsigned N = 8;
+  AlignedVector<int16_t> input(N * N);
+  std::iota(input.begin(), input.end(), static_cast<int16_t>(0));
+
+  AlignedVector<int16_t> ref(N * N);
+  references::Transpose(input.begin(), ref.begin(), N, N);
+
+  // Overwrite input.
+  __m128i *t = input.as<__m128i>();
+  Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]);
+
+  for (std::size_t i = 0; i < input.size(); ++i) {
+    CHECK_MESSAGE(ref[i] == input[i], "16-bit transpose failure at: " << i << ": " << ref[i] << " != " << input[i]);
+  }
+}
+
+#ifndef __INTEL_COMPILER
+INTGEMM_SSSE3
+#endif
+TEST_CASE("Transpose 8", "[transpose]") {
+  if (kCPU < CPUType::SSSE3) return;
+  const unsigned N = 16;
+  AlignedVector<int8_t> input(N * N);
+  std::iota(input.begin(), input.end(), static_cast<int8_t>(0));
+
+  AlignedVector<int8_t> ref(input.size());
+  references::Transpose(input.begin(), ref.begin(), N, N);
+
+  // Overwrite input.
+  __m128i *t = input.as<__m128i>();
+  Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]);
+
+  for (std::size_t i = 0; i < input.size(); ++i) {
+    CHECK_MESSAGE(ref[i] == input[i], "8-bit transpose failure at " << i << ": " << (int16_t)ref[i] << " != " << (int16_t)input[i]);
+  }
+}
+
+template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
+  std::mt19937 gen;
+  // Go somewhat out of range too.
+  std::uniform_real_distribution<float> dist(-129.0, 129.0);
+  // Create array.
+  AlignedVector<float> input(rows * cols);
+  for (auto& it : input) {
+    it = dist(gen);
+  }
+
+  using Integer = typename Routine::Integer;
+  // Call Prepare
+  AlignedVector<Integer> test(input.size());
+  Routine::PrepareB(input.begin(), test.begin(), 1, rows, cols);
+
+  // Compute reference output.
+  AlignedVector<Integer> quantized(input.size());
+  Routine::Quantize(input.begin(), quantized.begin(), 1, static_cast<Index>(input.size()));
+  AlignedVector<Integer> reference(input.size());
+  // Note this won't work for Int8/Int16 generic routines because tile sizes vary.
+  references::Rearragement(quantized.begin(), reference.begin(), Routine::kBTileRow, Routine::kBTileCol, rows, cols);
+  CHECK_MESSAGE(memcmp(reference.begin(), test.begin(), test.size() * sizeof(Integer)) == 0, Routine::kName << " Mismatch:\n" <<
+    "Quantized Input" << '\n' << PrintMatrix(quantized.begin(), rows, cols) << "Reference" << '\n' <<
+    PrintMatrix(reference.begin(), rows, cols) << "Routine" << '\n' << PrintMatrix(test.begin(), rows, cols));
+}
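The ±129 input range above deliberately strays past what the integer type can hold, so the quantization inside PrepareB must saturate. A minimal scalar model of that step, a sketch assuming round-to-nearest with clamping to the int8 range (this helper is illustrative, not part of the intgemm API):

    #include <cmath>
    #include <cstdint>

    // Hypothetical scalar stand-in for one element of Routine::Quantize.
    int8_t QuantizeScalar(float value, float quant_mult) {
      float scaled = value * quant_mult;       // e.g. 129.0f * 1 = 129.0f
      if (scaled > 127.0f) scaled = 127.0f;    // saturate high
      if (scaled < -128.0f) scaled = -128.0f;  // saturate low (assumed bound)
      return static_cast<int8_t>(std::lround(scaled));
    }

The test then checks that PrepareB equals Quantize followed by rearrangement into kBTileRow x kBTileCol register tiles.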
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE("Prepare AVX512", "[prepare]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestPrepare<AVX512BW::Kernels8>(64, 8);
+  TestPrepare<AVX512BW::Kernels8>(256, 32);
+  TestPrepare<AVX512BW::Kernels16>(64, 8);
+  TestPrepare<AVX512BW::Kernels16>(256, 32);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE("Prepare AVX2", "[prepare]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestPrepare<AVX2::Kernels8>(64, 32);
+  TestPrepare<AVX2::Kernels16>(64, 32);
+}
+#endif
+
+TEST_CASE("Prepare SSSE3", "[prepare]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestPrepare<SSSE3::Kernels8>(16, 8);
+  TestPrepare<SSSE3::Kernels8>(32, 16);
+  TestPrepare<SSSE3::Kernels8>(32, 32);
+}
+
+TEST_CASE("Prepare SSE2", "[prepare]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestPrepare<SSE2::Kernels16>(8, 8);
+  TestPrepare<SSE2::Kernels16>(32, 32);
+}
+
+template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 16) {
+  std::mt19937 gen;
+  // Go somewhat out of range too.
+  std::uniform_real_distribution<float> dist(-129.0, 129.0);
+  AlignedVector<float> input(rows * cols);
+  for (auto& it : input) {
+    it = dist(gen);
+  }
+  using Integer = typename Routine::Integer;
+  AlignedVector<Integer> prepared(input.size());
+  Routine::PrepareB(input.begin(), prepared.begin(), 1, rows, cols);
+
+  const int kSelectCols = 24;
+  Index select_cols[kSelectCols];
+  std::uniform_int_distribution<Index> col_dist(0, cols - 1);
+  for (auto& it : select_cols) {
+    it = col_dist(gen);
+  }
+
+  AlignedVector<Integer> test(rows * kSelectCols);
+  Routine::SelectColumnsB(prepared.begin(), test.begin(), rows, select_cols, select_cols + kSelectCols);
+
+  // Select columns manually in float space.
+  AlignedVector<float> selected(rows * kSelectCols);
+  for (Index r = 0; r < rows; ++r) {
+    for (int c = 0; c < kSelectCols; ++c) {
+      assert(c + r * kSelectCols < rows * kSelectCols);
+      selected[c + r * kSelectCols] = input[select_cols[c] + r * cols];
+    }
+  }
+  AlignedVector<Integer> ref(rows * kSelectCols);
+  Routine::PrepareB(selected.begin(), ref.begin(), 1, rows, kSelectCols);
+  CHECK_MESSAGE(memcmp(ref.begin(), test.begin(), sizeof(Integer) * rows * kSelectCols) == 0, "Reference:\n" <<
+    PrintMatrix(ref.begin(), rows, kSelectCols) << PrintMatrix(test.begin(), rows, kSelectCols));
+}
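SelectColumnsB exists so a caller can restrict an already-prepared B to a shortlist of output columns, as a decoder with a target-vocabulary shortlist would, without re-preparing B from floats. A sketch of that usage (the wrapper and index values are illustrative, not from this file; eight indices are used on the assumption that column counts follow the same multiple-of-8 shape rules the test's kSelectCols = 24 obeys):

    // `B_prep` must come from Routine::PrepareB with `rows` rows;
    // `out` must hold rows * 8 integers.
    template <class Routine>
    void ShortlistB(const typename Routine::Integer *B_prep,
                    typename Routine::Integer *out, Index rows) {
      static const Index kCols[8] = {0, 1, 2, 3, 5, 8, 13, 21};
      Routine::SelectColumnsB(B_prep, out, rows, kCols, kCols + 8);
    }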
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE("SelectColumnsB AVX512", "[select]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestSelectColumnsB<AVX512BW::Kernels8>();
+  TestSelectColumnsB<AVX512BW::Kernels16>(256, 256);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE("SelectColumnsB AVX2", "[select]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestSelectColumnsB<AVX2::Kernels8>(256, 256);
+  TestSelectColumnsB<AVX2::Kernels16>(256, 256);
+}
+#endif
+
+TEST_CASE("SelectColumnsB SSSE3", "[select]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestSelectColumnsB<SSSE3::Kernels8>();
+  TestSelectColumnsB<SSSE3::Kernels8>(256, 256);
+}
+
+TEST_CASE("SelectColumnsB SSE2", "[select]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestSelectColumnsB<SSE2::Kernels16>();
+  TestSelectColumnsB<SSE2::Kernels16>(256, 256);
+}
+
+template <class Register> void TestMax() {
+  Register r = set1_ps<Register>(-2.0);
+  for (std::size_t i = 0; i < sizeof(Register) / sizeof(float); ++i) {
+    Register c = r;
+    reinterpret_cast<float*>(&c)[i] = -1.0;
+    CHECK_MESSAGE((MaxFloat32(c) == -1.0), "MaxFloat32 produced " << MaxFloat32(c));
+  }
+}
+
+TEST_CASE("Max", "[max]") {
+  TestMax<__m128>();
+}
+
+void CompareMaxAbs(const float *begin, const float *end, float test, std::size_t offset) {
+  float largest = std::fabs(*std::max_element(begin, end));
+  float smallest = std::fabs(*std::min_element(begin, end));
+  largest = std::max(largest, smallest);
+  CHECK_MESSAGE(largest == test, "Error: " << largest << " versus " << test << " in length " << (end - begin) << " offset " << offset);
+}
+
+template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute() {
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-8.0, 8.0);
+  const std::size_t kLengthMax = 65;
+  AlignedVector<float> test(kLengthMax);
+  for (std::size_t len = 1; len < kLengthMax; ++len) {
+    for (std::size_t t = 0; t < len; ++t) {
+      // Fill with [-8, 8).
+      for (auto& it : test) {
+        it = dist(gen);
+      }
+      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
+      test[t] = -32.0;
+      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
+      test[t] = 32.0;
+      CompareMaxAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
+    }
+  }
+}
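MaxAbsolute is the building block callers use to pick a quantization multiplier before PrepareA. The usual pattern, a sketch assuming the 8-bit path (the wrapper name is illustrative):

    // Scale so the largest magnitude in A lands at the top of the int8 range.
    float QuantMultFromA(const float *begin, const float *end) {
      return 127.0f / MaxAbsolute(begin, end);
    }

The planted -32.0 / 32.0 values in TestMaxAbsolute make sure the backend finds an extreme at every offset and for every vector-remainder length up to 64.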
+
+TEST_CASE("MaxAbsolute SSE2", "[max]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMaxAbsolute<SSE2::MaxAbsolute>();
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+TEST_CASE("MaxAbsolute AVX2", "[max]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMaxAbsolute<AVX2::MaxAbsolute>();
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE("MaxAbsolute AVX512BW", "[max]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestMaxAbsolute<AVX512BW::MaxAbsolute>();
+}
+#endif
+
+// Based on https://arxiv.org/abs/1705.01991
+
+// Copyright (c) 2017 Microsoft Corporation
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+// Compute A*B slowly in floats.
+template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols,
+    float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult * quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+  OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
+  // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
+  //   callbacks::Unquantize(unquant_mult),
+  //   callbacks::Write<float>(test_C.begin())
+  // ));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // A's preparation is just quantization (no rearranging), so A_prep works directly as reference input.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
+    return sum * unquant_mult;
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
+    return static_cast<float>(sum);
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+    int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
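The unquant_mult choice follows from scaling both operands: each integer product carries two factors of quant_mult, so the accumulated sum must be divided by quant_mult squared. Spelled out as comments:

    // a_int ≈ quant_mult * a,   b_int ≈ quant_mult * b
    // sum_int = Σ a_int * b_int ≈ quant_mult² * Σ a * b
    // C = sum_int * unquant_mult = sum_int / quant_mult² ≈ Σ a * b
    // 16-bit: quant_mult = 1024 = 2^10, so inputs drawn from (-1, 1)
    // use about 10 of the 15 magnitude bits of an int16.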
+
+template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
+    float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult * quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+  OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // A's preparation is just quantization (no rearranging), so A_prep works directly as reference input.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
+    float ret = std::max(0.0f, sum * unquant_mult);
+    return ret;
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
+    return static_cast<float>(std::max(0.0, sum));
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+    int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
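The commented-out call in TestMultiply shows the compositional form of the same operation: UnquantizeAndWrite is a fused equivalent of a callbacks::Sequence of Unquantize and Write. Side by side, with A_prep, B_prep and test_C as in that function (a sketch of the equivalence, not additional test code):

    // Fused callback, as the tests use:
    Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
        callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
    // Composed pipeline, from the commented-out code above:
    Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
        callbacks::Sequence(
          callbacks::Unquantize(unquant_mult),
          callbacks::Write<float>(test_C.begin())));

The Relu and bias variants below swap in the corresponding fused callbacks.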
+
+// Code duplication could be avoided with variadic templates, since the different WriteC symbols
+// require different numbers of arguments, but I don't think the refactoring is worth it.
+template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
+    float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  AlignedVector<float> bias(B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+  for (auto& it : bias) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult * quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+
+  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // A's preparation is just quantization (no rearranging), so A_prep works directly as reference input.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+    return sum * unquant_mult + bias[info.col_idx];
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+    return static_cast<float>(sum) + bias[info.col_idx];
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+    int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
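Outside the tests, callers normally go through the width-dispatched wrappers in intgemm.h rather than an ISA-specific Routine. A rough end-to-end sketch of the biased 8-bit path under those wrappers; the shared quant_mult, the 127/max scaling, and the function name are illustrative assumptions:

    // A: A_rows x width, B: width x B_cols, bias: B_cols, C: A_rows x B_cols,
    // all row-major floats in AlignedVector buffers allocated by the caller.
    void GemmWithBias(const AlignedVector<float> &A, const AlignedVector<float> &B,
                      const AlignedVector<float> &bias, AlignedVector<float> &C,
                      Index A_rows, Index width, Index B_cols) {
      float quant_mult = 127.0f / MaxAbsolute(A.begin(), A.end());
      AlignedVector<int8_t> A_prep(A.size());
      AlignedVector<int8_t> B_prep(B.size());
      Int8::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
      Int8::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
      Int8::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
          callbacks::UnquantizeAndAddBiasAndWrite(1.0f / (quant_mult * quant_mult),
                                                  bias.begin(), C.begin()));
    }

In practice B is prepared once offline and reused, which is the point of the separate PrepareB step.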
+
+template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
+    float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+  using Integer = typename Routine::Integer;
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  AlignedVector<float> bias(B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+  for (auto& it : bias) {
+    it = dist(gen);
+  }
+
+  float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+  float unquant_mult = 1.0f / (quant_mult * quant_mult);
+
+  AlignedVector<Integer> A_prep(A.size());
+  AlignedVector<Integer> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+
+  Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));
+
+  AlignedVector<Integer> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+  AlignedVector<float> slowint_C(test_C.size());
+  // A's preparation is just quantization (no rearranging), so A_prep works directly as reference input.
+  references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+    return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
+  });
+
+  AlignedVector<float> float_C(test_C.size());
+  references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+    return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
+  });
+
+  CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+    int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
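Every TEST_CASE below instantiates one of the four helpers above for a specific ISA. The trailing numbers map positionally onto the helper's tolerance parameters, for example:

    // TestMultiply<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
    //   A_rows = 8, width = 2048, B_cols = 256,
    //   int_tolerance = .1, float_tolerance = 1,
    //   MSE_float_tolerance = 0.02, MSE_int_tolerance left at its 0 default.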
+
+TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiply<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiply<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiply<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiply<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
+TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
+TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBias<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyBias<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBias<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBias<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
+TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::SSE2) return;
+  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+  TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
+TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiply<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiply<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiply<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiply<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
+TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
+TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiplyBias<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiplyBias<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiplyBias<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiplyBias<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
+TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+  TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
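The 8-bit tolerances above are far looser than the 16-bit ones (33 versus .1 at width 2048). That gap is consistent with saturation in the 16-bit intermediates of pre-VNNI 8-bit kernels: pmaddubsw adds adjacent byte products with signed saturation before anything reaches a 32-bit accumulator, and the chance of clipping grows with the inner dimension. An illustrative worst case (my interpretation of the tolerances; this reasoning is not stated in the file):

    // pmaddubsw lane: saturate_i16(u8_0 * i8_0 + u8_1 * i8_1)
    // 255 * 127 + 255 * 127 = 64770, which clips to 32767 and loses magnitude.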
0.1f); +} + +TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f); + TestMultiplyBias<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f); + TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f); +} + +TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f); + TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f); + TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f); +} + +TEST_CASE ("Multiply AVX2 16bit", "[multiply]") { + if (kCPU < CPUType::AVX2) return; + TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiply<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyBias<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} + +TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX2) return; + TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f); + TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW + TEST_CASE ("Multiply AVX512 8bit", "[multiply]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiply<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiply<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f); + TestMultiply<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiply<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiply<AVX512BW::Kernels8>(248, 256, 256, 0, 
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+TEST_CASE ("Multiply AVX512 8bit", "[multiply]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestMultiply<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+  TestMultiply<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+  TestMultiply<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+  TestMultiply<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiply<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+}
+
+TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+  TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+  TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+  TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+}
+
+TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+  TestMultiplyBias<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+  TestMultiplyBias<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+  TestMultiplyBias<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiplyBias<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+}
+
+TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+  TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+  TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+  TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+  TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+}
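Note that the VNNI cases below keep an int_tolerance of 0 even at width 2048, versus 3.7 for plain AVX512BW. That fits the instruction's design (my reading, not stated in this file): vpdpbusd multiplies four unsigned-by-signed byte pairs and accumulates them straight into a 32-bit lane, so the 16-bit intermediate saturation discussed earlier largely disappears. In intrinsic form:

    // __m512i acc, a_unsigned_bytes, b_signed_bytes;
    // acc = _mm512_dpbusd_epi32(acc, a_unsigned_bytes, b_signed_bytes);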
"[biased_multiply_relu]") { + if (kCPU < CPUType::AVX512VNNI) return; + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f); + TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f); + } + #endif + + TEST_CASE ("Multiply AVX512 16bit", "[multiply]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiply<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiply<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } + + TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } + + + TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiplyBias<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } + + TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") { + if (kCPU < CPUType::AVX512BW) return; + TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f); + TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f); + } +#endif + +} // namespace intgemm |