From 26a029d407be480d791972afb5975cf62c9360a6 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Fri, 19 Apr 2024 02:47:55 +0200
Subject: Adding upstream version 124.0.1.

Signed-off-by: Daniel Baumann
---
 third_party/gemmology/LICENSE                      |   22 +
 third_party/gemmology/gemmology.h                  | 1326 ++++++++++++++++++++
 third_party/gemmology/gemmology_fwd.h              |  218 ++++
 .../gemmology/kernels/GemmologyEngineAVX2.cpp      |   19 +
 .../gemmology/kernels/GemmologyEngineAVX512BW.cpp  |   19 +
 .../kernels/GemmologyEngineAVX512VNNI.cpp          |   19 +
 .../gemmology/kernels/GemmologyEngineAVXVNNI.cpp   |   19 +
 .../gemmology/kernels/GemmologyEngineNeon64.cpp    |   19 +
 .../gemmology/kernels/GemmologyEngineSSE2.cpp      |   19 +
 .../gemmology/kernels/GemmologyEngineSSSE3.cpp     |   19 +
 third_party/gemmology/moz.yaml                     |   29 +
 11 files changed, 1728 insertions(+)
 create mode 100644 third_party/gemmology/LICENSE
 create mode 100644 third_party/gemmology/gemmology.h
 create mode 100644 third_party/gemmology/gemmology_fwd.h
 create mode 100644 third_party/gemmology/kernels/GemmologyEngineAVX2.cpp
 create mode 100644 third_party/gemmology/kernels/GemmologyEngineAVX512BW.cpp
 create mode 100644 third_party/gemmology/kernels/GemmologyEngineAVX512VNNI.cpp
 create mode 100644 third_party/gemmology/kernels/GemmologyEngineAVXVNNI.cpp
 create mode 100644 third_party/gemmology/kernels/GemmologyEngineNeon64.cpp
 create mode 100644 third_party/gemmology/kernels/GemmologyEngineSSE2.cpp
 create mode 100644 third_party/gemmology/kernels/GemmologyEngineSSSE3.cpp
 create mode 100644 third_party/gemmology/moz.yaml

(limited to 'third_party/gemmology')

diff --git a/third_party/gemmology/LICENSE b/third_party/gemmology/LICENSE
new file mode 100644
index 0000000000..fe8b644629
--- /dev/null
+++ b/third_party/gemmology/LICENSE
@@ -0,0 +1,22 @@
+MIT License
+
+Copyright (c) 2023 Serge Guelton
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ +-- + +The original 8-bit code came from: +MIT License + +Copyright (c) 2017--2019 University of Edinburgh, Nikolay Bogoychev, Mateusz Chudyk, Kenneth Heafield, and Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/third_party/gemmology/gemmology.h b/third_party/gemmology/gemmology.h new file mode 100644 index 0000000000..d774c53388 --- /dev/null +++ b/third_party/gemmology/gemmology.h @@ -0,0 +1,1326 @@ +#ifndef GEMMOLOGY_H +#define GEMMOLOGY_H + +#include "gemmology_fwd.h" + +#include +#include +#include + +#include + +namespace gemmology { + +namespace { + +// +// Arch specific implementation of various elementary operations +// + +namespace kernel { + +#ifdef __AVX512BW__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi8(first, second), + _mm512_unpackhi_epi8(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi16(first, second), + _mm512_unpackhi_epi16(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi32(first, second), + _mm512_unpackhi_epi32(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi64(first, second), + _mm512_unpackhi_epi64(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm512_packs_epi16(first, second); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm512_packs_epi32(first, second); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm512_madd_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm512_maddubs_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm512_madd_epi16(x, y); +} + +template +inline xsimd::batch +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + // Form [0th 128-bit register of pack0123, 0st 128-bit register of pack4567, + // 2nd 
128-bit register of pack0123, 2nd 128-bit register of pack4567] + __m512i mix0 = + _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6)); + // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567, + // 3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567] + __m512i mix1 = + _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2)); + __m512i added = _mm512_add_epi32(mix0, mix1); + // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7. + // Fold register over itself. + return _mm256_add_epi32(_mm512_castsi512_si256(added), + _mm512_extracti64x4_epi64(added, 1)); +} +#endif + +#ifdef __AVX2__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi8(first, second), + _mm256_unpackhi_epi8(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi16(first, second), + _mm256_unpackhi_epi16(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi32(first, second), + _mm256_unpackhi_epi32(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi64(first, second), + _mm256_unpackhi_epi64(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm256_packs_epi16(first, second); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm256_packs_epi32(first, second); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm256_madd_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm256_maddubs_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm256_maddubs_epi16(xsimd::abs(x), _mm256_sign_epi8(y, x)); +} + +template +inline xsimd::batch +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f + __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21); + // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s + __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0); + return _mm256_add_epi32(rev, blended); +} + +#ifdef __AVXVNNI__ + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch) { + return _mm256_dpbusd_avx_epi32(z, x, y); +} +#endif + +#ifdef __AVX512VNNI__ + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch>) { + return _mm512_dpbusd_epi32(z, x, y); +} + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch>) { + return _mm512_dpbusd_epi32(z, x, y); +} +#endif + +#endif + +#ifdef __SSSE3__ + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm_maddubs_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + 
xsimd::kernel::requires_arch) { + return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x)); +} +#endif + +#ifdef __SSE2__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm_packs_epi16(first, second); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm_packs_epi32(first, second); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm_madd_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch a, xsimd::batch b, + xsimd::kernel::requires_arch) { + // Adapted from + // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2 + // a = 0x00 0x01 0xFE 0x04 ... + // b = 0x00 0x02 0x80 0x84 ... + + // To extend signed 8-bit value, MSB has to be set to 0xFF + __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128()); + + // sign_mask_b = 0x00 0x00 0xFF 0xFF ... + + // Unpack positives with 0x00, negatives with 0xFF + __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128()); + __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128()); + __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b); + __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b); + + // Here - valid 16-bit signed integers corresponding to the 8-bit input + // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ... + + // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts + __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l); + __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h); + + // Now go back from 32-bit values to 16-bit values & signed saturate + return _mm_packs_epi32(madd_epi32_l, madd_epi32_h); +} + +template +inline xsimd::batch +madd(xsimd::batch a, xsimd::batch b, + xsimd::kernel::requires_arch) { + // adapted + // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2 + // a = 0x00 0x01 0xFE 0x04 ... + // b = 0x00 0x02 0x80 0x84 ... + + // To extend signed 8-bit value, MSB has to be set to 0xFF + __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128()); + __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128()); + + // sign_mask_a = 0x00 0x00 0xFF 0x00 ... + // sign_mask_b = 0x00 0x00 0xFF 0xFF ... 
+ + // Unpack positives with 0x00, negatives with 0xFF + __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a); + __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a); + __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b); + __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b); + + // Here - valid 16-bit signed integers corresponding to the 8-bit input + // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ... + + // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts + __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l); + __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h); + + // Now go back from 32-bit values to 16-bit values & signed saturate + return _mm_packs_epi32(madd_epi32_l, madd_epi32_h); +} + +template +inline std::tuple, xsimd::batch> +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + return {pack0123, pack4567}; +} + +#endif + +#if __ARM_ARCH >= 7 +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + + return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second)); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second)); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + + int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y)); + int32x4_t high = vmull_s16(vget_high_s16(x), vget_high_s16(y)); + + int32x2_t low_sum = vpadd_s32(vget_low_s32(low), vget_high_s32(low)); + int32x2_t high_sum = vpadd_s32(vget_low_s32(high), vget_high_s32(high)); + + return vcombine_s32(low_sum, high_sum); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + + // This would be much simpler if x86 would choose to zero extend OR sign + // extend, not both. This could probably be optimized better. + + // Zero extend x + int16x8_t x_odd = + vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x), 8)); + int16x8_t x_even = vreinterpretq_s16_u16( + vbicq_u16(vreinterpretq_u16_u8(x), vdupq_n_u16(0xff00))); + + // Sign extend by shifting left then shifting right. 
+ int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y), 8), 8); + int16x8_t y_odd = vshrq_n_s16(vreinterpretq_s16_s8(y), 8); + + // multiply + int16x8_t prod1 = vmulq_s16(x_even, y_even); + int16x8_t prod2 = vmulq_s16(x_odd, y_odd); + + // saturated add + return vqaddq_s16(prod1, prod2); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y)); + int16x8_t high = vmull_s8(vget_high_s8(x), vget_high_s8(y)); + + int16x4_t low_sum = vpadd_s16(vget_low_s16(low), vget_high_s16(low)); + int16x4_t high_sum = vpadd_s16(vget_low_s16(high), vget_high_s16(high)); + + return vcombine_s16(low_sum, high_sum); +} + +template +inline std::tuple, xsimd::batch> +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + return {pack0123, pack4567}; +} +#endif + +#ifdef __aarch64__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s8(first, second), vzip2q_s8(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s16(first, second), vzip2q_s16(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s32(first, second), vzip2q_s32(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s64(first, second), vzip2q_s64(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second)); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second)); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y)); + return vmlal_high_s16(low, x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + + int16x8_t tl = vmull_s8(vreinterpret_s8_u8(vget_low_u8(x)), + vget_low_s8(y)); + int16x8_t th = vmull_high_s8(vreinterpretq_s8_u8(x), y); + return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th)); +} + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch) { + int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))), + vmovl_s8(vget_low_s8(y))); + int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))), + vmovl_s8(vget_high_s8(y))); + return vpadalq_s16(vpadalq_s16(z, tl), th); + //TODO: investigate using vdotq_s32 +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y)); + return vmlal_high_s8(low, x, y); +} + +#endif + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch) { + return z + madd(xsimd::batch(1), madd(x, y, Arch{}), Arch{}); +} + +} // namespace kernel + +// +// Generic dispatcher for interleave, deinterleave madd and PermuteSummer +// + +template 
+std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second) { + return kernel::interleave(first, second, Arch{}); +} + +template +xsimd::batch deinterleave(xsimd::batch first, + xsimd::batch second) { + return kernel::deinterleave(first, second, Arch{}); +} +template +xsimd::batch deinterleave(xsimd::batch first, + xsimd::batch second) { + return kernel::deinterleave(first, second, Arch{}); +} + +template +inline xsimd::batch madd(xsimd::batch x, + xsimd::batch y) { + return kernel::madd(x, y, Arch{}); +} +template +inline xsimd::batch madd(xsimd::batch x, + xsimd::batch y) { + return kernel::madd(x, y, Arch{}); +} +template +inline xsimd::batch madd(xsimd::batch x, + xsimd::batch y) { + return kernel::madd(x, y, Arch{}); +} +template +inline xsimd::batch maddw(xsimd::batch x, + xsimd::batch y, + xsimd::batch z + ) { + return kernel::maddw(x, y, z, Arch{}); +} +template +inline xsimd::batch maddw(xsimd::batch x, + xsimd::batch y + ) { + return maddw(x, y, xsimd::batch((int32_t)0)); +} + +template +inline auto PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567) + -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) { + return kernel::PermuteSummer(pack0123, pack4567, Arch{}); +} + +template +inline xsimd::batch Pack0123(xsimd::batch sum0, + xsimd::batch sum1, + xsimd::batch sum2, + xsimd::batch sum3) { + std::tie(sum0, sum1) = interleave(sum0, sum1); + auto pack01 = sum0 + sum1; + std::tie(sum2, sum3) = interleave(sum2, sum3); + auto pack23 = sum2 + sum3; + + auto packed = interleave(xsimd::bitwise_cast(pack01), + xsimd::bitwise_cast(pack23)); + return xsimd::bitwise_cast(std::get<0>(packed)) + + xsimd::bitwise_cast(std::get<1>(packed)); +} + +template +static inline xsimd::batch +quantize(xsimd::batch input, + xsimd::batch quant_mult) { + return xsimd::nearbyint_as_int(input * quant_mult); +} + +template +inline xsimd::batch +QuantizerGrab(const float *input, xsimd::batch quant_mult_reg) { + return quantize(xsimd::batch::load_unaligned(input), + quant_mult_reg); +} + +#ifdef __AVX512BW__ +inline __m512 Concat(const __m256 first, const __m256 second) { + // INTGEMM_AVX512DQ but that goes with INTGEMM_AVX512BW anyway. + return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1); +} + +// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be +// controlled independently. +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set + * INTGEMM_AVX512BW */ +inline __m512i QuantizerGrabHalves(const float *input0, const float *input1, + const __m512 quant_mult_reg) { + __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1)); + appended = _mm512_mul_ps(appended, quant_mult_reg); + return _mm512_cvtps_epi32(appended); +} +#else +template +inline xsimd::batch +QuantizerGrabHalves(const float *input0, const float *input1, + xsimd::batch quant_mult_reg); +#endif + +/* Read 8 floats at a time from input0, input1, input2, and input3. Quantize + * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate + * the result into one register and return it. 
+ */ +class QuantizeTile8 { + template struct Tiler { + static constexpr uint32_t get(std::size_t i, std::size_t n) { + size_t factor = xsimd::batch::size / 4; + return (i % factor) * 4 + i / factor; + } + }; + +public: + template + static inline xsimd::batch + Consecutive(xsimd::batch quant_mult, const float *input) { + return Tile(quant_mult, input + 0 * xsimd::batch::size, + input + 1 * xsimd::batch::size, + input + 2 * xsimd::batch::size, + input + 3 * xsimd::batch::size); + } + + template + static inline xsimd::batch + ConsecutiveU(xsimd::batch quant_mult, const float *input) { + return TileU(quant_mult, input + 0 * xsimd::batch::size, + input + 1 * xsimd::batch::size, + input + 2 * xsimd::batch::size, + input + 3 * xsimd::batch::size); + } + + template + static inline xsimd::batch + ConsecutiveWithWrapping(xsimd::batch quant_mult, + const float *input, size_t cols_left, size_t cols, + size_t row_step) { + using batchf32 = xsimd::batch; + const float *inputs[4]; + for (size_t i = 0; i < std::size(inputs); ++i) { + while (cols_left < batchf32::size) { + input += cols * (row_step - 1); + cols_left += cols; + } + inputs[i] = input; + input += batchf32::size; + cols_left -= batchf32::size; + } + return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]); + } + + template + static inline xsimd::batch + ForReshape(xsimd::batch quant_mult, const float *input, + size_t cols) { + using batchf32 = xsimd::batch; + using batch8 = xsimd::batch; + using batch16 = xsimd::batch; + using batch32 = xsimd::batch; + using ubatch32 = xsimd::batch; + + // Put higher rows in the second half of the register. These will jumble + // around in the same way then conveniently land in the right place. + if constexpr (batchf32::size == 16) { + const batch8 neg127(-127); + // In reverse order: grabbing the first 32-bit values from each 128-bit + // register, then the second 32-bit values, etc. Grab 4 registers at a + // time in 32-bit format. + batch32 g0 = + QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult); + batch32 g1 = + QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult); + batch32 g2 = + QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult); + batch32 g3 = + QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult); + + // Pack 32-bit to 16-bit. + batch16 packed0 = deinterleave(g0, g1); + batch16 packed1 = deinterleave(g2, g3); + // Pack 16-bit to 8-bit. + batch8 packed = deinterleave(packed0, packed1); + // Ban -128. + packed = xsimd::max(packed, neg127); + + return xsimd::bitwise_cast( + xsimd::swizzle(xsimd::bitwise_cast(packed), + xsimd::make_batch_constant>())); + } else if constexpr (batchf32::size == 8) + return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols, + input + 18 * cols); + else if constexpr (batchf32::size == 4) + // Skip a row. + return Tile(quant_mult, input, input + 4, input + 2 * cols, + input + 2 * cols + 4); + else + return {}; + } + + template + static inline xsimd::batch + Tile(xsimd::batch quant_mult, const float *input0, + const float *input1, const float *input2, const float *input3) { + using batch8 = xsimd::batch; + using batch16 = xsimd::batch; + using batch32 = xsimd::batch; + using ubatch32 = xsimd::batch; + + const batch8 neg127(-127); + // Grab 4 registers at a time in 32-bit format. 
+ batch32 g0 = QuantizerGrab(input0, quant_mult); + batch32 g1 = QuantizerGrab(input1, quant_mult); + batch32 g2 = QuantizerGrab(input2, quant_mult); + batch32 g3 = QuantizerGrab(input3, quant_mult); + // Pack 32-bit to 16-bit. + batch16 packed0 = deinterleave(g0, g1); + batch16 packed1 = deinterleave(g2, g3); + // Pack 16-bit to 8-bit. + batch8 packed = deinterleave(packed0, packed1); + // Ban -128. + packed = xsimd::max(packed, neg127); + + if constexpr (batch32::size == 4) + return packed; + // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 + // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7 + // Technically this could be removed so long as the rows are bigger than 16 + // and the values are only used for GEMM. + return xsimd::bitwise_cast( + xsimd::swizzle(xsimd::bitwise_cast(packed), + xsimd::make_batch_constant>())); + } + +private: + // A version that produces uint8_ts + template + static inline xsimd::batch + TileU(xsimd::batch quant_mult, const float *input0, + const float *input1, const float *input2, const float *input3) { + using batch8 = xsimd::batch; + using batch16 = xsimd::batch; + using batch32 = xsimd::batch; + using ubatch32 = xsimd::batch; + + const batch8 neg127 = -127; + const batch8 pos127 = +127; + // Grab 4 registers at a time in 32-bit format. + batch32 g0 = QuantizerGrab(input0, quant_mult); + batch32 g1 = QuantizerGrab(input1, quant_mult); + batch32 g2 = QuantizerGrab(input2, quant_mult); + batch32 g3 = QuantizerGrab(input3, quant_mult); + // Pack 32-bit to 16-bit. + batch16 packed0 = deinterleave(g0, g1); + batch16 packed1 = deinterleave(g2, g3); + // Pack 16-bit to 8-bit. + batch8 packed = deinterleave(packed0, packed1); + // Ban -128. + packed = xsimd::max(packed, neg127); // Could be removed if we use +128 + packed = packed + pos127; + if (batch32::size == 4) + return xsimd::bitwise_cast(packed); + // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 + // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7 + // Technically this could be removed so long as the rows are bigger than 16 + // and the values are only used for GEMM. 
+ return xsimd::bitwise_cast( + xsimd::swizzle(xsimd::bitwise_cast(packed), + xsimd::make_batch_constant>())); + } +}; + +template +inline void Transpose16InLane( + xsimd::batch &r0, xsimd::batch &r1, + xsimd::batch &r2, xsimd::batch &r3, + xsimd::batch &r4, xsimd::batch &r5, + xsimd::batch &r6, xsimd::batch &r7) { + /* r0: columns 0 1 2 3 4 5 6 7 from row 0 + r1: columns 0 1 2 3 4 5 6 7 from row 1*/ + auto r0_16 = xsimd::bitwise_cast(r0); + auto r1_16 = xsimd::bitwise_cast(r1); + auto r2_16 = xsimd::bitwise_cast(r2); + auto r3_16 = xsimd::bitwise_cast(r3); + auto r4_16 = xsimd::bitwise_cast(r4); + auto r5_16 = xsimd::bitwise_cast(r5); + auto r6_16 = xsimd::bitwise_cast(r6); + auto r7_16 = xsimd::bitwise_cast(r7); + + std::tie(r0_16, r1_16) = interleave(r0_16, r1_16); + std::tie(r2_16, r3_16) = interleave(r2_16, r3_16); + std::tie(r4_16, r5_16) = interleave(r4_16, r5_16); + std::tie(r6_16, r7_16) = interleave(r6_16, r7_16); + /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1 + r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1 + r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3 + r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3 + r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5 + r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5 + r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7 + r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/ + auto r0_32 = xsimd::bitwise_cast(r0_16); + auto r2_32 = xsimd::bitwise_cast(r2_16); + auto r1_32 = xsimd::bitwise_cast(r1_16); + auto r3_32 = xsimd::bitwise_cast(r3_16); + auto r4_32 = xsimd::bitwise_cast(r4_16); + auto r6_32 = xsimd::bitwise_cast(r6_16); + auto r5_32 = xsimd::bitwise_cast(r5_16); + auto r7_32 = xsimd::bitwise_cast(r7_16); + + std::tie(r0_32, r2_32) = interleave(r0_32, r2_32); + std::tie(r1_32, r3_32) = interleave(r1_32, r3_32); + std::tie(r4_32, r6_32) = interleave(r4_32, r6_32); + std::tie(r5_32, r7_32) = interleave(r5_32, r7_32); + /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3 + r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3 + r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3 + r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3 + r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7 + r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7 + r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7 + r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/ + + auto r0_64 = xsimd::bitwise_cast(r0_32); + auto r2_64 = xsimd::bitwise_cast(r2_32); + auto r1_64 = xsimd::bitwise_cast(r1_32); + auto r3_64 = xsimd::bitwise_cast(r3_32); + auto r4_64 = xsimd::bitwise_cast(r4_32); + auto r6_64 = xsimd::bitwise_cast(r6_32); + auto r5_64 = xsimd::bitwise_cast(r5_32); + auto r7_64 = xsimd::bitwise_cast(r7_32); + + std::tie(r0_64, r4_64) = interleave(r0_64, r4_64); + std::tie(r1_64, r5_64) = interleave(r1_64, r5_64); + std::tie(r2_64, r6_64) = interleave(r2_64, r6_64); + std::tie(r3_64, r7_64) = interleave(r3_64, r7_64); + + r0 = xsimd::bitwise_cast(r0_64); + r1 = xsimd::bitwise_cast(r1_64); + r2 = xsimd::bitwise_cast(r2_64); + r3 = xsimd::bitwise_cast(r3_64); + r4 = xsimd::bitwise_cast(r4_64); + r5 = xsimd::bitwise_cast(r5_64); + r6 = xsimd::bitwise_cast(r6_64); + r7 = xsimd::bitwise_cast(r7_64); + /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7 + r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7 + r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7 + r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7 + r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7 + r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/ + /* Empirically gcc is able to remove these movs and just rename 
the outputs of + * Interleave64. */ + std::swap(r1, r4); + std::swap(r3, r6); +} + +template +void SelectColumnsOfB(const xsimd::batch *input, + xsimd::batch *output, + size_t rows_bytes /* number of bytes in a row */, + const IntegerTy *cols_begin, const IntegerTy *cols_end) { + using batch8 = xsimd::batch; + /* Do columns for multiples of 8.*/ + size_t register_rows = rows_bytes / batch8::size; + const batch8 *starts[8]; + for (; cols_begin != cols_end; cols_begin += 8) { + for (size_t k = 0; k < 8; ++k) { + starts[k] = + input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows; + } + for (size_t r = 0; r < register_rows; ++r) { + for (size_t k = 0; k < 8; ++k) { + *(output++) = *starts[k]; + starts[k] += 8; + } + } + } +} + +} // namespace + +namespace callbacks { +template +xsimd::batch Unquantize::operator()(xsimd::batch total, size_t, size_t, + size_t) { + return xsimd::batch_cast(total) * unquant_mult; +} + +template +std::tuple, xsimd::batch> Unquantize::operator()( + std::tuple, xsimd::batch> total, + size_t, size_t, size_t) { + return std::make_tuple( + xsimd::batch_cast(std::get<0>(total)) * unquant_mult, + xsimd::batch_cast(std::get<1>(total)) * unquant_mult); +} + +template +xsimd::batch AddBias::operator()(xsimd::batch total, size_t, + size_t col_idx, size_t) { + return total + xsimd::batch::load_aligned(bias_addr + col_idx); +} + +template +std::tuple, xsimd::batch> +AddBias::operator()( + std::tuple, xsimd::batch> total, + size_t, size_t col_idx, size_t) { + return std::make_tuple( + std::get<0>(total) + xsimd::batch::load_aligned(bias_addr + col_idx + 0), + std::get<1>(total) + + xsimd::batch::load_aligned(bias_addr + col_idx + + xsimd::batch::size)); +} + +template +void Write::operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size) { + result.store_aligned(output_addr + row_idx * col_size + col_idx); +} + +template +void Write::operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size) { + xsimd::bitwise_cast(result).store_aligned( + output_addr + row_idx * col_size + col_idx); +} + +template +void Write::operator()( + std::tuple, xsimd::batch> result, + size_t row_idx, size_t col_idx, size_t col_size) { + std::get<0>(result).store_aligned(output_addr + row_idx * col_size + col_idx + + 0); + std::get<1>(result).store_aligned(output_addr + row_idx * col_size + col_idx + + xsimd::batch::size); +} + +template +void Write::operator()( + std::tuple, xsimd::batch> result, + size_t row_idx, size_t col_idx, size_t col_size) { + xsimd::bitwise_cast(std::get<0>(result)) + .store_aligned(output_addr + row_idx * col_size + col_idx + 0); + xsimd::bitwise_cast(std::get<1>(result)) + .store_aligned(output_addr + row_idx * col_size + col_idx + + xsimd::batch::size); +} + +template +void UnquantizeAndWrite::operator()(T const &total, size_t row_idx, + size_t col_idx, size_t col_size) { + auto unquantized = unquantize(total, row_idx, col_idx, col_size); + write(unquantized, row_idx, col_idx, col_size); +} + +template +void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx, + size_t col_idx, size_t col_size) { + auto unquantized = unquantize(total, row_idx, col_idx, col_size); + auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size); + write(bias_added, row_idx, col_idx, col_size); +} +} // namespace callbacks + +template +void Engine::QuantizeU(const float *input, uint8_t *output, + float quant_mult, size_t size) { + using batch8 = xsimd::batch; + + xsimd::batch q(quant_mult); + const float 
*end = input + size; + for (; input != end; input += batch8::size, output += batch8::size) { + auto tile = QuantizeTile8::ConsecutiveU(q, input); + tile.store_aligned(output); + } +} + +template +void Engine::Quantize(const float *const input, int8_t *const output, + float quant_mult, size_t size) { + using batch8 = xsimd::batch; + + const std::size_t kBatch = batch8::size; + const std::size_t fast_end = size & ~(kBatch - 1); + + xsimd::batch q(quant_mult); + for (std::size_t i = 0; i < fast_end; i += kBatch) { + auto tile = QuantizeTile8::Consecutive(q, input + i); + tile.store_aligned(output + i); + } + + std::size_t overhang = size & (kBatch - 1); + if (!overhang) + return; + /* Each does size(xsimd::batch) / 32 == kBatch / 4 floats at a + * time. If we're allowed to read one of them, then we can read the whole + * register. + */ + const float *inputs[4]; + std::size_t i; + for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) { + inputs[i] = &input[fast_end + i * (kBatch / 4)]; + } + /* These will be clipped off. */ + for (; i < 4; ++i) { + inputs[i] = &input[fast_end]; + } + auto result = + QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]); + std::memcpy(output + (size & ~(kBatch - 1)), &result, overhang); +} + +template +template +void Engine::SelectColumnsB(const int8_t *input, int8_t *output, + size_t rows, const IntegerTy *cols_begin, + const IntegerTy *cols_end) { + using batch8 = xsimd::batch; + SelectColumnsOfB(reinterpret_cast(input), + reinterpret_cast(output), rows, cols_begin, + cols_end); +} + +template +void Engine::PrepareBTransposed(const float *input, int8_t *output, + float quant_mult, size_t cols, + size_t rows) { + using batch8 = xsimd::batch; + const size_t RegisterElemsInt = batch8::size; + const size_t kColStride = 8; + + xsimd::batch q(quant_mult); + auto *output_it = reinterpret_cast(output); + size_t r = 0; + size_t c = 0; + while (r < rows) { + for (size_t ri = 0; ri < 8; ++ri) + *output_it++ = QuantizeTile8::ConsecutiveWithWrapping( + q, input + (r + ri) * cols + c, cols - c, cols, 8); + c += RegisterElemsInt; + while (c >= cols) { + r += kColStride; + c -= cols; + } + } +} + +template +void Engine::PrepareBQuantizedTransposed(const int8_t *input, + int8_t *output, size_t cols, + size_t rows) { + using batch8 = xsimd::batch; + const size_t RegisterElems = batch8::size; + const size_t kColStride = 8; + + auto *output_it = reinterpret_cast(output); + for (size_t r = 0; r < rows; r += kColStride) + for (size_t c = 0; c < cols; c += RegisterElems) + for (size_t ri = 0; ri < 8; ++ri) + *output_it++ = + *reinterpret_cast(input + (r + ri) * cols + c); +} + +template +void Engine::PrepareB(const float *input, int8_t *output_shadow, + float quant_mult, size_t rows, size_t cols) { + using batch8 = xsimd::batch; + + xsimd::batch q(quant_mult); + /* Currently all multipliers have a stride of 8 columns.*/ + const size_t kColStride = 8; + auto *output = reinterpret_cast(output_shadow); + for (size_t c = 0; c < cols; c += kColStride) { + for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) { + output[0] = + QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols); + output[1] = + QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols); + output[2] = + QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols); + output[3] = + QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols); + output[4] = + QuantizeTile8::ForReshape(q, input + cols * (r + 8) + c, cols); + output[5] = + QuantizeTile8::ForReshape(q, 
input + cols * (r + 9) + c, cols); + output[6] = + QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols); + output[7] = + QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols); + std::tie(output[0], output[1]) = + interleave(xsimd::bitwise_cast(output[0]), + xsimd::bitwise_cast(output[1])); + std::tie(output[2], output[3]) = + interleave(xsimd::bitwise_cast(output[2]), + xsimd::bitwise_cast(output[3])); + std::tie(output[4], output[5]) = + interleave(xsimd::bitwise_cast(output[4]), + xsimd::bitwise_cast(output[5])); + std::tie(output[6], output[7]) = + interleave(xsimd::bitwise_cast(output[6]), + xsimd::bitwise_cast(output[7])); + Transpose16InLane(output[0], output[1], output[2], output[3], output[4], + output[5], output[6], output[7]); + } + } +} + +template +void Engine::PrepareA(const float *input, int8_t *output, + float quant_mult, size_t rows, size_t cols) { + Quantize(input, output, quant_mult, rows * cols); +} + +template +void Engine::Shift::PrepareA(const float *input, uint8_t *output, + float quant_mult, size_t rows, size_t cols) { + QuantizeU(input, output, quant_mult, rows * cols); +} + +template +template +void Engine::Shift::Multiply(const uint8_t *A, const int8_t *B, + size_t A_rows, size_t width, size_t B_cols, + Callback callback) { + + using batch8 = xsimd::batch; + using ubatch8 = xsimd::batch; + using batch32 = xsimd::batch; + + const size_t simd_width = width / batch8::size; + for (size_t B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const auto *B0_col = + reinterpret_cast(B) + simd_width * B0_colidx; + /* Process one row of A at a time. Doesn't seem to be faster to do multiple + * rows of A at once.*/ + for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { + const auto *A_row = + reinterpret_cast(A + A_rowidx * width); + /* These will be packed 16-bit integers containing sums for each row of B + multiplied by the row of A. Iterate over shared (inner) dimension.*/ + /* Upcast to 32-bit and horizontally add. 
Seems a bit faster if this is + * declared here.*/ + size_t k = 0; + ubatch8 a = *(A_row + k); + batch32 isum0 = maddw(a, *(B0_col + k * 8)); + batch32 isum1 = maddw(a, *(B0_col + k * 8 + 1)); + batch32 isum2 = maddw(a, *(B0_col + k * 8 + 2)); + batch32 isum3 = maddw(a, *(B0_col + k * 8 + 3)); + batch32 isum4 = maddw(a, *(B0_col + k * 8 + 4)); + batch32 isum5 = maddw(a, *(B0_col + k * 8 + 5)); + batch32 isum6 = maddw(a, *(B0_col + k * 8 + 6)); + batch32 isum7 = maddw(a, *(B0_col + k * 8 + 7)); + for (k = 1; k < simd_width; ++k) { + a = *(A_row + k); + /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ + /* Upcast to 32-bit and horizontally add.*/ + isum0 = maddw(a, *(B0_col + k * 8 + 0), isum0); + isum1 = maddw(a, *(B0_col + k * 8 + 1), isum1); + isum2 = maddw(a, *(B0_col + k * 8 + 2), isum2); + isum3 = maddw(a, *(B0_col + k * 8 + 3), isum3); + isum4 = maddw(a, *(B0_col + k * 8 + 4), isum4); + isum5 = maddw(a, *(B0_col + k * 8 + 5), isum5); + isum6 = maddw(a, *(B0_col + k * 8 + 6), isum6); + isum7 = maddw(a, *(B0_col + k * 8 + 7), isum7); + } + /* Reduce sums within 128-bit lanes.*/ + auto pack0123 = Pack0123(isum0, isum1, isum2, isum3); + auto pack4567 = Pack0123(isum4, isum5, isum6, isum7); + /*The specific implementation may need to reduce further.*/ + auto total = PermuteSummer(pack0123, pack4567); + callback(total, A_rowidx, B0_colidx, B_cols); + } + } +} + +template +template +void Engine::Shift::PrepareBias(const int8_t *B, size_t width, + size_t B_cols, Callback C) { + using batch8 = xsimd::batch; + const size_t simd_width = width / batch8::size; + xsimd::batch a(1); + for (size_t j = 0; j < B_cols; j += 8) { + /*Process one row of A at a time. Doesn't seem to be faster to do multiple + * rows of A at once.*/ + const int8_t *B_j = B + j * width; + + /* Rather than initializing as zeros and adding, just initialize the + * first.*/ + /* These will be packed 16-bit integers containing sums for each column of + * B multiplied by the row of A.*/ + /* Upcast to 32-bit and horizontally add. 
Seems a bit faster if this is + * declared here.*/ + auto isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size])); + auto isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size])); + auto isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size])); + auto isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size])); + auto isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size])); + auto isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size])); + auto isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size])); + auto isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size])); + + B_j += 8 * batch8::size; + + for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) { + isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]), isum0); + isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]), isum1); + isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]), isum2); + isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]), isum3); + isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]), isum4); + isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]), isum5); + isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]), isum6); + isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]), isum7); + } + + auto pack0123 = Pack0123(isum0, isum1, isum2, isum3); + auto pack4567 = Pack0123(isum4, isum5, isum6, isum7); + + auto total = PermuteSummer(pack0123, pack4567); + C(total, 0, j, B_cols); + } +} + +} // namespace gemmology + +#endif diff --git a/third_party/gemmology/gemmology_fwd.h b/third_party/gemmology/gemmology_fwd.h new file mode 100644 index 0000000000..83e3719b4f --- /dev/null +++ b/third_party/gemmology/gemmology_fwd.h @@ -0,0 +1,218 @@ +/*************************************************************** + * _ * + * | | * + * __ _ ___ _ __ ___ _ __ ___ ___ | | ___ __ _ _ _ * + * / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _ \| |/ _ \ / _` | | | | * + * | (_| | __/ | | | | | | | | | | (_) | | (_) | (_| | |_| | * + * \__, |\___|_| |_| |_|_| |_| |_|\___/|_|\___/ \__, |\__, | * + * __/ | __/ | __/ | * + * |___/ |___/ |___/ * + * * + * version 0.1 * + ***************************************************************/ + +#ifndef GEMMOLOGY_FWD_H +#define GEMMOLOGY_FWD_H + +#include +#include +#include +#include + +namespace gemmology { + +namespace callbacks { + +struct Unquantize { + float unquant_mult; + template + xsimd::batch operator()(xsimd::batch total, size_t, size_t, size_t); + template + std::tuple, xsimd::batch> operator()( + std::tuple, xsimd::batch> + total, + size_t, size_t, size_t); +}; + +struct AddBias { + const float *bias_addr; + template + xsimd::batch operator()(xsimd::batch total, size_t, size_t col_idx, + size_t); + template + std::tuple, xsimd::batch> + operator()( + std::tuple, xsimd::batch> total, + size_t, size_t col_idx, size_t); +}; + +struct Write { + float *output_addr; + + Write(float *o) : output_addr(o) {} + + template + void operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size); + template + void operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size); + + template + void operator()( + std::tuple, xsimd::batch> result, + size_t row_idx, size_t col_idx, size_t col_size); + + template + void operator()( + std::tuple, xsimd::batch> + result, + size_t row_idx, size_t col_idx, size_t col_size); +}; + +struct UnquantizeAndWrite { + + Unquantize unquantize; + Write write; + + UnquantizeAndWrite(float factor, float *output) + : 
unquantize{factor}, write{output} {} + + template + void operator()(T const &total, size_t row_idx, size_t col_idx, + size_t col_size); +}; + +struct UnquantizeAndAddBiasAndWrite { + + Unquantize unquantize; + AddBias add_bias; + Write write; + + UnquantizeAndAddBiasAndWrite(float factor, const float *bias, float *output) + : unquantize{factor}, add_bias{bias}, write{output} {} + + template + void operator()(T const &total, size_t row_idx, size_t col_idx, + size_t col_size); +}; + +} // namespace callbacks + +// +// Arch-specific implementation of each routine +// +template struct Engine { + + static void QuantizeU(const float *input, uint8_t *output, float quant_mult, + size_t size); + + static void Quantize(const float *const input, int8_t *const output, + float quant_mult, size_t size); + + template + static void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows, + const IntegerTy *cols_begin, + const IntegerTy *cols_end); + + static void PrepareBTransposed(const float *input, int8_t *output, + float quant_mult, size_t cols, size_t rows); + + static void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, + size_t cols, size_t rows); + + static void PrepareB(const float *input, int8_t *output_shadow, + float quant_mult, size_t rows, size_t cols); + + static void PrepareA(const float *input, int8_t *output, float quant_mult, + size_t rows, size_t cols); + + struct Shift { + + static void PrepareA(const float *input, uint8_t *output, float quant_mult, + size_t rows, size_t cols); + + template + static void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, + size_t width, size_t B_cols, Callback callback); + + template + static void PrepareBias(const int8_t *B, size_t width, size_t B_cols, + Callback C); + }; +}; + +// +// Top-level wrappers that mostly match intgemm API +// + +template +inline void QuantizeU(const float *input, uint8_t *output, float quant_mult, + size_t size) { + return Engine::QuantizeU(input, output, quant_mult, size); +} + +template +inline void Quantize(const float *const input, int8_t *const output, + float quant_mult, size_t size) { + return Engine::Quantize(input, output, quant_mult, size); +} + +template +inline void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows, + const IntegerTy *cols_begin, + const IntegerTy *cols_end) { + return Engine::SelectColumnsB(input, output, rows, cols_begin, + cols_end); +} + +template +inline void PrepareBTransposed(const float *input, int8_t *output, + float quant_mult, size_t cols, size_t rows) { + return Engine::PrepareBTransposed(input, output, quant_mult, cols, + rows); +} + +template +inline void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, + size_t cols, size_t rows) { + return Engine::PrepareBQuantizedTransposed(input, output, cols, rows); +} + +template +inline void PrepareB(const float *input, int8_t *output_shadow, + float quant_mult, size_t rows, size_t cols) { + return Engine::PrepareB(input, output_shadow, quant_mult, rows, cols); +} + +template +inline void PrepareA(const float *input, int8_t *output, float quant_mult, + size_t rows, size_t cols) { + return Engine::PrepareA(input, output, quant_mult, rows, cols); +} + +namespace Shift { + +template +inline void PrepareA(const float *input, uint8_t *output, float quant_mult, + size_t rows, size_t cols) { + return Engine::Shift::PrepareA(input, output, quant_mult, rows, cols); +} + +template +inline void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, + size_t width, size_t B_cols, 
Callback C) {
+  return Engine<Arch>::Shift::Multiply(A, B, A_rows, width, B_cols, C);
+}
+
+template <class Arch, class Callback>
+inline void PrepareBias(const int8_t *B, size_t width, size_t B_cols,
+                        Callback C) {
+  return Engine<Arch>::Shift::PrepareBias(B, width, B_cols, C);
+}
+
+} // namespace Shift
+
+} // namespace gemmology
+
+#endif
diff --git a/third_party/gemmology/kernels/GemmologyEngineAVX2.cpp b/third_party/gemmology/kernels/GemmologyEngineAVX2.cpp
new file mode 100644
index 0000000000..2bb55d4a1a
--- /dev/null
+++ b/third_party/gemmology/kernels/GemmologyEngineAVX2.cpp
@@ -0,0 +1,19 @@
+/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* this source code form is subject to the terms of the mozilla public
+ * license, v. 2.0. if a copy of the mpl was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <gemmology.h>
+
+namespace gemmology {
+template struct Engine<xsimd::avx2>;
+template void Engine<xsimd::avx2>::SelectColumnsB(int8_t const*, int8_t*,
+                                                  size_t, uint32_t const*,
+                                                  uint32_t const*);
+template void Engine<xsimd::avx2>::Shift::Multiply(
+    uint8_t const*, int8_t const*, size_t, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+template void Engine<xsimd::avx2>::Shift::PrepareBias(
+    int8_t const*, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+} // namespace gemmology
diff --git a/third_party/gemmology/kernels/GemmologyEngineAVX512BW.cpp b/third_party/gemmology/kernels/GemmologyEngineAVX512BW.cpp
new file mode 100644
index 0000000000..3cb1d35017
--- /dev/null
+++ b/third_party/gemmology/kernels/GemmologyEngineAVX512BW.cpp
@@ -0,0 +1,19 @@
+/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* this source code form is subject to the terms of the mozilla public
+ * license, v. 2.0. if a copy of the mpl was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <gemmology.h>
+
+namespace gemmology {
+template struct Engine<xsimd::avx512bw>;
+template void Engine<xsimd::avx512bw>::SelectColumnsB(int8_t const*, int8_t*,
+                                                      size_t, uint32_t const*,
+                                                      uint32_t const*);
+template void Engine<xsimd::avx512bw>::Shift::Multiply(
+    uint8_t const*, int8_t const*, size_t, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+template void Engine<xsimd::avx512bw>::Shift::PrepareBias(
+    int8_t const*, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+} // namespace gemmology
diff --git a/third_party/gemmology/kernels/GemmologyEngineAVX512VNNI.cpp b/third_party/gemmology/kernels/GemmologyEngineAVX512VNNI.cpp
new file mode 100644
index 0000000000..80425fafd4
--- /dev/null
+++ b/third_party/gemmology/kernels/GemmologyEngineAVX512VNNI.cpp
@@ -0,0 +1,19 @@
+/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* this source code form is subject to the terms of the mozilla public
+ * license, v. 2.0. if a copy of the mpl was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <gemmology.h>
+
+namespace gemmology {
+template struct Engine<xsimd::avx512vnni<xsimd::avx512bw>>;
+template void Engine<xsimd::avx512vnni<xsimd::avx512bw>>::SelectColumnsB(int8_t const*, int8_t*,
+                                                                         size_t, uint32_t const*,
+                                                                         uint32_t const*);
+template void Engine<xsimd::avx512vnni<xsimd::avx512bw>>::Shift::Multiply(
+    uint8_t const*, int8_t const*, size_t, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+template void Engine<xsimd::avx512vnni<xsimd::avx512bw>>::Shift::PrepareBias(
+    int8_t const*, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+} // namespace gemmology
diff --git a/third_party/gemmology/kernels/GemmologyEngineAVXVNNI.cpp b/third_party/gemmology/kernels/GemmologyEngineAVXVNNI.cpp
new file mode 100644
index 0000000000..c0a057346b
--- /dev/null
+++ b/third_party/gemmology/kernels/GemmologyEngineAVXVNNI.cpp
@@ -0,0 +1,19 @@
+/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* this source code form is subject to the terms of the mozilla public
+ * license, v. 2.0. if a copy of the mpl was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <gemmology.h>
+
+namespace gemmology {
+template struct Engine<xsimd::avxvnni>;
+template void Engine<xsimd::avxvnni>::SelectColumnsB(int8_t const*, int8_t*,
+                                                     size_t, uint32_t const*,
+                                                     uint32_t const*);
+template void Engine<xsimd::avxvnni>::Shift::Multiply(
+    uint8_t const*, int8_t const*, size_t, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+template void Engine<xsimd::avxvnni>::Shift::PrepareBias(
+    int8_t const*, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+} // namespace gemmology
diff --git a/third_party/gemmology/kernels/GemmologyEngineNeon64.cpp b/third_party/gemmology/kernels/GemmologyEngineNeon64.cpp
new file mode 100644
index 0000000000..63801f8ceb
--- /dev/null
+++ b/third_party/gemmology/kernels/GemmologyEngineNeon64.cpp
@@ -0,0 +1,19 @@
+/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* this source code form is subject to the terms of the mozilla public
+ * license, v. 2.0. if a copy of the mpl was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <gemmology.h>
+
+namespace gemmology {
+template struct Engine<xsimd::neon64>;
+template void Engine<xsimd::neon64>::SelectColumnsB(int8_t const*, int8_t*,
+                                                    size_t, uint32_t const*,
+                                                    uint32_t const*);
+template void Engine<xsimd::neon64>::Shift::Multiply(
+    uint8_t const*, int8_t const*, size_t, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+template void Engine<xsimd::neon64>::Shift::PrepareBias(
+    int8_t const*, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+} // namespace gemmology
diff --git a/third_party/gemmology/kernels/GemmologyEngineSSE2.cpp b/third_party/gemmology/kernels/GemmologyEngineSSE2.cpp
new file mode 100644
index 0000000000..134f9e0e92
--- /dev/null
+++ b/third_party/gemmology/kernels/GemmologyEngineSSE2.cpp
@@ -0,0 +1,19 @@
+/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* this source code form is subject to the terms of the mozilla public
+ * license, v. 2.0. if a copy of the mpl was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <gemmology.h>
+
+namespace gemmology {
+template struct Engine<xsimd::sse2>;
+template void Engine<xsimd::sse2>::SelectColumnsB(int8_t const*, int8_t*,
+                                                  size_t, uint32_t const*,
+                                                  uint32_t const*);
+template void Engine<xsimd::sse2>::Shift::Multiply(
+    uint8_t const*, int8_t const*, size_t, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+template void Engine<xsimd::sse2>::Shift::PrepareBias(
+    int8_t const*, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+} // namespace gemmology
diff --git a/third_party/gemmology/kernels/GemmologyEngineSSSE3.cpp b/third_party/gemmology/kernels/GemmologyEngineSSSE3.cpp
new file mode 100644
index 0000000000..9b6a6e1bff
--- /dev/null
+++ b/third_party/gemmology/kernels/GemmologyEngineSSSE3.cpp
@@ -0,0 +1,19 @@
+/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* this source code form is subject to the terms of the mozilla public
+ * license, v. 2.0. if a copy of the mpl was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <gemmology.h>
+
+namespace gemmology {
+template struct Engine<xsimd::ssse3>;
+template void Engine<xsimd::ssse3>::SelectColumnsB(int8_t const*, int8_t*,
+                                                   size_t, uint32_t const*,
+                                                   uint32_t const*);
+template void Engine<xsimd::ssse3>::Shift::Multiply(
+    uint8_t const*, int8_t const*, size_t, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+template void Engine<xsimd::ssse3>::Shift::PrepareBias(
+    int8_t const*, size_t, size_t,
+    gemmology::callbacks::UnquantizeAndAddBiasAndWrite);
+} // namespace gemmology
diff --git a/third_party/gemmology/moz.yaml b/third_party/gemmology/moz.yaml
new file mode 100644
index 0000000000..d9f9472da7
--- /dev/null
+++ b/third_party/gemmology/moz.yaml
@@ -0,0 +1,29 @@
+schema: 1
+
+bugzilla:
+  product: Core
+  component: "JavaScript: WebAssembly"
+
+origin:
+  name: gemmology
+  description: small integer matrix multiply
+
+  url: https://github.com/mozilla/gemmology
+
+  release: ec535e87d0ab9d1457ff6d2af247cc8113e74694 (2024-02-05T09:05:20Z).
+  revision: ec535e87d0ab9d1457ff6d2af247cc8113e74694
+
+  license: MIT
+
+vendoring:
+  url: https://github.com/mozilla/gemmology
+  source-hosting: github
+  tracking: commit
+
+  exclude:
+    - ".*"
+    - "*.rst"
+    - test
+
+  keep:
+    - kernels/*.cpp
--
cgit v1.2.3
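
The vendored API above is driven in three steps: Shift::PrepareA quantizes the activations to unsigned 8-bit (shifted by +127), PrepareB quantizes and interleaves the signed 8-bit weights, and Shift::PrepareBias folds the +127 shift back out through the bias term before Shift::Multiply produces the float result. The sketch below illustrates that flow only; the matrix shapes, the 64-byte alignment, the quantization scale and the use of xsimd::default_arch are assumptions made for this example, not values taken from the patch.

#include <cstddef>
#include <cstdint>

#include <gemmology.h>      // assumed include path, matching the kernel files above
#include <xsimd/xsimd.hpp>

int main() {
  // Assumes the compile-time baseline maps default_arch to SSE2/NEON64 or
  // better, i.e. one of the architectures this patch provides kernels for.
  using Arch = xsimd::default_arch;

  // Hypothetical shapes: width and B_cols are multiples of the register width
  // and of the 8-column stride the kernels use.
  constexpr std::size_t A_rows = 8, width = 64, B_cols = 8;

  // Buffers are 64-byte aligned because the callbacks use aligned loads/stores.
  alignas(64) static float A[A_rows * width];
  alignas(64) static float B[width * B_cols];
  alignas(64) static float bias[B_cols] = {};
  alignas(64) static float bias_prepared[B_cols];
  alignas(64) static float C[A_rows * B_cols];
  alignas(64) static std::uint8_t A_prepared[A_rows * width];
  alignas(64) static std::int8_t B_prepared[width * B_cols];

  for (std::size_t i = 0; i < A_rows * width; ++i) A[i] = 0.25f;
  for (std::size_t i = 0; i < width * B_cols; ++i) B[i] = 0.5f;

  // Hypothetical scale: inputs are assumed to stay within [-2, 2].
  const float quant_mult = 127.0f / 2.0f;
  const float unquant_mult = 1.0f / (quant_mult * quant_mult);

  // A becomes unsigned 8-bit (quantized value + 127); B becomes interleaved int8.
  gemmology::Engine<Arch>::Shift::PrepareA(A, A_prepared, quant_mult, A_rows, width);
  gemmology::Engine<Arch>::PrepareB(B, B_prepared, quant_mult, width, B_cols);

  // Fold the +127 shift of A into the bias:
  // bias_prepared[j] = bias[j] - 127 * unquant_mult * colsum(B_prepared)[j].
  gemmology::Engine<Arch>::Shift::PrepareBias(
      B_prepared, width, B_cols,
      gemmology::callbacks::UnquantizeAndAddBiasAndWrite(-127.0f * unquant_mult,
                                                         bias, bias_prepared));

  // C is approximately A * B + bias, written back as floats.
  gemmology::Engine<Arch>::Shift::Multiply(
      A_prepared, B_prepared, A_rows, width, B_cols,
      gemmology::callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult,
                                                         bias_prepared, C));
  return 0;
}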