diff options
Diffstat (limited to 'third_party/gemmology')
-rw-r--r-- | third_party/gemmology/gemmology.h | 95 | ||||
-rw-r--r-- | third_party/gemmology/kernels/GemmologyEngineNeon64I8mm.cpp | 19 | ||||
-rw-r--r-- | third_party/gemmology/moz.yaml | 4 |
3 files changed, 84 insertions, 34 deletions
diff --git a/third_party/gemmology/gemmology.h b/third_party/gemmology/gemmology.h index d774c53388..eb5ebed3b4 100644 --- a/third_party/gemmology/gemmology.h +++ b/third_party/gemmology/gemmology.h @@ -198,6 +198,17 @@ PermuteSummer(xsimd::batch<int32_t, Arch> pack0123, return _mm256_add_epi32(rev, blended); } +template <class Arch> +inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::avx2>) { + auto pack01 = _mm256_hadd_epi32(sum0, sum1); + auto pack23 = _mm256_hadd_epi32(sum2, sum3); + return _mm256_hadd_epi32(pack01, pack23); +} + #ifdef __AVXVNNI__ template <class Arch> @@ -245,6 +256,17 @@ madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y, xsimd::kernel::requires_arch<xsimd::ssse3>) { return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x)); } + +template <class Arch> +inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::ssse3>) { + auto pack01 = _mm_hadd_epi32(sum0, sum1); + auto pack23 = _mm_hadd_epi32(sum2, sum3); + return _mm_hadd_epi32(pack01, pack23); +} #endif #ifdef __SSE2__ @@ -524,7 +546,8 @@ xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first, xsimd::batch<int16_t, Arch> second, xsimd::kernel::requires_arch<xsimd::neon64>) { - return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second)); + + return vqmovn_high_s16(vqmovn_s16(first), second); } template <class Arch> @@ -532,27 +555,18 @@ xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first, xsimd::batch<int32_t, Arch> second, xsimd::kernel::requires_arch<xsimd::neon64>) { - return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second)); + return vqmovn_high_s32(vqmovn_s32(first), second); } +#ifdef __ARM_FEATURE_MATMUL_INT8 template <class Arch> inline xsimd::batch<int32_t, Arch> -madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y, - xsimd::kernel::requires_arch<xsimd::neon64>) { - int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y)); - return vmlal_high_s16(low, x, y); -} - -template <class Arch> -inline xsimd::batch<int16_t, Arch> -madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y, - xsimd::kernel::requires_arch<xsimd::neon64>) { - - int16x8_t tl = vmull_s8(vreinterpret_s8_u8(vget_low_u8(x)), - vget_low_s8(y)); - int16x8_t th = vmull_high_s8(vreinterpretq_s8_u8(x), y); - return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th)); +maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y, + xsimd::batch<int32_t, Arch> z, + xsimd::kernel::requires_arch<xsimd::i8mm<xsimd::neon64>>) { + return vusdotq_s32(z, x, y); } +#endif template <class Arch> inline xsimd::batch<int32_t, Arch> @@ -564,15 +578,17 @@ maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y, int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))), vmovl_s8(vget_high_s8(y))); return vpadalq_s16(vpadalq_s16(z, tl), th); - //TODO: investigate using vdotq_s32 } template <class Arch> -inline xsimd::batch<int16_t, Arch> -madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y, - xsimd::kernel::requires_arch<xsimd::neon64>) { - int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y)); - return vmlal_high_s8(low, x, y); +inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::neon64>) { + auto pack01 = vpaddq_s32(sum0, sum1); + auto pack23 = vpaddq_s32(sum2, sum3); + return vpaddq_s32(pack01, pack23); } #endif @@ -644,20 +660,35 @@ inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123, return kernel::PermuteSummer(pack0123, pack4567, Arch{}); } + +namespace kernel { + + template <class Arch> + inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::generic>) { + + std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{}); + auto pack01 = sum0 + sum1; + std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{}); + auto pack23 = sum2 + sum3; + + auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01), + xsimd::bitwise_cast<int64_t>(pack23), + Arch{}); + return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) + + xsimd::bitwise_cast<int32_t>(std::get<1>(packed)); + } +} + template <class Arch> inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1, xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3) { - std::tie(sum0, sum1) = interleave(sum0, sum1); - auto pack01 = sum0 + sum1; - std::tie(sum2, sum3) = interleave(sum2, sum3); - auto pack23 = sum2 + sum3; - - auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01), - xsimd::bitwise_cast<int64_t>(pack23)); - return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) + - xsimd::bitwise_cast<int32_t>(std::get<1>(packed)); + return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{}); } template <class Arch> diff --git a/third_party/gemmology/kernels/GemmologyEngineNeon64I8mm.cpp b/third_party/gemmology/kernels/GemmologyEngineNeon64I8mm.cpp new file mode 100644 index 0000000000..d8259e750f --- /dev/null +++ b/third_party/gemmology/kernels/GemmologyEngineNeon64I8mm.cpp @@ -0,0 +1,19 @@ +/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* this source code form is subject to the terms of the mozilla public + * license, v. 2.0. if a copy of the mpl was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <gemmology.h> + +namespace gemmology { +template struct Engine<xsimd::i8mm<xsimd::neon64>>; +template void Engine<xsimd::i8mm<xsimd::neon64>>::SelectColumnsB(int8_t const*, int8_t*, + size_t, uint32_t const*, + uint32_t const*); +template void Engine<xsimd::i8mm<xsimd::neon64>>::Shift::Multiply( + uint8_t const*, int8_t const*, size_t, size_t, size_t, + gemmology::callbacks::UnquantizeAndAddBiasAndWrite); +template void Engine<xsimd::i8mm<xsimd::neon64>>::Shift::PrepareBias( + int8_t const*, size_t, size_t, + gemmology::callbacks::UnquantizeAndAddBiasAndWrite); +} // namespace gemmology diff --git a/third_party/gemmology/moz.yaml b/third_party/gemmology/moz.yaml index d9f9472da7..749227e2ee 100644 --- a/third_party/gemmology/moz.yaml +++ b/third_party/gemmology/moz.yaml @@ -10,8 +10,8 @@ origin: url: https://github.com/mozilla/gemmology - release: ec535e87d0ab9d1457ff6d2af247cc8113e74694 (2024-02-05T09:05:20Z). - revision: ec535e87d0ab9d1457ff6d2af247cc8113e74694 + release: dbcd029c3bc6e183355ea597216d379677ff9b19 (2024-02-20T12:36:14Z). + revision: dbcd029c3bc6e183355ea597216d379677ff9b19 license: MIT |