diff options
Diffstat (limited to 'third_party/gemmology/gemmology.h')
-rw-r--r-- | third_party/gemmology/gemmology.h | 95 |
1 files changed, 63 insertions, 32 deletions
diff --git a/third_party/gemmology/gemmology.h b/third_party/gemmology/gemmology.h index d774c53388..eb5ebed3b4 100644 --- a/third_party/gemmology/gemmology.h +++ b/third_party/gemmology/gemmology.h @@ -198,6 +198,17 @@ PermuteSummer(xsimd::batch<int32_t, Arch> pack0123, return _mm256_add_epi32(rev, blended); } +template <class Arch> +inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::avx2>) { + auto pack01 = _mm256_hadd_epi32(sum0, sum1); + auto pack23 = _mm256_hadd_epi32(sum2, sum3); + return _mm256_hadd_epi32(pack01, pack23); +} + #ifdef __AVXVNNI__ template <class Arch> @@ -245,6 +256,17 @@ madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y, xsimd::kernel::requires_arch<xsimd::ssse3>) { return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x)); } + +template <class Arch> +inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::ssse3>) { + auto pack01 = _mm_hadd_epi32(sum0, sum1); + auto pack23 = _mm_hadd_epi32(sum2, sum3); + return _mm_hadd_epi32(pack01, pack23); +} #endif #ifdef __SSE2__ @@ -524,7 +546,8 @@ xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first, xsimd::batch<int16_t, Arch> second, xsimd::kernel::requires_arch<xsimd::neon64>) { - return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second)); + + return vqmovn_high_s16(vqmovn_s16(first), second); } template <class Arch> @@ -532,27 +555,18 @@ xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first, xsimd::batch<int32_t, Arch> second, xsimd::kernel::requires_arch<xsimd::neon64>) { - return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second)); + return vqmovn_high_s32(vqmovn_s32(first), second); } +#ifdef __ARM_FEATURE_MATMUL_INT8 template <class Arch> inline xsimd::batch<int32_t, Arch> -madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y, - xsimd::kernel::requires_arch<xsimd::neon64>) { - int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y)); - return vmlal_high_s16(low, x, y); -} - -template <class Arch> -inline xsimd::batch<int16_t, Arch> -madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y, - xsimd::kernel::requires_arch<xsimd::neon64>) { - - int16x8_t tl = vmull_s8(vreinterpret_s8_u8(vget_low_u8(x)), - vget_low_s8(y)); - int16x8_t th = vmull_high_s8(vreinterpretq_s8_u8(x), y); - return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th)); +maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y, + xsimd::batch<int32_t, Arch> z, + xsimd::kernel::requires_arch<xsimd::i8mm<xsimd::neon64>>) { + return vusdotq_s32(z, x, y); } +#endif template <class Arch> inline xsimd::batch<int32_t, Arch> @@ -564,15 +578,17 @@ maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y, int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))), vmovl_s8(vget_high_s8(y))); return vpadalq_s16(vpadalq_s16(z, tl), th); - //TODO: investigate using vdotq_s32 } template <class Arch> -inline xsimd::batch<int16_t, Arch> -madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y, - xsimd::kernel::requires_arch<xsimd::neon64>) { - int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y)); - return vmlal_high_s8(low, x, y); +inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::neon64>) { + auto pack01 = vpaddq_s32(sum0, sum1); + auto pack23 = vpaddq_s32(sum2, sum3); + return vpaddq_s32(pack01, pack23); } #endif @@ -644,20 +660,35 @@ inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123, return kernel::PermuteSummer(pack0123, pack4567, Arch{}); } + +namespace kernel { + + template <class Arch> + inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, + xsimd::batch<int32_t, Arch> sum1, + xsimd::batch<int32_t, Arch> sum2, + xsimd::batch<int32_t, Arch> sum3, + xsimd::kernel::requires_arch<xsimd::generic>) { + + std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{}); + auto pack01 = sum0 + sum1; + std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{}); + auto pack23 = sum2 + sum3; + + auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01), + xsimd::bitwise_cast<int64_t>(pack23), + Arch{}); + return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) + + xsimd::bitwise_cast<int32_t>(std::get<1>(packed)); + } +} + template <class Arch> inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1, xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3) { - std::tie(sum0, sum1) = interleave(sum0, sum1); - auto pack01 = sum0 + sum1; - std::tie(sum2, sum3) = interleave(sum2, sum3); - auto pack23 = sum2 + sum3; - - auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01), - xsimd::bitwise_cast<int64_t>(pack23)); - return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) + - xsimd::bitwise_cast<int32_t>(std::get<1>(packed)); + return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{}); } template <class Arch> |