Diffstat (limited to 'third_party/gemmology/gemmology.h')
-rw-r--r--  third_party/gemmology/gemmology.h  1335
1 file changed, 1335 insertions, 0 deletions
diff --git a/third_party/gemmology/gemmology.h b/third_party/gemmology/gemmology.h
new file mode 100644
index 0000000000..86eb0b36be
--- /dev/null
+++ b/third_party/gemmology/gemmology.h
@@ -0,0 +1,1335 @@
+#ifndef GEMMOLOGY_H
+#define GEMMOLOGY_H
+
+#include "gemmology_fwd.h"
+
+#include <cstdint>
+#include <cstring>
+#include <iterator>
+#include <tuple>
+#include <utility>
+
+#include <xsimd/xsimd.hpp>
+
+namespace gemmology {
+
+namespace {
+
+//
+// Arch specific implementation of various elementary operations
+//
+
+namespace kernel {
+
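+// Each architecture below provides the same small set of primitives:
+// interleave (zip two registers into their low and high halves), deinterleave
+// (saturating pack of two wider-element registers into one narrower one),
+// madd (multiply adjacent pairs and horizontally add into the next wider
+// element type), and PermuteSummer (cross-lane combination of the packed
+// 32-bit sums; a pass-through on 128-bit targets).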
+#ifdef __AVX512BW__
+template <class Arch>
+std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
+interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return {_mm512_unpacklo_epi8(first, second),
+ _mm512_unpackhi_epi8(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
+interleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return {_mm512_unpacklo_epi16(first, second),
+ _mm512_unpackhi_epi16(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
+interleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return {_mm512_unpacklo_epi32(first, second),
+ _mm512_unpackhi_epi32(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
+interleave(xsimd::batch<int64_t, Arch> first,
+ xsimd::batch<int64_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return {_mm512_unpacklo_epi64(first, second),
+ _mm512_unpackhi_epi64(first, second)};
+}
+
+template <class Arch>
+xsimd::batch<int8_t, Arch>
+deinterleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return _mm512_packs_epi16(first, second);
+}
+
+template <class Arch>
+xsimd::batch<int16_t, Arch>
+deinterleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return _mm512_packs_epi32(first, second);
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return _mm512_madd_epi16(x, y);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return _mm512_maddubs_epi16(x, y);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+ return _mm512_madd_epi16(x, y);
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, xsimd::avx2>
+PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
+ xsimd::batch<int32_t, Arch> pack4567,
+ xsimd::kernel::requires_arch<xsimd::avx512bw>) {
+  // Form [0th 128-bit register of pack0123, 0th 128-bit register of pack4567,
+ // 2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567]
+ __m512i mix0 =
+ _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6));
+ // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567,
+ // 3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567]
+ __m512i mix1 =
+ _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2));
+ __m512i added = _mm512_add_epi32(mix0, mix1);
+ // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7.
+ // Fold register over itself.
+ return _mm256_add_epi32(_mm512_castsi512_si256(added),
+ _mm512_extracti64x4_epi64(added, 1));
+}
+#endif
+
+#ifdef __AVX2__
+template <class Arch>
+std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
+interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return {_mm256_unpacklo_epi8(first, second),
+ _mm256_unpackhi_epi8(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
+interleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return {_mm256_unpacklo_epi16(first, second),
+ _mm256_unpackhi_epi16(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
+interleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return {_mm256_unpacklo_epi32(first, second),
+ _mm256_unpackhi_epi32(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
+interleave(xsimd::batch<int64_t, Arch> first,
+ xsimd::batch<int64_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return {_mm256_unpacklo_epi64(first, second),
+ _mm256_unpackhi_epi64(first, second)};
+}
+
+template <class Arch>
+xsimd::batch<int8_t, Arch>
+deinterleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return _mm256_packs_epi16(first, second);
+}
+
+template <class Arch>
+xsimd::batch<int16_t, Arch>
+deinterleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return _mm256_packs_epi32(first, second);
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return _mm256_madd_epi16(x, y);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return _mm256_maddubs_epi16(x, y);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ return _mm256_maddubs_epi16(xsimd::abs(x), _mm256_sign_epi8(y, x));
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
+ xsimd::batch<int32_t, Arch> pack4567,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f
+ __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
+ // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s
+ __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
+ return _mm256_add_epi32(rev, blended);
+}
+#endif
+
+#ifdef __SSSE3__
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::ssse3>) {
+ return _mm_maddubs_epi16(x, y);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::ssse3>) {
+ return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x));
+}
+#endif
+
+#ifdef __SSE2__
+template <class Arch>
+std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
+interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return {_mm_unpacklo_epi8(first, second), _mm_unpackhi_epi8(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
+interleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return {_mm_unpacklo_epi16(first, second), _mm_unpackhi_epi16(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
+interleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return {_mm_unpacklo_epi32(first, second), _mm_unpackhi_epi32(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
+interleave(xsimd::batch<int64_t, Arch> first,
+ xsimd::batch<int64_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return {_mm_unpacklo_epi64(first, second), _mm_unpackhi_epi64(first, second)};
+}
+
+template <class Arch>
+xsimd::batch<int8_t, Arch>
+deinterleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return _mm_packs_epi16(first, second);
+}
+
+template <class Arch>
+xsimd::batch<int16_t, Arch>
+deinterleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return _mm_packs_epi32(first, second);
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return _mm_madd_epi16(x, y);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<uint8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ // Adapted from
+ // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
+ // a = 0x00 0x01 0xFE 0x04 ...
+ // b = 0x00 0x02 0x80 0x84 ...
+
+ // To extend signed 8-bit value, MSB has to be set to 0xFF
+ __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());
+
+ // sign_mask_b = 0x00 0x00 0xFF 0xFF ...
+
+  // Zero-extend a (unsigned) and sign-extend b: unpack positives with 0x00,
+  // negatives with 0xFF
+ __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128());
+ __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128());
+ __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
+ __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);
+
+  // Here - valid 16-bit integers corresponding to the 8-bit input; a is
+  // zero-extended, so a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0x00 0x04 0x00 ...
+
+ // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
+ __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
+ __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);
+
+ // Now go back from 32-bit values to 16-bit values & signed saturate
+ return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+  // Adapted from
+  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
+ // a = 0x00 0x01 0xFE 0x04 ...
+ // b = 0x00 0x02 0x80 0x84 ...
+
+ // To extend signed 8-bit value, MSB has to be set to 0xFF
+ __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128());
+ __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());
+
+ // sign_mask_a = 0x00 0x00 0xFF 0x00 ...
+ // sign_mask_b = 0x00 0x00 0xFF 0xFF ...
+
+ // Unpack positives with 0x00, negatives with 0xFF
+ __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a);
+ __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a);
+ __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
+ __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);
+
+ // Here - valid 16-bit signed integers corresponding to the 8-bit input
+ // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...
+
+ // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
+ __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
+ __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);
+
+ // Now go back from 32-bit values to 16-bit values & signed saturate
+ return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
+}
+
+template <class Arch>
+inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
+PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
+ xsimd::batch<int32_t, Arch> pack4567,
+ xsimd::kernel::requires_arch<xsimd::sse2>) {
+ return {pack0123, pack4567};
+}
+
+#endif
+
+#if __ARM_ARCH >= 7
+template <class Arch>
+std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
+interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+ int8x8_t first_lo = vget_low_s8(first);
+ int8x8_t second_lo = vget_low_s8(second);
+ int8x8x2_t result_lo = vzip_s8(first_lo, second_lo);
+ int8x8_t first_hi = vget_high_s8(first);
+ int8x8_t second_hi = vget_high_s8(second);
+ int8x8x2_t result_hi = vzip_s8(first_hi, second_hi);
+ return {vcombine_s8(result_lo.val[0], result_lo.val[1]),
+ vcombine_s8(result_hi.val[0], result_hi.val[1])};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
+interleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+ int16x4_t first_lo = vget_low_s16(first);
+ int16x4_t second_lo = vget_low_s16(second);
+ int16x4x2_t result_lo = vzip_s16(first_lo, second_lo);
+ int16x4_t first_hi = vget_high_s16(first);
+ int16x4_t second_hi = vget_high_s16(second);
+ int16x4x2_t result_hi = vzip_s16(first_hi, second_hi);
+ return {vcombine_s16(result_lo.val[0], result_lo.val[1]),
+ vcombine_s16(result_hi.val[0], result_hi.val[1])};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
+interleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+ int32x2_t first_lo = vget_low_s32(first);
+ int32x2_t second_lo = vget_low_s32(second);
+ int32x2x2_t result_lo = vzip_s32(first_lo, second_lo);
+ int32x2_t first_hi = vget_high_s32(first);
+ int32x2_t second_hi = vget_high_s32(second);
+ int32x2x2_t result_hi = vzip_s32(first_hi, second_hi);
+ return {vcombine_s32(result_lo.val[0], result_lo.val[1]),
+ vcombine_s32(result_hi.val[0], result_hi.val[1])};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
+interleave(xsimd::batch<int64_t, Arch> first,
+ xsimd::batch<int64_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+ int64x1_t first_lo = vget_low_s64(first);
+ int64x1_t second_lo = vget_low_s64(second);
+ int64x1_t first_hi = vget_high_s64(first);
+ int64x1_t second_hi = vget_high_s64(second);
+ return {vcombine_s64(first_lo, second_lo), vcombine_s64(first_hi, second_hi)};
+}
+
+template <class Arch>
+xsimd::batch<int8_t, Arch>
+deinterleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+
+ return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
+}
+
+template <class Arch>
+xsimd::batch<int16_t, Arch>
+deinterleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+ return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+
+ int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
+ int32x4_t high = vmull_s16(vget_high_s16(x), vget_high_s16(y));
+
+ int32x2_t low_sum = vpadd_s32(vget_low_s32(low), vget_high_s32(low));
+ int32x2_t high_sum = vpadd_s32(vget_low_s32(high), vget_high_s32(high));
+
+ return vcombine_s32(low_sum, high_sum);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+
+ // This would be much simpler if x86 would choose to zero extend OR sign
+ // extend, not both. This could probably be optimized better.
+
+ // Zero extend x
+ int16x8_t x_odd =
+ vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x), 8));
+ int16x8_t x_even = vreinterpretq_s16_u16(
+ vbicq_u16(vreinterpretq_u16_u8(x), vdupq_n_u16(0xff00)));
+
+ // Sign extend by shifting left then shifting right.
+ int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y), 8), 8);
+ int16x8_t y_odd = vshrq_n_s16(vreinterpretq_s16_s8(y), 8);
+
+ // multiply
+ int16x8_t prod1 = vmulq_s16(x_even, y_even);
+ int16x8_t prod2 = vmulq_s16(x_odd, y_odd);
+
+ // saturated add
+ return vqaddq_s16(prod1, prod2);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+ int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
+ int16x8_t high = vmull_s8(vget_high_s8(x), vget_high_s8(y));
+
+ int16x4_t low_sum = vpadd_s16(vget_low_s16(low), vget_high_s16(low));
+ int16x4_t high_sum = vpadd_s16(vget_low_s16(high), vget_high_s16(high));
+
+ return vcombine_s16(low_sum, high_sum);
+}
+
+template <class Arch>
+inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
+PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
+ xsimd::batch<int32_t, Arch> pack4567,
+ xsimd::kernel::requires_arch<xsimd::neon>) {
+ return {pack0123, pack4567};
+}
+#endif
+
+#ifdef __aarch64__
+template <class Arch>
+std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
+interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ return {vzip1q_s8(first, second), vzip2q_s8(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
+interleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ return {vzip1q_s16(first, second), vzip2q_s16(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
+interleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ return {vzip1q_s32(first, second), vzip2q_s32(first, second)};
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
+interleave(xsimd::batch<int64_t, Arch> first,
+ xsimd::batch<int64_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ return {vzip1q_s64(first, second), vzip2q_s64(first, second)};
+}
+
+template <class Arch>
+xsimd::batch<int8_t, Arch>
+deinterleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
+}
+
+template <class Arch>
+xsimd::batch<int16_t, Arch>
+deinterleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
+ int32x4_t high = vmull_high_s16(x, y);
+ return vpaddq_s32(low, high);
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+
+ int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))),
+ vmovl_s8(vget_low_s8(y)));
+ int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))),
+ vmovl_s8(vget_high_s8(y)));
+ return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th));
+}
+
+template <class Arch>
+inline xsimd::batch<int16_t, Arch>
+madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
+ int16x8_t high = vmull_high_s8(x, y);
+ return vpaddq_s16(low, high);
+}
+
+#endif
+
+} // namespace kernel
+
+//
+// Generic dispatchers for interleave, deinterleave, madd and PermuteSummer
+//
+
+template <class T, class Arch>
+std::tuple<xsimd::batch<T, Arch>, xsimd::batch<T, Arch>>
+interleave(xsimd::batch<T, Arch> first, xsimd::batch<T, Arch> second) {
+ return kernel::interleave(first, second, Arch{});
+}
+
+template <class Arch>
+xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first,
+ xsimd::batch<int16_t, Arch> second) {
+ return kernel::deinterleave(first, second, Arch{});
+}
+template <class Arch>
+xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first,
+ xsimd::batch<int32_t, Arch> second) {
+ return kernel::deinterleave(first, second, Arch{});
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch> madd(xsimd::batch<int16_t, Arch> x,
+ xsimd::batch<int16_t, Arch> y) {
+ return kernel::madd(x, y, Arch{});
+}
+template <class Arch>
+inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<int8_t, Arch> x,
+ xsimd::batch<int8_t, Arch> y) {
+ return kernel::madd(x, y, Arch{});
+}
+template <class Arch>
+inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<uint8_t, Arch> x,
+ xsimd::batch<int8_t, Arch> y) {
+ return kernel::madd(x, y, Arch{});
+}
+
+template <class Arch>
+inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
+ xsimd::batch<int32_t, Arch> pack4567)
+ -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) {
+ return kernel::PermuteSummer(pack0123, pack4567, Arch{});
+}
+
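+// Pack0123 folds four registers of 32-bit partial sums into one register:
+// within each 128-bit lane, slots 0..3 end up holding the horizontal sums of
+// sum0..sum3 for that lane, so PermuteSummer only has to combine lanes.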
+template <class Arch>
+inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
+ xsimd::batch<int32_t, Arch> sum1,
+ xsimd::batch<int32_t, Arch> sum2,
+ xsimd::batch<int32_t, Arch> sum3) {
+ std::tie(sum0, sum1) = interleave(sum0, sum1);
+ auto pack01 = sum0 + sum1;
+ std::tie(sum2, sum3) = interleave(sum2, sum3);
+ auto pack23 = sum2 + sum3;
+
+ auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),
+ xsimd::bitwise_cast<int64_t>(pack23));
+ return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +
+ xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
+}
+
+template <class Arch>
+static inline xsimd::batch<int32_t, Arch>
+quantize(xsimd::batch<float, Arch> input,
+ xsimd::batch<float, Arch> quant_mult) {
+ return xsimd::nearbyint_as_int(input * quant_mult);
+}
+
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+QuantizerGrab(const float *input, xsimd::batch<float, Arch> quant_mult_reg) {
+ return quantize(xsimd::batch<float, Arch>::load_unaligned(input),
+ quant_mult_reg);
+}
+
+#ifdef __AVX512BW__
+inline __m512 Concat(const __m256 first, const __m256 second) {
+  // _mm512_insertf32x8 requires AVX512DQ, but that goes along with AVX512BW
+  // anyway.
+ return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1);
+}
+
+// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be
+// controlled independently.
+/* Only AVX512F is necessary, but due to a GCC 5.4 bug we have to require
+ * AVX512BW as well. */
+inline __m512i QuantizerGrabHalves(const float *input0, const float *input1,
+ const __m512 quant_mult_reg) {
+ __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1));
+ appended = _mm512_mul_ps(appended, quant_mult_reg);
+ return _mm512_cvtps_epi32(appended);
+}
+#else
+template <class Arch>
+inline xsimd::batch<int32_t, Arch>
+QuantizerGrabHalves(const float *input0, const float *input1,
+ xsimd::batch<float, Arch> quant_mult_reg);
+#endif
+
+/* Read one register's worth of floats at a time from each of input0, input1,
+ * input2, and input3. Quantize them to 8-bit by multiplying with
+ * quant_mult_reg then rounding. Concatenate the result into one register and
+ * return it.
+ */
+class QuantizeTile8 {
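+  // Swizzle index pattern used by Tile/TileU/ForReshape below: it undoes the
+  // per-128-bit-lane ordering introduced by the two pack steps so that the
+  // 32-bit groups come out in consecutive order.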
+ template <class Arch> struct Tiler {
+ static constexpr uint32_t get(std::size_t i, std::size_t n) {
+ size_t factor = xsimd::batch<float, Arch>::size / 4;
+ return (i % factor) * 4 + i / factor;
+ }
+ };
+
+public:
+ template <class Arch>
+ static inline xsimd::batch<int8_t, Arch>
+ Consecutive(xsimd::batch<float, Arch> quant_mult, const float *input) {
+ return Tile(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
+ input + 1 * xsimd::batch<float, Arch>::size,
+ input + 2 * xsimd::batch<float, Arch>::size,
+ input + 3 * xsimd::batch<float, Arch>::size);
+ }
+
+ template <class Arch>
+ static inline xsimd::batch<uint8_t, Arch>
+ ConsecutiveU(xsimd::batch<float, Arch> quant_mult, const float *input) {
+ return TileU(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
+ input + 1 * xsimd::batch<float, Arch>::size,
+ input + 2 * xsimd::batch<float, Arch>::size,
+ input + 3 * xsimd::batch<float, Arch>::size);
+ }
+
+ template <class Arch>
+ static inline xsimd::batch<int8_t, Arch>
+ ConsecutiveWithWrapping(xsimd::batch<float, Arch> quant_mult,
+ const float *input, size_t cols_left, size_t cols,
+ size_t row_step) {
+ using batchf32 = xsimd::batch<float, Arch>;
+ const float *inputs[4];
+ for (size_t i = 0; i < std::size(inputs); ++i) {
+ while (cols_left < batchf32::size) {
+ input += cols * (row_step - 1);
+ cols_left += cols;
+ }
+ inputs[i] = input;
+ input += batchf32::size;
+ cols_left -= batchf32::size;
+ }
+ return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]);
+ }
+
+ template <class Arch>
+ static inline xsimd::batch<int8_t, Arch>
+ ForReshape(xsimd::batch<float, Arch> quant_mult, const float *input,
+ size_t cols) {
+ using batchf32 = xsimd::batch<float, Arch>;
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ using batch16 = xsimd::batch<int16_t, Arch>;
+ using batch32 = xsimd::batch<int32_t, Arch>;
+ using ubatch32 = xsimd::batch<uint32_t, Arch>;
+
+ // Put higher rows in the second half of the register. These will jumble
+ // around in the same way then conveniently land in the right place.
+ if constexpr (batchf32::size == 16) {
+ const batch8 neg127(-127);
+ // In reverse order: grabbing the first 32-bit values from each 128-bit
+ // register, then the second 32-bit values, etc. Grab 4 registers at a
+ // time in 32-bit format.
+ batch32 g0 =
+ QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult);
+ batch32 g1 =
+ QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult);
+ batch32 g2 =
+ QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult);
+ batch32 g3 =
+ QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult);
+
+ // Pack 32-bit to 16-bit.
+ batch16 packed0 = deinterleave(g0, g1);
+ batch16 packed1 = deinterleave(g2, g3);
+ // Pack 16-bit to 8-bit.
+ batch8 packed = deinterleave(packed0, packed1);
+ // Ban -128.
+ packed = xsimd::max(packed, neg127);
+
+ return xsimd::bitwise_cast<int8_t>(
+ xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
+ xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
+ } else if constexpr (batchf32::size == 8)
+ return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols,
+ input + 18 * cols);
+ else if constexpr (batchf32::size == 4)
+ // Skip a row.
+ return Tile(quant_mult, input, input + 4, input + 2 * cols,
+ input + 2 * cols + 4);
+ else
+ return {};
+ }
+
+ template <class Arch>
+ static inline xsimd::batch<int8_t, Arch>
+ Tile(xsimd::batch<float, Arch> quant_mult, const float *input0,
+ const float *input1, const float *input2, const float *input3) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ using batch16 = xsimd::batch<int16_t, Arch>;
+ using batch32 = xsimd::batch<int32_t, Arch>;
+ using ubatch32 = xsimd::batch<uint32_t, Arch>;
+
+ const batch8 neg127(-127);
+ // Grab 4 registers at a time in 32-bit format.
+ batch32 g0 = QuantizerGrab(input0, quant_mult);
+ batch32 g1 = QuantizerGrab(input1, quant_mult);
+ batch32 g2 = QuantizerGrab(input2, quant_mult);
+ batch32 g3 = QuantizerGrab(input3, quant_mult);
+ // Pack 32-bit to 16-bit.
+ batch16 packed0 = deinterleave(g0, g1);
+ batch16 packed1 = deinterleave(g2, g3);
+ // Pack 16-bit to 8-bit.
+ batch8 packed = deinterleave(packed0, packed1);
+ // Ban -128.
+ packed = xsimd::max(packed, neg127);
+
+ if constexpr (batch32::size == 4)
+ return packed;
+  // Currently in order 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27
+  //                    4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31,
+  // or as 32-bit integers: 0 2 4 6 1 3 5 7.
+  // Technically this swizzle could be removed so long as the rows are bigger
+  // than 16 and the values are only used for GEMM.
+ return xsimd::bitwise_cast<int8_t>(
+ xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
+ xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
+ }
+
+private:
+ // A version that produces uint8_ts
+ template <class Arch>
+ static inline xsimd::batch<uint8_t, Arch>
+ TileU(xsimd::batch<float, Arch> quant_mult, const float *input0,
+ const float *input1, const float *input2, const float *input3) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ using batch16 = xsimd::batch<int16_t, Arch>;
+ using batch32 = xsimd::batch<int32_t, Arch>;
+ using ubatch32 = xsimd::batch<uint32_t, Arch>;
+
+ const batch8 neg127 = -127;
+ const batch8 pos127 = +127;
+ // Grab 4 registers at a time in 32-bit format.
+ batch32 g0 = QuantizerGrab(input0, quant_mult);
+ batch32 g1 = QuantizerGrab(input1, quant_mult);
+ batch32 g2 = QuantizerGrab(input2, quant_mult);
+ batch32 g3 = QuantizerGrab(input3, quant_mult);
+ // Pack 32-bit to 16-bit.
+ batch16 packed0 = deinterleave(g0, g1);
+ batch16 packed1 = deinterleave(g2, g3);
+ // Pack 16-bit to 8-bit.
+ batch8 packed = deinterleave(packed0, packed1);
+ // Ban -128.
+ packed = xsimd::max(packed, neg127); // Could be removed if we use +128
+ packed = packed + pos127;
+ if (batch32::size == 4)
+ return xsimd::bitwise_cast<uint8_t>(packed);
+  // Currently in order 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27
+  //                    4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31,
+  // or as 32-bit integers: 0 2 4 6 1 3 5 7.
+  // Technically this swizzle could be removed so long as the rows are bigger
+  // than 16 and the values are only used for GEMM.
+ return xsimd::bitwise_cast<uint8_t>(
+ xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
+ xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
+ }
+};
+
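+/* Transpose 8x8 tiles of 16-bit entries, held across the eight registers,
+ * within each 128-bit lane; the column/row comments below track the layout
+ * through each interleave stage. */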
+template <class Arch>
+inline void Transpose16InLane(
+ xsimd::batch<int8_t, Arch> &r0, xsimd::batch<int8_t, Arch> &r1,
+ xsimd::batch<int8_t, Arch> &r2, xsimd::batch<int8_t, Arch> &r3,
+ xsimd::batch<int8_t, Arch> &r4, xsimd::batch<int8_t, Arch> &r5,
+ xsimd::batch<int8_t, Arch> &r6, xsimd::batch<int8_t, Arch> &r7) {
+ /* r0: columns 0 1 2 3 4 5 6 7 from row 0
+ r1: columns 0 1 2 3 4 5 6 7 from row 1*/
+ auto r0_16 = xsimd::bitwise_cast<int16_t>(r0);
+ auto r1_16 = xsimd::bitwise_cast<int16_t>(r1);
+ auto r2_16 = xsimd::bitwise_cast<int16_t>(r2);
+ auto r3_16 = xsimd::bitwise_cast<int16_t>(r3);
+ auto r4_16 = xsimd::bitwise_cast<int16_t>(r4);
+ auto r5_16 = xsimd::bitwise_cast<int16_t>(r5);
+ auto r6_16 = xsimd::bitwise_cast<int16_t>(r6);
+ auto r7_16 = xsimd::bitwise_cast<int16_t>(r7);
+
+ std::tie(r0_16, r1_16) = interleave(r0_16, r1_16);
+ std::tie(r2_16, r3_16) = interleave(r2_16, r3_16);
+ std::tie(r4_16, r5_16) = interleave(r4_16, r5_16);
+ std::tie(r6_16, r7_16) = interleave(r6_16, r7_16);
+ /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
+ r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
+ r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
+ r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
+ r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
+ r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
+ r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
+ r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/
+ auto r0_32 = xsimd::bitwise_cast<int32_t>(r0_16);
+ auto r2_32 = xsimd::bitwise_cast<int32_t>(r2_16);
+ auto r1_32 = xsimd::bitwise_cast<int32_t>(r1_16);
+ auto r3_32 = xsimd::bitwise_cast<int32_t>(r3_16);
+ auto r4_32 = xsimd::bitwise_cast<int32_t>(r4_16);
+ auto r6_32 = xsimd::bitwise_cast<int32_t>(r6_16);
+ auto r5_32 = xsimd::bitwise_cast<int32_t>(r5_16);
+ auto r7_32 = xsimd::bitwise_cast<int32_t>(r7_16);
+
+ std::tie(r0_32, r2_32) = interleave(r0_32, r2_32);
+ std::tie(r1_32, r3_32) = interleave(r1_32, r3_32);
+ std::tie(r4_32, r6_32) = interleave(r4_32, r6_32);
+ std::tie(r5_32, r7_32) = interleave(r5_32, r7_32);
+ /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
+ r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
+ r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
+ r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
+ r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
+ r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
+ r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
+ r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/
+
+ auto r0_64 = xsimd::bitwise_cast<int64_t>(r0_32);
+ auto r2_64 = xsimd::bitwise_cast<int64_t>(r2_32);
+ auto r1_64 = xsimd::bitwise_cast<int64_t>(r1_32);
+ auto r3_64 = xsimd::bitwise_cast<int64_t>(r3_32);
+ auto r4_64 = xsimd::bitwise_cast<int64_t>(r4_32);
+ auto r6_64 = xsimd::bitwise_cast<int64_t>(r6_32);
+ auto r5_64 = xsimd::bitwise_cast<int64_t>(r5_32);
+ auto r7_64 = xsimd::bitwise_cast<int64_t>(r7_32);
+
+ std::tie(r0_64, r4_64) = interleave(r0_64, r4_64);
+ std::tie(r1_64, r5_64) = interleave(r1_64, r5_64);
+ std::tie(r2_64, r6_64) = interleave(r2_64, r6_64);
+ std::tie(r3_64, r7_64) = interleave(r3_64, r7_64);
+
+ r0 = xsimd::bitwise_cast<int8_t>(r0_64);
+ r1 = xsimd::bitwise_cast<int8_t>(r1_64);
+ r2 = xsimd::bitwise_cast<int8_t>(r2_64);
+ r3 = xsimd::bitwise_cast<int8_t>(r3_64);
+ r4 = xsimd::bitwise_cast<int8_t>(r4_64);
+ r5 = xsimd::bitwise_cast<int8_t>(r5_64);
+ r6 = xsimd::bitwise_cast<int8_t>(r6_64);
+ r7 = xsimd::bitwise_cast<int8_t>(r7_64);
+ /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
+ r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
+ r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
+ r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
+ r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
+ r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/
+  /* Empirically gcc is able to remove these movs and just rename the outputs
+   * of the 64-bit interleave. */
+ std::swap(r1, r4);
+ std::swap(r3, r6);
+}
+
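+/* Copy selected columns of a prepared B matrix into a new prepared B. The
+ * columns are visited in groups of 8 (the multiplier stride), so the range
+ * cols_begin..cols_end is expected to hold a multiple of 8 indices. */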
+template <class Arch, typename IntegerTy>
+void SelectColumnsOfB(const xsimd::batch<int8_t, Arch> *input,
+ xsimd::batch<int8_t, Arch> *output,
+ size_t rows_bytes /* number of bytes in a row */,
+ const IntegerTy *cols_begin, const IntegerTy *cols_end) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ /* Do columns for multiples of 8.*/
+ size_t register_rows = rows_bytes / batch8::size;
+ const batch8 *starts[8];
+ for (; cols_begin != cols_end; cols_begin += 8) {
+ for (size_t k = 0; k < 8; ++k) {
+ starts[k] =
+ input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows;
+ }
+ for (size_t r = 0; r < register_rows; ++r) {
+ for (size_t k = 0; k < 8; ++k) {
+ *(output++) = *starts[k];
+ starts[k] += 8;
+ }
+ }
+ }
+}
+
+} // namespace
+
+namespace callbacks {
+template <class Arch>
+xsimd::batch<float, Arch> Unquantize::operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t,
+ size_t) {
+ return xsimd::batch_cast<float>(total) * unquant_mult;
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> Unquantize::operator()(
+ std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> total,
+ size_t, size_t, size_t) {
+ return std::make_tuple(
+ xsimd::batch_cast<float>(std::get<0>(total)) * unquant_mult,
+ xsimd::batch_cast<float>(std::get<1>(total)) * unquant_mult);
+}
+
+template <class Arch>
+xsimd::batch<float, Arch> AddBias::operator()(xsimd::batch<float, Arch> total, size_t,
+ size_t col_idx, size_t) {
+ return total + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx);
+}
+
+template <class Arch>
+std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
+AddBias::operator()(
+ std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
+ size_t, size_t col_idx, size_t) {
+ return std::make_tuple(
+ std::get<0>(total) + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx + 0),
+ std::get<1>(total) +
+ xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx +
+ xsimd::batch<float, Arch>::size));
+}
+
+template <class Arch>
+void Write::operator()(xsimd::batch<float, Arch> result, size_t row_idx,
+ size_t col_idx, size_t col_size) {
+ result.store_aligned(output_addr + row_idx * col_size + col_idx);
+}
+
+template <class Arch>
+void Write::operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
+ size_t col_idx, size_t col_size) {
+ xsimd::bitwise_cast<float>(result).store_aligned(
+ output_addr + row_idx * col_size + col_idx);
+}
+
+template <class Arch>
+void Write::operator()(
+ std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
+ size_t row_idx, size_t col_idx, size_t col_size) {
+ std::get<0>(result).store_aligned(output_addr + row_idx * col_size + col_idx +
+ 0);
+ std::get<1>(result).store_aligned(output_addr + row_idx * col_size + col_idx +
+ xsimd::batch<float, Arch>::size);
+}
+
+template <class Arch>
+void Write::operator()(
+ std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> result,
+ size_t row_idx, size_t col_idx, size_t col_size) {
+ xsimd::bitwise_cast<float>(std::get<0>(result))
+ .store_aligned(output_addr + row_idx * col_size + col_idx + 0);
+ xsimd::bitwise_cast<float>(std::get<1>(result))
+ .store_aligned(output_addr + row_idx * col_size + col_idx +
+ xsimd::batch<int32_t, Arch>::size);
+}
+
+template <class T>
+void UnquantizeAndWrite::operator()(T const &total, size_t row_idx,
+ size_t col_idx, size_t col_size) {
+ auto unquantized = unquantize(total, row_idx, col_idx, col_size);
+ write(unquantized, row_idx, col_idx, col_size);
+}
+
+template <class T>
+void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx,
+ size_t col_idx, size_t col_size) {
+ auto unquantized = unquantize(total, row_idx, col_idx, col_size);
+ auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size);
+ write(bias_added, row_idx, col_idx, col_size);
+}
+} // namespace callbacks
+
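+/* Quantize floats to unsigned 8-bit. Values are clipped to [-127, 127] and
+ * then offset by +127 (see QuantizeTile8::TileU), so outputs land in
+ * [0, 254]. This path assumes size is a multiple of the register width. */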
+template <class Arch>
+void Engine<Arch>::QuantizeU(const float *input, uint8_t *output,
+ float quant_mult, size_t size) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+
+ xsimd::batch<float, Arch> q(quant_mult);
+ const float *end = input + size;
+ for (; input != end; input += batch8::size, output += batch8::size) {
+ auto tile = QuantizeTile8::ConsecutiveU(q, input);
+ tile.store_aligned(output);
+ }
+}
+
+template <class Arch>
+void Engine<Arch>::Quantize(const float *const input, int8_t *const output,
+ float quant_mult, size_t size) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+
+ const std::size_t kBatch = batch8::size;
+ const std::size_t fast_end = size & ~(kBatch - 1);
+
+ xsimd::batch<float, Arch> q(quant_mult);
+ for (std::size_t i = 0; i < fast_end; i += kBatch) {
+ auto tile = QuantizeTile8::Consecutive(q, input + i);
+ tile.store_aligned(output + i);
+ }
+
+ std::size_t overhang = size & (kBatch - 1);
+ if (!overhang)
+ return;
+  /* Each QuantizerGrab reads kBatch / 4 floats (one float register) at a
+   * time. If we're allowed to read one of them, then we can read the whole
+   * register.
+   */
+ const float *inputs[4];
+ std::size_t i;
+ for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) {
+ inputs[i] = &input[fast_end + i * (kBatch / 4)];
+ }
+ /* These will be clipped off. */
+ for (; i < 4; ++i) {
+ inputs[i] = &input[fast_end];
+ }
+ auto result =
+ QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]);
+ std::memcpy(output + (size & ~(kBatch - 1)), &result, overhang);
+}
+
+template <class Arch>
+template <typename IntegerTy>
+void Engine<Arch>::SelectColumnsB(const int8_t *input, int8_t *output,
+ size_t rows, const IntegerTy *cols_begin,
+ const IntegerTy *cols_end) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ SelectColumnsOfB(reinterpret_cast<const batch8 *>(input),
+ reinterpret_cast<batch8 *>(output), rows, cols_begin,
+ cols_end);
+}
+
+template <class Arch>
+void Engine<Arch>::PrepareBTransposed(const float *input, int8_t *output,
+ float quant_mult, size_t cols,
+ size_t rows) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ const size_t RegisterElemsInt = batch8::size;
+ const size_t kColStride = 8;
+
+ xsimd::batch<float, Arch> q(quant_mult);
+ auto *output_it = reinterpret_cast<batch8 *>(output);
+ size_t r = 0;
+ size_t c = 0;
+ while (r < rows) {
+ for (size_t ri = 0; ri < 8; ++ri)
+ *output_it++ = QuantizeTile8::ConsecutiveWithWrapping(
+ q, input + (r + ri) * cols + c, cols - c, cols, 8);
+ c += RegisterElemsInt;
+ while (c >= cols) {
+ r += kColStride;
+ c -= cols;
+ }
+ }
+}
+
+template <class Arch>
+void Engine<Arch>::PrepareBQuantizedTransposed(const int8_t *input,
+ int8_t *output, size_t cols,
+ size_t rows) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ const size_t RegisterElems = batch8::size;
+ const size_t kColStride = 8;
+
+ auto *output_it = reinterpret_cast<batch8 *>(output);
+ for (size_t r = 0; r < rows; r += kColStride)
+ for (size_t c = 0; c < cols; c += RegisterElems)
+ for (size_t ri = 0; ri < 8; ++ri)
+ *output_it++ =
+ *reinterpret_cast<const batch8 *>(input + (r + ri) * cols + c);
+}
+
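+/* Quantize B and rearrange it into the interleaved, lane-transposed layout
+ * that Multiply expects: 8-column stripes whose rows are regrouped through
+ * interleave and Transpose16InLane. */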
+template <class Arch>
+void Engine<Arch>::PrepareB(const float *input, int8_t *output_shadow,
+ float quant_mult, size_t rows, size_t cols) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+
+ xsimd::batch<float, Arch> q(quant_mult);
+ /* Currently all multipliers have a stride of 8 columns.*/
+ const size_t kColStride = 8;
+ auto *output = reinterpret_cast<batch8 *>(output_shadow);
+ for (size_t c = 0; c < cols; c += kColStride) {
+ for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) {
+ output[0] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols);
+ output[1] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols);
+ output[2] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols);
+ output[3] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols);
+ output[4] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 8) + c, cols);
+ output[5] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 9) + c, cols);
+ output[6] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols);
+ output[7] =
+ QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols);
+ std::tie(output[0], output[1]) =
+ interleave(xsimd::bitwise_cast<int8_t>(output[0]),
+ xsimd::bitwise_cast<int8_t>(output[1]));
+ std::tie(output[2], output[3]) =
+ interleave(xsimd::bitwise_cast<int8_t>(output[2]),
+ xsimd::bitwise_cast<int8_t>(output[3]));
+ std::tie(output[4], output[5]) =
+ interleave(xsimd::bitwise_cast<int8_t>(output[4]),
+ xsimd::bitwise_cast<int8_t>(output[5]));
+ std::tie(output[6], output[7]) =
+ interleave(xsimd::bitwise_cast<int8_t>(output[6]),
+ xsimd::bitwise_cast<int8_t>(output[7]));
+ Transpose16InLane(output[0], output[1], output[2], output[3], output[4],
+ output[5], output[6], output[7]);
+ }
+ }
+}
+
+template <class Arch>
+void Engine<Arch>::PrepareA(const float *input, int8_t *output,
+ float quant_mult, size_t rows, size_t cols) {
+ Quantize(input, output, quant_mult, rows * cols);
+}
+
+template <class Arch>
+void Engine<Arch>::Shift::PrepareA(const float *input, uint8_t *output,
+ float quant_mult, size_t rows, size_t cols) {
+ QuantizeU(input, output, quant_mult, rows * cols);
+}
+
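+/* Multiply A (unsigned 8-bit, as produced by Shift::PrepareA, i.e. offset
+ * into [0, 254]) by B (signed 8-bit, as produced by PrepareB). Adjacent
+ * products are summed into 16-bit via madd, widened to 32-bit with a second
+ * madd against ones, and accumulated; the callback receives the reduced
+ * totals for each row of A and 8-column block of B. */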
+template <class Arch>
+template <class Callback>
+void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
+ size_t A_rows, size_t width, size_t B_cols,
+ Callback callback) {
+
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ using ubatch8 = xsimd::batch<uint8_t, Arch>;
+ using batch16 = xsimd::batch<int16_t, Arch>;
+ using batch32 = xsimd::batch<int32_t, Arch>;
+
+ const size_t simd_width = width / batch8::size;
+ for (size_t B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const auto *B0_col =
+ reinterpret_cast<const batch8 *>(B) + simd_width * B0_colidx;
+ /* Process one row of A at a time. Doesn't seem to be faster to do multiple
+ * rows of A at once.*/
+ for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
+ const auto *A_row =
+ reinterpret_cast<const ubatch8 *>(A + A_rowidx * width);
+      /* These will be packed 16-bit integers containing sums for each column
+         of B multiplied by the row of A. Iterate over the shared (inner)
+         dimension. */
+ size_t k = 0;
+ ubatch8 a = *(A_row + k);
+ batch16 sum0 = madd(a, *(B0_col + k * 8));
+ batch16 sum1 = madd(a, *(B0_col + k * 8 + 1));
+ batch16 sum2 = madd(a, *(B0_col + k * 8 + 2));
+ batch16 sum3 = madd(a, *(B0_col + k * 8 + 3));
+ batch16 sum4 = madd(a, *(B0_col + k * 8 + 4));
+ batch16 sum5 = madd(a, *(B0_col + k * 8 + 5));
+ batch16 sum6 = madd(a, *(B0_col + k * 8 + 6));
+ batch16 sum7 = madd(a, *(B0_col + k * 8 + 7));
+ /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
+ * declared here.*/
+ batch16 ones(1);
+ batch32 isum0 = madd(sum0, ones);
+ batch32 isum1 = madd(sum1, ones);
+ batch32 isum2 = madd(sum2, ones);
+ batch32 isum3 = madd(sum3, ones);
+ batch32 isum4 = madd(sum4, ones);
+ batch32 isum5 = madd(sum5, ones);
+ batch32 isum6 = madd(sum6, ones);
+ batch32 isum7 = madd(sum7, ones);
+ for (k = 1; k < simd_width; ++k) {
+ a = *(A_row + k);
+ /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/
+ batch16 mult0 = madd(a, *(B0_col + k * 8));
+ batch16 mult1 = madd(a, *(B0_col + k * 8 + 1));
+ batch16 mult2 = madd(a, *(B0_col + k * 8 + 2));
+ batch16 mult3 = madd(a, *(B0_col + k * 8 + 3));
+ batch16 mult4 = madd(a, *(B0_col + k * 8 + 4));
+ batch16 mult5 = madd(a, *(B0_col + k * 8 + 5));
+ batch16 mult6 = madd(a, *(B0_col + k * 8 + 6));
+ batch16 mult7 = madd(a, *(B0_col + k * 8 + 7));
+ /* Upcast to 32-bit and horizontally add.*/
+ batch32 imult0 = madd(mult0, ones);
+ batch32 imult1 = madd(mult1, ones);
+ batch32 imult2 = madd(mult2, ones);
+ batch32 imult3 = madd(mult3, ones);
+ batch32 imult4 = madd(mult4, ones);
+ batch32 imult5 = madd(mult5, ones);
+ batch32 imult6 = madd(mult6, ones);
+ batch32 imult7 = madd(mult7, ones);
+ /*Add in 32bit*/
+ isum0 += imult0;
+ isum1 += imult1;
+ isum2 += imult2;
+ isum3 += imult3;
+ isum4 += imult4;
+ isum5 += imult5;
+ isum6 += imult6;
+ isum7 += imult7;
+ }
+ /* Reduce sums within 128-bit lanes.*/
+ auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
+ auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
+ /*The specific implementation may need to reduce further.*/
+ auto total = PermuteSummer(pack0123, pack4567);
+ callback(total, A_rowidx, B0_colidx, B_cols);
+ }
+ }
+}
+
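+/* Compute per-column sums of B (a row of ones times B), reduced the same way
+ * as in Multiply. The callback typically scales these to compensate for the
+ * +127 offset that Shift::PrepareA applies to A. */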
+template <class Arch>
+template <class Callback>
+void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
+ size_t B_cols, Callback C) {
+ using batch8 = xsimd::batch<int8_t, Arch>;
+ using batch16 = xsimd::batch<int16_t, Arch>;
+ const size_t simd_width = width / batch8::size;
+ xsimd::batch<uint8_t, Arch> a(1);
+ for (size_t j = 0; j < B_cols; j += 8) {
+    /* Process 8 columns of B at a time, using a broadcast register of ones in
+     * place of a row of A. */
+ const int8_t *B_j = B + j * width;
+
+ /* Rather than initializing as zeros and adding, just initialize the
+ * first.*/
+ /* These will be packed 16-bit integers containing sums for each column of
+ * B multiplied by the row of A.*/
+ auto sum0 = madd(a, batch8::load_aligned(&B_j[0 * batch8::size]));
+ auto sum1 = madd(a, batch8::load_aligned(&B_j[1 * batch8::size]));
+ auto sum2 = madd(a, batch8::load_aligned(&B_j[2 * batch8::size]));
+ auto sum3 = madd(a, batch8::load_aligned(&B_j[3 * batch8::size]));
+ auto sum4 = madd(a, batch8::load_aligned(&B_j[4 * batch8::size]));
+ auto sum5 = madd(a, batch8::load_aligned(&B_j[5 * batch8::size]));
+ auto sum6 = madd(a, batch8::load_aligned(&B_j[6 * batch8::size]));
+ auto sum7 = madd(a, batch8::load_aligned(&B_j[7 * batch8::size]));
+
+ B_j += 8 * batch8::size;
+
+ /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
+ * declared here.*/
+ batch16 ones(1);
+ auto isum0 = madd(sum0, ones);
+ auto isum1 = madd(sum1, ones);
+ auto isum2 = madd(sum2, ones);
+ auto isum3 = madd(sum3, ones);
+ auto isum4 = madd(sum4, ones);
+ auto isum5 = madd(sum5, ones);
+ auto isum6 = madd(sum6, ones);
+ auto isum7 = madd(sum7, ones);
+
+ for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) {
+ isum0 +=
+ madd(madd(a, batch8::load_aligned(&B_j[0 * batch8::size])), ones);
+ isum1 +=
+ madd(madd(a, batch8::load_aligned(&B_j[1 * batch8::size])), ones);
+ isum2 +=
+ madd(madd(a, batch8::load_aligned(&B_j[2 * batch8::size])), ones);
+ isum3 +=
+ madd(madd(a, batch8::load_aligned(&B_j[3 * batch8::size])), ones);
+ isum4 +=
+ madd(madd(a, batch8::load_aligned(&B_j[4 * batch8::size])), ones);
+ isum5 +=
+ madd(madd(a, batch8::load_aligned(&B_j[5 * batch8::size])), ones);
+ isum6 +=
+ madd(madd(a, batch8::load_aligned(&B_j[6 * batch8::size])), ones);
+ isum7 +=
+ madd(madd(a, batch8::load_aligned(&B_j[7 * batch8::size])), ones);
+ }
+
+ auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
+ auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
+
+ auto total = PermuteSummer(pack0123, pack4567);
+ C(total, 0, j, B_cols);
+ }
+}
+
+} // namespace gemmology
+
+#endif