summary | refs | log | tree | commit | diff | stats
path: root/third_party/gemmology/gemmology.h
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/gemmology/gemmology.h')
-rw-r--r-- third_party/gemmology/gemmology.h 95
1 files changed, 63 insertions, 32 deletions
diff --git a/third_party/gemmology/gemmology.h b/third_party/gemmology/gemmology.h
index d774c53388..eb5ebed3b4 100644
--- a/third_party/gemmology/gemmology.h
+++ b/third_party/gemmology/gemmology.h
@@ -198,6 +198,17 @@ PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
return _mm256_add_epi32(rev, blended);
}
+template <class Arch>  // AVX2 overload of Pack0123, chosen via the trailing requires_arch<avx2> tag argument
+inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
+ xsimd::batch<int32_t, Arch> sum1,
+ xsimd::batch<int32_t, Arch> sum2,
+ xsimd::batch<int32_t, Arch> sum3,
+ xsimd::kernel::requires_arch<xsimd::avx2>) {
+ auto pack01 = _mm256_hadd_epi32(sum0, sum1);  // horizontal add of adjacent int32 pairs of sum0 and sum1
+ auto pack23 = _mm256_hadd_epi32(sum2, sum3);  // same pairwise reduction for sum2 and sum3
+ return _mm256_hadd_epi32(pack01, pack23);  // second hadd round; hadd acts on each 128-bit half separately, so the halves are not merged here
+}
+
#ifdef __AVXVNNI__
template <class Arch>
@@ -245,6 +256,17 @@ madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
xsimd::kernel::requires_arch<xsimd::ssse3>) {
return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x));
}
+
+template <class Arch>  // SSSE3 overload of Pack0123, chosen via the trailing requires_arch<ssse3> tag argument
+inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
+ xsimd::batch<int32_t, Arch> sum1,
+ xsimd::batch<int32_t, Arch> sum2,
+ xsimd::batch<int32_t, Arch> sum3,
+ xsimd::kernel::requires_arch<xsimd::ssse3>) {
+ auto pack01 = _mm_hadd_epi32(sum0, sum1);  // horizontal add of adjacent int32 pairs of sum0 and sum1
+ auto pack23 = _mm_hadd_epi32(sum2, sum3);  // same pairwise reduction for sum2 and sum3
+ return _mm_hadd_epi32(pack01, pack23);  // second hadd round: element i now holds the total of all elements of sum_i
+}
#endif
#ifdef __SSE2__
@@ -524,7 +546,8 @@ xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
xsimd::batch<int16_t, Arch> second,
xsimd::kernel::requires_arch<xsimd::neon64>) {
- return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
+
+ return vqmovn_high_s16(vqmovn_s16(first), second);
}
template <class Arch>
@@ -532,27 +555,18 @@ xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
xsimd::batch<int32_t, Arch> second,
xsimd::kernel::requires_arch<xsimd::neon64>) {
- return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
+ return vqmovn_high_s32(vqmovn_s32(first), second);
}
+#ifdef __ARM_FEATURE_MATMUL_INT8  // overload below is only compiled when this feature macro is defined
template <class Arch>
inline xsimd::batch<int32_t, Arch>
-madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
- xsimd::kernel::requires_arch<xsimd::neon64>) {
- int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
- return vmlal_high_s16(low, x, y);
-}
-
-template <class Arch>
-inline xsimd::batch<int16_t, Arch>
-madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
- xsimd::kernel::requires_arch<xsimd::neon64>) {
-
- int16x8_t tl = vmull_s8(vreinterpret_s8_u8(vget_low_u8(x)),
- vget_low_s8(y));
- int16x8_t th = vmull_high_s8(vreinterpretq_s8_u8(x), y);
- return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th));
+maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,  // i8mm overload: x is unsigned 8-bit, y is signed 8-bit
+ xsimd::batch<int32_t, Arch> z,  // int32 accumulator, forwarded as the first argument of vusdotq_s32
+ xsimd::kernel::requires_arch<xsimd::i8mm<xsimd::neon64>>) {
+ return vusdotq_s32(z, x, y);  // unsigned-by-signed 8-bit dot product accumulated onto z by a single intrinsic
}
+#endif  // __ARM_FEATURE_MATMUL_INT8
template <class Arch>
inline xsimd::batch<int32_t, Arch>
@@ -564,15 +578,17 @@ maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))),
vmovl_s8(vget_high_s8(y)));
return vpadalq_s16(vpadalq_s16(z, tl), th);
- //TODO: investigate using vdotq_s32
}
template <class Arch>
-inline xsimd::batch<int16_t, Arch>
-madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
- xsimd::kernel::requires_arch<xsimd::neon64>) {
- int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
- return vmlal_high_s8(low, x, y);
+inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,  // NEON64 overload of Pack0123, chosen via the requires_arch<neon64> tag
+ xsimd::batch<int32_t, Arch> sum1,
+ xsimd::batch<int32_t, Arch> sum2,
+ xsimd::batch<int32_t, Arch> sum3,
+ xsimd::kernel::requires_arch<xsimd::neon64>) {
+ auto pack01 = vpaddq_s32(sum0, sum1);  // pairwise add of adjacent int32 elements of sum0 and sum1
+ auto pack23 = vpaddq_s32(sum2, sum3);  // same pairwise reduction for sum2 and sum3
+ return vpaddq_s32(pack01, pack23);  // second pairwise round: element i now holds the total of all elements of sum_i
}
#endif
@@ -644,20 +660,35 @@ inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
return kernel::PermuteSummer(pack0123, pack4567, Arch{});
}
+
+namespace kernel {  // holds the arch-tagged Pack0123 implementation that the untagged Pack0123 wrapper calls
+
+ template <class Arch>  // overload taken when the tag argument is requires_arch<generic>
+ inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
+ xsimd::batch<int32_t, Arch> sum1,
+ xsimd::batch<int32_t, Arch> sum2,
+ xsimd::batch<int32_t, Arch> sum3,
+ xsimd::kernel::requires_arch<xsimd::generic>) {
+
+ std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{});  // overwrite sum0/sum1 with the pair returned by interleave (defined elsewhere in this header)
+ auto pack01 = sum0 + sum1;  // element-wise add of the two interleaved results
+ std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{});  // same step for sum2/sum3
+ auto pack23 = sum2 + sum3;  // element-wise add of the two interleaved results
+
+ auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),  // second round: interleave again with the data viewed as int64 elements
+ xsimd::bitwise_cast<int64_t>(pack23),
+ Arch{});
+ return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +  // view both results as int32 again and add them element-wise
+ xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
+ }
+}  // namespace kernel
+
template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
xsimd::batch<int32_t, Arch> sum1,
xsimd::batch<int32_t, Arch> sum2,
xsimd::batch<int32_t, Arch> sum3) {
- std::tie(sum0, sum1) = interleave(sum0, sum1);
- auto pack01 = sum0 + sum1;
- std::tie(sum2, sum3) = interleave(sum2, sum3);
- auto pack23 = sum2 + sum3;
-
- auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),
- xsimd::bitwise_cast<int64_t>(pack23));
- return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +
- xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
+ return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{});  // forward to kernel::Pack0123; the Arch{} tag picks the overload by overload resolution
}
template <class Arch>