#ifndef GEMMOLOGY_H
#define GEMMOLOGY_H

#include "gemmology_fwd.h"

#include <cstdint>
#include <cstring>
#include <iterator>
#include <tuple>

#include <xsimd/xsimd.hpp>

namespace gemmology {

namespace {

//
// Arch-specific implementation of various elementary operations
//

namespace kernel {

#ifdef __AVX512BW__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi8(first, second),
          _mm512_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi16(first, second),
          _mm512_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi32(first, second),
          _mm512_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi64(first, second),
          _mm512_unpackhi_epi64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int32_t, xsimd::avx2>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  // Form [0th 128-bit register of pack0123, 0th 128-bit register of pack4567,
  //       2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567]
  __m512i mix0 = _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567,
                                            (0 << 4) | (1 << 6));
  // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567,
  //       3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567]
  __m512i mix1 =
      _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2));
  __m512i added = _mm512_add_epi32(mix0, mix1);
  // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7.
  // Fold register over itself.
  return _mm256_add_epi32(_mm512_castsi512_si256(added),
                          _mm512_extracti64x4_epi64(added, 1));
}
#endif
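// Semantics sketch (illustration, not part of the API): for 4-lane 128-bit
// registers, interleave zips its two inputs element-wise and deinterleave is
// the saturating pack that narrows two registers into one:
//
//   interleave({a0,a1,a2,a3}, {b0,b1,b2,b3})
//       == {{a0,b0,a1,b1}, {a2,b2,a3,b3}}
//   deinterleave({w0,w1,w2,w3}, {w4,w5,w6,w7})  // int32 -> int16
//       == {sat(w0), ..., sat(w7)}              // clamped to [-32768, 32767]
//
// On wider registers the same pattern applies independently per 128-bit lane.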
#ifdef __AVX2__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi8(first, second),
          _mm256_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi16(first, second),
          _mm256_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi32(first, second),
          _mm256_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi64(first, second),
          _mm256_unpackhi_epi64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  // maddubs wants an unsigned first operand: use |x| and move the sign of x
  // onto y instead.
  return _mm256_maddubs_epi16(xsimd::abs(x), _mm256_sign_epi8(y, x));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx2>) {
  // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f
  __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
  // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s
  __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
  return _mm256_add_epi32(rev, blended);
}
#endif

#ifdef __SSSE3__
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x));
}
#endif

#ifdef __SSE2__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {_mm_unpacklo_epi8(first, second), _mm_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {_mm_unpacklo_epi16(first, second),
          _mm_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {_mm_unpacklo_epi32(first, second),
          _mm_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {_mm_unpacklo_epi64(first, second),
          _mm_unpackhi_epi64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi32(first, second);
}
template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // Adapted from
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To extend a signed 8-bit value, the MSB has to be set to 0xFF
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());

  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack positives with 0x00, negatives with 0xFF
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128());
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128());
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0x00 0x04 0x00 ...

  // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values & signed saturate
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}
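// Reference semantics of the 8-bit madd overloads (illustrative only): each
// 16-bit output lane i is
//
//   out[i] = saturate_int16(a[2*i] * b[2*i] + a[2*i+1] * b[2*i+1])
//
// with a unsigned in the overload above (matching _mm_maddubs_epi16 on SSSE3)
// and both operands signed in the overload below.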
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // Adapted from
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To extend a signed 8-bit value, the MSB has to be set to 0xFF
  __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128());
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());

  // sign_mask_a = 0x00 0x00 0xFF 0x00 ...
  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack positives with 0x00, negatives with 0xFF
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a);
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a);
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit signed integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...

  // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values & signed saturate
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}

template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {pack0123, pack4567};
}
#endif

#if __ARM_ARCH >= 7
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  int8x8_t first_lo = vget_low_s8(first);
  int8x8_t second_lo = vget_low_s8(second);
  int8x8x2_t result_lo = vzip_s8(first_lo, second_lo);

  int8x8_t first_hi = vget_high_s8(first);
  int8x8_t second_hi = vget_high_s8(second);
  int8x8x2_t result_hi = vzip_s8(first_hi, second_hi);

  return {vcombine_s8(result_lo.val[0], result_lo.val[1]),
          vcombine_s8(result_hi.val[0], result_hi.val[1])};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  int16x4_t first_lo = vget_low_s16(first);
  int16x4_t second_lo = vget_low_s16(second);
  int16x4x2_t result_lo = vzip_s16(first_lo, second_lo);

  int16x4_t first_hi = vget_high_s16(first);
  int16x4_t second_hi = vget_high_s16(second);
  int16x4x2_t result_hi = vzip_s16(first_hi, second_hi);

  return {vcombine_s16(result_lo.val[0], result_lo.val[1]),
          vcombine_s16(result_hi.val[0], result_hi.val[1])};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  int32x2_t first_lo = vget_low_s32(first);
  int32x2_t second_lo = vget_low_s32(second);
  int32x2x2_t result_lo = vzip_s32(first_lo, second_lo);

  int32x2_t first_hi = vget_high_s32(first);
  int32x2_t second_hi = vget_high_s32(second);
  int32x2x2_t result_hi = vzip_s32(first_hi, second_hi);

  return {vcombine_s32(result_lo.val[0], result_lo.val[1]),
          vcombine_s32(result_hi.val[0], result_hi.val[1])};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  int64x1_t first_lo = vget_low_s64(first);
  int64x1_t second_lo = vget_low_s64(second);
  int64x1_t first_hi = vget_high_s64(first);
  int64x1_t second_hi = vget_high_s64(second);
  return {vcombine_s64(first_lo, second_lo),
          vcombine_s64(first_hi, second_hi)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
  int32x4_t high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

  int32x2_t low_sum = vpadd_s32(vget_low_s32(low), vget_high_s32(low));
  int32x2_t high_sum = vpadd_s32(vget_low_s32(high), vget_high_s32(high));

  return vcombine_s32(low_sum, high_sum);
}
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  // This would be much simpler if x86 would choose to zero extend OR sign
  // extend, not both. This could probably be optimized better.

  // Zero extend x
  int16x8_t x_odd =
      vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x), 8));
  int16x8_t x_even = vreinterpretq_s16_u16(
      vbicq_u16(vreinterpretq_u16_u8(x), vdupq_n_u16(0xff00)));

  // Sign extend y by shifting left then shifting right.
  int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y), 8), 8);
  int16x8_t y_odd = vshrq_n_s16(vreinterpretq_s16_s8(y), 8);

  // Multiply
  int16x8_t prod1 = vmulq_s16(x_even, y_even);
  int16x8_t prod2 = vmulq_s16(x_odd, y_odd);

  // Saturated add
  return vqaddq_s16(prod1, prod2);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
  int16x8_t high = vmull_s8(vget_high_s8(x), vget_high_s8(y));

  int16x4_t low_sum = vpadd_s16(vget_low_s16(low), vget_high_s16(low));
  int16x4_t high_sum = vpadd_s16(vget_low_s16(high), vget_high_s16(high));

  return vcombine_s16(low_sum, high_sum);
}

template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::neon>) {
  return {pack0123, pack4567};
}
#endif

#ifdef __aarch64__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s8(first, second), vzip2q_s8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s16(first, second), vzip2q_s16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s32(first, second), vzip2q_s32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s64(first, second), vzip2q_s64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon64>) {
  int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
  int32x4_t high = vmull_high_s16(x, y);
  return vpaddq_s32(low, high);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon64>) {
  int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))),
                           vmovl_s8(vget_low_s8(y)));
  int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))),
                           vmovl_s8(vget_high_s8(y)));
  return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th));
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon64>) {
  int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
  int16x8_t high = vmull_high_s8(x, y);
  return vpaddq_s16(low, high);
}
#endif

} // namespace kernel

//
// Generic dispatchers for interleave, deinterleave, madd and PermuteSummer
//

template <class T, class Arch>
std::tuple<xsimd::batch<T, Arch>, xsimd::batch<T, Arch>>
interleave(xsimd::batch<T, Arch> first, xsimd::batch<T, Arch> second) {
  return kernel::interleave(first, second, Arch{});
}

template <class Arch>
xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first,
                                        xsimd::batch<int16_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}

template <class Arch>
xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first,
                                         xsimd::batch<int32_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> madd(xsimd::batch<int16_t, Arch> x,
                                        xsimd::batch<int16_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<int8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<uint8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
                          xsimd::batch<int32_t, Arch> pack4567)
    -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) {
  return kernel::PermuteSummer(pack0123, pack4567, Arch{});
}
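// Usage sketch of the dispatchers (illustrative only; assumes AVX2 was
// compiled in):
//
//   xsimd::batch<int16_t, xsimd::avx2> x = ..., y = ...;
//   xsimd::batch<int32_t, xsimd::avx2> dots = madd(x, y);
//   auto zipped = interleave(x, y); // std::tuple of two batches
//
// Overload resolution on xsimd::kernel::requires_arch<...> selects the best
// kernel for the architecture tag, falling back along xsimd's arch hierarchy.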
// Combine four per-column int32 accumulators into one register whose 32-bit
// lanes hold, within each 128-bit lane, the totals for columns 0, 1, 2 and 3.
template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3) {
  std::tie(sum0, sum1) = interleave(sum0, sum1);
  auto pack01 = sum0 + sum1;
  std::tie(sum2, sum3) = interleave(sum2, sum3);
  auto pack23 = sum2 + sum3;

  auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),
                           xsimd::bitwise_cast<int64_t>(pack23));
  return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +
         xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
}

template <class Arch>
static inline xsimd::batch<int32_t, Arch>
quantize(xsimd::batch<float, Arch> input,
         xsimd::batch<float, Arch> quant_mult) {
  return xsimd::nearbyint_as_int(input * quant_mult);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrab(const float *input, xsimd::batch<float, Arch> quant_mult_reg) {
  return quantize(xsimd::batch<float, Arch>::load_unaligned(input),
                  quant_mult_reg);
}

#ifdef __AVX512BW__
inline __m512 Concat(const __m256 first, const __m256 second) {
  // INTGEMM_AVX512DQ, but that goes with INTGEMM_AVX512BW anyway.
  return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1);
}

// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be
// controlled independently.
/* Only INTGEMM_AVX512F is necessary, but due to a GCC 5.4 bug we have to set
 * INTGEMM_AVX512BW. */
inline __m512i QuantizerGrabHalves(const float *input0, const float *input1,
                                   const __m512 quant_mult_reg) {
  __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1));
  appended = _mm512_mul_ps(appended, quant_mult_reg);
  return _mm512_cvtps_epi32(appended);
}
#else
template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrabHalves(const float *input0, const float *input1,
                    xsimd::batch<float, Arch> quant_mult_reg);
#endif
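// Example (illustrative only): with quant_mult = 10, quantize rounds to
// nearest, so {0.4f, -1.26f, 2.61f} maps to {4, -13, 26}. QuantizerGrab is
// the same operation applied to an unaligned load from memory.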
/* Read 8 floats at a time from input0, input1, input2, and input3. Quantize
 * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate
 * the result into one register and return it.
 */
class QuantizeTile8 {

  template <class Arch>
  struct Tiler {
    static constexpr uint32_t get(std::size_t i, std::size_t n) {
      size_t factor = xsimd::batch<float, Arch>::size / 4;
      return (i % factor) * 4 + i / factor;
    }
  };

public:
  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Consecutive(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return Tile(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                input + 1 * xsimd::batch<float, Arch>::size,
                input + 2 * xsimd::batch<float, Arch>::size,
                input + 3 * xsimd::batch<float, Arch>::size);
  }

  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  ConsecutiveU(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return TileU(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                 input + 1 * xsimd::batch<float, Arch>::size,
                 input + 2 * xsimd::batch<float, Arch>::size,
                 input + 3 * xsimd::batch<float, Arch>::size);
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ConsecutiveWithWrapping(xsimd::batch<float, Arch> quant_mult,
                          const float *input, size_t cols_left, size_t cols,
                          size_t row_step) {
    using batchf32 = xsimd::batch<float, Arch>;
    const float *inputs[4];
    for (size_t i = 0; i < std::size(inputs); ++i) {
      while (cols_left < batchf32::size) {
        input += cols * (row_step - 1);
        cols_left += cols;
      }
      inputs[i] = input;
      input += batchf32::size;
      cols_left -= batchf32::size;
    }
    return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]);
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ForReshape(xsimd::batch<float, Arch> quant_mult, const float *input,
             size_t cols) {
    using batchf32 = xsimd::batch<float, Arch>;
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;
    using ubatch32 = xsimd::batch<uint32_t, Arch>;

    // Put higher rows in the second half of the register. These will jumble
    // around in the same way then conveniently land in the right place.
    if constexpr (batchf32::size == 16) {
      const batch8 neg127(-127);
      // In reverse order: grabbing the first 32-bit values from each 128-bit
      // register, then the second 32-bit values, etc. Grab 4 registers at a
      // time in 32-bit format.
      batch32 g0 =
          QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult);
      batch32 g1 = QuantizerGrabHalves(input + 16 * cols, input + 18 * cols,
                                       quant_mult);
      batch32 g2 = QuantizerGrabHalves(input + 32 * cols, input + 34 * cols,
                                       quant_mult);
      batch32 g3 = QuantizerGrabHalves(input + 48 * cols, input + 50 * cols,
                                       quant_mult);
      // Pack 32-bit to 16-bit.
      batch16 packed0 = deinterleave(g0, g1);
      batch16 packed1 = deinterleave(g2, g3);
      // Pack 16-bit to 8-bit.
      batch8 packed = deinterleave(packed0, packed1);
      // Ban -128.
      packed = xsimd::max(packed, neg127);
      return xsimd::bitwise_cast<int8_t>(
          xsimd::swizzle(xsimd::bitwise_cast<uint32_t>(packed),
                         xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
    } else if constexpr (batchf32::size == 8)
      return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols,
                  input + 18 * cols);
    else if constexpr (batchf32::size == 4)
      // Skip a row.
      return Tile(quant_mult, input, input + 4, input + 2 * cols,
                  input + 2 * cols + 4);
    else
      return {};
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Tile(xsimd::batch<float, Arch> quant_mult, const float *input0,
       const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;
    using ubatch32 = xsimd::batch<uint32_t, Arch>;

    const batch8 neg127(-127);
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    // Ban -128.
    packed = xsimd::max(packed, neg127);

    if constexpr (batch32::size == 4)
      return packed;
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31, or as 32-bit integers: 0 2 4 6 1 3 5 7.
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<int8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<uint32_t>(packed),
                       xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
  }
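  // Why the swizzle (illustrative): on AVX2, deinterleave packs per 128-bit
  // lane, so after the two packs the 32-bit groups sit in order 0 2 4 6 1 3 5
  // 7. Tiler<Arch>::get maps output lane i to source lane (i % 2) * 4 + i / 2,
  // i.e. the permutation [0 4 1 5 2 6 3 7], which restores consecutive order
  // 0 1 2 3 4 5 6 7.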
private:
  // A version that produces uint8_ts
  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  TileU(xsimd::batch<float, Arch> quant_mult, const float *input0,
        const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;
    using ubatch32 = xsimd::batch<uint32_t, Arch>;

    const batch8 neg127 = -127;
    const batch8 pos127 = +127;
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    // Ban -128.
    packed = xsimd::max(packed, neg127); // Could be removed if we use +128
    // Shift into the unsigned range.
    packed = packed + pos127;
    if (batch32::size == 4)
      return xsimd::bitwise_cast<uint8_t>(packed);
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31, or as 32-bit integers: 0 2 4 6 1 3 5 7.
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<uint8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<uint32_t>(packed),
                       xsimd::make_batch_constant<ubatch32, Tiler<Arch>>()));
  }
};
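// Usage sketch (illustrative only): one tile consumes four float registers
// and produces one 8-bit register, e.g. 32 floats -> 32 int8_t on AVX2:
//
//   xsimd::batch<float, xsimd::avx2> q(127.0f / max_absolute_value);
//   auto tile = QuantizeTile8::Consecutive(q, input); // reads input[0..31]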
template <class Arch>
inline void Transpose16InLane(
    xsimd::batch<int8_t, Arch> &r0, xsimd::batch<int8_t, Arch> &r1,
    xsimd::batch<int8_t, Arch> &r2, xsimd::batch<int8_t, Arch> &r3,
    xsimd::batch<int8_t, Arch> &r4, xsimd::batch<int8_t, Arch> &r5,
    xsimd::batch<int8_t, Arch> &r6, xsimd::batch<int8_t, Arch> &r7) {
  /* r0: columns 0 1 2 3 4 5 6 7 from row 0
     r1: columns 0 1 2 3 4 5 6 7 from row 1 */
  auto r0_16 = xsimd::bitwise_cast<int16_t>(r0);
  auto r1_16 = xsimd::bitwise_cast<int16_t>(r1);
  auto r2_16 = xsimd::bitwise_cast<int16_t>(r2);
  auto r3_16 = xsimd::bitwise_cast<int16_t>(r3);
  auto r4_16 = xsimd::bitwise_cast<int16_t>(r4);
  auto r5_16 = xsimd::bitwise_cast<int16_t>(r5);
  auto r6_16 = xsimd::bitwise_cast<int16_t>(r6);
  auto r7_16 = xsimd::bitwise_cast<int16_t>(r7);

  std::tie(r0_16, r1_16) = interleave(r0_16, r1_16);
  std::tie(r2_16, r3_16) = interleave(r2_16, r3_16);
  std::tie(r4_16, r5_16) = interleave(r4_16, r5_16);
  std::tie(r6_16, r7_16) = interleave(r6_16, r7_16);
  /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
     r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
     r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
     r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
     r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
     r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
     r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
     r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7 */

  auto r0_32 = xsimd::bitwise_cast<int32_t>(r0_16);
  auto r2_32 = xsimd::bitwise_cast<int32_t>(r2_16);
  auto r1_32 = xsimd::bitwise_cast<int32_t>(r1_16);
  auto r3_32 = xsimd::bitwise_cast<int32_t>(r3_16);
  auto r4_32 = xsimd::bitwise_cast<int32_t>(r4_16);
  auto r6_32 = xsimd::bitwise_cast<int32_t>(r6_16);
  auto r5_32 = xsimd::bitwise_cast<int32_t>(r5_16);
  auto r7_32 = xsimd::bitwise_cast<int32_t>(r7_16);

  std::tie(r0_32, r2_32) = interleave(r0_32, r2_32);
  std::tie(r1_32, r3_32) = interleave(r1_32, r3_32);
  std::tie(r4_32, r6_32) = interleave(r4_32, r6_32);
  std::tie(r5_32, r7_32) = interleave(r5_32, r7_32);
  /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
     r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
     r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
     r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
     r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
     r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
     r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
     r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7 */

  auto r0_64 = xsimd::bitwise_cast<int64_t>(r0_32);
  auto r2_64 = xsimd::bitwise_cast<int64_t>(r2_32);
  auto r1_64 = xsimd::bitwise_cast<int64_t>(r1_32);
  auto r3_64 = xsimd::bitwise_cast<int64_t>(r3_32);
  auto r4_64 = xsimd::bitwise_cast<int64_t>(r4_32);
  auto r6_64 = xsimd::bitwise_cast<int64_t>(r6_32);
  auto r5_64 = xsimd::bitwise_cast<int64_t>(r5_32);
  auto r7_64 = xsimd::bitwise_cast<int64_t>(r7_32);

  std::tie(r0_64, r4_64) = interleave(r0_64, r4_64);
  std::tie(r1_64, r5_64) = interleave(r1_64, r5_64);
  std::tie(r2_64, r6_64) = interleave(r2_64, r6_64);
  std::tie(r3_64, r7_64) = interleave(r3_64, r7_64);

  r0 = xsimd::bitwise_cast<int8_t>(r0_64);
  r1 = xsimd::bitwise_cast<int8_t>(r1_64);
  r2 = xsimd::bitwise_cast<int8_t>(r2_64);
  r3 = xsimd::bitwise_cast<int8_t>(r3_64);
  r4 = xsimd::bitwise_cast<int8_t>(r4_64);
  r5 = xsimd::bitwise_cast<int8_t>(r5_64);
  r6 = xsimd::bitwise_cast<int8_t>(r6_64);
  r7 = xsimd::bitwise_cast<int8_t>(r7_64);
  /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
     r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
     r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
     r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
     r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
     r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7
     r6: columns 3 3 3 3 3 3 3 3 from rows 0 through 7
     r7: columns 7 7 7 7 7 7 7 7 from rows 0 through 7 */
  /* Empirically gcc is able to remove these movs and just rename the outputs
   * of the 64-bit interleave. */
  std::swap(r1, r4);
  std::swap(r3, r6);
}

template <class Arch, class IntegerTy>
void SelectColumnsOfB(const xsimd::batch<int8_t, Arch> *input,
                      xsimd::batch<int8_t, Arch> *output,
                      size_t rows_bytes /* number of bytes in a row */,
                      const IntegerTy *cols_begin, const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  /* Do columns for multiples of 8. */
  size_t register_rows = rows_bytes / batch8::size;
  const batch8 *starts[8];
  for (; cols_begin != cols_end; cols_begin += 8) {
    for (size_t k = 0; k < 8; ++k) {
      starts[k] =
          input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows;
    }
    for (size_t r = 0; r < register_rows; ++r) {
      for (size_t k = 0; k < 8; ++k) {
        *(output++) = *starts[k];
        starts[k] += 8;
      }
    }
  }
}

} // namespace
namespace callbacks {

template <class Arch>
xsimd::batch<float, Arch>
Unquantize::operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t,
                       size_t) {
  return xsimd::batch_cast<float>(total) * unquant_mult;
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
Unquantize::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> total,
    size_t, size_t, size_t) {
  return std::make_tuple(
      xsimd::batch_cast<float>(std::get<0>(total)) * unquant_mult,
      xsimd::batch_cast<float>(std::get<1>(total)) * unquant_mult);
}

template <class Arch>
xsimd::batch<float, Arch>
AddBias::operator()(xsimd::batch<float, Arch> total, size_t, size_t col_idx,
                    size_t) {
  return total + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx);
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
AddBias::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
    size_t, size_t col_idx, size_t) {
  return std::make_tuple(
      std::get<0>(total) + xsimd::batch<float, Arch>::load_aligned(
                               bias_addr + col_idx + 0),
      std::get<1>(total) + xsimd::batch<float, Arch>::load_aligned(
                               bias_addr + col_idx +
                               xsimd::batch<float, Arch>::size));
}

template <class Arch>
void Write::operator()(xsimd::batch<float, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  result.store_aligned(output_addr + row_idx * col_size + col_idx);
}

template <class Arch>
void Write::operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(result).store_aligned(
      output_addr + row_idx * col_size + col_idx);
}

template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  std::get<0>(result).store_aligned(output_addr + row_idx * col_size +
                                    col_idx + 0);
  std::get<1>(result).store_aligned(output_addr + row_idx * col_size +
                                    col_idx + xsimd::batch<float, Arch>::size);
}

template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
        result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(std::get<0>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx + 0);
  xsimd::bitwise_cast<float>(std::get<1>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx +
                     xsimd::batch<int32_t, Arch>::size);
}

template <class T>
void UnquantizeAndWrite::operator()(T const &total, size_t row_idx,
                                    size_t col_idx, size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  write(unquantized, row_idx, col_idx, col_size);
}

template <class T>
void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx,
                                              size_t col_idx,
                                              size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size);
  write(bias_added, row_idx, col_idx, col_size);
}

} // namespace callbacks

template <class Arch>
void Engine<Arch>::QuantizeU(const float *input, uint8_t *output,
                             float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  xsimd::batch<float, Arch> q(quant_mult);
  const float *end = input + size;
  for (; input != end; input += batch8::size, output += batch8::size) {
    auto tile = QuantizeTile8::ConsecutiveU(q, input);
    tile.store_aligned(output);
  }
}

template <class Arch>
void Engine<Arch>::Quantize(const float *const input, int8_t *const output,
                            float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  const std::size_t kBatch = batch8::size;
  const std::size_t fast_end = size & ~(kBatch - 1);

  xsimd::batch<float, Arch> q(quant_mult);
  for (std::size_t i = 0; i < fast_end; i += kBatch) {
    auto tile = QuantizeTile8::Consecutive(q, input + i);
    tile.store_aligned(output + i);
  }

  std::size_t overhang = size & (kBatch - 1);
  if (!overhang)
    return;
  /* Each tile input covers kBatch / 4 floats at a time. If we're allowed to
   * read one of them, then we can read the whole register. */
  const float *inputs[4];
  std::size_t i;
  for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) {
    inputs[i] = &input[fast_end + i * (kBatch / 4)];
  }
  /* These will be clipped off. */
  for (; i < 4; ++i) {
    inputs[i] = &input[fast_end];
  }
  auto result =
      QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]);
  std::memcpy(output + fast_end, &result, overhang);
}
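// Usage sketch (illustrative only; `in` may be unaligned, but `out` must be
// aligned for the target architecture since the fast path uses aligned
// stores):
//
//   alignas(xsimd::avx2::alignment()) int8_t out[1000];
//   gemmology::Engine<xsimd::avx2>::Quantize(in, out, 127.0f / max_abs, 1000);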
template <class Arch>
template <class IntegerTy>
void Engine<Arch>::SelectColumnsB(const int8_t *input, int8_t *output,
                                  size_t rows, const IntegerTy *cols_begin,
                                  const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  SelectColumnsOfB(reinterpret_cast<const batch8 *>(input),
                   reinterpret_cast<batch8 *>(output), rows, cols_begin,
                   cols_end);
}

template <class Arch>
void Engine<Arch>::PrepareBTransposed(const float *input, int8_t *output,
                                      float quant_mult, size_t cols,
                                      size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElemsInt = batch8::size;
  const size_t kColStride = 8;

  xsimd::batch<float, Arch> q(quant_mult);
  auto *output_it = reinterpret_cast<batch8 *>(output);
  size_t r = 0;
  size_t c = 0;
  while (r < rows) {
    for (size_t ri = 0; ri < 8; ++ri)
      *output_it++ = QuantizeTile8::ConsecutiveWithWrapping(
          q, input + (r + ri) * cols + c, cols - c, cols, 8);
    c += RegisterElemsInt;
    while (c >= cols) {
      r += kColStride;
      c -= cols;
    }
  }
}

template <class Arch>
void Engine<Arch>::PrepareBQuantizedTransposed(const int8_t *input,
                                               int8_t *output, size_t cols,
                                               size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElems = batch8::size;
  const size_t kColStride = 8;

  auto *output_it = reinterpret_cast<batch8 *>(output);
  for (size_t r = 0; r < rows; r += kColStride)
    for (size_t c = 0; c < cols; c += RegisterElems)
      for (size_t ri = 0; ri < 8; ++ri)
        *output_it++ =
            *reinterpret_cast<const batch8 *>(input + (r + ri) * cols + c);
}

template <class Arch>
void Engine<Arch>::PrepareB(const float *input, int8_t *output_shadow,
                            float quant_mult, size_t rows, size_t cols) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  xsimd::batch<float, Arch> q(quant_mult);
  /* Currently all multipliers have a stride of 8 columns. */
  const size_t kColStride = 8;
  auto *output = reinterpret_cast<batch8 *>(output_shadow);
  for (size_t c = 0; c < cols; c += kColStride) {
    for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) {
      output[0] = QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols);
      output[1] = QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols);
      output[2] = QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols);
      output[3] = QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols);
      output[4] = QuantizeTile8::ForReshape(q, input + cols * (r + 8) + c, cols);
      output[5] = QuantizeTile8::ForReshape(q, input + cols * (r + 9) + c, cols);
      output[6] = QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols);
      output[7] = QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols);
      std::tie(output[0], output[1]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[0]),
                     xsimd::bitwise_cast<int8_t>(output[1]));
      std::tie(output[2], output[3]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[2]),
                     xsimd::bitwise_cast<int8_t>(output[3]));
      std::tie(output[4], output[5]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[4]),
                     xsimd::bitwise_cast<int8_t>(output[5]));
      std::tie(output[6], output[7]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[6]),
                     xsimd::bitwise_cast<int8_t>(output[7]));
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4],
                        output[5], output[6], output[7]);
    }
  }
}

template <class Arch>
void Engine<Arch>::PrepareA(const float *input, int8_t *output,
                            float quant_mult, size_t rows, size_t cols) {
  Quantize(input, output, quant_mult, rows * cols);
}

template <class Arch>
void Engine<Arch>::Shift::PrepareA(const float *input, uint8_t *output,
                                   float quant_mult, size_t rows,
                                   size_t cols) {
  QuantizeU(input, output, quant_mult, rows * cols);
}
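// End-to-end sketch of the Shift path (illustrative only; the callback
// constructors are assumed from their declarations in gemmology_fwd.h):
//
//   using E = gemmology::Engine<xsimd::avx2>;
//   E::Shift::PrepareA(A, A_q, a_scale, A_rows, width);  // uint8_t, +127
//   E::PrepareB(B, B_q, b_scale, width, B_cols);         // int8_t, reshaped
//   // Compensate for the +127 shift of A: fold 127 * colsum(B) into the
//   // bias, then multiply.
//   E::Shift::PrepareBias(B_q, width, B_cols,
//       callbacks::UnquantizeAndAddBiasAndWrite(bias_mult, bias, bias_q));
//   E::Shift::Multiply(A_q, B_q, A_rows, width, B_cols,
//       callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias_q, C));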
template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
                                   size_t A_rows, size_t width, size_t B_cols,
                                   Callback callback) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  using ubatch8 = xsimd::batch<uint8_t, Arch>;
  using batch16 = xsimd::batch<int16_t, Arch>;
  using batch32 = xsimd::batch<int32_t, Arch>;

  const size_t simd_width = width / batch8::size;
  for (size_t B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
    const auto *B0_col =
        reinterpret_cast<const batch8 *>(B) + simd_width * B0_colidx;
    /* Process one row of A at a time. Doesn't seem to be faster to do
     * multiple rows of A at once. */
    for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
      const auto *A_row =
          reinterpret_cast<const ubatch8 *>(A + A_rowidx * width);
      /* These will be packed 16-bit integers containing sums for each row of
       * B multiplied by the row of A. Iterate over the shared (inner)
       * dimension. */
      size_t k = 0;
      ubatch8 a = *(A_row + k);
      batch16 sum0 = madd(a, *(B0_col + k * 8));
      batch16 sum1 = madd(a, *(B0_col + k * 8 + 1));
      batch16 sum2 = madd(a, *(B0_col + k * 8 + 2));
      batch16 sum3 = madd(a, *(B0_col + k * 8 + 3));
      batch16 sum4 = madd(a, *(B0_col + k * 8 + 4));
      batch16 sum5 = madd(a, *(B0_col + k * 8 + 5));
      batch16 sum6 = madd(a, *(B0_col + k * 8 + 6));
      batch16 sum7 = madd(a, *(B0_col + k * 8 + 7));
      /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
       * declared here. */
      batch16 ones(1);
      batch32 isum0 = madd(sum0, ones);
      batch32 isum1 = madd(sum1, ones);
      batch32 isum2 = madd(sum2, ones);
      batch32 isum3 = madd(sum3, ones);
      batch32 isum4 = madd(sum4, ones);
      batch32 isum5 = madd(sum5, ones);
      batch32 isum6 = madd(sum6, ones);
      batch32 isum7 = madd(sum7, ones);
      for (k = 1; k < simd_width; ++k) {
        a = *(A_row + k);
        /* Multiply 8-bit, horizontally add to packed 16-bit integers. */
        batch16 mult0 = madd(a, *(B0_col + k * 8));
        batch16 mult1 = madd(a, *(B0_col + k * 8 + 1));
        batch16 mult2 = madd(a, *(B0_col + k * 8 + 2));
        batch16 mult3 = madd(a, *(B0_col + k * 8 + 3));
        batch16 mult4 = madd(a, *(B0_col + k * 8 + 4));
        batch16 mult5 = madd(a, *(B0_col + k * 8 + 5));
        batch16 mult6 = madd(a, *(B0_col + k * 8 + 6));
        batch16 mult7 = madd(a, *(B0_col + k * 8 + 7));
        /* Upcast to 32-bit and horizontally add. */
        batch32 imult0 = madd(mult0, ones);
        batch32 imult1 = madd(mult1, ones);
        batch32 imult2 = madd(mult2, ones);
        batch32 imult3 = madd(mult3, ones);
        batch32 imult4 = madd(mult4, ones);
        batch32 imult5 = madd(mult5, ones);
        batch32 imult6 = madd(mult6, ones);
        batch32 imult7 = madd(mult7, ones);
        /* Add in 32-bit. */
        isum0 += imult0;
        isum1 += imult1;
        isum2 += imult2;
        isum3 += imult3;
        isum4 += imult4;
        isum5 += imult5;
        isum6 += imult6;
        isum7 += imult7;
      }
      /* Reduce sums within 128-bit lanes. */
      auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
      auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
      /* The specific implementation may need to reduce further. */
      auto total = PermuteSummer(pack0123, pack4567);
      callback(total, A_rowidx, B0_colidx, B_cols);
    }
  }
}
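// Reference model (illustrative only): modulo the 16-bit saturation inside
// madd, the callback for row r and column block j receives eight int32 sums
//
//   total[n] = sum_{k=0}^{width-1} A[r * width + k] * B_block[k][j + n]
//
// where B_block is the reshaped layout produced by PrepareB, i.e. one row of
// the shifted A against eight prepared columns of B.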
template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
                                      size_t B_cols, Callback C) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  using batch16 = xsimd::batch<int16_t, Arch>;
  const size_t simd_width = width / batch8::size;
  xsimd::batch<uint8_t, Arch> a(1);
  for (size_t j = 0; j < B_cols; j += 8) {
    /* Process one block of 8 columns of B at a time, as if multiplying by a
     * row of ones. */
    const int8_t *B_j = B + j * width;

    /* Rather than initializing as zeros and adding, just initialize the
     * first. */
    /* These will be packed 16-bit integers containing sums for each column of
     * B multiplied by the row of ones. */
    auto sum0 = madd(a, batch8::load_aligned(&B_j[0 * batch8::size]));
    auto sum1 = madd(a, batch8::load_aligned(&B_j[1 * batch8::size]));
    auto sum2 = madd(a, batch8::load_aligned(&B_j[2 * batch8::size]));
    auto sum3 = madd(a, batch8::load_aligned(&B_j[3 * batch8::size]));
    auto sum4 = madd(a, batch8::load_aligned(&B_j[4 * batch8::size]));
    auto sum5 = madd(a, batch8::load_aligned(&B_j[5 * batch8::size]));
    auto sum6 = madd(a, batch8::load_aligned(&B_j[6 * batch8::size]));
    auto sum7 = madd(a, batch8::load_aligned(&B_j[7 * batch8::size]));

    B_j += 8 * batch8::size;

    /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
     * declared here. */
    batch16 ones(1);
    auto isum0 = madd(sum0, ones);
    auto isum1 = madd(sum1, ones);
    auto isum2 = madd(sum2, ones);
    auto isum3 = madd(sum3, ones);
    auto isum4 = madd(sum4, ones);
    auto isum5 = madd(sum5, ones);
    auto isum6 = madd(sum6, ones);
    auto isum7 = madd(sum7, ones);

    for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) {
      isum0 += madd(madd(a, batch8::load_aligned(&B_j[0 * batch8::size])), ones);
      isum1 += madd(madd(a, batch8::load_aligned(&B_j[1 * batch8::size])), ones);
      isum2 += madd(madd(a, batch8::load_aligned(&B_j[2 * batch8::size])), ones);
      isum3 += madd(madd(a, batch8::load_aligned(&B_j[3 * batch8::size])), ones);
      isum4 += madd(madd(a, batch8::load_aligned(&B_j[4 * batch8::size])), ones);
      isum5 += madd(madd(a, batch8::load_aligned(&B_j[5 * batch8::size])), ones);
      isum6 += madd(madd(a, batch8::load_aligned(&B_j[6 * batch8::size])), ones);
      isum7 += madd(madd(a, batch8::load_aligned(&B_j[7 * batch8::size])), ones);
    }

    auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
    auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);

    auto total = PermuteSummer(pack0123, pack4567);
    C(total, 0, j, B_cols);
  }
}

} // namespace gemmology

#endif