From 26a029d407be480d791972afb5975cf62c9360a6 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Fri, 19 Apr 2024 02:47:55 +0200
Subject: Adding upstream version 124.0.1.

Signed-off-by: Daniel Baumann
---
 third_party/aom/aom_dsp/arm/aom_convolve8_neon.c | 349 +++
 .../aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c | 460 +++
 .../aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c | 408 +++
 .../aom/aom_dsp/arm/aom_convolve_copy_neon.c | 154 +
 third_party/aom/aom_dsp/arm/avg_neon.c | 309 ++
 third_party/aom/aom_dsp/arm/avg_pred_neon.c | 221 ++
 third_party/aom/aom_dsp/arm/avg_sve.c | 62 +
 third_party/aom/aom_dsp/arm/blend_a64_mask_neon.c | 492 ++++
 third_party/aom/aom_dsp/arm/blend_neon.h | 125 +
 third_party/aom/aom_dsp/arm/blk_sse_sum_neon.c | 124 +
 third_party/aom/aom_dsp/arm/blk_sse_sum_sve.c | 106 +
 third_party/aom/aom_dsp/arm/dist_wtd_avg_neon.h | 65 +
 third_party/aom/aom_dsp/arm/dot_sve.h | 42 +
 third_party/aom/aom_dsp/arm/fwd_txfm_neon.c | 304 ++
 third_party/aom/aom_dsp/arm/hadamard_neon.c | 325 ++
 third_party/aom/aom_dsp/arm/highbd_avg_neon.c | 125 +
 third_party/aom/aom_dsp/arm/highbd_avg_pred_neon.c | 190 ++
 .../aom/aom_dsp/arm/highbd_blend_a64_hmask_neon.c | 97 +
 .../aom/aom_dsp/arm/highbd_blend_a64_mask_neon.c | 473 +++
 .../aom/aom_dsp/arm/highbd_blend_a64_vmask_neon.c | 105 +
 .../aom/aom_dsp/arm/highbd_convolve8_neon.c | 363 +++
 third_party/aom/aom_dsp/arm/highbd_hadamard_neon.c | 213 ++
 .../aom/aom_dsp/arm/highbd_intrapred_neon.c | 2730 +++++++++++++++++
 .../aom/aom_dsp/arm/highbd_loopfilter_neon.c | 1265 ++++++++
 .../aom/aom_dsp/arm/highbd_masked_sad_neon.c | 354 +++
 third_party/aom/aom_dsp/arm/highbd_obmc_sad_neon.c | 211 ++
 .../aom/aom_dsp/arm/highbd_obmc_variance_neon.c | 369 +++
 third_party/aom/aom_dsp/arm/highbd_quantize_neon.c | 431 +++
 third_party/aom/aom_dsp/arm/highbd_sad_neon.c | 509 ++++
 third_party/aom/aom_dsp/arm/highbd_sadxd_neon.c | 617 ++++
 third_party/aom/aom_dsp/arm/highbd_sse_neon.c | 284 ++
 third_party/aom/aom_dsp/arm/highbd_sse_sve.c | 215 ++
 .../aom/aom_dsp/arm/highbd_subpel_variance_neon.c | 1497 ++++++++++
 third_party/aom/aom_dsp/arm/highbd_variance_neon.c | 502 ++++
 .../aom/aom_dsp/arm/highbd_variance_neon_dotprod.c | 92 +
 third_party/aom/aom_dsp/arm/highbd_variance_sve.c | 430 +++
 third_party/aom/aom_dsp/arm/intrapred_neon.c | 3110 ++++++++++++++++++++
 third_party/aom/aom_dsp/arm/loopfilter_neon.c | 1045 +++++++
 third_party/aom/aom_dsp/arm/masked_sad4d_neon.c | 562 ++++
 third_party/aom/aom_dsp/arm/masked_sad_neon.c | 244 ++
 third_party/aom/aom_dsp/arm/mem_neon.h | 1253 ++++++++
 third_party/aom/aom_dsp/arm/obmc_sad_neon.c | 250 ++
 third_party/aom/aom_dsp/arm/obmc_variance_neon.c | 290 ++
 third_party/aom/aom_dsp/arm/reinterpret_neon.h | 33 +
 third_party/aom/aom_dsp/arm/sad_neon.c | 873 ++++++
 third_party/aom/aom_dsp/arm/sad_neon_dotprod.c | 530 ++++
 third_party/aom/aom_dsp/arm/sadxd_neon.c | 514 ++++
 third_party/aom/aom_dsp/arm/sadxd_neon_dotprod.c | 289 ++
 third_party/aom/aom_dsp/arm/sse_neon.c | 210 ++
 third_party/aom/aom_dsp/arm/sse_neon_dotprod.c | 223 ++
 third_party/aom/aom_dsp/arm/subpel_variance_neon.c | 1103 +++++++
 third_party/aom/aom_dsp/arm/subtract_neon.c | 166 ++
 third_party/aom/aom_dsp/arm/sum_neon.h | 311 ++
 third_party/aom/aom_dsp/arm/sum_squares_neon.c | 574 ++++
 .../aom/aom_dsp/arm/sum_squares_neon_dotprod.c | 154 +
 third_party/aom/aom_dsp/arm/sum_squares_sve.c | 402 +++
 third_party/aom/aom_dsp/arm/transpose_neon.h | 1263 ++++++++
 third_party/aom/aom_dsp/arm/variance_neon.c | 470 +++
 .../aom/aom_dsp/arm/variance_neon_dotprod.c | 314 ++
 59 files changed, 28801 insertions(+)
 create mode 100644 third_party/aom/aom_dsp/arm/aom_convolve8_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c
 create mode 100644 third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c
 create mode 100644 third_party/aom/aom_dsp/arm/aom_convolve_copy_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/avg_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/avg_pred_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/avg_sve.c
 create mode 100644 third_party/aom/aom_dsp/arm/blend_a64_mask_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/blend_neon.h
 create mode 100644 third_party/aom/aom_dsp/arm/blk_sse_sum_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/blk_sse_sum_sve.c
 create mode 100644 third_party/aom/aom_dsp/arm/dist_wtd_avg_neon.h
 create mode 100644 third_party/aom/aom_dsp/arm/dot_sve.h
 create mode 100644 third_party/aom/aom_dsp/arm/fwd_txfm_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/hadamard_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_avg_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_avg_pred_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_blend_a64_hmask_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_blend_a64_mask_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_blend_a64_vmask_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_convolve8_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_hadamard_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_intrapred_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_loopfilter_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_masked_sad_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_obmc_sad_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_obmc_variance_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_quantize_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_sad_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_sadxd_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_sse_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_sse_sve.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_subpel_variance_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_variance_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_variance_neon_dotprod.c
 create mode 100644 third_party/aom/aom_dsp/arm/highbd_variance_sve.c
 create mode 100644 third_party/aom/aom_dsp/arm/intrapred_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/loopfilter_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/masked_sad4d_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/masked_sad_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/mem_neon.h
 create mode 100644 third_party/aom/aom_dsp/arm/obmc_sad_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/obmc_variance_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/reinterpret_neon.h
 create mode 100644 third_party/aom/aom_dsp/arm/sad_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/sad_neon_dotprod.c
 create mode 100644 third_party/aom/aom_dsp/arm/sadxd_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/sadxd_neon_dotprod.c
 create mode 100644 third_party/aom/aom_dsp/arm/sse_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/sse_neon_dotprod.c
 create mode 100644 third_party/aom/aom_dsp/arm/subpel_variance_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/subtract_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/sum_neon.h
 create mode 100644 third_party/aom/aom_dsp/arm/sum_squares_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/sum_squares_neon_dotprod.c
 create mode 100644 third_party/aom/aom_dsp/arm/sum_squares_sve.c
 create mode 100644 third_party/aom/aom_dsp/arm/transpose_neon.h
 create mode 100644 third_party/aom/aom_dsp/arm/variance_neon.c
 create mode 100644 third_party/aom/aom_dsp/arm/variance_neon_dotprod.c

diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c
new file mode 100644
index 0000000000..7441108b01
--- /dev/null
+++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c
@@ -0,0 +1,349 @@
+/*
+ * Copyright (c) 2014 The WebM project authors. All Rights Reserved.
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <string.h>
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+#include "aom/aom_integer.h"
+#include "aom_dsp/aom_dsp_common.h"
+#include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/transpose_neon.h"
+#include "aom_ports/mem.h"
+
+static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
+                                    const int16x4_t s2, const int16x4_t s3,
+                                    const int16x4_t s4, const int16x4_t s5,
+                                    const int16x4_t s6, const int16x4_t s7,
+                                    const int16x8_t filter) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x4_t sum;
+
+  sum = vmul_lane_s16(s0, filter_lo, 0);
+  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
+  sum = vqadd_s16(sum, vmul_lane_s16(s3, filter_lo, 3));
+  sum = vqadd_s16(sum, vmul_lane_s16(s4, filter_hi, 0));
+  return sum;
+}
+
+static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
+                                    const int16x8_t s2, const int16x8_t s3,
+                                    const int16x8_t s4, const int16x8_t s5,
+                                    const int16x8_t s6, const int16x8_t s7,
+                                    const int16x8_t filter) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+  int16x8_t sum;
+
+  sum = vmulq_lane_s16(s0, filter_lo, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
+  sum = vqaddq_s16(sum, vmulq_lane_s16(s3, filter_lo, 3));
+  sum = vqaddq_s16(sum, vmulq_lane_s16(s4, filter_hi, 0));
+  return vqrshrun_n_s16(sum, FILTER_BITS);
+}
+
+void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+                              uint8_t *dst, ptrdiff_t dst_stride,
+                              const int16_t *filter_x, int x_step_q4,
+                              const int16_t *filter_y, int y_step_q4, int w,
+                              int h) {
+  const int16x8_t filter = vld1q_s16(filter_x);
+
+  assert((intptr_t)dst % 4 == 0);
+  assert(dst_stride % 4 == 0);
+
(void)x_step_q4; + (void)filter_y; + (void)y_step_q4; + + src -= ((SUBPEL_TAPS / 2) - 1); + + if (h == 4) { + uint8x8_t t0, t1, t2, t3, d01, d23; + int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3; + + load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3); + transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3); + s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + + src += 7; + + do { + load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3); + transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3); + s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + + d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter); + d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter); + d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter); + d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); + + transpose_elems_inplace_u8_4x4(&d01, &d23); + + store_u8x4_strided_x2(dst + 0 * dst_stride, 2 * dst_stride, d01); + store_u8x4_strided_x2(dst + 1 * dst_stride, 2 * dst_stride, d23); + + s0 = s4; + s1 = s5; + s2 = s6; + s3 = s7; + s4 = s8; + s5 = s9; + s6 = s10; + src += 4; + dst += 4; + w -= 4; + } while (w != 0); + } else { + uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7, d0, d1, d2, d3; + int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10; + + if (w == 4) { + do { + load_u8_8x8(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); + transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); + s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); + s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); + s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); + s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); + s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); + s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); + s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); + + load_u8_8x8(src + 7, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, + &t7); + transpose_elems_u8_4x8(t0, t1, t2, t3, t4, t5, t6, t7, &t0, &t1, &t2, + &t3); + s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); + s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); + s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); + s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); + + d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); + d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); + d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); + d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); + + transpose_elems_inplace_u8_8x4(&d0, &d1, &d2, &d3); + + store_u8x4_strided_x2(dst + 0 * dst_stride, 4 * dst_stride, d0); + store_u8x4_strided_x2(dst + 1 * dst_stride, 4 * dst_stride, d1); + store_u8x4_strided_x2(dst + 2 * dst_stride, 4 * dst_stride, d2); + store_u8x4_strided_x2(dst + 3 * dst_stride, 4 * dst_stride, d3); + + src += 8 * src_stride; + dst += 8 * dst_stride; + h -= 8; + } while (h > 0); + } else { + uint8x8_t d4, d5, d6, d7; + int16x8_t s11, s12, s13, s14; + int width; + const uint8_t *s; + uint8_t *d; + + do { + load_u8_8x8(src, src_stride, &t0, &t1, &t2, &t3, 
&t4, &t5, &t6, &t7); + transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); + s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); + s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); + s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); + s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); + s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); + s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); + s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); + + width = w; + s = src + 7; + d = dst; + + do { + load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); + transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, + &t7); + s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); + s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); + s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); + s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); + s11 = vreinterpretq_s16_u16(vmovl_u8(t4)); + s12 = vreinterpretq_s16_u16(vmovl_u8(t5)); + s13 = vreinterpretq_s16_u16(vmovl_u8(t6)); + s14 = vreinterpretq_s16_u16(vmovl_u8(t7)); + + d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); + d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); + d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); + d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); + d4 = convolve8_8(s4, s5, s6, s7, s8, s9, s10, s11, filter); + d5 = convolve8_8(s5, s6, s7, s8, s9, s10, s11, s12, filter); + d6 = convolve8_8(s6, s7, s8, s9, s10, s11, s12, s13, filter); + d7 = convolve8_8(s7, s8, s9, s10, s11, s12, s13, s14, filter); + + transpose_elems_inplace_u8_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, + &d7); + + store_u8_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7); + + s0 = s8; + s1 = s9; + s2 = s10; + s3 = s11; + s4 = s12; + s5 = s13; + s6 = s14; + s += 8; + d += 8; + width -= 8; + } while (width != 0); + src += 8 * src_stride; + dst += 8 * dst_stride; + h -= 8; + } while (h > 0); + } + } +} + +void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, int w, + int h) { + const int16x8_t filter = vld1q_s16(filter_y); + + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + + (void)filter_x; + (void)x_step_q4; + (void)y_step_q4; + + src -= ((SUBPEL_TAPS / 2) - 1) * src_stride; + + if (w == 4) { + uint8x8_t t0, t1, t2, t3, t4, t5, t6, d01, d23; + int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3; + + load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); + s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + s4 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t4))); + s5 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t5))); + s6 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t6))); + + src += 7 * src_stride; + + do { + load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3); + s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + + d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter); + d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter); + d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter); + d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), 
FILTER_BITS); + + store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); + store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); + + s0 = s4; + s1 = s5; + s2 = s6; + s3 = s7; + s4 = s8; + s5 = s9; + s6 = s10; + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h != 0); + } else { + uint8x8_t t0, t1, t2, t3, t4, t5, t6, d0, d1, d2, d3; + int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10; + int height; + const uint8_t *s; + uint8_t *d; + + do { + load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); + s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); + s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); + s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); + s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); + s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); + s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); + s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); + + height = h; + s = src + 7 * src_stride; + d = dst; + + do { + load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3); + s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); + s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); + s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); + s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); + + d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); + d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); + d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); + d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); + + store_u8_8x4(d, dst_stride, d0, d1, d2, d3); + + s0 = s4; + s1 = s5; + s2 = s6; + s3 = s7; + s4 = s8; + s5 = s9; + s6 = s10; + s += 4 * src_stride; + d += 4 * dst_stride; + height -= 4; + } while (height != 0); + src += 8; + dst += 8; + w -= 8; + } while (w != 0); + } +} diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c new file mode 100644 index 0000000000..ac0a6efd00 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c @@ -0,0 +1,460 @@ +/* + * Copyright (c) 2014 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/aom_dsp_common.h" +#include "aom_dsp/aom_filter.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_ports/mem.h" + +DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = { + 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6, + 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10, + 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 +}; + +DECLARE_ALIGNED(16, static const uint8_t, dot_prod_tran_concat_tbl[32]) = { + 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, + 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 +}; + +DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = { + /* Shift left and insert new last column in transposed 4x4 block. 
*/ + 1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28, + /* Shift left and insert two new columns in transposed 4x4 block. */ + 2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29, + /* Shift left and insert three new columns in transposed 4x4 block. */ + 3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30 +}; + +static INLINE int16x4_t convolve8_4_sdot(uint8x16_t samples, + const int8x8_t filter, + const int32x4_t correction, + const uint8x16_t range_limit, + const uint8x16x2_t permute_tbl) { + int8x16_t clamped_samples, permuted_samples[2]; + int32x4_t sum; + + /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ + clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit)); + + /* Permute samples ready for dot product. */ + /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ + permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]); + /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ + permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]); + + /* Accumulate dot product into 'correction' to account for range clamp. */ + sum = vdotq_lane_s32(correction, permuted_samples[0], filter, 0); + sum = vdotq_lane_s32(sum, permuted_samples[1], filter, 1); + + /* Further narrowing and packing is performed by the caller. */ + return vqmovn_s32(sum); +} + +static INLINE uint8x8_t convolve8_8_sdot(uint8x16_t samples, + const int8x8_t filter, + const int32x4_t correction, + const uint8x16_t range_limit, + const uint8x16x3_t permute_tbl) { + int8x16_t clamped_samples, permuted_samples[3]; + int32x4_t sum0, sum1; + int16x8_t sum; + + /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ + clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit)); + + /* Permute samples ready for dot product. */ + /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ + permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]); + /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ + permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]); + /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */ + permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]); + + /* Accumulate dot product into 'correction' to account for range clamp. */ + /* First 4 output values. */ + sum0 = vdotq_lane_s32(correction, permuted_samples[0], filter, 0); + sum0 = vdotq_lane_s32(sum0, permuted_samples[1], filter, 1); + /* Second 4 output values. */ + sum1 = vdotq_lane_s32(correction, permuted_samples[1], filter, 0); + sum1 = vdotq_lane_s32(sum1, permuted_samples[2], filter, 1); + + /* Narrow and re-pack. 
*/ + sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); + return vqrshrun_n_s16(sum, FILTER_BITS); +} + +void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, + int w, int h) { + const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x)); + const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_x), 128); + const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp)); + const uint8x16_t range_limit = vdupq_n_u8(128); + uint8x16_t s0, s1, s2, s3; + + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + + (void)x_step_q4; + (void)filter_y; + (void)y_step_q4; + + src -= ((SUBPEL_TAPS / 2) - 1); + + if (w == 4) { + const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl); + do { + int16x4_t t0, t1, t2, t3; + uint8x8_t d01, d23; + + load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3); + + t0 = convolve8_4_sdot(s0, filter, correction, range_limit, perm_tbl); + t1 = convolve8_4_sdot(s1, filter, correction, range_limit, perm_tbl); + t2 = convolve8_4_sdot(s2, filter, correction, range_limit, perm_tbl); + t3 = convolve8_4_sdot(s3, filter, correction, range_limit, perm_tbl); + d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS); + + store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); + store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); + + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 0); + } else { + const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); + const uint8_t *s; + uint8_t *d; + int width; + uint8x8_t d0, d1, d2, d3; + + do { + width = w; + s = src; + d = dst; + do { + load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3); + + d0 = convolve8_8_sdot(s0, filter, correction, range_limit, perm_tbl); + d1 = convolve8_8_sdot(s1, filter, correction, range_limit, perm_tbl); + d2 = convolve8_8_sdot(s2, filter, correction, range_limit, perm_tbl); + d3 = convolve8_8_sdot(s3, filter, correction, range_limit, perm_tbl); + + store_u8_8x4(d, dst_stride, d0, d1, d2, d3); + + s += 8; + d += 8; + width -= 8; + } while (width != 0); + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 0); + } +} + +static INLINE void transpose_concat_4x4(int8x8_t a0, int8x8_t a1, int8x8_t a2, + int8x8_t a3, int8x16_t *b, + const uint8x16_t permute_tbl) { + /* Transpose 8-bit elements and concatenate result rows as follows: + * a0: 00, 01, 02, 03, XX, XX, XX, XX + * a1: 10, 11, 12, 13, XX, XX, XX, XX + * a2: 20, 21, 22, 23, XX, XX, XX, XX + * a3: 30, 31, 32, 33, XX, XX, XX, XX + * + * b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + * + * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it + * as an argument is preferable to loading it directly from memory as this + * inline helper is called many times from the same parent function. 
+ */ + + int8x16x2_t samples = { { vcombine_s8(a0, a1), vcombine_s8(a2, a3) } }; + *b = vqtbl2q_s8(samples, permute_tbl); +} + +static INLINE void transpose_concat_8x4(int8x8_t a0, int8x8_t a1, int8x8_t a2, + int8x8_t a3, int8x16_t *b0, + int8x16_t *b1, + const uint8x16x2_t permute_tbl) { + /* Transpose 8-bit elements and concatenate result rows as follows: + * a0: 00, 01, 02, 03, 04, 05, 06, 07 + * a1: 10, 11, 12, 13, 14, 15, 16, 17 + * a2: 20, 21, 22, 23, 24, 25, 26, 27 + * a3: 30, 31, 32, 33, 34, 35, 36, 37 + * + * b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + * b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37 + * + * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it + * as an argument is preferable to loading it directly from memory as this + * inline helper is called many times from the same parent function. + */ + + int8x16x2_t samples = { { vcombine_s8(a0, a1), vcombine_s8(a2, a3) } }; + *b0 = vqtbl2q_s8(samples, permute_tbl.val[0]); + *b1 = vqtbl2q_s8(samples, permute_tbl.val[1]); +} + +static INLINE int16x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo, + const int8x16_t samples_hi, + const int32x4_t correction, + const int8x8_t filter) { + /* Sample range-clamping and permutation are performed by the caller. */ + int32x4_t sum; + + /* Accumulate dot product into 'correction' to account for range clamp. */ + sum = vdotq_lane_s32(correction, samples_lo, filter, 0); + sum = vdotq_lane_s32(sum, samples_hi, filter, 1); + + /* Further narrowing and packing is performed by the caller. */ + return vqmovn_s32(sum); +} + +static INLINE uint8x8_t convolve8_8_sdot_partial(const int8x16_t samples0_lo, + const int8x16_t samples0_hi, + const int8x16_t samples1_lo, + const int8x16_t samples1_hi, + const int32x4_t correction, + const int8x8_t filter) { + /* Sample range-clamping and permutation are performed by the caller. */ + int32x4_t sum0, sum1; + int16x8_t sum; + + /* Accumulate dot product into 'correction' to account for range clamp. */ + /* First 4 output values. */ + sum0 = vdotq_lane_s32(correction, samples0_lo, filter, 0); + sum0 = vdotq_lane_s32(sum0, samples0_hi, filter, 1); + /* Second 4 output values. */ + sum1 = vdotq_lane_s32(correction, samples1_lo, filter, 0); + sum1 = vdotq_lane_s32(sum1, samples1_hi, filter, 1); + + /* Narrow and re-pack. 
*/ + sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); + return vqrshrun_n_s16(sum, FILTER_BITS); +} + +void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, + int w, int h) { + const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y)); + const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_y), 128); + const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp)); + const uint8x8_t range_limit = vdup_n_u8(128); + const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl); + uint8x8_t t0, t1, t2, t3, t4, t5, t6; + int8x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10; + int8x16x2_t samples_LUT; + + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + + (void)filter_x; + (void)x_step_q4; + (void)y_step_q4; + + src -= ((SUBPEL_TAPS / 2) - 1) * src_stride; + + if (w == 4) { + const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl); + int8x16_t s0123, s1234, s2345, s3456, s4567, s5678, s6789, s78910; + int16x4_t d0, d1, d2, d3; + uint8x8_t d01, d23; + + load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); + src += 7 * src_stride; + + /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ + s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit)); + s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit)); + s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit)); + s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit)); + s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit)); + s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit)); + s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit)); + s7 = vdup_n_s8(0); + s8 = vdup_n_s8(0); + s9 = vdup_n_s8(0); + + /* This operation combines a conventional transpose and the sample permute + * (see horizontal case) required before computing the dot product. + */ + transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl); + transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl); + transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl); + transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl); + transpose_concat_4x4(s4, s5, s6, s7, &s4567, tran_concat_tbl); + transpose_concat_4x4(s5, s6, s7, s8, &s5678, tran_concat_tbl); + transpose_concat_4x4(s6, s7, s8, s9, &s6789, tran_concat_tbl); + + do { + uint8x8_t t7, t8, t9, t10; + + load_u8_8x4(src, src_stride, &t7, &t8, &t9, &t10); + + s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit)); + s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit)); + s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit)); + s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit)); + + transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl); + + /* Merge new data into block from previous iteration. 
*/ + samples_LUT.val[0] = s3456; + samples_LUT.val[1] = s78910; + s4567 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]); + s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]); + s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]); + + d0 = convolve8_4_sdot_partial(s0123, s4567, correction, filter); + d1 = convolve8_4_sdot_partial(s1234, s5678, correction, filter); + d2 = convolve8_4_sdot_partial(s2345, s6789, correction, filter); + d3 = convolve8_4_sdot_partial(s3456, s78910, correction, filter); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); + + store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); + store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); + + /* Prepare block for next iteration - re-using as much as possible. */ + /* Shuffle everything up four rows. */ + s0123 = s4567; + s1234 = s5678; + s2345 = s6789; + s3456 = s78910; + + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h != 0); + } else { + const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl); + int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi, + s3456_lo, s3456_hi, s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, + s6789_hi, s78910_lo, s78910_hi; + uint8x8_t d0, d1, d2, d3; + const uint8_t *s; + uint8_t *d; + int height; + + do { + height = h; + s = src; + d = dst; + + load_u8_8x7(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); + s += 7 * src_stride; + + /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ + s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit)); + s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit)); + s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit)); + s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit)); + s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit)); + s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit)); + s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit)); + s7 = vdup_n_s8(0); + s8 = vdup_n_s8(0); + s9 = vdup_n_s8(0); + + /* This operation combines a conventional transpose and the sample permute + * (see horizontal case) required before computing the dot product. + */ + transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi, + tran_concat_tbl); + transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi, + tran_concat_tbl); + transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi, + tran_concat_tbl); + transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi, + tran_concat_tbl); + transpose_concat_8x4(s4, s5, s6, s7, &s4567_lo, &s4567_hi, + tran_concat_tbl); + transpose_concat_8x4(s5, s6, s7, s8, &s5678_lo, &s5678_hi, + tran_concat_tbl); + transpose_concat_8x4(s6, s7, s8, s9, &s6789_lo, &s6789_hi, + tran_concat_tbl); + + do { + uint8x8_t t7, t8, t9, t10; + + load_u8_8x4(s, src_stride, &t7, &t8, &t9, &t10); + + s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit)); + s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit)); + s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit)); + s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit)); + + transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi, + tran_concat_tbl); + + /* Merge new data into block from previous iteration. 
*/ + samples_LUT.val[0] = s3456_lo; + samples_LUT.val[1] = s78910_lo; + s4567_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]); + s5678_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]); + s6789_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]); + + samples_LUT.val[0] = s3456_hi; + samples_LUT.val[1] = s78910_hi; + s4567_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]); + s5678_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]); + s6789_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]); + + d0 = convolve8_8_sdot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi, + correction, filter); + d1 = convolve8_8_sdot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi, + correction, filter); + d2 = convolve8_8_sdot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi, + correction, filter); + d3 = convolve8_8_sdot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi, + correction, filter); + + store_u8_8x4(d, dst_stride, d0, d1, d2, d3); + + /* Prepare block for next iteration - re-using as much as possible. */ + /* Shuffle everything up four rows. */ + s0123_lo = s4567_lo; + s0123_hi = s4567_hi; + s1234_lo = s5678_lo; + s1234_hi = s5678_hi; + s2345_lo = s6789_lo; + s2345_hi = s6789_hi; + s3456_lo = s78910_lo; + s3456_hi = s78910_hi; + + s += 4 * src_stride; + d += 4 * dst_stride; + height -= 4; + } while (height != 0); + src += 8; + dst += 8; + w -= 8; + } while (w != 0); + } +} diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c new file mode 100644 index 0000000000..c314c0a192 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c @@ -0,0 +1,408 @@ +/* + * Copyright (c) 2014 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/aom_dsp_common.h" +#include "aom_dsp/aom_filter.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_ports/mem.h" + +DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = { + 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6, + 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10, + 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 +}; + +DECLARE_ALIGNED(16, static const uint8_t, dot_prod_tran_concat_tbl[32]) = { + 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, + 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 +}; + +DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = { + /* Shift left and insert new last column in transposed 4x4 block. */ + 1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28, + /* Shift left and insert two new columns in transposed 4x4 block. */ + 2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29, + /* Shift left and insert three new columns in transposed 4x4 block. 
*/ + 3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30 +}; + +static INLINE int16x4_t convolve8_4_usdot(const uint8x16_t samples, + const int8x8_t filter, + const uint8x16x2_t permute_tbl) { + uint8x16_t permuted_samples[2]; + int32x4_t sum; + + /* Permute samples ready for dot product. */ + /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ + permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]); + /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ + permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]); + + sum = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filter, 0); + sum = vusdotq_lane_s32(sum, permuted_samples[1], filter, 1); + + /* Further narrowing and packing is performed by the caller. */ + return vqmovn_s32(sum); +} + +static INLINE uint8x8_t convolve8_8_usdot(const uint8x16_t samples, + const int8x8_t filter, + const uint8x16x3_t permute_tbl) { + uint8x16_t permuted_samples[3]; + int32x4_t sum0, sum1; + int16x8_t sum; + + /* Permute samples ready for dot product. */ + /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ + permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]); + /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ + permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]); + /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */ + permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]); + + /* First 4 output values. */ + sum0 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filter, 0); + sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filter, 1); + /* Second 4 output values. */ + sum1 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filter, 0); + sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filter, 1); + + /* Narrow and re-pack. 
*/ + sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); + return vqrshrun_n_s16(sum, FILTER_BITS); +} + +void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, + int w, int h) { + const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x)); + uint8x16_t s0, s1, s2, s3; + + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + + (void)x_step_q4; + (void)filter_y; + (void)y_step_q4; + + src -= ((SUBPEL_TAPS / 2) - 1); + + if (w == 4) { + const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl); + do { + int16x4_t t0, t1, t2, t3; + uint8x8_t d01, d23; + + load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3); + + t0 = convolve8_4_usdot(s0, filter, perm_tbl); + t1 = convolve8_4_usdot(s1, filter, perm_tbl); + t2 = convolve8_4_usdot(s2, filter, perm_tbl); + t3 = convolve8_4_usdot(s3, filter, perm_tbl); + d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS); + + store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); + store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); + + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 0); + } else { + const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); + const uint8_t *s; + uint8_t *d; + int width; + uint8x8_t d0, d1, d2, d3; + + do { + width = w; + s = src; + d = dst; + do { + load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3); + + d0 = convolve8_8_usdot(s0, filter, perm_tbl); + d1 = convolve8_8_usdot(s1, filter, perm_tbl); + d2 = convolve8_8_usdot(s2, filter, perm_tbl); + d3 = convolve8_8_usdot(s3, filter, perm_tbl); + + store_u8_8x4(d, dst_stride, d0, d1, d2, d3); + + s += 8; + d += 8; + width -= 8; + } while (width != 0); + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 0); + } +} + +static INLINE void transpose_concat_4x4(uint8x8_t a0, uint8x8_t a1, + uint8x8_t a2, uint8x8_t a3, + uint8x16_t *b, + const uint8x16_t permute_tbl) { + /* Transpose 8-bit elements and concatenate result rows as follows: + * a0: 00, 01, 02, 03, XX, XX, XX, XX + * a1: 10, 11, 12, 13, XX, XX, XX, XX + * a2: 20, 21, 22, 23, XX, XX, XX, XX + * a3: 30, 31, 32, 33, XX, XX, XX, XX + * + * b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + * + * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it + * as an argument is preferable to loading it directly from memory as this + * inline helper is called many times from the same parent function. + */ + + uint8x16x2_t samples = { { vcombine_u8(a0, a1), vcombine_u8(a2, a3) } }; + *b = vqtbl2q_u8(samples, permute_tbl); +} + +static INLINE void transpose_concat_8x4(uint8x8_t a0, uint8x8_t a1, + uint8x8_t a2, uint8x8_t a3, + uint8x16_t *b0, uint8x16_t *b1, + const uint8x16x2_t permute_tbl) { + /* Transpose 8-bit elements and concatenate result rows as follows: + * a0: 00, 01, 02, 03, 04, 05, 06, 07 + * a1: 10, 11, 12, 13, 14, 15, 16, 17 + * a2: 20, 21, 22, 23, 24, 25, 26, 27 + * a3: 30, 31, 32, 33, 34, 35, 36, 37 + * + * b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + * b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37 + * + * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it + * as an argument is preferable to loading it directly from memory as this + * inline helper is called many times from the same parent function. 
+ */ + + uint8x16x2_t samples = { { vcombine_u8(a0, a1), vcombine_u8(a2, a3) } }; + *b0 = vqtbl2q_u8(samples, permute_tbl.val[0]); + *b1 = vqtbl2q_u8(samples, permute_tbl.val[1]); +} + +static INLINE int16x4_t convolve8_4_usdot_partial(const uint8x16_t samples_lo, + const uint8x16_t samples_hi, + const int8x8_t filter) { + /* Sample permutation is performed by the caller. */ + int32x4_t sum; + + sum = vusdotq_lane_s32(vdupq_n_s32(0), samples_lo, filter, 0); + sum = vusdotq_lane_s32(sum, samples_hi, filter, 1); + + /* Further narrowing and packing is performed by the caller. */ + return vqmovn_s32(sum); +} + +static INLINE uint8x8_t convolve8_8_usdot_partial(const uint8x16_t samples0_lo, + const uint8x16_t samples0_hi, + const uint8x16_t samples1_lo, + const uint8x16_t samples1_hi, + const int8x8_t filter) { + /* Sample permutation is performed by the caller. */ + int32x4_t sum0, sum1; + int16x8_t sum; + + /* First 4 output values. */ + sum0 = vusdotq_lane_s32(vdupq_n_s32(0), samples0_lo, filter, 0); + sum0 = vusdotq_lane_s32(sum0, samples0_hi, filter, 1); + /* Second 4 output values. */ + sum1 = vusdotq_lane_s32(vdupq_n_s32(0), samples1_lo, filter, 0); + sum1 = vusdotq_lane_s32(sum1, samples1_hi, filter, 1); + + /* Narrow and re-pack. */ + sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); + return vqrshrun_n_s16(sum, FILTER_BITS); +} + +void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, int w, + int h) { + const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y)); + const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl); + uint8x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10; + uint8x16x2_t samples_LUT; + + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + + (void)filter_x; + (void)x_step_q4; + (void)y_step_q4; + + src -= ((SUBPEL_TAPS / 2) - 1) * src_stride; + + if (w == 4) { + const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl); + uint8x16_t s0123, s1234, s2345, s3456, s4567, s5678, s6789, s78910; + int16x4_t d0, d1, d2, d3; + uint8x8_t d01, d23; + + load_u8_8x7(src, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6); + src += 7 * src_stride; + + s7 = vdup_n_u8(0); + s8 = vdup_n_u8(0); + s9 = vdup_n_u8(0); + + /* This operation combines a conventional transpose and the sample permute + * (see horizontal case) required before computing the dot product. + */ + transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl); + transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl); + transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl); + transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl); + transpose_concat_4x4(s4, s5, s6, s7, &s4567, tran_concat_tbl); + transpose_concat_4x4(s5, s6, s7, s8, &s5678, tran_concat_tbl); + transpose_concat_4x4(s6, s7, s8, s9, &s6789, tran_concat_tbl); + + do { + load_u8_8x4(src, src_stride, &s7, &s8, &s9, &s10); + + transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl); + + /* Merge new data into block from previous iteration. 
*/ + samples_LUT.val[0] = s3456; + samples_LUT.val[1] = s78910; + s4567 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[0]); + s5678 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]); + s6789 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]); + + d0 = convolve8_4_usdot_partial(s0123, s4567, filter); + d1 = convolve8_4_usdot_partial(s1234, s5678, filter); + d2 = convolve8_4_usdot_partial(s2345, s6789, filter); + d3 = convolve8_4_usdot_partial(s3456, s78910, filter); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); + + store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); + store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); + + /* Prepare block for next iteration - re-using as much as possible. */ + /* Shuffle everything up four rows. */ + s0123 = s4567; + s1234 = s5678; + s2345 = s6789; + s3456 = s78910; + + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h != 0); + } else { + const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl); + uint8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi, + s3456_lo, s3456_hi, s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, + s6789_hi, s78910_lo, s78910_hi; + uint8x8_t d0, d1, d2, d3; + const uint8_t *s; + uint8_t *d; + int height; + + do { + height = h; + s = src; + d = dst; + + load_u8_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6); + s += 7 * src_stride; + + s7 = vdup_n_u8(0); + s8 = vdup_n_u8(0); + s9 = vdup_n_u8(0); + + /* This operation combines a conventional transpose and the sample permute + * (see horizontal case) required before computing the dot product. + */ + transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi, + tran_concat_tbl); + transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi, + tran_concat_tbl); + transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi, + tran_concat_tbl); + transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi, + tran_concat_tbl); + transpose_concat_8x4(s4, s5, s6, s7, &s4567_lo, &s4567_hi, + tran_concat_tbl); + transpose_concat_8x4(s5, s6, s7, s8, &s5678_lo, &s5678_hi, + tran_concat_tbl); + transpose_concat_8x4(s6, s7, s8, s9, &s6789_lo, &s6789_hi, + tran_concat_tbl); + + do { + load_u8_8x4(s, src_stride, &s7, &s8, &s9, &s10); + + transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi, + tran_concat_tbl); + + /* Merge new data into block from previous iteration. */ + samples_LUT.val[0] = s3456_lo; + samples_LUT.val[1] = s78910_lo; + s4567_lo = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[0]); + s5678_lo = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]); + s6789_lo = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]); + + samples_LUT.val[0] = s3456_hi; + samples_LUT.val[1] = s78910_hi; + s4567_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[0]); + s5678_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]); + s6789_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]); + + d0 = convolve8_8_usdot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi, + filter); + d1 = convolve8_8_usdot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi, + filter); + d2 = convolve8_8_usdot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi, + filter); + d3 = convolve8_8_usdot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi, + filter); + + store_u8_8x4(d, dst_stride, d0, d1, d2, d3); + + /* Prepare block for next iteration - re-using as much as possible. */ + /* Shuffle everything up four rows. 
*/ + s0123_lo = s4567_lo; + s0123_hi = s4567_hi; + s1234_lo = s5678_lo; + s1234_hi = s5678_hi; + s2345_lo = s6789_lo; + s2345_hi = s6789_hi; + s3456_lo = s78910_lo; + s3456_hi = s78910_hi; + + s += 4 * src_stride; + d += 4 * dst_stride; + height -= 4; + } while (height != 0); + src += 8; + dst += 8; + w -= 8; + } while (w != 0); + } +} diff --git a/third_party/aom/aom_dsp/arm/aom_convolve_copy_neon.c b/third_party/aom/aom_dsp/arm/aom_convolve_copy_neon.c new file mode 100644 index 0000000000..325d6f29ff --- /dev/null +++ b/third_party/aom/aom_dsp/arm/aom_convolve_copy_neon.c @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include + +#include "config/aom_dsp_rtcd.h" + +void aom_convolve_copy_neon(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, int w, int h) { + const uint8_t *src1; + uint8_t *dst1; + int y; + + if (!(w & 0x0F)) { + for (y = 0; y < h; ++y) { + src1 = src; + dst1 = dst; + for (int x = 0; x < (w >> 4); ++x) { + vst1q_u8(dst1, vld1q_u8(src1)); + src1 += 16; + dst1 += 16; + } + src += src_stride; + dst += dst_stride; + } + } else if (!(w & 0x07)) { + for (y = 0; y < h; ++y) { + vst1_u8(dst, vld1_u8(src)); + src += src_stride; + dst += dst_stride; + } + } else if (!(w & 0x03)) { + for (y = 0; y < h; ++y) { + memcpy(dst, src, sizeof(uint32_t)); + src += src_stride; + dst += dst_stride; + } + } else if (!(w & 0x01)) { + for (y = 0; y < h; ++y) { + memcpy(dst, src, sizeof(uint16_t)); + src += src_stride; + dst += dst_stride; + } + } +} + +#if CONFIG_AV1_HIGHBITDEPTH +void aom_highbd_convolve_copy_neon(const uint16_t *src, ptrdiff_t src_stride, + uint16_t *dst, ptrdiff_t dst_stride, int w, + int h) { + if (w < 8) { // copy4 + uint16x4_t s0, s1; + do { + s0 = vld1_u16(src); + src += src_stride; + s1 = vld1_u16(src); + src += src_stride; + + vst1_u16(dst, s0); + dst += dst_stride; + vst1_u16(dst, s1); + dst += dst_stride; + h -= 2; + } while (h != 0); + } else if (w == 8) { // copy8 + uint16x8_t s0, s1; + do { + s0 = vld1q_u16(src); + src += src_stride; + s1 = vld1q_u16(src); + src += src_stride; + + vst1q_u16(dst, s0); + dst += dst_stride; + vst1q_u16(dst, s1); + dst += dst_stride; + h -= 2; + } while (h != 0); + } else if (w < 32) { // copy16 + uint16x8_t s0, s1, s2, s3; + do { + s0 = vld1q_u16(src); + s1 = vld1q_u16(src + 8); + src += src_stride; + s2 = vld1q_u16(src); + s3 = vld1q_u16(src + 8); + src += src_stride; + + vst1q_u16(dst, s0); + vst1q_u16(dst + 8, s1); + dst += dst_stride; + vst1q_u16(dst, s2); + vst1q_u16(dst + 8, s3); + dst += dst_stride; + h -= 2; + } while (h != 0); + } else if (w == 32) { // copy32 + uint16x8_t s0, s1, s2, s3; + do { + s0 = vld1q_u16(src); + s1 = vld1q_u16(src + 8); + s2 = vld1q_u16(src + 16); + s3 = vld1q_u16(src + 24); + src += src_stride; + + vst1q_u16(dst, s0); + vst1q_u16(dst + 8, s1); + vst1q_u16(dst + 16, s2); + vst1q_u16(dst + 24, s3); + dst += dst_stride; + } while (--h != 0); + } else { // copy64 + uint16x8_t s0, s1, s2, s3, s4, s5, s6, s7; + do { + const uint16_t *s = src; + uint16_t *d = dst; + int width = w; + do { + s0 = vld1q_u16(s); + s1 = vld1q_u16(s + 8); + s2 = vld1q_u16(s + 16); + s3 = vld1q_u16(s + 24); + 
s4 = vld1q_u16(s + 32); + s5 = vld1q_u16(s + 40); + s6 = vld1q_u16(s + 48); + s7 = vld1q_u16(s + 56); + + vst1q_u16(d, s0); + vst1q_u16(d + 8, s1); + vst1q_u16(d + 16, s2); + vst1q_u16(d + 24, s3); + vst1q_u16(d + 32, s4); + vst1q_u16(d + 40, s5); + vst1q_u16(d + 48, s6); + vst1q_u16(d + 56, s7); + s += 64; + d += 64; + width -= 64; + } while (width > 0); + src += src_stride; + dst += dst_stride; + } while (--h != 0); + } +} + +#endif // CONFIG_AV1_HIGHBITDEPTH diff --git a/third_party/aom/aom_dsp/arm/avg_neon.c b/third_party/aom/aom_dsp/arm/avg_neon.c new file mode 100644 index 0000000000..2e79b2ef69 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/avg_neon.c @@ -0,0 +1,309 @@ +/* + * Copyright (c) 2019, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_ports/mem.h" + +unsigned int aom_avg_4x4_neon(const uint8_t *p, int stride) { + const uint8x8_t s0 = load_unaligned_u8(p, stride); + const uint8x8_t s1 = load_unaligned_u8(p + 2 * stride, stride); + + const uint32_t sum = horizontal_add_u16x8(vaddl_u8(s0, s1)); + return (sum + (1 << 3)) >> 4; +} + +unsigned int aom_avg_8x8_neon(const uint8_t *p, int stride) { + uint8x8_t s0 = vld1_u8(p); + p += stride; + uint8x8_t s1 = vld1_u8(p); + p += stride; + uint16x8_t acc = vaddl_u8(s0, s1); + + int i = 0; + do { + const uint8x8_t si = vld1_u8(p); + p += stride; + acc = vaddw_u8(acc, si); + } while (++i < 6); + + const uint32_t sum = horizontal_add_u16x8(acc); + return (sum + (1 << 5)) >> 6; +} + +void aom_avg_8x8_quad_neon(const uint8_t *s, int p, int x16_idx, int y16_idx, + int *avg) { + avg[0] = aom_avg_8x8_neon(s + y16_idx * p + x16_idx, p); + avg[1] = aom_avg_8x8_neon(s + y16_idx * p + (x16_idx + 8), p); + avg[2] = aom_avg_8x8_neon(s + (y16_idx + 8) * p + x16_idx, p); + avg[3] = aom_avg_8x8_neon(s + (y16_idx + 8) * p + (x16_idx + 8), p); +} + +int aom_satd_lp_neon(const int16_t *coeff, int length) { + int16x8_t s0 = vld1q_s16(coeff); + int16x8_t s1 = vld1q_s16(coeff + 8); + + int16x8_t abs0 = vabsq_s16(s0); + int16x8_t abs1 = vabsq_s16(s1); + + int32x4_t acc0 = vpaddlq_s16(abs0); + int32x4_t acc1 = vpaddlq_s16(abs1); + + length -= 16; + coeff += 16; + + while (length != 0) { + s0 = vld1q_s16(coeff); + s1 = vld1q_s16(coeff + 8); + + abs0 = vabsq_s16(s0); + abs1 = vabsq_s16(s1); + + acc0 = vpadalq_s16(acc0, abs0); + acc1 = vpadalq_s16(acc1, abs1); + + length -= 16; + coeff += 16; + } + + int32x4_t accum = vaddq_s32(acc0, acc1); + return horizontal_add_s32x4(accum); +} + +void aom_int_pro_row_neon(int16_t *hbuf, const uint8_t *ref, + const int ref_stride, const int width, + const int height, int norm_factor) { + assert(width % 16 == 0); + assert(height % 4 == 0); + + const int16x8_t neg_norm_factor = vdupq_n_s16(-norm_factor); + uint16x8_t sum_lo[2], sum_hi[2]; + + int w = 0; + do { + const uint8_t *r = ref + w; + uint8x16_t r0 = vld1q_u8(r + 0 * ref_stride); + uint8x16_t r1 = vld1q_u8(r + 1 * ref_stride); + uint8x16_t r2 = vld1q_u8(r + 2 * ref_stride); + uint8x16_t r3 
= vld1q_u8(r + 3 * ref_stride); + + sum_lo[0] = vaddl_u8(vget_low_u8(r0), vget_low_u8(r1)); + sum_hi[0] = vaddl_u8(vget_high_u8(r0), vget_high_u8(r1)); + sum_lo[1] = vaddl_u8(vget_low_u8(r2), vget_low_u8(r3)); + sum_hi[1] = vaddl_u8(vget_high_u8(r2), vget_high_u8(r3)); + + r += 4 * ref_stride; + + for (int h = height - 4; h != 0; h -= 4) { + r0 = vld1q_u8(r + 0 * ref_stride); + r1 = vld1q_u8(r + 1 * ref_stride); + r2 = vld1q_u8(r + 2 * ref_stride); + r3 = vld1q_u8(r + 3 * ref_stride); + + uint16x8_t tmp0_lo = vaddl_u8(vget_low_u8(r0), vget_low_u8(r1)); + uint16x8_t tmp0_hi = vaddl_u8(vget_high_u8(r0), vget_high_u8(r1)); + uint16x8_t tmp1_lo = vaddl_u8(vget_low_u8(r2), vget_low_u8(r3)); + uint16x8_t tmp1_hi = vaddl_u8(vget_high_u8(r2), vget_high_u8(r3)); + + sum_lo[0] = vaddq_u16(sum_lo[0], tmp0_lo); + sum_hi[0] = vaddq_u16(sum_hi[0], tmp0_hi); + sum_lo[1] = vaddq_u16(sum_lo[1], tmp1_lo); + sum_hi[1] = vaddq_u16(sum_hi[1], tmp1_hi); + + r += 4 * ref_stride; + } + + sum_lo[0] = vaddq_u16(sum_lo[0], sum_lo[1]); + sum_hi[0] = vaddq_u16(sum_hi[0], sum_hi[1]); + + const int16x8_t avg0 = + vshlq_s16(vreinterpretq_s16_u16(sum_lo[0]), neg_norm_factor); + const int16x8_t avg1 = + vshlq_s16(vreinterpretq_s16_u16(sum_hi[0]), neg_norm_factor); + + vst1q_s16(hbuf + w, avg0); + vst1q_s16(hbuf + w + 8, avg1); + w += 16; + } while (w < width); +} + +void aom_int_pro_col_neon(int16_t *vbuf, const uint8_t *ref, + const int ref_stride, const int width, + const int height, int norm_factor) { + assert(width % 16 == 0); + assert(height % 4 == 0); + + const int16x4_t neg_norm_factor = vdup_n_s16(-norm_factor); + uint16x8_t sum[4]; + + int h = 0; + do { + sum[0] = vpaddlq_u8(vld1q_u8(ref + 0 * ref_stride)); + sum[1] = vpaddlq_u8(vld1q_u8(ref + 1 * ref_stride)); + sum[2] = vpaddlq_u8(vld1q_u8(ref + 2 * ref_stride)); + sum[3] = vpaddlq_u8(vld1q_u8(ref + 3 * ref_stride)); + + for (int w = 16; w < width; w += 16) { + sum[0] = vpadalq_u8(sum[0], vld1q_u8(ref + 0 * ref_stride + w)); + sum[1] = vpadalq_u8(sum[1], vld1q_u8(ref + 1 * ref_stride + w)); + sum[2] = vpadalq_u8(sum[2], vld1q_u8(ref + 2 * ref_stride + w)); + sum[3] = vpadalq_u8(sum[3], vld1q_u8(ref + 3 * ref_stride + w)); + } + + uint16x4_t sum_4d = vmovn_u32(horizontal_add_4d_u16x8(sum)); + int16x4_t avg = vshl_s16(vreinterpret_s16_u16(sum_4d), neg_norm_factor); + vst1_s16(vbuf + h, avg); + + ref += 4 * ref_stride; + h += 4; + } while (h < height); +} + +// coeff: 20 bits, dynamic range [-524287, 524287]. +// length: value range {16, 32, 64, 128, 256, 512, 1024}. 
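For reference, here is a scalar sketch of the quantity the NEON kernel below computes, together with the worked overflow bound implied by the comment above. The helper name satd_scalar_reference is illustrative only and not part of the upstream patch; it assumes abs() and the tran_low_t typedef are available as they are elsewhere in this file.

// Editorial sketch: scalar equivalent of the kernel below.
// With |coeff[i]| <= 524287 and length <= 1024, the total is at most
// 524287 * 1024 = 536,869,888, which fits comfortably in a signed 32-bit
// accumulator (INT32_MAX = 2,147,483,647), so no 64-bit widening is needed.
static int satd_scalar_reference(const tran_low_t *coeff, int length) {
  int accum = 0;
  for (int i = 0; i < length; ++i) accum += abs(coeff[i]);
  return accum;
}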
+int aom_satd_neon(const tran_low_t *coeff, int length) { + const int32x4_t zero = vdupq_n_s32(0); + + int32x4_t s0 = vld1q_s32(&coeff[0]); + int32x4_t s1 = vld1q_s32(&coeff[4]); + int32x4_t s2 = vld1q_s32(&coeff[8]); + int32x4_t s3 = vld1q_s32(&coeff[12]); + + int32x4_t accum0 = vabsq_s32(s0); + int32x4_t accum1 = vabsq_s32(s2); + accum0 = vabaq_s32(accum0, s1, zero); + accum1 = vabaq_s32(accum1, s3, zero); + + length -= 16; + coeff += 16; + + while (length != 0) { + s0 = vld1q_s32(&coeff[0]); + s1 = vld1q_s32(&coeff[4]); + s2 = vld1q_s32(&coeff[8]); + s3 = vld1q_s32(&coeff[12]); + + accum0 = vabaq_s32(accum0, s0, zero); + accum1 = vabaq_s32(accum1, s1, zero); + accum0 = vabaq_s32(accum0, s2, zero); + accum1 = vabaq_s32(accum1, s3, zero); + + length -= 16; + coeff += 16; + } + + // satd: 30 bits, dynamic range [-524287 * 1024, 524287 * 1024] + return horizontal_add_s32x4(vaddq_s32(accum0, accum1)); +} + +int aom_vector_var_neon(const int16_t *ref, const int16_t *src, int bwl) { + assert(bwl >= 2 && bwl <= 5); + int width = 4 << bwl; + + int16x8_t r = vld1q_s16(ref); + int16x8_t s = vld1q_s16(src); + + // diff: dynamic range [-510, 510] 10 (signed) bits. + int16x8_t diff = vsubq_s16(r, s); + // v_mean: dynamic range 16 * diff -> [-8160, 8160], 14 (signed) bits. + int16x8_t v_mean = diff; + // v_sse: dynamic range 2 * 16 * diff^2 -> [0, 8,323,200], 24 (signed) bits. + int32x4_t v_sse[2]; + v_sse[0] = vmull_s16(vget_low_s16(diff), vget_low_s16(diff)); + v_sse[1] = vmull_s16(vget_high_s16(diff), vget_high_s16(diff)); + + ref += 8; + src += 8; + width -= 8; + + do { + r = vld1q_s16(ref); + s = vld1q_s16(src); + + diff = vsubq_s16(r, s); + v_mean = vaddq_s16(v_mean, diff); + + v_sse[0] = vmlal_s16(v_sse[0], vget_low_s16(diff), vget_low_s16(diff)); + v_sse[1] = vmlal_s16(v_sse[1], vget_high_s16(diff), vget_high_s16(diff)); + + ref += 8; + src += 8; + width -= 8; + } while (width != 0); + + // Dynamic range [0, 65280], 16 (unsigned) bits. + const uint32_t mean_abs = abs(horizontal_add_s16x8(v_mean)); + const int32_t sse = horizontal_add_s32x4(vaddq_s32(v_sse[0], v_sse[1])); + + // (mean_abs * mean_abs): dynamic range 32 (unsigned) bits. + return sse - ((mean_abs * mean_abs) >> (bwl + 2)); +} + +void aom_minmax_8x8_neon(const uint8_t *a, int a_stride, const uint8_t *b, + int b_stride, int *min, int *max) { + // Load and concatenate. + const uint8x16_t a01 = load_u8_8x2(a + 0 * a_stride, a_stride); + const uint8x16_t a23 = load_u8_8x2(a + 2 * a_stride, a_stride); + const uint8x16_t a45 = load_u8_8x2(a + 4 * a_stride, a_stride); + const uint8x16_t a67 = load_u8_8x2(a + 6 * a_stride, a_stride); + + const uint8x16_t b01 = load_u8_8x2(b + 0 * b_stride, b_stride); + const uint8x16_t b23 = load_u8_8x2(b + 2 * b_stride, b_stride); + const uint8x16_t b45 = load_u8_8x2(b + 4 * b_stride, b_stride); + const uint8x16_t b67 = load_u8_8x2(b + 6 * b_stride, b_stride); + + // Absolute difference. + const uint8x16_t ab01_diff = vabdq_u8(a01, b01); + const uint8x16_t ab23_diff = vabdq_u8(a23, b23); + const uint8x16_t ab45_diff = vabdq_u8(a45, b45); + const uint8x16_t ab67_diff = vabdq_u8(a67, b67); + + // Max values between the Q vectors. 
+ const uint8x16_t ab0123_max = vmaxq_u8(ab01_diff, ab23_diff); + const uint8x16_t ab4567_max = vmaxq_u8(ab45_diff, ab67_diff); + const uint8x16_t ab0123_min = vminq_u8(ab01_diff, ab23_diff); + const uint8x16_t ab4567_min = vminq_u8(ab45_diff, ab67_diff); + + const uint8x16_t ab07_max = vmaxq_u8(ab0123_max, ab4567_max); + const uint8x16_t ab07_min = vminq_u8(ab0123_min, ab4567_min); + +#if AOM_ARCH_AARCH64 + *min = *max = 0; // Clear high bits + *((uint8_t *)max) = vmaxvq_u8(ab07_max); + *((uint8_t *)min) = vminvq_u8(ab07_min); +#else + // Split into 64-bit vectors and execute pairwise min/max. + uint8x8_t ab_max = vmax_u8(vget_high_u8(ab07_max), vget_low_u8(ab07_max)); + uint8x8_t ab_min = vmin_u8(vget_high_u8(ab07_min), vget_low_u8(ab07_min)); + + // Enough runs of vpmax/min propagate the max/min values to every position. + ab_max = vpmax_u8(ab_max, ab_max); + ab_min = vpmin_u8(ab_min, ab_min); + + ab_max = vpmax_u8(ab_max, ab_max); + ab_min = vpmin_u8(ab_min, ab_min); + + ab_max = vpmax_u8(ab_max, ab_max); + ab_min = vpmin_u8(ab_min, ab_min); + + *min = *max = 0; // Clear high bits + // Store directly to avoid costly neon->gpr transfer. + vst1_lane_u8((uint8_t *)max, ab_max, 0); + vst1_lane_u8((uint8_t *)min, ab_min, 0); +#endif +} diff --git a/third_party/aom/aom_dsp/arm/avg_pred_neon.c b/third_party/aom/aom_dsp/arm/avg_pred_neon.c new file mode 100644 index 0000000000..b17f7fca7f --- /dev/null +++ b/third_party/aom/aom_dsp/arm/avg_pred_neon.c @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/dist_wtd_avg_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/blend.h" + +void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width, + int height, const uint8_t *ref, int ref_stride) { + if (width > 8) { + do { + const uint8_t *pred_ptr = pred; + const uint8_t *ref_ptr = ref; + uint8_t *comp_pred_ptr = comp_pred; + int w = width; + + do { + const uint8x16_t p = vld1q_u8(pred_ptr); + const uint8x16_t r = vld1q_u8(ref_ptr); + const uint8x16_t avg = vrhaddq_u8(p, r); + + vst1q_u8(comp_pred_ptr, avg); + + ref_ptr += 16; + pred_ptr += 16; + comp_pred_ptr += 16; + w -= 16; + } while (w != 0); + + ref += ref_stride; + pred += width; + comp_pred += width; + } while (--height != 0); + } else if (width == 8) { + int h = height / 2; + + do { + const uint8x16_t p = vld1q_u8(pred); + const uint8x16_t r = load_u8_8x2(ref, ref_stride); + const uint8x16_t avg = vrhaddq_u8(p, r); + + vst1q_u8(comp_pred, avg); + + ref += 2 * ref_stride; + pred += 16; + comp_pred += 16; + } while (--h != 0); + } else { + int h = height / 4; + assert(width == 4); + + do { + const uint8x16_t p = vld1q_u8(pred); + const uint8x16_t r = load_unaligned_u8q(ref, ref_stride); + const uint8x16_t avg = vrhaddq_u8(p, r); + + vst1q_u8(comp_pred, avg); + + ref += 4 * ref_stride; + pred += 16; + comp_pred += 16; + } while (--h != 0); + } +} + +void aom_dist_wtd_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, + int width, int height, const uint8_t *ref, + int ref_stride, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + + if (width > 8) { + do { + const uint8_t *pred_ptr = pred; + const uint8_t *ref_ptr = ref; + uint8_t *comp_pred_ptr = comp_pred; + int w = width; + + do { + const uint8x16_t p = vld1q_u8(pred_ptr); + const uint8x16_t r = vld1q_u8(ref_ptr); + + const uint8x16_t wtd_avg = + dist_wtd_avg_u8x16(r, p, fwd_offset, bck_offset); + + vst1q_u8(comp_pred_ptr, wtd_avg); + + ref_ptr += 16; + pred_ptr += 16; + comp_pred_ptr += 16; + w -= 16; + } while (w != 0); + + ref += ref_stride; + pred += width; + comp_pred += width; + } while (--height != 0); + } else if (width == 8) { + int h = height / 2; + + do { + const uint8x16_t p = vld1q_u8(pred); + const uint8x16_t r = load_u8_8x2(ref, ref_stride); + + const uint8x16_t wtd_avg = + dist_wtd_avg_u8x16(r, p, fwd_offset, bck_offset); + + vst1q_u8(comp_pred, wtd_avg); + + ref += 2 * ref_stride; + pred += 16; + comp_pred += 16; + } while (--h != 0); + } else { + int h = height / 2; + assert(width == 4); + + do { + const uint8x8_t p = vld1_u8(pred); + const uint8x8_t r = load_unaligned_u8_4x2(ref, ref_stride); + + const uint8x8_t wtd_avg = dist_wtd_avg_u8x8(r, p, vget_low_u8(fwd_offset), + vget_low_u8(bck_offset)); + + vst1_u8(comp_pred, wtd_avg); + + ref += 2 * ref_stride; + pred += 8; + comp_pred += 8; + } while (--h != 0); + } +} + +void aom_comp_mask_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width, + int height, const uint8_t *ref, int ref_stride, + const uint8_t *mask, int mask_stride, + int invert_mask) { + const uint8_t *src0 = invert_mask ? pred : ref; + const uint8_t *src1 = invert_mask ? ref : pred; + const int src_stride0 = invert_mask ? width : ref_stride; + const int src_stride1 = invert_mask ? 
ref_stride : width; + + if (width > 8) { + do { + const uint8_t *src0_ptr = src0; + const uint8_t *src1_ptr = src1; + const uint8_t *mask_ptr = mask; + uint8_t *comp_pred_ptr = comp_pred; + int w = width; + + do { + const uint8x16_t s0 = vld1q_u8(src0_ptr); + const uint8x16_t s1 = vld1q_u8(src1_ptr); + const uint8x16_t m0 = vld1q_u8(mask_ptr); + + uint8x16_t blend_u8 = alpha_blend_a64_u8x16(m0, s0, s1); + + vst1q_u8(comp_pred_ptr, blend_u8); + + src0_ptr += 16; + src1_ptr += 16; + mask_ptr += 16; + comp_pred_ptr += 16; + w -= 16; + } while (w != 0); + + src0 += src_stride0; + src1 += src_stride1; + mask += mask_stride; + comp_pred += width; + } while (--height != 0); + } else if (width == 8) { + do { + const uint8x8_t s0 = vld1_u8(src0); + const uint8x8_t s1 = vld1_u8(src1); + const uint8x8_t m0 = vld1_u8(mask); + + uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, s0, s1); + + vst1_u8(comp_pred, blend_u8); + + src0 += src_stride0; + src1 += src_stride1; + mask += mask_stride; + comp_pred += 8; + } while (--height != 0); + } else { + int h = height / 2; + assert(width == 4); + + do { + const uint8x8_t s0 = load_unaligned_u8(src0, src_stride0); + const uint8x8_t s1 = load_unaligned_u8(src1, src_stride1); + const uint8x8_t m0 = load_unaligned_u8(mask, mask_stride); + + uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, s0, s1); + + vst1_u8(comp_pred, blend_u8); + + src0 += 2 * src_stride0; + src1 += 2 * src_stride1; + mask += 2 * mask_stride; + comp_pred += 8; + } while (--h != 0); + } +} diff --git a/third_party/aom/aom_dsp/arm/avg_sve.c b/third_party/aom/aom_dsp/arm/avg_sve.c new file mode 100644 index 0000000000..bbf5a9447c --- /dev/null +++ b/third_party/aom/aom_dsp/arm/avg_sve.c @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "aom/aom_integer.h" +#include "aom_dsp/arm/dot_sve.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_ports/mem.h" + +int aom_vector_var_sve(const int16_t *ref, const int16_t *src, int bwl) { + assert(bwl >= 2 && bwl <= 5); + int width = 4 << bwl; + + int64x2_t sse_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + int16x8_t v_mean[2] = { vdupq_n_s16(0), vdupq_n_s16(0) }; + + do { + int16x8_t r0 = vld1q_s16(ref); + int16x8_t s0 = vld1q_s16(src); + + // diff: dynamic range [-510, 510] 10 (signed) bits. + int16x8_t diff0 = vsubq_s16(r0, s0); + // v_mean: dynamic range 16 * diff -> [-8160, 8160], 14 (signed) bits. + v_mean[0] = vaddq_s16(v_mean[0], diff0); + + // v_sse: dynamic range 2 * 16 * diff^2 -> [0, 8,323,200], 24 (signed) bits. + sse_s64[0] = aom_sdotq_s16(sse_s64[0], diff0, diff0); + + int16x8_t r1 = vld1q_s16(ref + 8); + int16x8_t s1 = vld1q_s16(src + 8); + + // diff: dynamic range [-510, 510] 10 (signed) bits. + int16x8_t diff1 = vsubq_s16(r1, s1); + // v_mean: dynamic range 16 * diff -> [-8160, 8160], 14 (signed) bits. + v_mean[1] = vaddq_s16(v_mean[1], diff1); + + // v_sse: dynamic range 2 * 16 * diff^2 -> [0, 8,323,200], 24 (signed) bits. 
+ sse_s64[1] = aom_sdotq_s16(sse_s64[1], diff1, diff1); + + ref += 16; + src += 16; + width -= 16; + } while (width != 0); + + // Dynamic range [0, 65280], 16 (unsigned) bits. + const uint32_t mean_abs = abs(vaddlvq_s16(vaddq_s16(v_mean[0], v_mean[1]))); + const int64_t sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[1])); + + // (mean_abs * mean_abs): dynamic range 32 (unsigned) bits. + return (int)(sse - ((mean_abs * mean_abs) >> (bwl + 2))); +} diff --git a/third_party/aom/aom_dsp/arm/blend_a64_mask_neon.c b/third_party/aom/aom_dsp/arm/blend_a64_mask_neon.c new file mode 100644 index 0000000000..1bc3b80310 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/blend_a64_mask_neon.c @@ -0,0 +1,492 @@ +/* + * Copyright (c) 2018, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include + +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/aom_dsp_common.h" +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/blend.h" + +uint8x8_t alpha_blend_a64_d16_u16x8(uint16x8_t m, uint16x8_t a, uint16x8_t b, + uint16x8_t round_offset) { + const uint16x8_t m_inv = vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m); + + uint32x4_t blend_u32_lo = vmull_u16(vget_low_u16(m), vget_low_u16(a)); + uint32x4_t blend_u32_hi = vmull_u16(vget_high_u16(m), vget_high_u16(a)); + + blend_u32_lo = vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b)); + blend_u32_hi = + vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b)); + + uint16x4_t blend_u16_lo = vshrn_n_u32(blend_u32_lo, AOM_BLEND_A64_ROUND_BITS); + uint16x4_t blend_u16_hi = vshrn_n_u32(blend_u32_hi, AOM_BLEND_A64_ROUND_BITS); + + uint16x8_t res = vcombine_u16(blend_u16_lo, blend_u16_hi); + + res = vqsubq_u16(res, round_offset); + + return vqrshrn_n_u16(res, + 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS); +} + +void aom_lowbd_blend_a64_d16_mask_neon( + uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0, + uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride, + const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh, + ConvolveParams *conv_params) { + (void)conv_params; + + const int bd = 8; + const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS; + const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) + + (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)); + const uint16x8_t offset_vec = vdupq_n_u16(round_offset); + + assert(IMPLIES((void *)src0 == dst, src0_stride == dst_stride)); + assert(IMPLIES((void *)src1 == dst, src1_stride == dst_stride)); + + assert(h >= 4); + assert(w >= 4); + assert(IS_POWER_OF_TWO(h)); + assert(IS_POWER_OF_TWO(w)); + + if (subw == 0 && subh == 0) { + if (w >= 8) { + do { + int i = 0; + do { + uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i)); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint8x8_t blend = alpha_blend_a64_d16_u16x8(m0, s0, s1, offset_vec); + + vst1_u8(dst + i, blend); + i += 8; + } while (i < w); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + 
dst += dst_stride; + } while (--h != 0); + } else { + do { + uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride)); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint8x8_t blend = alpha_blend_a64_d16_u16x8(m0, s0, s1, offset_vec); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 2 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else if (subw == 1 && subh == 1) { + if (w >= 8) { + do { + int i = 0; + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i); + uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8); + uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t m_avg = + vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3)); + + uint8x8_t blend = + alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec); + + vst1_u8(dst + i, blend); + i += 8; + } while (i < w); + + mask += 2 * mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride); + uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3)); + uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 4 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else if (subw == 1 && subh == 0) { + if (w >= 8) { + do { + int i = 0; + do { + uint8x8_t m0 = vld1_u8(mask + 2 * i); + uint8x8_t m1 = vld1_u8(mask + 2 * i + 8); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); + uint8x8_t blend = + alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec); + + vst1_u8(dst + i, blend); + i += 8; + } while (i < w); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); + uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 2 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else { + if (w >= 8) { + do { + int i = 0; + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1)); + uint8x8_t blend = + alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec); + + vst1_u8(dst + i, blend); + i += 8; + } while (i < w); + + mask += 2 * mask_stride; + src0 += src0_stride; + src1 
+= src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0_2 = + load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); + uint8x8_t m1_3 = + load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3)); + uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 4 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } +} + +void aom_blend_a64_mask_neon(uint8_t *dst, uint32_t dst_stride, + const uint8_t *src0, uint32_t src0_stride, + const uint8_t *src1, uint32_t src1_stride, + const uint8_t *mask, uint32_t mask_stride, int w, + int h, int subw, int subh) { + assert(IMPLIES(src0 == dst, src0_stride == dst_stride)); + assert(IMPLIES(src1 == dst, src1_stride == dst_stride)); + + assert(h >= 1); + assert(w >= 1); + assert(IS_POWER_OF_TWO(h)); + assert(IS_POWER_OF_TWO(w)); + + if ((subw | subh) == 0) { + if (w > 8) { + do { + int i = 0; + do { + uint8x16_t m0 = vld1q_u8(mask + i); + uint8x16_t s0 = vld1q_u8(src0 + i); + uint8x16_t s1 = vld1q_u8(src1 + i); + + uint8x16_t blend = alpha_blend_a64_u8x16(m0, s0, s1); + + vst1q_u8(dst + i, blend); + i += 16; + } while (i < w); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else if (w == 8) { + do { + uint8x8_t m0 = vld1_u8(mask); + uint8x8_t s0 = vld1_u8(src0); + uint8x8_t s1 = vld1_u8(src1); + + uint8x8_t blend = alpha_blend_a64_u8x8(m0, s0, s1); + + vst1_u8(dst, blend); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0 = load_unaligned_u8_4x2(mask, mask_stride); + uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride); + uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride); + + uint8x8_t blend = alpha_blend_a64_u8x8(m0, s0, s1); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 2 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else if ((subw & subh) == 1) { + if (w > 8) { + do { + int i = 0; + do { + uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i); + uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i); + uint8x16_t m2 = vld1q_u8(mask + 0 * mask_stride + 2 * i + 16); + uint8x16_t m3 = vld1q_u8(mask + 1 * mask_stride + 2 * i + 16); + uint8x16_t s0 = vld1q_u8(src0 + i); + uint8x16_t s1 = vld1q_u8(src1 + i); + + uint8x16_t m_avg = avg_blend_pairwise_u8x16_4(m0, m1, m2, m3); + uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1); + + vst1q_u8(dst + i, blend); + + i += 16; + } while (i < w); + + mask += 2 * mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else if (w == 8) { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 8); + uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 8); + uint8x8_t s0 = vld1_u8(src0); + uint8x8_t s1 = vld1_u8(src1); + + uint8x8_t m_avg = avg_blend_pairwise_u8x8_4(m0, m1, m2, m3); + uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1); + + vst1_u8(dst, blend); + + mask += 2 * mask_stride; + src0 += src0_stride; + 
src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride); + uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride); + uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride); + uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride); + + uint8x8_t m_avg = avg_blend_pairwise_u8x8_4(m0, m1, m2, m3); + uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 4 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else if (subw == 1 && subh == 0) { + if (w > 8) { + do { + int i = 0; + + do { + uint8x16_t m0 = vld1q_u8(mask + 2 * i); + uint8x16_t m1 = vld1q_u8(mask + 2 * i + 16); + uint8x16_t s0 = vld1q_u8(src0 + i); + uint8x16_t s1 = vld1q_u8(src1 + i); + + uint8x16_t m_avg = avg_blend_pairwise_u8x16(m0, m1); + uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1); + + vst1q_u8(dst + i, blend); + + i += 16; + } while (i < w); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else if (w == 8) { + do { + uint8x8_t m0 = vld1_u8(mask); + uint8x8_t m1 = vld1_u8(mask + 8); + uint8x8_t s0 = vld1_u8(src0); + uint8x8_t s1 = vld1_u8(src1); + + uint8x8_t m_avg = avg_blend_pairwise_u8x8(m0, m1); + uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1); + + vst1_u8(dst, blend); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride); + uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride); + + uint8x8_t m_avg = avg_blend_pairwise_u8x8(m0, m1); + uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 2 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else { + if (w > 8) { + do { + int i = 0; + do { + uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + i); + uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + i); + uint8x16_t s0 = vld1q_u8(src0 + i); + uint8x16_t s1 = vld1q_u8(src1 + i); + + uint8x16_t m_avg = avg_blend_u8x16(m0, m1); + uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1); + + vst1q_u8(dst + i, blend); + + i += 16; + } while (i < w); + + mask += 2 * mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else if (w == 8) { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint8x8_t s0 = vld1_u8(src0); + uint8x8_t s1 = vld1_u8(src1); + + uint8x8_t m_avg = avg_blend_u8x8(m0, m1); + uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1); + + vst1_u8(dst, blend); + + mask += 2 * mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0_2 = + load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); + uint8x8_t m1_3 = + load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); + uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride); + uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride); + + uint8x8_t m_avg = avg_blend_u8x8(m0_2, m1_3); + uint8x8_t blend = 
alpha_blend_a64_u8x8(m_avg, s0, s1); + + store_u8x4_strided_x2(dst, dst_stride, blend); + + mask += 4 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } +} diff --git a/third_party/aom/aom_dsp/arm/blend_neon.h b/third_party/aom/aom_dsp/arm/blend_neon.h new file mode 100644 index 0000000000..c8a03224e4 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/blend_neon.h @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#ifndef AOM_AOM_DSP_ARM_BLEND_NEON_H_ +#define AOM_AOM_DSP_ARM_BLEND_NEON_H_ + +#include + +#include "aom_dsp/blend.h" + +static INLINE uint8x16_t alpha_blend_a64_u8x16(uint8x16_t m, uint8x16_t a, + uint8x16_t b) { + const uint8x16_t m_inv = vsubq_u8(vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA), m); + + uint16x8_t blend_u16_lo = vmull_u8(vget_low_u8(m), vget_low_u8(a)); + uint16x8_t blend_u16_hi = vmull_u8(vget_high_u8(m), vget_high_u8(a)); + + blend_u16_lo = vmlal_u8(blend_u16_lo, vget_low_u8(m_inv), vget_low_u8(b)); + blend_u16_hi = vmlal_u8(blend_u16_hi, vget_high_u8(m_inv), vget_high_u8(b)); + + uint8x8_t blend_u8_lo = vrshrn_n_u16(blend_u16_lo, AOM_BLEND_A64_ROUND_BITS); + uint8x8_t blend_u8_hi = vrshrn_n_u16(blend_u16_hi, AOM_BLEND_A64_ROUND_BITS); + + return vcombine_u8(blend_u8_lo, blend_u8_hi); +} + +static INLINE uint8x8_t alpha_blend_a64_u8x8(uint8x8_t m, uint8x8_t a, + uint8x8_t b) { + const uint8x8_t m_inv = vsub_u8(vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA), m); + + uint16x8_t blend_u16 = vmull_u8(m, a); + + blend_u16 = vmlal_u8(blend_u16, m_inv, b); + + return vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS); +} + +#if CONFIG_AV1_HIGHBITDEPTH +static INLINE uint16x8_t alpha_blend_a64_u16x8(uint16x8_t m, uint16x8_t a, + uint16x8_t b) { + uint16x8_t m_inv = vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m); + + uint32x4_t blend_u32_lo = vmull_u16(vget_low_u16(a), vget_low_u16(m)); + uint32x4_t blend_u32_hi = vmull_u16(vget_high_u16(a), vget_high_u16(m)); + + blend_u32_lo = vmlal_u16(blend_u32_lo, vget_low_u16(b), vget_low_u16(m_inv)); + blend_u32_hi = + vmlal_u16(blend_u32_hi, vget_high_u16(b), vget_high_u16(m_inv)); + + uint16x4_t blend_u16_lo = + vrshrn_n_u32(blend_u32_lo, AOM_BLEND_A64_ROUND_BITS); + uint16x4_t blend_u16_hi = + vrshrn_n_u32(blend_u32_hi, AOM_BLEND_A64_ROUND_BITS); + + return vcombine_u16(blend_u16_lo, blend_u16_hi); +} + +static INLINE uint16x4_t alpha_blend_a64_u16x4(uint16x4_t m, uint16x4_t a, + uint16x4_t b) { + const uint16x4_t m_inv = vsub_u16(vdup_n_u16(AOM_BLEND_A64_MAX_ALPHA), m); + + uint32x4_t blend_u16 = vmull_u16(m, a); + + blend_u16 = vmlal_u16(blend_u16, m_inv, b); + + return vrshrn_n_u32(blend_u16, AOM_BLEND_A64_ROUND_BITS); +} +#endif // CONFIG_AV1_HIGHBITDEPTH + +static INLINE uint8x8_t avg_blend_u8x8(uint8x8_t a, uint8x8_t b) { + return vrhadd_u8(a, b); +} + +static INLINE uint8x16_t avg_blend_u8x16(uint8x16_t a, uint8x16_t b) { + return vrhaddq_u8(a, b); +} + +static INLINE uint8x8_t avg_blend_pairwise_u8x8(uint8x8_t a, uint8x8_t b) { + return vrshr_n_u8(vpadd_u8(a, 
b), 1); +} + +static INLINE uint8x16_t avg_blend_pairwise_u8x16(uint8x16_t a, uint8x16_t b) { +#if AOM_ARCH_AARCH64 + return vrshrq_n_u8(vpaddq_u8(a, b), 1); +#else + uint8x8_t sum_pairwise_a = vpadd_u8(vget_low_u8(a), vget_high_u8(a)); + uint8x8_t sum_pairwise_b = vpadd_u8(vget_low_u8(b), vget_high_u8(b)); + return vrshrq_n_u8(vcombine_u8(sum_pairwise_a, sum_pairwise_b), 1); +#endif // AOM_ARCH_AARCH64 +} + +static INLINE uint8x8_t avg_blend_pairwise_u8x8_4(uint8x8_t a, uint8x8_t b, + uint8x8_t c, uint8x8_t d) { + uint8x8_t a_c = vpadd_u8(a, c); + uint8x8_t b_d = vpadd_u8(b, d); + return vrshr_n_u8(vqadd_u8(a_c, b_d), 2); +} + +static INLINE uint8x16_t avg_blend_pairwise_u8x16_4(uint8x16_t a, uint8x16_t b, + uint8x16_t c, + uint8x16_t d) { +#if AOM_ARCH_AARCH64 + uint8x16_t a_c = vpaddq_u8(a, c); + uint8x16_t b_d = vpaddq_u8(b, d); + return vrshrq_n_u8(vqaddq_u8(a_c, b_d), 2); +#else + uint8x8_t sum_pairwise_a = vpadd_u8(vget_low_u8(a), vget_high_u8(a)); + uint8x8_t sum_pairwise_b = vpadd_u8(vget_low_u8(b), vget_high_u8(b)); + uint8x8_t sum_pairwise_c = vpadd_u8(vget_low_u8(c), vget_high_u8(c)); + uint8x8_t sum_pairwise_d = vpadd_u8(vget_low_u8(d), vget_high_u8(d)); + uint8x16_t a_c = vcombine_u8(sum_pairwise_a, sum_pairwise_c); + uint8x16_t b_d = vcombine_u8(sum_pairwise_b, sum_pairwise_d); + return vrshrq_n_u8(vqaddq_u8(a_c, b_d), 2); +#endif // AOM_ARCH_AARCH64 +} + +#endif // AOM_AOM_DSP_ARM_BLEND_NEON_H_ diff --git a/third_party/aom/aom_dsp/arm/blk_sse_sum_neon.c b/third_party/aom/aom_dsp/arm/blk_sse_sum_neon.c new file mode 100644 index 0000000000..f2ada93e95 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/blk_sse_sum_neon.c @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include + +#include "config/aom_dsp_rtcd.h" +#include "config/aom_config.h" + +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void get_blk_sse_sum_4xh_neon(const int16_t *data, int stride, + int bh, int *x_sum, + int64_t *x2_sum) { + int i = bh; + int32x4_t sum = vdupq_n_s32(0); + int32x4_t sse = vdupq_n_s32(0); + + do { + int16x8_t d = vcombine_s16(vld1_s16(data), vld1_s16(data + stride)); + + sum = vpadalq_s16(sum, d); + + sse = vmlal_s16(sse, vget_low_s16(d), vget_low_s16(d)); + sse = vmlal_s16(sse, vget_high_s16(d), vget_high_s16(d)); + + data += 2 * stride; + i -= 2; + } while (i != 0); + + *x_sum = horizontal_add_s32x4(sum); + *x2_sum = horizontal_long_add_s32x4(sse); +} + +static INLINE void get_blk_sse_sum_8xh_neon(const int16_t *data, int stride, + int bh, int *x_sum, + int64_t *x2_sum) { + int i = bh; + int32x4_t sum = vdupq_n_s32(0); + int32x4_t sse = vdupq_n_s32(0); + + // Input is 12-bit wide, so we can add up to 127 squared elements in a signed + // 32-bits element. Since we're accumulating into an int32x4_t and the maximum + // value for bh is 32, we don't have to worry about sse overflowing. 
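To make the bound in the comment above concrete, a short editorial note follows; it assumes the 12-bit inputs can reach a magnitude of 4095 and is not part of the upstream patch.

  // Worked numbers: 4095 * 4095 = 16,769,025 and INT32_MAX / 16,769,025 is
  // roughly 128, which is where the 127-element limit above comes from. In
  // this 8-wide kernel each 32-bit lane of `sse` accumulates two squared
  // values per row (vget_low + vget_high spread across four lanes), and
  // bh <= 32, so at most 64 terms land in any lane.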
+ + do { + int16x8_t d = vld1q_s16(data); + + sum = vpadalq_s16(sum, d); + + sse = vmlal_s16(sse, vget_low_s16(d), vget_low_s16(d)); + sse = vmlal_s16(sse, vget_high_s16(d), vget_high_s16(d)); + + data += stride; + } while (--i != 0); + + *x_sum = horizontal_add_s32x4(sum); + *x2_sum = horizontal_long_add_s32x4(sse); +} + +static INLINE void get_blk_sse_sum_large_neon(const int16_t *data, int stride, + int bw, int bh, int *x_sum, + int64_t *x2_sum) { + int32x4_t sum = vdupq_n_s32(0); + int64x2_t sse = vdupq_n_s64(0); + + // Input is 12-bit wide, so we can add up to 127 squared elements in a signed + // 32-bits element. Since we're accumulating into an int32x4_t vector that + // means we can process up to (127*4)/bw rows before we need to widen to + // 64 bits. + + int i_limit = (127 * 4) / bw; + int i_tmp = bh > i_limit ? i_limit : bh; + + int i = 0; + do { + int32x4_t sse_s32 = vdupq_n_s32(0); + do { + int j = bw; + const int16_t *data_ptr = data; + do { + int16x8_t d = vld1q_s16(data_ptr); + + sum = vpadalq_s16(sum, d); + + sse_s32 = vmlal_s16(sse_s32, vget_low_s16(d), vget_low_s16(d)); + sse_s32 = vmlal_s16(sse_s32, vget_high_s16(d), vget_high_s16(d)); + + data_ptr += 8; + j -= 8; + } while (j != 0); + + data += stride; + i++; + } while (i < i_tmp && i < bh); + + sse = vpadalq_s32(sse, sse_s32); + i_tmp += i_limit; + } while (i < bh); + + *x_sum = horizontal_add_s32x4(sum); + *x2_sum = horizontal_add_s64x2(sse); +} + +void aom_get_blk_sse_sum_neon(const int16_t *data, int stride, int bw, int bh, + int *x_sum, int64_t *x2_sum) { + if (bw == 4) { + get_blk_sse_sum_4xh_neon(data, stride, bh, x_sum, x2_sum); + } else if (bw == 8) { + get_blk_sse_sum_8xh_neon(data, stride, bh, x_sum, x2_sum); + } else { + assert(bw % 8 == 0); + get_blk_sse_sum_large_neon(data, stride, bw, bh, x_sum, x2_sum); + } +} diff --git a/third_party/aom/aom_dsp/arm/blk_sse_sum_sve.c b/third_party/aom/aom_dsp/arm/blk_sse_sum_sve.c new file mode 100644 index 0000000000..18bdc5dbfe --- /dev/null +++ b/third_party/aom/aom_dsp/arm/blk_sse_sum_sve.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include "config/aom_dsp_rtcd.h" +#include "config/aom_config.h" + +#include "aom_dsp/arm/dot_sve.h" +#include "aom_dsp/arm/mem_neon.h" + +static INLINE void get_blk_sse_sum_4xh_sve(const int16_t *data, int stride, + int bh, int *x_sum, + int64_t *x2_sum) { + int32x4_t sum = vdupq_n_s32(0); + int64x2_t sse = vdupq_n_s64(0); + + do { + int16x8_t d = vcombine_s16(vld1_s16(data), vld1_s16(data + stride)); + + sum = vpadalq_s16(sum, d); + + sse = aom_sdotq_s16(sse, d, d); + + data += 2 * stride; + bh -= 2; + } while (bh != 0); + + *x_sum = vaddvq_s32(sum); + *x2_sum = vaddvq_s64(sse); +} + +static INLINE void get_blk_sse_sum_8xh_sve(const int16_t *data, int stride, + int bh, int *x_sum, + int64_t *x2_sum) { + int32x4_t sum[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + + do { + int16x8_t d0 = vld1q_s16(data); + int16x8_t d1 = vld1q_s16(data + stride); + + sum[0] = vpadalq_s16(sum[0], d0); + sum[1] = vpadalq_s16(sum[1], d1); + + sse[0] = aom_sdotq_s16(sse[0], d0, d0); + sse[1] = aom_sdotq_s16(sse[1], d1, d1); + + data += 2 * stride; + bh -= 2; + } while (bh != 0); + + *x_sum = vaddvq_s32(vaddq_s32(sum[0], sum[1])); + *x2_sum = vaddvq_s64(vaddq_s64(sse[0], sse[1])); +} + +static INLINE void get_blk_sse_sum_large_sve(const int16_t *data, int stride, + int bw, int bh, int *x_sum, + int64_t *x2_sum) { + int32x4_t sum[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + + do { + int j = bw; + const int16_t *data_ptr = data; + do { + int16x8_t d0 = vld1q_s16(data_ptr); + int16x8_t d1 = vld1q_s16(data_ptr + 8); + + sum[0] = vpadalq_s16(sum[0], d0); + sum[1] = vpadalq_s16(sum[1], d1); + + sse[0] = aom_sdotq_s16(sse[0], d0, d0); + sse[1] = aom_sdotq_s16(sse[1], d1, d1); + + data_ptr += 16; + j -= 16; + } while (j != 0); + + data += stride; + } while (--bh != 0); + + *x_sum = vaddvq_s32(vaddq_s32(sum[0], sum[1])); + *x2_sum = vaddvq_s64(vaddq_s64(sse[0], sse[1])); +} + +void aom_get_blk_sse_sum_sve(const int16_t *data, int stride, int bw, int bh, + int *x_sum, int64_t *x2_sum) { + if (bw == 4) { + get_blk_sse_sum_4xh_sve(data, stride, bh, x_sum, x2_sum); + } else if (bw == 8) { + get_blk_sse_sum_8xh_sve(data, stride, bh, x_sum, x2_sum); + } else { + assert(bw % 16 == 0); + get_blk_sse_sum_large_sve(data, stride, bw, bh, x_sum, x2_sum); + } +} diff --git a/third_party/aom/aom_dsp/arm/dist_wtd_avg_neon.h b/third_party/aom/aom_dsp/arm/dist_wtd_avg_neon.h new file mode 100644 index 0000000000..19c9b04c57 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/dist_wtd_avg_neon.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_ +#define AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_ + +#include + +#include "aom_dsp/aom_dsp_common.h" +#include "av1/common/enums.h" + +static INLINE uint8x8_t dist_wtd_avg_u8x8(uint8x8_t a, uint8x8_t b, + uint8x8_t wta, uint8x8_t wtb) { + uint16x8_t wtd_sum = vmull_u8(a, wta); + + wtd_sum = vmlal_u8(wtd_sum, b, wtb); + + return vrshrn_n_u16(wtd_sum, DIST_PRECISION_BITS); +} + +static INLINE uint16x4_t dist_wtd_avg_u16x4(uint16x4_t a, uint16x4_t b, + uint16x4_t wta, uint16x4_t wtb) { + uint32x4_t wtd_sum = vmull_u16(a, wta); + + wtd_sum = vmlal_u16(wtd_sum, b, wtb); + + return vrshrn_n_u32(wtd_sum, DIST_PRECISION_BITS); +} + +static INLINE uint8x16_t dist_wtd_avg_u8x16(uint8x16_t a, uint8x16_t b, + uint8x16_t wta, uint8x16_t wtb) { + uint16x8_t wtd_sum_lo = vmull_u8(vget_low_u8(a), vget_low_u8(wta)); + uint16x8_t wtd_sum_hi = vmull_u8(vget_high_u8(a), vget_high_u8(wta)); + + wtd_sum_lo = vmlal_u8(wtd_sum_lo, vget_low_u8(b), vget_low_u8(wtb)); + wtd_sum_hi = vmlal_u8(wtd_sum_hi, vget_high_u8(b), vget_high_u8(wtb)); + + uint8x8_t wtd_avg_lo = vrshrn_n_u16(wtd_sum_lo, DIST_PRECISION_BITS); + uint8x8_t wtd_avg_hi = vrshrn_n_u16(wtd_sum_hi, DIST_PRECISION_BITS); + + return vcombine_u8(wtd_avg_lo, wtd_avg_hi); +} + +static INLINE uint16x8_t dist_wtd_avg_u16x8(uint16x8_t a, uint16x8_t b, + uint16x8_t wta, uint16x8_t wtb) { + uint32x4_t wtd_sum_lo = vmull_u16(vget_low_u16(a), vget_low_u16(wta)); + uint32x4_t wtd_sum_hi = vmull_u16(vget_high_u16(a), vget_high_u16(wta)); + + wtd_sum_lo = vmlal_u16(wtd_sum_lo, vget_low_u16(b), vget_low_u16(wtb)); + wtd_sum_hi = vmlal_u16(wtd_sum_hi, vget_high_u16(b), vget_high_u16(wtb)); + + uint16x4_t wtd_avg_lo = vrshrn_n_u32(wtd_sum_lo, DIST_PRECISION_BITS); + uint16x4_t wtd_avg_hi = vrshrn_n_u32(wtd_sum_hi, DIST_PRECISION_BITS); + + return vcombine_u16(wtd_avg_lo, wtd_avg_hi); +} + +#endif // AOM_AOM_DSP_ARM_DIST_WTD_AVG_NEON_H_ diff --git a/third_party/aom/aom_dsp/arm/dot_sve.h b/third_party/aom/aom_dsp/arm/dot_sve.h new file mode 100644 index 0000000000..cf49f23606 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/dot_sve.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef AOM_AOM_DSP_ARM_DOT_SVE_H_ +#define AOM_AOM_DSP_ARM_DOT_SVE_H_ + +#include + +#include "config/aom_dsp_rtcd.h" +#include "config/aom_config.h" + +// Dot product instructions operating on 16-bit input elements are exclusive to +// the SVE instruction set. However, we can access these instructions from a +// predominantly Neon context by making use of the Neon-SVE bridge intrinsics +// to reinterpret Neon vectors as SVE vectors - with the high part of the SVE +// vector (if it's longer than 128 bits) being "don't care". + +// While sub-optimal on machines that have SVE vector length > 128-bit - as the +// remainder of the vector is unused - this approach is still beneficial when +// compared to a Neon-only solution. 
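A minimal usage sketch of the bridge pattern described above: wrap the Neon operands with svset_neonq, issue the SVE dot product, and pull the low 128 bits back out with svget_neonq. The helper sum_of_squares_s16 below is illustrative only (not part of the upstream patch) and assumes a build with both Neon and SVE enabled.

// Editorial sketch: 64-bit sum of squares of int16_t data using the same
// Neon<->SVE bridge pattern as aom_sdotq_s16() below.
#include <arm_neon.h>
#include <arm_neon_sve_bridge.h>

static inline int64_t sum_of_squares_s16(const int16_t *x, int n) {
  // n is assumed to be a multiple of 8.
  int64x2_t acc = vdupq_n_s64(0);
  for (int i = 0; i < n; i += 8) {
    int16x8_t v = vld1q_s16(x + i);
    // Each 64-bit accumulator lane gathers the dot product of four
    // consecutive 16-bit elements.
    acc = svget_neonq_s64(svdot_s64(svset_neonq_s64(svundef_s64(), acc),
                                    svset_neonq_s16(svundef_s16(), v),
                                    svset_neonq_s16(svundef_s16(), v)));
  }
  return vaddvq_s64(acc);
}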
+ +static INLINE uint64x2_t aom_udotq_u16(uint64x2_t acc, uint16x8_t x, + uint16x8_t y) { + return svget_neonq_u64(svdot_u64(svset_neonq_u64(svundef_u64(), acc), + svset_neonq_u16(svundef_u16(), x), + svset_neonq_u16(svundef_u16(), y))); +} + +static INLINE int64x2_t aom_sdotq_s16(int64x2_t acc, int16x8_t x, int16x8_t y) { + return svget_neonq_s64(svdot_s64(svset_neonq_s64(svundef_s64(), acc), + svset_neonq_s16(svundef_s16(), x), + svset_neonq_s16(svundef_s16(), y))); +} + +#endif // AOM_AOM_DSP_ARM_DOT_SVE_H_ diff --git a/third_party/aom/aom_dsp/arm/fwd_txfm_neon.c b/third_party/aom/aom_dsp/arm/fwd_txfm_neon.c new file mode 100644 index 0000000000..a4d6322f24 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/fwd_txfm_neon.c @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" + +#include "aom_dsp/txfm_common.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" + +static void aom_fdct4x4_helper(const int16_t *input, int stride, + int16x4_t *input_0, int16x4_t *input_1, + int16x4_t *input_2, int16x4_t *input_3) { + *input_0 = vshl_n_s16(vld1_s16(input + 0 * stride), 4); + *input_1 = vshl_n_s16(vld1_s16(input + 1 * stride), 4); + *input_2 = vshl_n_s16(vld1_s16(input + 2 * stride), 4); + *input_3 = vshl_n_s16(vld1_s16(input + 3 * stride), 4); + // If the very first value != 0, then add 1. + if (input[0] != 0) { + const int16x4_t one = vreinterpret_s16_s64(vdup_n_s64(1)); + *input_0 = vadd_s16(*input_0, one); + } + + for (int i = 0; i < 2; ++i) { + const int16x8_t input_01 = vcombine_s16(*input_0, *input_1); + const int16x8_t input_32 = vcombine_s16(*input_3, *input_2); + + // in_0 +/- in_3, in_1 +/- in_2 + const int16x8_t s_01 = vaddq_s16(input_01, input_32); + const int16x8_t s_32 = vsubq_s16(input_01, input_32); + + // step_0 +/- step_1, step_2 +/- step_3 + const int16x4_t s_0 = vget_low_s16(s_01); + const int16x4_t s_1 = vget_high_s16(s_01); + const int16x4_t s_2 = vget_high_s16(s_32); + const int16x4_t s_3 = vget_low_s16(s_32); + + // (s_0 +/- s_1) * cospi_16_64 + // Must expand all elements to s32. See 'needs32' comment in fwd_txfm.c. 
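As an editorial aside on the "needs32" note above, the scalar shape of the next few lines and the rough magnitudes involved are sketched below; the ROUND_POWER_OF_TWO form is only an illustration of the rounding shift, not a quote of the reference code.

    // Scalar shape of the widening step (editorial illustration):
    //   out_0 ~= ROUND_POWER_OF_TWO((s_0 + s_1) * cospi_16_64, DCT_CONST_BITS)
    //   out_2 ~= ROUND_POWER_OF_TWO((s_0 - s_1) * cospi_16_64, DCT_CONST_BITS)
    // With cospi_16_64 = 11585 and |s_0 + s_1| up to roughly 16320 (inputs are
    // pre-shifted left by 4), the product reaches about 1.9e8, far outside
    // int16_t range, hence vaddl_s16/vsubl_s16 widen to 32 bits first.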
+ const int32x4_t s_0_p_s_1 = vaddl_s16(s_0, s_1); + const int32x4_t s_0_m_s_1 = vsubl_s16(s_0, s_1); + const int32x4_t temp1 = vmulq_n_s32(s_0_p_s_1, (int32_t)cospi_16_64); + const int32x4_t temp2 = vmulq_n_s32(s_0_m_s_1, (int32_t)cospi_16_64); + + // fdct_round_shift + int16x4_t out_0 = vrshrn_n_s32(temp1, DCT_CONST_BITS); + int16x4_t out_2 = vrshrn_n_s32(temp2, DCT_CONST_BITS); + + // s_3 * cospi_8_64 + s_2 * cospi_24_64 + // s_3 * cospi_24_64 - s_2 * cospi_8_64 + const int32x4_t s_3_cospi_8_64 = vmull_n_s16(s_3, (int32_t)cospi_8_64); + const int32x4_t s_3_cospi_24_64 = vmull_n_s16(s_3, (int32_t)cospi_24_64); + + const int32x4_t temp3 = + vmlal_n_s16(s_3_cospi_8_64, s_2, (int32_t)cospi_24_64); + const int32x4_t temp4 = + vmlsl_n_s16(s_3_cospi_24_64, s_2, (int32_t)cospi_8_64); + + // fdct_round_shift + int16x4_t out_1 = vrshrn_n_s32(temp3, DCT_CONST_BITS); + int16x4_t out_3 = vrshrn_n_s32(temp4, DCT_CONST_BITS); + + // Only transpose the first pass + if (i == 0) { + transpose_elems_inplace_s16_4x4(&out_0, &out_1, &out_2, &out_3); + } + + *input_0 = out_0; + *input_1 = out_1; + *input_2 = out_2; + *input_3 = out_3; + } +} + +void aom_fdct4x4_neon(const int16_t *input, tran_low_t *final_output, + int stride) { + // input[M * stride] * 16 + int16x4_t input_0, input_1, input_2, input_3; + + aom_fdct4x4_helper(input, stride, &input_0, &input_1, &input_2, &input_3); + + // Not quite a rounding shift. Only add 1 despite shifting by 2. + const int16x8_t one = vdupq_n_s16(1); + int16x8_t out_01 = vcombine_s16(input_0, input_1); + int16x8_t out_23 = vcombine_s16(input_2, input_3); + out_01 = vshrq_n_s16(vaddq_s16(out_01, one), 2); + out_23 = vshrq_n_s16(vaddq_s16(out_23, one), 2); + store_s16q_to_tran_low(final_output + 0 * 8, out_01); + store_s16q_to_tran_low(final_output + 1 * 8, out_23); +} + +void aom_fdct4x4_lp_neon(const int16_t *input, int16_t *final_output, + int stride) { + // input[M * stride] * 16 + int16x4_t input_0, input_1, input_2, input_3; + + aom_fdct4x4_helper(input, stride, &input_0, &input_1, &input_2, &input_3); + + // Not quite a rounding shift. Only add 1 despite shifting by 2. 
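A small worked example for the comment above (editorial note; the (x + 1) >> 2 form is how the scalar post-scale is conventionally written, quoted here only for illustration):

  // Adding 1 reproduces the scalar post-scale (x + 1) >> 2 rather than a true
  // rounding shift, which would add 2 before shifting by 2. The two differ for
  // some inputs, e.g. x = 6: (6 + 1) >> 2 = 1, but (6 + 2) >> 2 = 2.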
+ const int16x8_t one = vdupq_n_s16(1); + int16x8_t out_01 = vcombine_s16(input_0, input_1); + int16x8_t out_23 = vcombine_s16(input_2, input_3); + out_01 = vshrq_n_s16(vaddq_s16(out_01, one), 2); + out_23 = vshrq_n_s16(vaddq_s16(out_23, one), 2); + vst1q_s16(final_output + 0 * 8, out_01); + vst1q_s16(final_output + 1 * 8, out_23); +} + +void aom_fdct8x8_neon(const int16_t *input, int16_t *final_output, int stride) { + // stage 1 + int16x8_t input_0 = vshlq_n_s16(vld1q_s16(&input[0 * stride]), 2); + int16x8_t input_1 = vshlq_n_s16(vld1q_s16(&input[1 * stride]), 2); + int16x8_t input_2 = vshlq_n_s16(vld1q_s16(&input[2 * stride]), 2); + int16x8_t input_3 = vshlq_n_s16(vld1q_s16(&input[3 * stride]), 2); + int16x8_t input_4 = vshlq_n_s16(vld1q_s16(&input[4 * stride]), 2); + int16x8_t input_5 = vshlq_n_s16(vld1q_s16(&input[5 * stride]), 2); + int16x8_t input_6 = vshlq_n_s16(vld1q_s16(&input[6 * stride]), 2); + int16x8_t input_7 = vshlq_n_s16(vld1q_s16(&input[7 * stride]), 2); + for (int i = 0; i < 2; ++i) { + int16x8_t out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7; + const int16x8_t v_s0 = vaddq_s16(input_0, input_7); + const int16x8_t v_s1 = vaddq_s16(input_1, input_6); + const int16x8_t v_s2 = vaddq_s16(input_2, input_5); + const int16x8_t v_s3 = vaddq_s16(input_3, input_4); + const int16x8_t v_s4 = vsubq_s16(input_3, input_4); + const int16x8_t v_s5 = vsubq_s16(input_2, input_5); + const int16x8_t v_s6 = vsubq_s16(input_1, input_6); + const int16x8_t v_s7 = vsubq_s16(input_0, input_7); + // fdct4(step, step); + int16x8_t v_x0 = vaddq_s16(v_s0, v_s3); + int16x8_t v_x1 = vaddq_s16(v_s1, v_s2); + int16x8_t v_x2 = vsubq_s16(v_s1, v_s2); + int16x8_t v_x3 = vsubq_s16(v_s0, v_s3); + // fdct4(step, step); + int32x4_t v_t0_lo = vaddl_s16(vget_low_s16(v_x0), vget_low_s16(v_x1)); + int32x4_t v_t0_hi = vaddl_s16(vget_high_s16(v_x0), vget_high_s16(v_x1)); + int32x4_t v_t1_lo = vsubl_s16(vget_low_s16(v_x0), vget_low_s16(v_x1)); + int32x4_t v_t1_hi = vsubl_s16(vget_high_s16(v_x0), vget_high_s16(v_x1)); + int32x4_t v_t2_lo = vmull_n_s16(vget_low_s16(v_x2), (int16_t)cospi_24_64); + int32x4_t v_t2_hi = vmull_n_s16(vget_high_s16(v_x2), (int16_t)cospi_24_64); + int32x4_t v_t3_lo = vmull_n_s16(vget_low_s16(v_x3), (int16_t)cospi_24_64); + int32x4_t v_t3_hi = vmull_n_s16(vget_high_s16(v_x3), (int16_t)cospi_24_64); + v_t2_lo = vmlal_n_s16(v_t2_lo, vget_low_s16(v_x3), (int16_t)cospi_8_64); + v_t2_hi = vmlal_n_s16(v_t2_hi, vget_high_s16(v_x3), (int16_t)cospi_8_64); + v_t3_lo = vmlsl_n_s16(v_t3_lo, vget_low_s16(v_x2), (int16_t)cospi_8_64); + v_t3_hi = vmlsl_n_s16(v_t3_hi, vget_high_s16(v_x2), (int16_t)cospi_8_64); + v_t0_lo = vmulq_n_s32(v_t0_lo, (int32_t)cospi_16_64); + v_t0_hi = vmulq_n_s32(v_t0_hi, (int32_t)cospi_16_64); + v_t1_lo = vmulq_n_s32(v_t1_lo, (int32_t)cospi_16_64); + v_t1_hi = vmulq_n_s32(v_t1_hi, (int32_t)cospi_16_64); + { + const int16x4_t a = vrshrn_n_s32(v_t0_lo, DCT_CONST_BITS); + const int16x4_t b = vrshrn_n_s32(v_t0_hi, DCT_CONST_BITS); + const int16x4_t c = vrshrn_n_s32(v_t1_lo, DCT_CONST_BITS); + const int16x4_t d = vrshrn_n_s32(v_t1_hi, DCT_CONST_BITS); + const int16x4_t e = vrshrn_n_s32(v_t2_lo, DCT_CONST_BITS); + const int16x4_t f = vrshrn_n_s32(v_t2_hi, DCT_CONST_BITS); + const int16x4_t g = vrshrn_n_s32(v_t3_lo, DCT_CONST_BITS); + const int16x4_t h = vrshrn_n_s32(v_t3_hi, DCT_CONST_BITS); + out_0 = vcombine_s16(a, c); // 00 01 02 03 40 41 42 43 + out_2 = vcombine_s16(e, g); // 20 21 22 23 60 61 62 63 + out_4 = vcombine_s16(b, d); // 04 05 06 07 44 45 46 47 + out_6 = vcombine_s16(f, 
h); // 24 25 26 27 64 65 66 67 + } + // Stage 2 + v_x0 = vsubq_s16(v_s6, v_s5); + v_x1 = vaddq_s16(v_s6, v_s5); + v_t0_lo = vmull_n_s16(vget_low_s16(v_x0), (int16_t)cospi_16_64); + v_t0_hi = vmull_n_s16(vget_high_s16(v_x0), (int16_t)cospi_16_64); + v_t1_lo = vmull_n_s16(vget_low_s16(v_x1), (int16_t)cospi_16_64); + v_t1_hi = vmull_n_s16(vget_high_s16(v_x1), (int16_t)cospi_16_64); + { + const int16x4_t a = vrshrn_n_s32(v_t0_lo, DCT_CONST_BITS); + const int16x4_t b = vrshrn_n_s32(v_t0_hi, DCT_CONST_BITS); + const int16x4_t c = vrshrn_n_s32(v_t1_lo, DCT_CONST_BITS); + const int16x4_t d = vrshrn_n_s32(v_t1_hi, DCT_CONST_BITS); + const int16x8_t ab = vcombine_s16(a, b); + const int16x8_t cd = vcombine_s16(c, d); + // Stage 3 + v_x0 = vaddq_s16(v_s4, ab); + v_x1 = vsubq_s16(v_s4, ab); + v_x2 = vsubq_s16(v_s7, cd); + v_x3 = vaddq_s16(v_s7, cd); + } + // Stage 4 + v_t0_lo = vmull_n_s16(vget_low_s16(v_x3), (int16_t)cospi_4_64); + v_t0_hi = vmull_n_s16(vget_high_s16(v_x3), (int16_t)cospi_4_64); + v_t0_lo = vmlal_n_s16(v_t0_lo, vget_low_s16(v_x0), (int16_t)cospi_28_64); + v_t0_hi = vmlal_n_s16(v_t0_hi, vget_high_s16(v_x0), (int16_t)cospi_28_64); + v_t1_lo = vmull_n_s16(vget_low_s16(v_x1), (int16_t)cospi_12_64); + v_t1_hi = vmull_n_s16(vget_high_s16(v_x1), (int16_t)cospi_12_64); + v_t1_lo = vmlal_n_s16(v_t1_lo, vget_low_s16(v_x2), (int16_t)cospi_20_64); + v_t1_hi = vmlal_n_s16(v_t1_hi, vget_high_s16(v_x2), (int16_t)cospi_20_64); + v_t2_lo = vmull_n_s16(vget_low_s16(v_x2), (int16_t)cospi_12_64); + v_t2_hi = vmull_n_s16(vget_high_s16(v_x2), (int16_t)cospi_12_64); + v_t2_lo = vmlsl_n_s16(v_t2_lo, vget_low_s16(v_x1), (int16_t)cospi_20_64); + v_t2_hi = vmlsl_n_s16(v_t2_hi, vget_high_s16(v_x1), (int16_t)cospi_20_64); + v_t3_lo = vmull_n_s16(vget_low_s16(v_x3), (int16_t)cospi_28_64); + v_t3_hi = vmull_n_s16(vget_high_s16(v_x3), (int16_t)cospi_28_64); + v_t3_lo = vmlsl_n_s16(v_t3_lo, vget_low_s16(v_x0), (int16_t)cospi_4_64); + v_t3_hi = vmlsl_n_s16(v_t3_hi, vget_high_s16(v_x0), (int16_t)cospi_4_64); + { + const int16x4_t a = vrshrn_n_s32(v_t0_lo, DCT_CONST_BITS); + const int16x4_t b = vrshrn_n_s32(v_t0_hi, DCT_CONST_BITS); + const int16x4_t c = vrshrn_n_s32(v_t1_lo, DCT_CONST_BITS); + const int16x4_t d = vrshrn_n_s32(v_t1_hi, DCT_CONST_BITS); + const int16x4_t e = vrshrn_n_s32(v_t2_lo, DCT_CONST_BITS); + const int16x4_t f = vrshrn_n_s32(v_t2_hi, DCT_CONST_BITS); + const int16x4_t g = vrshrn_n_s32(v_t3_lo, DCT_CONST_BITS); + const int16x4_t h = vrshrn_n_s32(v_t3_hi, DCT_CONST_BITS); + out_1 = vcombine_s16(a, c); // 10 11 12 13 50 51 52 53 + out_3 = vcombine_s16(e, g); // 30 31 32 33 70 71 72 73 + out_5 = vcombine_s16(b, d); // 14 15 16 17 54 55 56 57 + out_7 = vcombine_s16(f, h); // 34 35 36 37 74 75 76 77 + } + // transpose 8x8 + { + // 00 01 02 03 40 41 42 43 + // 10 11 12 13 50 51 52 53 + // 20 21 22 23 60 61 62 63 + // 30 31 32 33 70 71 72 73 + // 04 05 06 07 44 45 46 47 + // 14 15 16 17 54 55 56 57 + // 24 25 26 27 64 65 66 67 + // 34 35 36 37 74 75 76 77 + const int32x4x2_t r02_s32 = + vtrnq_s32(vreinterpretq_s32_s16(out_0), vreinterpretq_s32_s16(out_2)); + const int32x4x2_t r13_s32 = + vtrnq_s32(vreinterpretq_s32_s16(out_1), vreinterpretq_s32_s16(out_3)); + const int32x4x2_t r46_s32 = + vtrnq_s32(vreinterpretq_s32_s16(out_4), vreinterpretq_s32_s16(out_6)); + const int32x4x2_t r57_s32 = + vtrnq_s32(vreinterpretq_s32_s16(out_5), vreinterpretq_s32_s16(out_7)); + const int16x8x2_t r01_s16 = + vtrnq_s16(vreinterpretq_s16_s32(r02_s32.val[0]), + vreinterpretq_s16_s32(r13_s32.val[0])); + const int16x8x2_t 
r23_s16 = + vtrnq_s16(vreinterpretq_s16_s32(r02_s32.val[1]), + vreinterpretq_s16_s32(r13_s32.val[1])); + const int16x8x2_t r45_s16 = + vtrnq_s16(vreinterpretq_s16_s32(r46_s32.val[0]), + vreinterpretq_s16_s32(r57_s32.val[0])); + const int16x8x2_t r67_s16 = + vtrnq_s16(vreinterpretq_s16_s32(r46_s32.val[1]), + vreinterpretq_s16_s32(r57_s32.val[1])); + input_0 = r01_s16.val[0]; + input_1 = r01_s16.val[1]; + input_2 = r23_s16.val[0]; + input_3 = r23_s16.val[1]; + input_4 = r45_s16.val[0]; + input_5 = r45_s16.val[1]; + input_6 = r67_s16.val[0]; + input_7 = r67_s16.val[1]; + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + // 04 14 24 34 44 54 64 74 + // 05 15 25 35 45 55 65 75 + // 06 16 26 36 46 56 66 76 + // 07 17 27 37 47 57 67 77 + } + } // for + { + // from aom_dct_sse2.c + // Post-condition (division by two) + // division of two 16 bits signed numbers using shifts + // n / 2 = (n - (n >> 15)) >> 1 + const int16x8_t sign_in0 = vshrq_n_s16(input_0, 15); + const int16x8_t sign_in1 = vshrq_n_s16(input_1, 15); + const int16x8_t sign_in2 = vshrq_n_s16(input_2, 15); + const int16x8_t sign_in3 = vshrq_n_s16(input_3, 15); + const int16x8_t sign_in4 = vshrq_n_s16(input_4, 15); + const int16x8_t sign_in5 = vshrq_n_s16(input_5, 15); + const int16x8_t sign_in6 = vshrq_n_s16(input_6, 15); + const int16x8_t sign_in7 = vshrq_n_s16(input_7, 15); + input_0 = vhsubq_s16(input_0, sign_in0); + input_1 = vhsubq_s16(input_1, sign_in1); + input_2 = vhsubq_s16(input_2, sign_in2); + input_3 = vhsubq_s16(input_3, sign_in3); + input_4 = vhsubq_s16(input_4, sign_in4); + input_5 = vhsubq_s16(input_5, sign_in5); + input_6 = vhsubq_s16(input_6, sign_in6); + input_7 = vhsubq_s16(input_7, sign_in7); + // store results + vst1q_s16(&final_output[0 * 8], input_0); + vst1q_s16(&final_output[1 * 8], input_1); + vst1q_s16(&final_output[2 * 8], input_2); + vst1q_s16(&final_output[3 * 8], input_3); + vst1q_s16(&final_output[4 * 8], input_4); + vst1q_s16(&final_output[5 * 8], input_5); + vst1q_s16(&final_output[6 * 8], input_6); + vst1q_s16(&final_output[7 * 8], input_7); + } +} diff --git a/third_party/aom/aom_dsp/arm/hadamard_neon.c b/third_party/aom/aom_dsp/arm/hadamard_neon.c new file mode 100644 index 0000000000..d0f59227db --- /dev/null +++ b/third_party/aom/aom_dsp/arm/hadamard_neon.c @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2019, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include + +#include "config/aom_dsp_rtcd.h" +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" + +static INLINE void hadamard_4x4_one_pass(int16x4_t *a0, int16x4_t *a1, + int16x4_t *a2, int16x4_t *a3) { + const int16x4_t b0 = vhadd_s16(*a0, *a1); + const int16x4_t b1 = vhsub_s16(*a0, *a1); + const int16x4_t b2 = vhadd_s16(*a2, *a3); + const int16x4_t b3 = vhsub_s16(*a2, *a3); + + *a0 = vadd_s16(b0, b2); + *a1 = vadd_s16(b1, b3); + *a2 = vsub_s16(b0, b2); + *a3 = vsub_s16(b1, b3); +} + +void aom_hadamard_4x4_neon(const int16_t *src_diff, ptrdiff_t src_stride, + tran_low_t *coeff) { + int16x4_t a0 = vld1_s16(src_diff); + int16x4_t a1 = vld1_s16(src_diff + src_stride); + int16x4_t a2 = vld1_s16(src_diff + 2 * src_stride); + int16x4_t a3 = vld1_s16(src_diff + 3 * src_stride); + + hadamard_4x4_one_pass(&a0, &a1, &a2, &a3); + + transpose_elems_inplace_s16_4x4(&a0, &a1, &a2, &a3); + + hadamard_4x4_one_pass(&a0, &a1, &a2, &a3); + + store_s16_to_tran_low(coeff, a0); + store_s16_to_tran_low(coeff + 4, a1); + store_s16_to_tran_low(coeff + 8, a2); + store_s16_to_tran_low(coeff + 12, a3); +} + +static void hadamard8x8_one_pass(int16x8_t *a0, int16x8_t *a1, int16x8_t *a2, + int16x8_t *a3, int16x8_t *a4, int16x8_t *a5, + int16x8_t *a6, int16x8_t *a7) { + const int16x8_t b0 = vaddq_s16(*a0, *a1); + const int16x8_t b1 = vsubq_s16(*a0, *a1); + const int16x8_t b2 = vaddq_s16(*a2, *a3); + const int16x8_t b3 = vsubq_s16(*a2, *a3); + const int16x8_t b4 = vaddq_s16(*a4, *a5); + const int16x8_t b5 = vsubq_s16(*a4, *a5); + const int16x8_t b6 = vaddq_s16(*a6, *a7); + const int16x8_t b7 = vsubq_s16(*a6, *a7); + + const int16x8_t c0 = vaddq_s16(b0, b2); + const int16x8_t c1 = vaddq_s16(b1, b3); + const int16x8_t c2 = vsubq_s16(b0, b2); + const int16x8_t c3 = vsubq_s16(b1, b3); + const int16x8_t c4 = vaddq_s16(b4, b6); + const int16x8_t c5 = vaddq_s16(b5, b7); + const int16x8_t c6 = vsubq_s16(b4, b6); + const int16x8_t c7 = vsubq_s16(b5, b7); + + *a0 = vaddq_s16(c0, c4); + *a1 = vsubq_s16(c2, c6); + *a2 = vsubq_s16(c0, c4); + *a3 = vaddq_s16(c2, c6); + *a4 = vaddq_s16(c3, c7); + *a5 = vsubq_s16(c3, c7); + *a6 = vsubq_s16(c1, c5); + *a7 = vaddq_s16(c1, c5); +} + +void aom_hadamard_8x8_neon(const int16_t *src_diff, ptrdiff_t src_stride, + tran_low_t *coeff) { + int16x8_t a0 = vld1q_s16(src_diff); + int16x8_t a1 = vld1q_s16(src_diff + src_stride); + int16x8_t a2 = vld1q_s16(src_diff + 2 * src_stride); + int16x8_t a3 = vld1q_s16(src_diff + 3 * src_stride); + int16x8_t a4 = vld1q_s16(src_diff + 4 * src_stride); + int16x8_t a5 = vld1q_s16(src_diff + 5 * src_stride); + int16x8_t a6 = vld1q_s16(src_diff + 6 * src_stride); + int16x8_t a7 = vld1q_s16(src_diff + 7 * src_stride); + + hadamard8x8_one_pass(&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7); + + transpose_elems_inplace_s16_8x8(&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7); + + hadamard8x8_one_pass(&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7); + + // Skip the second transpose because it is not required. 
+ + store_s16q_to_tran_low(coeff + 0, a0); + store_s16q_to_tran_low(coeff + 8, a1); + store_s16q_to_tran_low(coeff + 16, a2); + store_s16q_to_tran_low(coeff + 24, a3); + store_s16q_to_tran_low(coeff + 32, a4); + store_s16q_to_tran_low(coeff + 40, a5); + store_s16q_to_tran_low(coeff + 48, a6); + store_s16q_to_tran_low(coeff + 56, a7); +} + +void aom_hadamard_lp_8x8_neon(const int16_t *src_diff, ptrdiff_t src_stride, + int16_t *coeff) { + int16x8_t a0 = vld1q_s16(src_diff); + int16x8_t a1 = vld1q_s16(src_diff + src_stride); + int16x8_t a2 = vld1q_s16(src_diff + 2 * src_stride); + int16x8_t a3 = vld1q_s16(src_diff + 3 * src_stride); + int16x8_t a4 = vld1q_s16(src_diff + 4 * src_stride); + int16x8_t a5 = vld1q_s16(src_diff + 5 * src_stride); + int16x8_t a6 = vld1q_s16(src_diff + 6 * src_stride); + int16x8_t a7 = vld1q_s16(src_diff + 7 * src_stride); + + hadamard8x8_one_pass(&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7); + + transpose_elems_inplace_s16_8x8(&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7); + + hadamard8x8_one_pass(&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7); + + // Skip the second transpose because it is not required. + + vst1q_s16(coeff + 0, a0); + vst1q_s16(coeff + 8, a1); + vst1q_s16(coeff + 16, a2); + vst1q_s16(coeff + 24, a3); + vst1q_s16(coeff + 32, a4); + vst1q_s16(coeff + 40, a5); + vst1q_s16(coeff + 48, a6); + vst1q_s16(coeff + 56, a7); +} + +void aom_hadamard_lp_8x8_dual_neon(const int16_t *src_diff, + ptrdiff_t src_stride, int16_t *coeff) { + for (int i = 0; i < 2; i++) { + aom_hadamard_lp_8x8_neon(src_diff + (i * 8), src_stride, coeff + (i * 64)); + } +} + +void aom_hadamard_lp_16x16_neon(const int16_t *src_diff, ptrdiff_t src_stride, + int16_t *coeff) { + /* Rearrange 16x16 to 8x32 and remove stride. + * Top left first. */ + aom_hadamard_lp_8x8_neon(src_diff + 0 + 0 * src_stride, src_stride, + coeff + 0); + /* Top right. */ + aom_hadamard_lp_8x8_neon(src_diff + 8 + 0 * src_stride, src_stride, + coeff + 64); + /* Bottom left. */ + aom_hadamard_lp_8x8_neon(src_diff + 0 + 8 * src_stride, src_stride, + coeff + 128); + /* Bottom right. */ + aom_hadamard_lp_8x8_neon(src_diff + 8 + 8 * src_stride, src_stride, + coeff + 192); + + for (int i = 0; i < 64; i += 8) { + const int16x8_t a0 = vld1q_s16(coeff + 0); + const int16x8_t a1 = vld1q_s16(coeff + 64); + const int16x8_t a2 = vld1q_s16(coeff + 128); + const int16x8_t a3 = vld1q_s16(coeff + 192); + + const int16x8_t b0 = vhaddq_s16(a0, a1); + const int16x8_t b1 = vhsubq_s16(a0, a1); + const int16x8_t b2 = vhaddq_s16(a2, a3); + const int16x8_t b3 = vhsubq_s16(a2, a3); + + const int16x8_t c0 = vaddq_s16(b0, b2); + const int16x8_t c1 = vaddq_s16(b1, b3); + const int16x8_t c2 = vsubq_s16(b0, b2); + const int16x8_t c3 = vsubq_s16(b1, b3); + + vst1q_s16(coeff + 0, c0); + vst1q_s16(coeff + 64, c1); + vst1q_s16(coeff + 128, c2); + vst1q_s16(coeff + 192, c3); + + coeff += 8; + } +} + +void aom_hadamard_16x16_neon(const int16_t *src_diff, ptrdiff_t src_stride, + tran_low_t *coeff) { + /* Rearrange 16x16 to 8x32 and remove stride. + * Top left first. */ + aom_hadamard_8x8_neon(src_diff + 0 + 0 * src_stride, src_stride, coeff + 0); + /* Top right. */ + aom_hadamard_8x8_neon(src_diff + 8 + 0 * src_stride, src_stride, coeff + 64); + /* Bottom left. */ + aom_hadamard_8x8_neon(src_diff + 0 + 8 * src_stride, src_stride, coeff + 128); + /* Bottom right. 
*/ + aom_hadamard_8x8_neon(src_diff + 8 + 8 * src_stride, src_stride, coeff + 192); + + // Each iteration of the loop operates on entire rows (16 samples each) + // because we need to swap the second and third quarters of every row in the + // output to match AVX2 output (i.e., aom_hadamard_16x16_avx2). See the for + // loop at the end of aom_hadamard_16x16_c. + for (int i = 0; i < 64; i += 16) { + const int32x4_t a00 = vld1q_s32(coeff + 0); + const int32x4_t a01 = vld1q_s32(coeff + 64); + const int32x4_t a02 = vld1q_s32(coeff + 128); + const int32x4_t a03 = vld1q_s32(coeff + 192); + + const int32x4_t b00 = vhaddq_s32(a00, a01); + const int32x4_t b01 = vhsubq_s32(a00, a01); + const int32x4_t b02 = vhaddq_s32(a02, a03); + const int32x4_t b03 = vhsubq_s32(a02, a03); + + const int32x4_t c00 = vaddq_s32(b00, b02); + const int32x4_t c01 = vaddq_s32(b01, b03); + const int32x4_t c02 = vsubq_s32(b00, b02); + const int32x4_t c03 = vsubq_s32(b01, b03); + + const int32x4_t a10 = vld1q_s32(coeff + 4 + 0); + const int32x4_t a11 = vld1q_s32(coeff + 4 + 64); + const int32x4_t a12 = vld1q_s32(coeff + 4 + 128); + const int32x4_t a13 = vld1q_s32(coeff + 4 + 192); + + const int32x4_t b10 = vhaddq_s32(a10, a11); + const int32x4_t b11 = vhsubq_s32(a10, a11); + const int32x4_t b12 = vhaddq_s32(a12, a13); + const int32x4_t b13 = vhsubq_s32(a12, a13); + + const int32x4_t c10 = vaddq_s32(b10, b12); + const int32x4_t c11 = vaddq_s32(b11, b13); + const int32x4_t c12 = vsubq_s32(b10, b12); + const int32x4_t c13 = vsubq_s32(b11, b13); + + const int32x4_t a20 = vld1q_s32(coeff + 8 + 0); + const int32x4_t a21 = vld1q_s32(coeff + 8 + 64); + const int32x4_t a22 = vld1q_s32(coeff + 8 + 128); + const int32x4_t a23 = vld1q_s32(coeff + 8 + 192); + + const int32x4_t b20 = vhaddq_s32(a20, a21); + const int32x4_t b21 = vhsubq_s32(a20, a21); + const int32x4_t b22 = vhaddq_s32(a22, a23); + const int32x4_t b23 = vhsubq_s32(a22, a23); + + const int32x4_t c20 = vaddq_s32(b20, b22); + const int32x4_t c21 = vaddq_s32(b21, b23); + const int32x4_t c22 = vsubq_s32(b20, b22); + const int32x4_t c23 = vsubq_s32(b21, b23); + + const int32x4_t a30 = vld1q_s32(coeff + 12 + 0); + const int32x4_t a31 = vld1q_s32(coeff + 12 + 64); + const int32x4_t a32 = vld1q_s32(coeff + 12 + 128); + const int32x4_t a33 = vld1q_s32(coeff + 12 + 192); + + const int32x4_t b30 = vhaddq_s32(a30, a31); + const int32x4_t b31 = vhsubq_s32(a30, a31); + const int32x4_t b32 = vhaddq_s32(a32, a33); + const int32x4_t b33 = vhsubq_s32(a32, a33); + + const int32x4_t c30 = vaddq_s32(b30, b32); + const int32x4_t c31 = vaddq_s32(b31, b33); + const int32x4_t c32 = vsubq_s32(b30, b32); + const int32x4_t c33 = vsubq_s32(b31, b33); + + vst1q_s32(coeff + 0 + 0, c00); + vst1q_s32(coeff + 0 + 4, c20); + vst1q_s32(coeff + 0 + 8, c10); + vst1q_s32(coeff + 0 + 12, c30); + + vst1q_s32(coeff + 64 + 0, c01); + vst1q_s32(coeff + 64 + 4, c21); + vst1q_s32(coeff + 64 + 8, c11); + vst1q_s32(coeff + 64 + 12, c31); + + vst1q_s32(coeff + 128 + 0, c02); + vst1q_s32(coeff + 128 + 4, c22); + vst1q_s32(coeff + 128 + 8, c12); + vst1q_s32(coeff + 128 + 12, c32); + + vst1q_s32(coeff + 192 + 0, c03); + vst1q_s32(coeff + 192 + 4, c23); + vst1q_s32(coeff + 192 + 8, c13); + vst1q_s32(coeff + 192 + 12, c33); + + coeff += 16; + } +} + +void aom_hadamard_32x32_neon(const int16_t *src_diff, ptrdiff_t src_stride, + tran_low_t *coeff) { + /* Top left first. */ + aom_hadamard_16x16_neon(src_diff + 0 + 0 * src_stride, src_stride, coeff + 0); + /* Top right. 
*/ + aom_hadamard_16x16_neon(src_diff + 16 + 0 * src_stride, src_stride, + coeff + 256); + /* Bottom left. */ + aom_hadamard_16x16_neon(src_diff + 0 + 16 * src_stride, src_stride, + coeff + 512); + /* Bottom right. */ + aom_hadamard_16x16_neon(src_diff + 16 + 16 * src_stride, src_stride, + coeff + 768); + + for (int i = 0; i < 256; i += 4) { + const int32x4_t a0 = vld1q_s32(coeff); + const int32x4_t a1 = vld1q_s32(coeff + 256); + const int32x4_t a2 = vld1q_s32(coeff + 512); + const int32x4_t a3 = vld1q_s32(coeff + 768); + + const int32x4_t b0 = vshrq_n_s32(vaddq_s32(a0, a1), 2); + const int32x4_t b1 = vshrq_n_s32(vsubq_s32(a0, a1), 2); + const int32x4_t b2 = vshrq_n_s32(vaddq_s32(a2, a3), 2); + const int32x4_t b3 = vshrq_n_s32(vsubq_s32(a2, a3), 2); + + const int32x4_t c0 = vaddq_s32(b0, b2); + const int32x4_t c1 = vaddq_s32(b1, b3); + const int32x4_t c2 = vsubq_s32(b0, b2); + const int32x4_t c3 = vsubq_s32(b1, b3); + + vst1q_s32(coeff + 0, c0); + vst1q_s32(coeff + 256, c1); + vst1q_s32(coeff + 512, c2); + vst1q_s32(coeff + 768, c3); + + coeff += 4; + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_avg_neon.c b/third_party/aom/aom_dsp/arm/highbd_avg_neon.c new file mode 100644 index 0000000000..47d5dae012 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_avg_neon.c @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2023 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_ports/mem.h" + +uint32_t aom_highbd_avg_4x4_neon(const uint8_t *a, int a_stride) { + const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a); + uint16x4_t sum, a0, a1, a2, a3; + + load_u16_4x4(a_ptr, a_stride, &a0, &a1, &a2, &a3); + + sum = vadd_u16(a0, a1); + sum = vadd_u16(sum, a2); + sum = vadd_u16(sum, a3); + + return (horizontal_add_u16x4(sum) + (1 << 3)) >> 4; +} + +uint32_t aom_highbd_avg_8x8_neon(const uint8_t *a, int a_stride) { + const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a); + uint16x8_t sum, a0, a1, a2, a3, a4, a5, a6, a7; + + load_u16_8x8(a_ptr, a_stride, &a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7); + + sum = vaddq_u16(a0, a1); + sum = vaddq_u16(sum, a2); + sum = vaddq_u16(sum, a3); + sum = vaddq_u16(sum, a4); + sum = vaddq_u16(sum, a5); + sum = vaddq_u16(sum, a6); + sum = vaddq_u16(sum, a7); + + return (horizontal_add_u16x8(sum) + (1 << 5)) >> 6; +} + +void aom_highbd_minmax_8x8_neon(const uint8_t *s8, int p, const uint8_t *d8, + int dp, int *min, int *max) { + const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(s8); + const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(d8); + + const uint16x8_t a0 = vld1q_u16(a_ptr + 0 * p); + const uint16x8_t a1 = vld1q_u16(a_ptr + 1 * p); + const uint16x8_t a2 = vld1q_u16(a_ptr + 2 * p); + const uint16x8_t a3 = vld1q_u16(a_ptr + 3 * p); + const uint16x8_t a4 = vld1q_u16(a_ptr + 4 * p); + const uint16x8_t a5 = vld1q_u16(a_ptr + 5 * p); + const uint16x8_t a6 = vld1q_u16(a_ptr + 6 * p); + const uint16x8_t a7 = vld1q_u16(a_ptr + 7 * p); + + const uint16x8_t b0 = vld1q_u16(b_ptr + 0 * dp); + const uint16x8_t b1 = vld1q_u16(b_ptr + 1 * dp); + const uint16x8_t b2 = vld1q_u16(b_ptr + 2 * dp); + const uint16x8_t b3 = vld1q_u16(b_ptr + 3 * dp); + const uint16x8_t b4 = vld1q_u16(b_ptr + 4 * dp); + const uint16x8_t b5 = vld1q_u16(b_ptr + 5 * dp); + const uint16x8_t b6 = vld1q_u16(b_ptr + 6 * dp); + const uint16x8_t b7 = vld1q_u16(b_ptr + 7 * dp); + + const uint16x8_t abs_diff0 = vabdq_u16(a0, b0); + const uint16x8_t abs_diff1 = vabdq_u16(a1, b1); + const uint16x8_t abs_diff2 = vabdq_u16(a2, b2); + const uint16x8_t abs_diff3 = vabdq_u16(a3, b3); + const uint16x8_t abs_diff4 = vabdq_u16(a4, b4); + const uint16x8_t abs_diff5 = vabdq_u16(a5, b5); + const uint16x8_t abs_diff6 = vabdq_u16(a6, b6); + const uint16x8_t abs_diff7 = vabdq_u16(a7, b7); + + const uint16x8_t max01 = vmaxq_u16(abs_diff0, abs_diff1); + const uint16x8_t max23 = vmaxq_u16(abs_diff2, abs_diff3); + const uint16x8_t max45 = vmaxq_u16(abs_diff4, abs_diff5); + const uint16x8_t max67 = vmaxq_u16(abs_diff6, abs_diff7); + + const uint16x8_t max0123 = vmaxq_u16(max01, max23); + const uint16x8_t max4567 = vmaxq_u16(max45, max67); + const uint16x8_t max07 = vmaxq_u16(max0123, max4567); + + const uint16x8_t min01 = vminq_u16(abs_diff0, abs_diff1); + const uint16x8_t min23 = vminq_u16(abs_diff2, abs_diff3); + const uint16x8_t min45 = vminq_u16(abs_diff4, abs_diff5); + const uint16x8_t min67 = vminq_u16(abs_diff6, abs_diff7); + + const uint16x8_t min0123 = vminq_u16(min01, min23); + const uint16x8_t min4567 = vminq_u16(min45, min67); + const uint16x8_t min07 = vminq_u16(min0123, min4567); + +#if AOM_ARCH_AARCH64 + *max = (int)vmaxvq_u16(max07); + *min = (int)vminvq_u16(min07); +#else + // Split into 64-bit vectors and execute pairwise min/max. 
+ uint16x4_t ab_max = vmax_u16(vget_high_u16(max07), vget_low_u16(max07)); + uint16x4_t ab_min = vmin_u16(vget_high_u16(min07), vget_low_u16(min07)); + + // Enough runs of vpmax/min propagate the max/min values to every position. + ab_max = vpmax_u16(ab_max, ab_max); + ab_min = vpmin_u16(ab_min, ab_min); + + ab_max = vpmax_u16(ab_max, ab_max); + ab_min = vpmin_u16(ab_min, ab_min); + + ab_max = vpmax_u16(ab_max, ab_max); + ab_min = vpmin_u16(ab_min, ab_min); + + *min = *max = 0; // Clear high bits + // Store directly to avoid costly neon->gpr transfer. + vst1_lane_u16((uint16_t *)max, ab_max, 0); + vst1_lane_u16((uint16_t *)min, ab_min, 0); +#endif +} diff --git a/third_party/aom/aom_dsp/arm/highbd_avg_pred_neon.c b/third_party/aom/aom_dsp/arm/highbd_avg_pred_neon.c new file mode 100644 index 0000000000..531309b025 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_avg_pred_neon.c @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2023 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/dist_wtd_avg_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/blend.h" + +void aom_highbd_comp_avg_pred_neon(uint8_t *comp_pred8, const uint8_t *pred8, + int width, int height, const uint8_t *ref8, + int ref_stride) { + const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8); + const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); + uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8); + + int i = height; + if (width > 8) { + do { + int j = 0; + do { + const uint16x8_t p = vld1q_u16(pred + j); + const uint16x8_t r = vld1q_u16(ref + j); + + uint16x8_t avg = vrhaddq_u16(p, r); + vst1q_u16(comp_pred + j, avg); + + j += 8; + } while (j < width); + + comp_pred += width; + pred += width; + ref += ref_stride; + } while (--i != 0); + } else if (width == 8) { + do { + const uint16x8_t p = vld1q_u16(pred); + const uint16x8_t r = vld1q_u16(ref); + + uint16x8_t avg = vrhaddq_u16(p, r); + vst1q_u16(comp_pred, avg); + + comp_pred += width; + pred += width; + ref += ref_stride; + } while (--i != 0); + } else { + assert(width == 4); + do { + const uint16x4_t p = vld1_u16(pred); + const uint16x4_t r = vld1_u16(ref); + + uint16x4_t avg = vrhadd_u16(p, r); + vst1_u16(comp_pred, avg); + + comp_pred += width; + pred += width; + ref += ref_stride; + } while (--i != 0); + } +} + +void aom_highbd_comp_mask_pred_neon(uint8_t *comp_pred8, const uint8_t *pred8, + int width, int height, const uint8_t *ref8, + int ref_stride, const uint8_t *mask, + int mask_stride, int invert_mask) { + uint16_t *pred = CONVERT_TO_SHORTPTR(pred8); + uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); + uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8); + + const uint16_t *src0 = invert_mask ? pred : ref; + const uint16_t *src1 = invert_mask ? ref : pred; + const int src_stride0 = invert_mask ? width : ref_stride; + const int src_stride1 = invert_mask ? 
ref_stride : width; + + if (width >= 8) { + do { + int j = 0; + + do { + const uint16x8_t s0 = vld1q_u16(src0 + j); + const uint16x8_t s1 = vld1q_u16(src1 + j); + const uint16x8_t m0 = vmovl_u8(vld1_u8(mask + j)); + + uint16x8_t blend_u16 = alpha_blend_a64_u16x8(m0, s0, s1); + + vst1q_u16(comp_pred + j, blend_u16); + + j += 8; + } while (j < width); + + src0 += src_stride0; + src1 += src_stride1; + mask += mask_stride; + comp_pred += width; + } while (--height != 0); + } else { + assert(width == 4); + + do { + const uint16x4_t s0 = vld1_u16(src0); + const uint16x4_t s1 = vld1_u16(src1); + const uint16x4_t m0 = vget_low_u16(vmovl_u8(load_unaligned_u8_4x1(mask))); + + uint16x4_t blend_u16 = alpha_blend_a64_u16x4(m0, s0, s1); + + vst1_u16(comp_pred, blend_u16); + + src0 += src_stride0; + src1 += src_stride1; + mask += mask_stride; + comp_pred += 4; + } while (--height != 0); + } +} + +void aom_highbd_dist_wtd_comp_avg_pred_neon( + uint8_t *comp_pred8, const uint8_t *pred8, int width, int height, + const uint8_t *ref8, int ref_stride, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint16x8_t fwd_offset_u16 = vdupq_n_u16(jcp_param->fwd_offset); + const uint16x8_t bck_offset_u16 = vdupq_n_u16(jcp_param->bck_offset); + const uint16_t *pred = CONVERT_TO_SHORTPTR(pred8); + const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); + uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8); + + if (width > 8) { + do { + int j = 0; + do { + const uint16x8_t p = vld1q_u16(pred + j); + const uint16x8_t r = vld1q_u16(ref + j); + + const uint16x8_t avg = + dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16); + + vst1q_u16(comp_pred + j, avg); + + j += 8; + } while (j < width); + + comp_pred += width; + pred += width; + ref += ref_stride; + } while (--height != 0); + } else if (width == 8) { + do { + const uint16x8_t p = vld1q_u16(pred); + const uint16x8_t r = vld1q_u16(ref); + + const uint16x8_t avg = + dist_wtd_avg_u16x8(r, p, fwd_offset_u16, bck_offset_u16); + + vst1q_u16(comp_pred, avg); + + comp_pred += width; + pred += width; + ref += ref_stride; + } while (--height != 0); + } else { + assert(width == 4); + do { + const uint16x4_t p = vld1_u16(pred); + const uint16x4_t r = vld1_u16(ref); + + const uint16x4_t avg = dist_wtd_avg_u16x4( + r, p, vget_low_u16(fwd_offset_u16), vget_low_u16(bck_offset_u16)); + + vst1_u16(comp_pred, avg); + + comp_pred += width; + pred += width; + ref += ref_stride; + } while (--height != 0); + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_blend_a64_hmask_neon.c b/third_party/aom/aom_dsp/arm/highbd_blend_a64_hmask_neon.c new file mode 100644 index 0000000000..8b03e91ac3 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_blend_a64_hmask_neon.c @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/blend.h" + +void aom_highbd_blend_a64_hmask_neon(uint8_t *dst_8, uint32_t dst_stride, + const uint8_t *src0_8, + uint32_t src0_stride, + const uint8_t *src1_8, + uint32_t src1_stride, const uint8_t *mask, + int w, int h, int bd) { + (void)bd; + + const uint16_t *src0 = CONVERT_TO_SHORTPTR(src0_8); + const uint16_t *src1 = CONVERT_TO_SHORTPTR(src1_8); + uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8); + + assert(IMPLIES(src0 == dst, src0_stride == dst_stride)); + assert(IMPLIES(src1 == dst, src1_stride == dst_stride)); + + assert(h >= 1); + assert(w >= 1); + assert(IS_POWER_OF_TWO(h)); + assert(IS_POWER_OF_TWO(w)); + + assert(bd == 8 || bd == 10 || bd == 12); + + if (w >= 8) { + do { + int i = 0; + do { + uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i)); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1); + + vst1q_u16(dst + i, blend); + i += 8; + } while (i < w); + + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else if (w == 4) { + const uint16x8_t m0 = vmovl_u8(load_unaligned_dup_u8_4x2(mask)); + do { + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1); + + store_u16x4_strided_x2(dst, dst_stride, blend); + + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } else if (w == 2 && h >= 8) { + const uint16x4_t m0 = + vget_low_u16(vmovl_u8(load_unaligned_dup_u8_2x4(mask))); + do { + uint16x4_t s0 = load_unaligned_u16_2x2(src0, src0_stride); + uint16x4_t s1 = load_unaligned_u16_2x2(src1, src1_stride); + + uint16x4_t blend = alpha_blend_a64_u16x4(m0, s0, s1); + + store_u16x2_strided_x2(dst, dst_stride, blend); + + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } else { + aom_highbd_blend_a64_hmask_c(dst_8, dst_stride, src0_8, src0_stride, src1_8, + src1_stride, mask, w, h, bd); + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_blend_a64_mask_neon.c b/third_party/aom/aom_dsp/arm/highbd_blend_a64_mask_neon.c new file mode 100644 index 0000000000..90b44fcc5e --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_blend_a64_mask_neon.c @@ -0,0 +1,473 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/blend.h" + +#define HBD_BLEND_A64_D16_MASK(bd, round0_bits) \ + static INLINE uint16x8_t alpha_##bd##_blend_a64_d16_u16x8( \ + uint16x8_t m, uint16x8_t a, uint16x8_t b, int32x4_t round_offset) { \ + const uint16x8_t m_inv = \ + vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m); \ + \ + uint32x4_t blend_u32_lo = vmlal_u16(vreinterpretq_u32_s32(round_offset), \ + vget_low_u16(m), vget_low_u16(a)); \ + uint32x4_t blend_u32_hi = vmlal_u16(vreinterpretq_u32_s32(round_offset), \ + vget_high_u16(m), vget_high_u16(a)); \ + \ + blend_u32_lo = \ + vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b)); \ + blend_u32_hi = \ + vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b)); \ + \ + uint16x4_t blend_u16_lo = \ + vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_lo), \ + AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS - \ + round0_bits - COMPOUND_ROUND1_BITS); \ + uint16x4_t blend_u16_hi = \ + vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_hi), \ + AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS - \ + round0_bits - COMPOUND_ROUND1_BITS); \ + \ + uint16x8_t blend_u16 = vcombine_u16(blend_u16_lo, blend_u16_hi); \ + blend_u16 = vminq_u16(blend_u16, vdupq_n_u16((1 << bd) - 1)); \ + \ + return blend_u16; \ + } \ + \ + static INLINE void highbd_##bd##_blend_a64_d16_mask_neon( \ + uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0, \ + uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride, \ + const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, \ + int subh) { \ + const int offset_bits = bd + 2 * FILTER_BITS - round0_bits; \ + int32_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) + \ + (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)); \ + int32x4_t offset = \ + vdupq_n_s32(-(round_offset << AOM_BLEND_A64_ROUND_BITS)); \ + \ + if ((subw | subh) == 0) { \ + if (w >= 8) { \ + do { \ + int i = 0; \ + do { \ + uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i)); \ + uint16x8_t s0 = vld1q_u16(src0 + i); \ + uint16x8_t s1 = vld1q_u16(src1 + i); \ + \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset); \ + \ + vst1q_u16(dst + i, blend); \ + i += 8; \ + } while (i < w); \ + \ + mask += mask_stride; \ + src0 += src0_stride; \ + src1 += src1_stride; \ + dst += dst_stride; \ + } while (--h != 0); \ + } else { \ + do { \ + uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride)); \ + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \ + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \ + \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset); \ + \ + store_u16x4_strided_x2(dst, dst_stride, blend); \ + \ + mask += 2 * mask_stride; \ + src0 += 2 * src0_stride; \ + src1 += 2 * src1_stride; \ + dst += 2 * dst_stride; \ + h -= 2; \ + } while (h != 0); \ + } \ + } else if ((subw & subh) == 1) { \ + if (w >= 8) { \ + do { \ + int i = 0; \ + do { \ + uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i); \ + uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i); \ + uint16x8_t s0 = vld1q_u16(src0 + i); \ + uint16x8_t s1 = vld1q_u16(src1 + i); \ + \ + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4( \ + vget_low_u8(m0), vget_low_u8(m1), vget_high_u8(m0), \ + vget_high_u8(m1))); \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \ + \ + vst1q_u16(dst + i, blend); \ + 
i += 8; \ + } while (i < w); \ + \ + mask += 2 * mask_stride; \ + src0 += src0_stride; \ + src1 += src1_stride; \ + dst += dst_stride; \ + } while (--h != 0); \ + } else { \ + do { \ + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); \ + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); \ + uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride); \ + uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride); \ + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \ + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \ + \ + uint16x8_t m_avg = \ + vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3)); \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \ + \ + store_u16x4_strided_x2(dst, dst_stride, blend); \ + \ + mask += 4 * mask_stride; \ + src0 += 2 * src0_stride; \ + src1 += 2 * src1_stride; \ + dst += 2 * dst_stride; \ + h -= 2; \ + } while (h != 0); \ + } \ + } else if (subw == 1 && subh == 0) { \ + if (w >= 8) { \ + do { \ + int i = 0; \ + do { \ + uint8x8_t m0 = vld1_u8(mask + 2 * i); \ + uint8x8_t m1 = vld1_u8(mask + 2 * i + 8); \ + uint16x8_t s0 = vld1q_u16(src0 + i); \ + uint16x8_t s1 = vld1q_u16(src1 + i); \ + \ + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \ + \ + vst1q_u16(dst + i, blend); \ + i += 8; \ + } while (i < w); \ + \ + mask += mask_stride; \ + src0 += src0_stride; \ + src1 += src1_stride; \ + dst += dst_stride; \ + } while (--h != 0); \ + } else { \ + do { \ + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); \ + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); \ + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \ + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \ + \ + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \ + \ + store_u16x4_strided_x2(dst, dst_stride, blend); \ + \ + mask += 2 * mask_stride; \ + src0 += 2 * src0_stride; \ + src1 += 2 * src1_stride; \ + dst += 2 * dst_stride; \ + h -= 2; \ + } while (h != 0); \ + } \ + } else { \ + if (w >= 8) { \ + do { \ + int i = 0; \ + do { \ + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i); \ + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i); \ + uint16x8_t s0 = vld1q_u16(src0 + i); \ + uint16x8_t s1 = vld1q_u16(src1 + i); \ + \ + uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1)); \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \ + \ + vst1q_u16(dst + i, blend); \ + i += 8; \ + } while (i < w); \ + \ + mask += 2 * mask_stride; \ + src0 += src0_stride; \ + src1 += src1_stride; \ + dst += dst_stride; \ + } while (--h != 0); \ + } else { \ + do { \ + uint8x8_t m0_2 = \ + load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); \ + uint8x8_t m1_3 = \ + load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); \ + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \ + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \ + \ + uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3)); \ + uint16x8_t blend = \ + alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \ + \ + store_u16x4_strided_x2(dst, dst_stride, blend); \ + \ + mask += 4 * mask_stride; \ + src0 += 2 * src0_stride; \ + src1 += 2 * src1_stride; \ + dst += 2 * dst_stride; \ + h -= 2; \ + } while (h != 0); \ + } \ + } \ + } + +// 12 bitdepth +HBD_BLEND_A64_D16_MASK(12, (ROUND0_BITS + 2)) +// 10 bitdepth +HBD_BLEND_A64_D16_MASK(10, ROUND0_BITS) +// 
8 bitdepth +HBD_BLEND_A64_D16_MASK(8, ROUND0_BITS) + +void aom_highbd_blend_a64_d16_mask_neon( + uint8_t *dst_8, uint32_t dst_stride, const CONV_BUF_TYPE *src0, + uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride, + const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh, + ConvolveParams *conv_params, const int bd) { + (void)conv_params; + assert(h >= 1); + assert(w >= 1); + assert(IS_POWER_OF_TWO(h)); + assert(IS_POWER_OF_TWO(w)); + + uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8); + assert(IMPLIES(src0 == dst, src0_stride == dst_stride)); + assert(IMPLIES(src1 == dst, src1_stride == dst_stride)); + + if (bd == 12) { + highbd_12_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1, + src1_stride, mask, mask_stride, w, h, + subw, subh); + } else if (bd == 10) { + highbd_10_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1, + src1_stride, mask, mask_stride, w, h, + subw, subh); + } else { + highbd_8_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1, + src1_stride, mask, mask_stride, w, h, subw, + subh); + } +} + +void aom_highbd_blend_a64_mask_neon(uint8_t *dst_8, uint32_t dst_stride, + const uint8_t *src0_8, uint32_t src0_stride, + const uint8_t *src1_8, uint32_t src1_stride, + const uint8_t *mask, uint32_t mask_stride, + int w, int h, int subw, int subh, int bd) { + (void)bd; + + const uint16_t *src0 = CONVERT_TO_SHORTPTR(src0_8); + const uint16_t *src1 = CONVERT_TO_SHORTPTR(src1_8); + uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8); + + assert(IMPLIES(src0 == dst, src0_stride == dst_stride)); + assert(IMPLIES(src1 == dst, src1_stride == dst_stride)); + + assert(h >= 1); + assert(w >= 1); + assert(IS_POWER_OF_TWO(h)); + assert(IS_POWER_OF_TWO(w)); + + assert(bd == 8 || bd == 10 || bd == 12); + + if ((subw | subh) == 0) { + if (w >= 8) { + do { + int i = 0; + do { + uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i)); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1); + + vst1q_u16(dst + i, blend); + i += 8; + } while (i < w); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride)); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1); + + store_u16x4_strided_x2(dst, dst_stride, blend); + + mask += 2 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else if ((subw & subh) == 1) { + if (w >= 8) { + do { + int i = 0; + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i); + uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8); + uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t m_avg = + vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3)); + + uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1); + + vst1q_u16(dst + i, blend); + + i += 8; + } while (i < w); + + mask += 2 * mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint8x8_t m2 = vld1_u8(mask + 2 
* mask_stride); + uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3)); + uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1); + + store_u16x4_strided_x2(dst, dst_stride, blend); + + mask += 4 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else if (subw == 1 && subh == 0) { + if (w >= 8) { + do { + int i = 0; + + do { + uint8x8_t m0 = vld1_u8(mask + 2 * i); + uint8x8_t m1 = vld1_u8(mask + 2 * i + 8); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); + uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1); + + vst1q_u16(dst + i, blend); + + i += 8; + } while (i < w); + + mask += mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); + uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1); + + store_u16x4_strided_x2(dst, dst_stride, blend); + + mask += 2 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } else { + if (w >= 8) { + do { + int i = 0; + do { + uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i); + uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i); + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1)); + uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1); + + vst1q_u16(dst + i, blend); + + i += 8; + } while (i < w); + + mask += 2 * mask_stride; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else { + do { + uint8x8_t m0_2 = + load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); + uint8x8_t m1_3 = + load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3)); + uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1); + + store_u16x4_strided_x2(dst, dst_stride, blend); + + mask += 4 * mask_stride; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_blend_a64_vmask_neon.c b/third_party/aom/aom_dsp/arm/highbd_blend_a64_vmask_neon.c new file mode 100644 index 0000000000..1292e20342 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_blend_a64_vmask_neon.c @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. 
If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/blend.h" + +void aom_highbd_blend_a64_vmask_neon(uint8_t *dst_8, uint32_t dst_stride, + const uint8_t *src0_8, + uint32_t src0_stride, + const uint8_t *src1_8, + uint32_t src1_stride, const uint8_t *mask, + int w, int h, int bd) { + (void)bd; + + const uint16_t *src0 = CONVERT_TO_SHORTPTR(src0_8); + const uint16_t *src1 = CONVERT_TO_SHORTPTR(src1_8); + uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8); + + assert(IMPLIES(src0 == dst, src0_stride == dst_stride)); + assert(IMPLIES(src1 == dst, src1_stride == dst_stride)); + + assert(h >= 1); + assert(w >= 1); + assert(IS_POWER_OF_TWO(h)); + assert(IS_POWER_OF_TWO(w)); + + assert(bd == 8 || bd == 10 || bd == 12); + + if (w >= 8) { + do { + uint16x8_t m = vmovl_u8(vdup_n_u8(mask[0])); + int i = 0; + do { + uint16x8_t s0 = vld1q_u16(src0 + i); + uint16x8_t s1 = vld1q_u16(src1 + i); + + uint16x8_t blend = alpha_blend_a64_u16x8(m, s0, s1); + + vst1q_u16(dst + i, blend); + i += 8; + } while (i < w); + + mask += 1; + src0 += src0_stride; + src1 += src1_stride; + dst += dst_stride; + } while (--h != 0); + } else if (w == 4) { + do { + uint16x4_t m1 = vdup_n_u16((uint16_t)mask[0]); + uint16x4_t m2 = vdup_n_u16((uint16_t)mask[1]); + uint16x8_t m = vcombine_u16(m1, m2); + uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); + uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); + + uint16x8_t blend = alpha_blend_a64_u16x8(m, s0, s1); + + store_u16x4_strided_x2(dst, dst_stride, blend); + + mask += 2; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } else if (w == 2 && h >= 8) { + do { + uint16x4_t m0 = vdup_n_u16(0); + m0 = vld1_lane_u16((uint16_t *)mask, m0, 0); + uint8x8_t m0_zip = + vzip_u8(vreinterpret_u8_u16(m0), vreinterpret_u8_u16(m0)).val[0]; + m0 = vget_low_u16(vmovl_u8(m0_zip)); + uint16x4_t s0 = load_unaligned_u16_2x2(src0, src0_stride); + uint16x4_t s1 = load_unaligned_u16_2x2(src1, src1_stride); + + uint16x4_t blend = alpha_blend_a64_u16x4(m0, s0, s1); + + store_u16x2_strided_x2(dst, dst_stride, blend); + + mask += 2; + src0 += 2 * src0_stride; + src1 += 2 * src1_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h != 0); + } else { + aom_highbd_blend_a64_vmask_c(dst_8, dst_stride, src0_8, src0_stride, src1_8, + src1_stride, mask, w, h, bd); + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_convolve8_neon.c b/third_party/aom/aom_dsp/arm/highbd_convolve8_neon.c new file mode 100644 index 0000000000..e25438c9b4 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_convolve8_neon.c @@ -0,0 +1,363 @@ +/* + * Copyright (c) 2014 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/aom_dsp_common.h" +#include "aom_dsp/aom_filter.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_ports/mem.h" + +static INLINE int32x4_t highbd_convolve8_4_s32( + const int16x4_t s0, const int16x4_t s1, const int16x4_t s2, + const int16x4_t s3, const int16x4_t s4, const int16x4_t s5, + const int16x4_t s6, const int16x4_t s7, const int16x8_t y_filter) { + const int16x4_t y_filter_lo = vget_low_s16(y_filter); + const int16x4_t y_filter_hi = vget_high_s16(y_filter); + + int32x4_t sum = vmull_lane_s16(s0, y_filter_lo, 0); + sum = vmlal_lane_s16(sum, s1, y_filter_lo, 1); + sum = vmlal_lane_s16(sum, s2, y_filter_lo, 2); + sum = vmlal_lane_s16(sum, s3, y_filter_lo, 3); + sum = vmlal_lane_s16(sum, s4, y_filter_hi, 0); + sum = vmlal_lane_s16(sum, s5, y_filter_hi, 1); + sum = vmlal_lane_s16(sum, s6, y_filter_hi, 2); + sum = vmlal_lane_s16(sum, s7, y_filter_hi, 3); + + return sum; +} + +static INLINE uint16x4_t highbd_convolve8_4_s32_s16( + const int16x4_t s0, const int16x4_t s1, const int16x4_t s2, + const int16x4_t s3, const int16x4_t s4, const int16x4_t s5, + const int16x4_t s6, const int16x4_t s7, const int16x8_t y_filter) { + int32x4_t sum = + highbd_convolve8_4_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter); + + return vqrshrun_n_s32(sum, FILTER_BITS); +} + +static INLINE int32x4_t highbd_convolve8_horiz4_s32( + const int16x8_t s0, const int16x8_t s1, const int16x8_t x_filter_0_7) { + const int16x8_t s2 = vextq_s16(s0, s1, 1); + const int16x8_t s3 = vextq_s16(s0, s1, 2); + const int16x8_t s4 = vextq_s16(s0, s1, 3); + const int16x4_t s0_lo = vget_low_s16(s0); + const int16x4_t s1_lo = vget_low_s16(s2); + const int16x4_t s2_lo = vget_low_s16(s3); + const int16x4_t s3_lo = vget_low_s16(s4); + const int16x4_t s4_lo = vget_high_s16(s0); + const int16x4_t s5_lo = vget_high_s16(s2); + const int16x4_t s6_lo = vget_high_s16(s3); + const int16x4_t s7_lo = vget_high_s16(s4); + + return highbd_convolve8_4_s32(s0_lo, s1_lo, s2_lo, s3_lo, s4_lo, s5_lo, s6_lo, + s7_lo, x_filter_0_7); +} + +static INLINE uint16x4_t highbd_convolve8_horiz4_s32_s16( + const int16x8_t s0, const int16x8_t s1, const int16x8_t x_filter_0_7) { + int32x4_t sum = highbd_convolve8_horiz4_s32(s0, s1, x_filter_0_7); + + return vqrshrun_n_s32(sum, FILTER_BITS); +} + +static INLINE void highbd_convolve8_8_s32( + const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, + const int16x8_t s3, const int16x8_t s4, const int16x8_t s5, + const int16x8_t s6, const int16x8_t s7, const int16x8_t y_filter, + int32x4_t *sum0, int32x4_t *sum1) { + const int16x4_t y_filter_lo = vget_low_s16(y_filter); + const int16x4_t y_filter_hi = vget_high_s16(y_filter); + + *sum0 = vmull_lane_s16(vget_low_s16(s0), y_filter_lo, 0); + *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s1), y_filter_lo, 1); + *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s2), y_filter_lo, 2); + *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s3), y_filter_lo, 3); + *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s4), y_filter_hi, 0); + *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s5), y_filter_hi, 1); + *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s6), y_filter_hi, 2); + *sum0 = vmlal_lane_s16(*sum0, vget_low_s16(s7), y_filter_hi, 3); + + *sum1 = vmull_lane_s16(vget_high_s16(s0), y_filter_lo, 0); + *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s1), y_filter_lo, 1); + *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s2), 
y_filter_lo, 2); + *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s3), y_filter_lo, 3); + *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s4), y_filter_hi, 0); + *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s5), y_filter_hi, 1); + *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s6), y_filter_hi, 2); + *sum1 = vmlal_lane_s16(*sum1, vget_high_s16(s7), y_filter_hi, 3); +} + +static INLINE void highbd_convolve8_horiz8_s32(const int16x8_t s0, + const int16x8_t s0_hi, + const int16x8_t x_filter_0_7, + int32x4_t *sum0, + int32x4_t *sum1) { + const int16x8_t s1 = vextq_s16(s0, s0_hi, 1); + const int16x8_t s2 = vextq_s16(s0, s0_hi, 2); + const int16x8_t s3 = vextq_s16(s0, s0_hi, 3); + const int16x8_t s4 = vextq_s16(s0, s0_hi, 4); + const int16x8_t s5 = vextq_s16(s0, s0_hi, 5); + const int16x8_t s6 = vextq_s16(s0, s0_hi, 6); + const int16x8_t s7 = vextq_s16(s0, s0_hi, 7); + + highbd_convolve8_8_s32(s0, s1, s2, s3, s4, s5, s6, s7, x_filter_0_7, sum0, + sum1); +} + +static INLINE uint16x8_t highbd_convolve8_horiz8_s32_s16( + const int16x8_t s0, const int16x8_t s1, const int16x8_t x_filter_0_7) { + int32x4_t sum0, sum1; + highbd_convolve8_horiz8_s32(s0, s1, x_filter_0_7, &sum0, &sum1); + + return vcombine_u16(vqrshrun_n_s32(sum0, FILTER_BITS), + vqrshrun_n_s32(sum1, FILTER_BITS)); +} + +static INLINE uint16x8_t highbd_convolve8_8_s32_s16( + const int16x8_t s0, const int16x8_t s1, const int16x8_t s2, + const int16x8_t s3, const int16x8_t s4, const int16x8_t s5, + const int16x8_t s6, const int16x8_t s7, const int16x8_t y_filter) { + int32x4_t sum0; + int32x4_t sum1; + highbd_convolve8_8_s32(s0, s1, s2, s3, s4, s5, s6, s7, y_filter, &sum0, + &sum1); + + return vcombine_u16(vqrshrun_n_s32(sum0, FILTER_BITS), + vqrshrun_n_s32(sum1, FILTER_BITS)); +} + +static void highbd_convolve_horiz_neon(const uint16_t *src_ptr, + ptrdiff_t src_stride, uint16_t *dst_ptr, + ptrdiff_t dst_stride, + const int16_t *x_filter_ptr, + int x_step_q4, int w, int h, int bd) { + assert(w >= 4 && h >= 4); + const uint16x8_t max = vdupq_n_u16((1 << bd) - 1); + const int16x8_t x_filter = vld1q_s16(x_filter_ptr); + + if (w == 4) { + const int16_t *s = (const int16_t *)src_ptr; + uint16_t *d = dst_ptr; + + do { + int16x8_t s0, s1, s2, s3; + load_s16_8x2(s, src_stride, &s0, &s2); + load_s16_8x2(s + 8, src_stride, &s1, &s3); + + uint16x4_t d0 = highbd_convolve8_horiz4_s32_s16(s0, s1, x_filter); + uint16x4_t d1 = highbd_convolve8_horiz4_s32_s16(s2, s3, x_filter); + + uint16x8_t d01 = vcombine_u16(d0, d1); + d01 = vminq_u16(d01, max); + + vst1_u16(d + 0 * dst_stride, vget_low_u16(d01)); + vst1_u16(d + 1 * dst_stride, vget_high_u16(d01)); + + s += 2 * src_stride; + d += 2 * dst_stride; + h -= 2; + } while (h > 0); + } else { + int height = h; + + do { + int width = w; + const int16_t *s = (const int16_t *)src_ptr; + uint16_t *d = dst_ptr; + int x_q4 = 0; + + const int16_t *src_x = &s[x_q4 >> SUBPEL_BITS]; + int16x8_t s0, s2, s4, s6; + load_s16_8x4(src_x, src_stride, &s0, &s2, &s4, &s6); + src_x += 8; + + do { + int16x8_t s1, s3, s5, s7; + load_s16_8x4(src_x, src_stride, &s1, &s3, &s5, &s7); + + uint16x8_t d0 = highbd_convolve8_horiz8_s32_s16(s0, s1, x_filter); + uint16x8_t d1 = highbd_convolve8_horiz8_s32_s16(s2, s3, x_filter); + uint16x8_t d2 = highbd_convolve8_horiz8_s32_s16(s4, s5, x_filter); + uint16x8_t d3 = highbd_convolve8_horiz8_s32_s16(s6, s7, x_filter); + + d0 = vminq_u16(d0, max); + d1 = vminq_u16(d1, max); + d2 = vminq_u16(d2, max); + d3 = vminq_u16(d3, max); + + store_u16_8x4(d, dst_stride, d0, d1, d2, d3); + + s0 = s1; + s2 = s3; + s4 = s5; 
+ s6 = s7; + src_x += 8; + d += 8; + width -= 8; + x_q4 += 8 * x_step_q4; + } while (width > 0); + src_ptr += 4 * src_stride; + dst_ptr += 4 * dst_stride; + height -= 4; + } while (height > 0); + } +} + +void aom_highbd_convolve8_horiz_neon(const uint8_t *src8, ptrdiff_t src_stride, + uint8_t *dst8, ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, + int w, int h, int bd) { + if (x_step_q4 != 16) { + aom_highbd_convolve8_horiz_c(src8, src_stride, dst8, dst_stride, filter_x, + x_step_q4, filter_y, y_step_q4, w, h, bd); + } else { + (void)filter_y; + (void)y_step_q4; + + uint16_t *src = CONVERT_TO_SHORTPTR(src8); + uint16_t *dst = CONVERT_TO_SHORTPTR(dst8); + + src -= SUBPEL_TAPS / 2 - 1; + highbd_convolve_horiz_neon(src, src_stride, dst, dst_stride, filter_x, + x_step_q4, w, h, bd); + } +} + +static void highbd_convolve_vert_neon(const uint16_t *src_ptr, + ptrdiff_t src_stride, uint16_t *dst_ptr, + ptrdiff_t dst_stride, + const int16_t *y_filter_ptr, int w, int h, + int bd) { + assert(w >= 4 && h >= 4); + const int16x8_t y_filter = vld1q_s16(y_filter_ptr); + const uint16x8_t max = vdupq_n_u16((1 << bd) - 1); + + if (w == 4) { + const int16_t *s = (const int16_t *)src_ptr; + uint16_t *d = dst_ptr; + + int16x4_t s0, s1, s2, s3, s4, s5, s6; + load_s16_4x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6); + s += 7 * src_stride; + + do { + int16x4_t s7, s8, s9, s10; + load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10); + + uint16x4_t d0 = + highbd_convolve8_4_s32_s16(s0, s1, s2, s3, s4, s5, s6, s7, y_filter); + uint16x4_t d1 = + highbd_convolve8_4_s32_s16(s1, s2, s3, s4, s5, s6, s7, s8, y_filter); + uint16x4_t d2 = + highbd_convolve8_4_s32_s16(s2, s3, s4, s5, s6, s7, s8, s9, y_filter); + uint16x4_t d3 = + highbd_convolve8_4_s32_s16(s3, s4, s5, s6, s7, s8, s9, s10, y_filter); + + uint16x8_t d01 = vcombine_u16(d0, d1); + uint16x8_t d23 = vcombine_u16(d2, d3); + + d01 = vminq_u16(d01, max); + d23 = vminq_u16(d23, max); + + vst1_u16(d + 0 * dst_stride, vget_low_u16(d01)); + vst1_u16(d + 1 * dst_stride, vget_high_u16(d01)); + vst1_u16(d + 2 * dst_stride, vget_low_u16(d23)); + vst1_u16(d + 3 * dst_stride, vget_high_u16(d23)); + + s0 = s4; + s1 = s5; + s2 = s6; + s3 = s7; + s4 = s8; + s5 = s9; + s6 = s10; + s += 4 * src_stride; + d += 4 * dst_stride; + h -= 4; + } while (h > 0); + } else { + do { + int height = h; + const int16_t *s = (const int16_t *)src_ptr; + uint16_t *d = dst_ptr; + + int16x8_t s0, s1, s2, s3, s4, s5, s6; + load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6); + s += 7 * src_stride; + + do { + int16x8_t s7, s8, s9, s10; + load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10); + + uint16x8_t d0 = highbd_convolve8_8_s32_s16(s0, s1, s2, s3, s4, s5, s6, + s7, y_filter); + uint16x8_t d1 = highbd_convolve8_8_s32_s16(s1, s2, s3, s4, s5, s6, s7, + s8, y_filter); + uint16x8_t d2 = highbd_convolve8_8_s32_s16(s2, s3, s4, s5, s6, s7, s8, + s9, y_filter); + uint16x8_t d3 = highbd_convolve8_8_s32_s16(s3, s4, s5, s6, s7, s8, s9, + s10, y_filter); + + d0 = vminq_u16(d0, max); + d1 = vminq_u16(d1, max); + d2 = vminq_u16(d2, max); + d3 = vminq_u16(d3, max); + + store_u16_8x4(d, dst_stride, d0, d1, d2, d3); + + s0 = s4; + s1 = s5; + s2 = s6; + s3 = s7; + s4 = s8; + s5 = s9; + s6 = s10; + s += 4 * src_stride; + d += 4 * dst_stride; + height -= 4; + } while (height > 0); + src_ptr += 8; + dst_ptr += 8; + w -= 8; + } while (w > 0); + } +} + +void aom_highbd_convolve8_vert_neon(const uint8_t *src8, ptrdiff_t src_stride, + uint8_t *dst8, 
ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, + int w, int h, int bd) { + if (y_step_q4 != 16) { + aom_highbd_convolve8_vert_c(src8, src_stride, dst8, dst_stride, filter_x, + x_step_q4, filter_y, y_step_q4, w, h, bd); + } else { + (void)filter_x; + (void)x_step_q4; + + uint16_t *src = CONVERT_TO_SHORTPTR(src8); + uint16_t *dst = CONVERT_TO_SHORTPTR(dst8); + + src -= (SUBPEL_TAPS / 2 - 1) * src_stride; + highbd_convolve_vert_neon(src, src_stride, dst, dst_stride, filter_y, w, h, + bd); + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_hadamard_neon.c b/third_party/aom/aom_dsp/arm/highbd_hadamard_neon.c new file mode 100644 index 0000000000..d28617c67e --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_hadamard_neon.c @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2023 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include "config/aom_dsp_rtcd.h" +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_ports/mem.h" + +static INLINE void hadamard_highbd_col8_first_pass(int16x8_t *a0, int16x8_t *a1, + int16x8_t *a2, int16x8_t *a3, + int16x8_t *a4, int16x8_t *a5, + int16x8_t *a6, + int16x8_t *a7) { + int16x8_t b0 = vaddq_s16(*a0, *a1); + int16x8_t b1 = vsubq_s16(*a0, *a1); + int16x8_t b2 = vaddq_s16(*a2, *a3); + int16x8_t b3 = vsubq_s16(*a2, *a3); + int16x8_t b4 = vaddq_s16(*a4, *a5); + int16x8_t b5 = vsubq_s16(*a4, *a5); + int16x8_t b6 = vaddq_s16(*a6, *a7); + int16x8_t b7 = vsubq_s16(*a6, *a7); + + int16x8_t c0 = vaddq_s16(b0, b2); + int16x8_t c2 = vsubq_s16(b0, b2); + int16x8_t c1 = vaddq_s16(b1, b3); + int16x8_t c3 = vsubq_s16(b1, b3); + int16x8_t c4 = vaddq_s16(b4, b6); + int16x8_t c6 = vsubq_s16(b4, b6); + int16x8_t c5 = vaddq_s16(b5, b7); + int16x8_t c7 = vsubq_s16(b5, b7); + + *a0 = vaddq_s16(c0, c4); + *a2 = vsubq_s16(c0, c4); + *a7 = vaddq_s16(c1, c5); + *a6 = vsubq_s16(c1, c5); + *a3 = vaddq_s16(c2, c6); + *a1 = vsubq_s16(c2, c6); + *a4 = vaddq_s16(c3, c7); + *a5 = vsubq_s16(c3, c7); +} + +static INLINE void hadamard_highbd_col4_second_pass(int16x4_t a0, int16x4_t a1, + int16x4_t a2, int16x4_t a3, + int16x4_t a4, int16x4_t a5, + int16x4_t a6, int16x4_t a7, + tran_low_t *coeff) { + int32x4_t b0 = vaddl_s16(a0, a1); + int32x4_t b1 = vsubl_s16(a0, a1); + int32x4_t b2 = vaddl_s16(a2, a3); + int32x4_t b3 = vsubl_s16(a2, a3); + int32x4_t b4 = vaddl_s16(a4, a5); + int32x4_t b5 = vsubl_s16(a4, a5); + int32x4_t b6 = vaddl_s16(a6, a7); + int32x4_t b7 = vsubl_s16(a6, a7); + + int32x4_t c0 = vaddq_s32(b0, b2); + int32x4_t c2 = vsubq_s32(b0, b2); + int32x4_t c1 = vaddq_s32(b1, b3); + int32x4_t c3 = vsubq_s32(b1, b3); + int32x4_t c4 = vaddq_s32(b4, b6); + int32x4_t c6 = vsubq_s32(b4, b6); + int32x4_t c5 = vaddq_s32(b5, b7); + int32x4_t c7 = vsubq_s32(b5, b7); + + int32x4_t d0 = vaddq_s32(c0, c4); + int32x4_t d2 = vsubq_s32(c0, c4); + int32x4_t d7 = vaddq_s32(c1, c5); + int32x4_t d6 = vsubq_s32(c1, 
c5); + int32x4_t d3 = vaddq_s32(c2, c6); + int32x4_t d1 = vsubq_s32(c2, c6); + int32x4_t d4 = vaddq_s32(c3, c7); + int32x4_t d5 = vsubq_s32(c3, c7); + + vst1q_s32(coeff + 0, d0); + vst1q_s32(coeff + 4, d1); + vst1q_s32(coeff + 8, d2); + vst1q_s32(coeff + 12, d3); + vst1q_s32(coeff + 16, d4); + vst1q_s32(coeff + 20, d5); + vst1q_s32(coeff + 24, d6); + vst1q_s32(coeff + 28, d7); +} + +void aom_highbd_hadamard_8x8_neon(const int16_t *src_diff, ptrdiff_t src_stride, + tran_low_t *coeff) { + int16x4_t b0, b1, b2, b3, b4, b5, b6, b7; + + int16x8_t s0 = vld1q_s16(src_diff + 0 * src_stride); + int16x8_t s1 = vld1q_s16(src_diff + 1 * src_stride); + int16x8_t s2 = vld1q_s16(src_diff + 2 * src_stride); + int16x8_t s3 = vld1q_s16(src_diff + 3 * src_stride); + int16x8_t s4 = vld1q_s16(src_diff + 4 * src_stride); + int16x8_t s5 = vld1q_s16(src_diff + 5 * src_stride); + int16x8_t s6 = vld1q_s16(src_diff + 6 * src_stride); + int16x8_t s7 = vld1q_s16(src_diff + 7 * src_stride); + + // For the first pass we can stay in 16-bit elements (4095*8 = 32760). + hadamard_highbd_col8_first_pass(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); + + transpose_elems_inplace_s16_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); + + // For the second pass we need to widen to 32-bit elements, so we're + // processing 4 columns at a time. + // Skip the second transpose because it is not required. + + b0 = vget_low_s16(s0); + b1 = vget_low_s16(s1); + b2 = vget_low_s16(s2); + b3 = vget_low_s16(s3); + b4 = vget_low_s16(s4); + b5 = vget_low_s16(s5); + b6 = vget_low_s16(s6); + b7 = vget_low_s16(s7); + + hadamard_highbd_col4_second_pass(b0, b1, b2, b3, b4, b5, b6, b7, coeff); + + b0 = vget_high_s16(s0); + b1 = vget_high_s16(s1); + b2 = vget_high_s16(s2); + b3 = vget_high_s16(s3); + b4 = vget_high_s16(s4); + b5 = vget_high_s16(s5); + b6 = vget_high_s16(s6); + b7 = vget_high_s16(s7); + + hadamard_highbd_col4_second_pass(b0, b1, b2, b3, b4, b5, b6, b7, coeff + 32); +} + +void aom_highbd_hadamard_16x16_neon(const int16_t *src_diff, + ptrdiff_t src_stride, tran_low_t *coeff) { + // Rearrange 16x16 to 8x32 and remove stride. + // Top left first. + aom_highbd_hadamard_8x8_neon(src_diff, src_stride, coeff); + // Top right. + aom_highbd_hadamard_8x8_neon(src_diff + 8, src_stride, coeff + 64); + // Bottom left. + aom_highbd_hadamard_8x8_neon(src_diff + 8 * src_stride, src_stride, + coeff + 128); + // Bottom right. + aom_highbd_hadamard_8x8_neon(src_diff + 8 * src_stride + 8, src_stride, + coeff + 192); + + for (int i = 0; i < 16; i++) { + int32x4_t a0 = vld1q_s32(coeff + 4 * i); + int32x4_t a1 = vld1q_s32(coeff + 4 * i + 64); + int32x4_t a2 = vld1q_s32(coeff + 4 * i + 128); + int32x4_t a3 = vld1q_s32(coeff + 4 * i + 192); + + int32x4_t b0 = vhaddq_s32(a0, a1); + int32x4_t b1 = vhsubq_s32(a0, a1); + int32x4_t b2 = vhaddq_s32(a2, a3); + int32x4_t b3 = vhsubq_s32(a2, a3); + + int32x4_t c0 = vaddq_s32(b0, b2); + int32x4_t c1 = vaddq_s32(b1, b3); + int32x4_t c2 = vsubq_s32(b0, b2); + int32x4_t c3 = vsubq_s32(b1, b3); + + vst1q_s32(coeff + 4 * i, c0); + vst1q_s32(coeff + 4 * i + 64, c1); + vst1q_s32(coeff + 4 * i + 128, c2); + vst1q_s32(coeff + 4 * i + 192, c3); + } +} + +void aom_highbd_hadamard_32x32_neon(const int16_t *src_diff, + ptrdiff_t src_stride, tran_low_t *coeff) { + // Rearrange 32x32 to 16x64 and remove stride. + // Top left first. + aom_highbd_hadamard_16x16_neon(src_diff, src_stride, coeff); + // Top right. + aom_highbd_hadamard_16x16_neon(src_diff + 16, src_stride, coeff + 256); + // Bottom left. 
+ aom_highbd_hadamard_16x16_neon(src_diff + 16 * src_stride, src_stride, + coeff + 512); + // Bottom right. + aom_highbd_hadamard_16x16_neon(src_diff + 16 * src_stride + 16, src_stride, + coeff + 768); + + for (int i = 0; i < 64; i++) { + int32x4_t a0 = vld1q_s32(coeff + 4 * i); + int32x4_t a1 = vld1q_s32(coeff + 4 * i + 256); + int32x4_t a2 = vld1q_s32(coeff + 4 * i + 512); + int32x4_t a3 = vld1q_s32(coeff + 4 * i + 768); + + int32x4_t b0 = vshrq_n_s32(vaddq_s32(a0, a1), 2); + int32x4_t b1 = vshrq_n_s32(vsubq_s32(a0, a1), 2); + int32x4_t b2 = vshrq_n_s32(vaddq_s32(a2, a3), 2); + int32x4_t b3 = vshrq_n_s32(vsubq_s32(a2, a3), 2); + + int32x4_t c0 = vaddq_s32(b0, b2); + int32x4_t c1 = vaddq_s32(b1, b3); + int32x4_t c2 = vsubq_s32(b0, b2); + int32x4_t c3 = vsubq_s32(b1, b3); + + vst1q_s32(coeff + 4 * i, c0); + vst1q_s32(coeff + 4 * i + 256, c1); + vst1q_s32(coeff + 4 * i + 512, c2); + vst1q_s32(coeff + 4 * i + 768, c3); + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_intrapred_neon.c b/third_party/aom/aom_dsp/arm/highbd_intrapred_neon.c new file mode 100644 index 0000000000..dc47974c68 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_intrapred_neon.c @@ -0,0 +1,2730 @@ +/* + * Copyright (c) 2022, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "config/av1_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_dsp/intrapred_common.h" + +// ----------------------------------------------------------------------------- +// DC + +static INLINE void highbd_dc_store_4xh(uint16_t *dst, ptrdiff_t stride, int h, + uint16x4_t dc) { + for (int i = 0; i < h; ++i) { + vst1_u16(dst + i * stride, dc); + } +} + +static INLINE void highbd_dc_store_8xh(uint16_t *dst, ptrdiff_t stride, int h, + uint16x8_t dc) { + for (int i = 0; i < h; ++i) { + vst1q_u16(dst + i * stride, dc); + } +} + +static INLINE void highbd_dc_store_16xh(uint16_t *dst, ptrdiff_t stride, int h, + uint16x8_t dc) { + for (int i = 0; i < h; ++i) { + vst1q_u16(dst + i * stride, dc); + vst1q_u16(dst + i * stride + 8, dc); + } +} + +static INLINE void highbd_dc_store_32xh(uint16_t *dst, ptrdiff_t stride, int h, + uint16x8_t dc) { + for (int i = 0; i < h; ++i) { + vst1q_u16(dst + i * stride, dc); + vst1q_u16(dst + i * stride + 8, dc); + vst1q_u16(dst + i * stride + 16, dc); + vst1q_u16(dst + i * stride + 24, dc); + } +} + +static INLINE void highbd_dc_store_64xh(uint16_t *dst, ptrdiff_t stride, int h, + uint16x8_t dc) { + for (int i = 0; i < h; ++i) { + vst1q_u16(dst + i * stride, dc); + vst1q_u16(dst + i * stride + 8, dc); + vst1q_u16(dst + i * stride + 16, dc); + vst1q_u16(dst + i * stride + 24, dc); + vst1q_u16(dst + i * stride + 32, dc); + vst1q_u16(dst + i * stride + 40, dc); + vst1q_u16(dst + i * stride + 48, dc); + vst1q_u16(dst + i * stride + 56, dc); + } +} + +static INLINE uint32x4_t horizontal_add_and_broadcast_long_u16x8(uint16x8_t a) { + // Need to assume input is up to 16 bits wide from dc 64x64 partial 
sum, so + // promote first. + const uint32x4_t b = vpaddlq_u16(a); +#if AOM_ARCH_AARCH64 + const uint32x4_t c = vpaddq_u32(b, b); + return vpaddq_u32(c, c); +#else + const uint32x2_t c = vadd_u32(vget_low_u32(b), vget_high_u32(b)); + const uint32x2_t d = vpadd_u32(c, c); + return vcombine_u32(d, d); +#endif +} + +static INLINE uint16x8_t highbd_dc_load_partial_sum_4(const uint16_t *left) { + // Nothing to do since sum is already one vector, but saves needing to + // special case w=4 or h=4 cases. The combine will be zero cost for a sane + // compiler since vld1 already sets the top half of a vector to zero as part + // of the operation. + return vcombine_u16(vld1_u16(left), vdup_n_u16(0)); +} + +static INLINE uint16x8_t highbd_dc_load_partial_sum_8(const uint16_t *left) { + // Nothing to do since sum is already one vector, but saves needing to + // special case w=8 or h=8 cases. + return vld1q_u16(left); +} + +static INLINE uint16x8_t highbd_dc_load_partial_sum_16(const uint16_t *left) { + const uint16x8_t a0 = vld1q_u16(left + 0); // up to 12 bits + const uint16x8_t a1 = vld1q_u16(left + 8); + return vaddq_u16(a0, a1); // up to 13 bits +} + +static INLINE uint16x8_t highbd_dc_load_partial_sum_32(const uint16_t *left) { + const uint16x8_t a0 = vld1q_u16(left + 0); // up to 12 bits + const uint16x8_t a1 = vld1q_u16(left + 8); + const uint16x8_t a2 = vld1q_u16(left + 16); + const uint16x8_t a3 = vld1q_u16(left + 24); + const uint16x8_t b0 = vaddq_u16(a0, a1); // up to 13 bits + const uint16x8_t b1 = vaddq_u16(a2, a3); + return vaddq_u16(b0, b1); // up to 14 bits +} + +static INLINE uint16x8_t highbd_dc_load_partial_sum_64(const uint16_t *left) { + const uint16x8_t a0 = vld1q_u16(left + 0); // up to 12 bits + const uint16x8_t a1 = vld1q_u16(left + 8); + const uint16x8_t a2 = vld1q_u16(left + 16); + const uint16x8_t a3 = vld1q_u16(left + 24); + const uint16x8_t a4 = vld1q_u16(left + 32); + const uint16x8_t a5 = vld1q_u16(left + 40); + const uint16x8_t a6 = vld1q_u16(left + 48); + const uint16x8_t a7 = vld1q_u16(left + 56); + const uint16x8_t b0 = vaddq_u16(a0, a1); // up to 13 bits + const uint16x8_t b1 = vaddq_u16(a2, a3); + const uint16x8_t b2 = vaddq_u16(a4, a5); + const uint16x8_t b3 = vaddq_u16(a6, a7); + const uint16x8_t c0 = vaddq_u16(b0, b1); // up to 14 bits + const uint16x8_t c1 = vaddq_u16(b2, b3); + return vaddq_u16(c0, c1); // up to 15 bits +} + +#define HIGHBD_DC_PREDICTOR(w, h, shift) \ + void aom_highbd_dc_predictor_##w##x##h##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + const uint16x8_t a = highbd_dc_load_partial_sum_##w(above); \ + const uint16x8_t l = highbd_dc_load_partial_sum_##h(left); \ + const uint32x4_t sum = \ + horizontal_add_and_broadcast_long_u16x8(vaddq_u16(a, l)); \ + const uint16x4_t dc0 = vrshrn_n_u32(sum, shift); \ + highbd_dc_store_##w##xh(dst, stride, (h), vdupq_lane_u16(dc0, 0)); \ + } + +void aom_highbd_dc_predictor_4x4_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + // In the rectangular cases we simply extend the shorter vector to uint16x8 + // in order to accumulate, however in the 4x4 case there is no shorter vector + // to extend so it is beneficial to do the whole calculation in uint16x4 + // instead. 
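+  // For example, at 12-bit depth the eight edge samples sum to at most
+  // 8 * 4095 = 32760, which still fits in 16 bits, so none of the three
+  // pairwise adds below can overflow and the rounding shift
+  // (sum + 4) >> 3 gives the DC value directly.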
+ (void)bd; + const uint16x4_t a = vld1_u16(above); // up to 12 bits + const uint16x4_t l = vld1_u16(left); + uint16x4_t sum = vpadd_u16(a, l); // up to 13 bits + sum = vpadd_u16(sum, sum); // up to 14 bits + sum = vpadd_u16(sum, sum); + const uint16x4_t dc = vrshr_n_u16(sum, 3); + highbd_dc_store_4xh(dst, stride, 4, dc); +} + +HIGHBD_DC_PREDICTOR(8, 8, 4) +HIGHBD_DC_PREDICTOR(16, 16, 5) +HIGHBD_DC_PREDICTOR(32, 32, 6) +HIGHBD_DC_PREDICTOR(64, 64, 7) + +#undef HIGHBD_DC_PREDICTOR + +static INLINE int divide_using_multiply_shift(int num, int shift1, + int multiplier, int shift2) { + const int interm = num >> shift1; + return interm * multiplier >> shift2; +} + +#define HIGHBD_DC_MULTIPLIER_1X2 0xAAAB +#define HIGHBD_DC_MULTIPLIER_1X4 0x6667 +#define HIGHBD_DC_SHIFT2 17 + +static INLINE int highbd_dc_predictor_rect(int bw, int bh, int sum, int shift1, + uint32_t multiplier) { + return divide_using_multiply_shift(sum + ((bw + bh) >> 1), shift1, multiplier, + HIGHBD_DC_SHIFT2); +} + +#undef HIGHBD_DC_SHIFT2 + +#define HIGHBD_DC_PREDICTOR_RECT(w, h, q, shift, mult) \ + void aom_highbd_dc_predictor_##w##x##h##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + uint16x8_t sum_above = highbd_dc_load_partial_sum_##w(above); \ + uint16x8_t sum_left = highbd_dc_load_partial_sum_##h(left); \ + uint16x8_t sum_vec = vaddq_u16(sum_left, sum_above); \ + int sum = horizontal_add_u16x8(sum_vec); \ + int dc0 = highbd_dc_predictor_rect((w), (h), sum, (shift), (mult)); \ + highbd_dc_store_##w##xh(dst, stride, (h), vdup##q##_n_u16(dc0)); \ + } + +HIGHBD_DC_PREDICTOR_RECT(4, 8, , 2, HIGHBD_DC_MULTIPLIER_1X2) +HIGHBD_DC_PREDICTOR_RECT(4, 16, , 2, HIGHBD_DC_MULTIPLIER_1X4) +HIGHBD_DC_PREDICTOR_RECT(8, 4, q, 2, HIGHBD_DC_MULTIPLIER_1X2) +HIGHBD_DC_PREDICTOR_RECT(8, 16, q, 3, HIGHBD_DC_MULTIPLIER_1X2) +HIGHBD_DC_PREDICTOR_RECT(8, 32, q, 3, HIGHBD_DC_MULTIPLIER_1X4) +HIGHBD_DC_PREDICTOR_RECT(16, 4, q, 2, HIGHBD_DC_MULTIPLIER_1X4) +HIGHBD_DC_PREDICTOR_RECT(16, 8, q, 3, HIGHBD_DC_MULTIPLIER_1X2) +HIGHBD_DC_PREDICTOR_RECT(16, 32, q, 4, HIGHBD_DC_MULTIPLIER_1X2) +HIGHBD_DC_PREDICTOR_RECT(16, 64, q, 4, HIGHBD_DC_MULTIPLIER_1X4) +HIGHBD_DC_PREDICTOR_RECT(32, 8, q, 3, HIGHBD_DC_MULTIPLIER_1X4) +HIGHBD_DC_PREDICTOR_RECT(32, 16, q, 4, HIGHBD_DC_MULTIPLIER_1X2) +HIGHBD_DC_PREDICTOR_RECT(32, 64, q, 5, HIGHBD_DC_MULTIPLIER_1X2) +HIGHBD_DC_PREDICTOR_RECT(64, 16, q, 4, HIGHBD_DC_MULTIPLIER_1X4) +HIGHBD_DC_PREDICTOR_RECT(64, 32, q, 5, HIGHBD_DC_MULTIPLIER_1X2) + +#undef HIGHBD_DC_PREDICTOR_RECT +#undef HIGHBD_DC_MULTIPLIER_1X2 +#undef HIGHBD_DC_MULTIPLIER_1X4 + +// ----------------------------------------------------------------------------- +// DC_128 + +#define HIGHBD_DC_PREDICTOR_128(w, h, q) \ + void aom_highbd_dc_128_predictor_##w##x##h##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)above; \ + (void)bd; \ + (void)left; \ + highbd_dc_store_##w##xh(dst, stride, (h), \ + vdup##q##_n_u16(0x80 << (bd - 8))); \ + } + +HIGHBD_DC_PREDICTOR_128(4, 4, ) +HIGHBD_DC_PREDICTOR_128(4, 8, ) +HIGHBD_DC_PREDICTOR_128(4, 16, ) +HIGHBD_DC_PREDICTOR_128(8, 4, q) +HIGHBD_DC_PREDICTOR_128(8, 8, q) +HIGHBD_DC_PREDICTOR_128(8, 16, q) +HIGHBD_DC_PREDICTOR_128(8, 32, q) +HIGHBD_DC_PREDICTOR_128(16, 4, q) +HIGHBD_DC_PREDICTOR_128(16, 8, q) +HIGHBD_DC_PREDICTOR_128(16, 16, q) +HIGHBD_DC_PREDICTOR_128(16, 32, q) +HIGHBD_DC_PREDICTOR_128(16, 64, q) +HIGHBD_DC_PREDICTOR_128(32, 8, q) +HIGHBD_DC_PREDICTOR_128(32, 16, q) 
+HIGHBD_DC_PREDICTOR_128(32, 32, q) +HIGHBD_DC_PREDICTOR_128(32, 64, q) +HIGHBD_DC_PREDICTOR_128(64, 16, q) +HIGHBD_DC_PREDICTOR_128(64, 32, q) +HIGHBD_DC_PREDICTOR_128(64, 64, q) + +#undef HIGHBD_DC_PREDICTOR_128 + +// ----------------------------------------------------------------------------- +// DC_LEFT + +static INLINE uint32x4_t highbd_dc_load_sum_4(const uint16_t *left) { + const uint16x4_t a = vld1_u16(left); // up to 12 bits + const uint16x4_t b = vpadd_u16(a, a); // up to 13 bits + return vcombine_u32(vpaddl_u16(b), vdup_n_u32(0)); +} + +static INLINE uint32x4_t highbd_dc_load_sum_8(const uint16_t *left) { + return horizontal_add_and_broadcast_long_u16x8(vld1q_u16(left)); +} + +static INLINE uint32x4_t highbd_dc_load_sum_16(const uint16_t *left) { + return horizontal_add_and_broadcast_long_u16x8( + highbd_dc_load_partial_sum_16(left)); +} + +static INLINE uint32x4_t highbd_dc_load_sum_32(const uint16_t *left) { + return horizontal_add_and_broadcast_long_u16x8( + highbd_dc_load_partial_sum_32(left)); +} + +static INLINE uint32x4_t highbd_dc_load_sum_64(const uint16_t *left) { + return horizontal_add_and_broadcast_long_u16x8( + highbd_dc_load_partial_sum_64(left)); +} + +#define DC_PREDICTOR_LEFT(w, h, shift, q) \ + void aom_highbd_dc_left_predictor_##w##x##h##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)above; \ + (void)bd; \ + const uint32x4_t sum = highbd_dc_load_sum_##h(left); \ + const uint16x4_t dc0 = vrshrn_n_u32(sum, (shift)); \ + highbd_dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u16(dc0, 0)); \ + } + +DC_PREDICTOR_LEFT(4, 4, 2, ) +DC_PREDICTOR_LEFT(4, 8, 3, ) +DC_PREDICTOR_LEFT(4, 16, 4, ) +DC_PREDICTOR_LEFT(8, 4, 2, q) +DC_PREDICTOR_LEFT(8, 8, 3, q) +DC_PREDICTOR_LEFT(8, 16, 4, q) +DC_PREDICTOR_LEFT(8, 32, 5, q) +DC_PREDICTOR_LEFT(16, 4, 2, q) +DC_PREDICTOR_LEFT(16, 8, 3, q) +DC_PREDICTOR_LEFT(16, 16, 4, q) +DC_PREDICTOR_LEFT(16, 32, 5, q) +DC_PREDICTOR_LEFT(16, 64, 6, q) +DC_PREDICTOR_LEFT(32, 8, 3, q) +DC_PREDICTOR_LEFT(32, 16, 4, q) +DC_PREDICTOR_LEFT(32, 32, 5, q) +DC_PREDICTOR_LEFT(32, 64, 6, q) +DC_PREDICTOR_LEFT(64, 16, 4, q) +DC_PREDICTOR_LEFT(64, 32, 5, q) +DC_PREDICTOR_LEFT(64, 64, 6, q) + +#undef DC_PREDICTOR_LEFT + +// ----------------------------------------------------------------------------- +// DC_TOP + +#define DC_PREDICTOR_TOP(w, h, shift, q) \ + void aom_highbd_dc_top_predictor_##w##x##h##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + (void)left; \ + const uint32x4_t sum = highbd_dc_load_sum_##w(above); \ + const uint16x4_t dc0 = vrshrn_n_u32(sum, (shift)); \ + highbd_dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u16(dc0, 0)); \ + } + +DC_PREDICTOR_TOP(4, 4, 2, ) +DC_PREDICTOR_TOP(4, 8, 2, ) +DC_PREDICTOR_TOP(4, 16, 2, ) +DC_PREDICTOR_TOP(8, 4, 3, q) +DC_PREDICTOR_TOP(8, 8, 3, q) +DC_PREDICTOR_TOP(8, 16, 3, q) +DC_PREDICTOR_TOP(8, 32, 3, q) +DC_PREDICTOR_TOP(16, 4, 4, q) +DC_PREDICTOR_TOP(16, 8, 4, q) +DC_PREDICTOR_TOP(16, 16, 4, q) +DC_PREDICTOR_TOP(16, 32, 4, q) +DC_PREDICTOR_TOP(16, 64, 4, q) +DC_PREDICTOR_TOP(32, 8, 5, q) +DC_PREDICTOR_TOP(32, 16, 5, q) +DC_PREDICTOR_TOP(32, 32, 5, q) +DC_PREDICTOR_TOP(32, 64, 5, q) +DC_PREDICTOR_TOP(64, 16, 6, q) +DC_PREDICTOR_TOP(64, 32, 6, q) +DC_PREDICTOR_TOP(64, 64, 6, q) + +#undef DC_PREDICTOR_TOP + +// ----------------------------------------------------------------------------- +// V_PRED + +#define HIGHBD_V_NXM(W, H) \ + void 
aom_highbd_v_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)left; \ + (void)bd; \ + vertical##W##xh_neon(dst, stride, above, H); \ + } + +static INLINE uint16x8x2_t load_uint16x8x2(uint16_t const *ptr) { + uint16x8x2_t x; + // Clang/gcc uses ldp here. + x.val[0] = vld1q_u16(ptr); + x.val[1] = vld1q_u16(ptr + 8); + return x; +} + +static INLINE void store_uint16x8x2(uint16_t *ptr, uint16x8x2_t x) { + vst1q_u16(ptr, x.val[0]); + vst1q_u16(ptr + 8, x.val[1]); +} + +static INLINE void vertical4xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const above, int height) { + const uint16x4_t row = vld1_u16(above); + int y = height; + do { + vst1_u16(dst, row); + vst1_u16(dst + stride, row); + dst += stride << 1; + y -= 2; + } while (y != 0); +} + +static INLINE void vertical8xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const above, int height) { + const uint16x8_t row = vld1q_u16(above); + int y = height; + do { + vst1q_u16(dst, row); + vst1q_u16(dst + stride, row); + dst += stride << 1; + y -= 2; + } while (y != 0); +} + +static INLINE void vertical16xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const above, int height) { + const uint16x8x2_t row = load_uint16x8x2(above); + int y = height; + do { + store_uint16x8x2(dst, row); + store_uint16x8x2(dst + stride, row); + dst += stride << 1; + y -= 2; + } while (y != 0); +} + +static INLINE uint16x8x4_t load_uint16x8x4(uint16_t const *ptr) { + uint16x8x4_t x; + // Clang/gcc uses ldp here. + x.val[0] = vld1q_u16(ptr); + x.val[1] = vld1q_u16(ptr + 8); + x.val[2] = vld1q_u16(ptr + 16); + x.val[3] = vld1q_u16(ptr + 24); + return x; +} + +static INLINE void store_uint16x8x4(uint16_t *ptr, uint16x8x4_t x) { + vst1q_u16(ptr, x.val[0]); + vst1q_u16(ptr + 8, x.val[1]); + vst1q_u16(ptr + 16, x.val[2]); + vst1q_u16(ptr + 24, x.val[3]); +} + +static INLINE void vertical32xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const above, int height) { + const uint16x8x4_t row = load_uint16x8x4(above); + int y = height; + do { + store_uint16x8x4(dst, row); + store_uint16x8x4(dst + stride, row); + dst += stride << 1; + y -= 2; + } while (y != 0); +} + +static INLINE void vertical64xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const above, int height) { + uint16_t *dst32 = dst + 32; + const uint16x8x4_t row = load_uint16x8x4(above); + const uint16x8x4_t row32 = load_uint16x8x4(above + 32); + int y = height; + do { + store_uint16x8x4(dst, row); + store_uint16x8x4(dst32, row32); + store_uint16x8x4(dst + stride, row); + store_uint16x8x4(dst32 + stride, row32); + dst += stride << 1; + dst32 += stride << 1; + y -= 2; + } while (y != 0); +} + +HIGHBD_V_NXM(4, 4) +HIGHBD_V_NXM(4, 8) +HIGHBD_V_NXM(4, 16) + +HIGHBD_V_NXM(8, 4) +HIGHBD_V_NXM(8, 8) +HIGHBD_V_NXM(8, 16) +HIGHBD_V_NXM(8, 32) + +HIGHBD_V_NXM(16, 4) +HIGHBD_V_NXM(16, 8) +HIGHBD_V_NXM(16, 16) +HIGHBD_V_NXM(16, 32) +HIGHBD_V_NXM(16, 64) + +HIGHBD_V_NXM(32, 8) +HIGHBD_V_NXM(32, 16) +HIGHBD_V_NXM(32, 32) +HIGHBD_V_NXM(32, 64) + +HIGHBD_V_NXM(64, 16) +HIGHBD_V_NXM(64, 32) +HIGHBD_V_NXM(64, 64) + +// ----------------------------------------------------------------------------- +// H_PRED + +static INLINE void highbd_h_store_4x4(uint16_t *dst, ptrdiff_t stride, + uint16x4_t left) { + vst1_u16(dst + 0 * stride, vdup_lane_u16(left, 0)); + vst1_u16(dst + 1 * stride, vdup_lane_u16(left, 1)); + vst1_u16(dst + 2 * stride, vdup_lane_u16(left, 2)); + vst1_u16(dst + 3 * stride, 
vdup_lane_u16(left, 3)); +} + +static INLINE void highbd_h_store_8x4(uint16_t *dst, ptrdiff_t stride, + uint16x4_t left) { + vst1q_u16(dst + 0 * stride, vdupq_lane_u16(left, 0)); + vst1q_u16(dst + 1 * stride, vdupq_lane_u16(left, 1)); + vst1q_u16(dst + 2 * stride, vdupq_lane_u16(left, 2)); + vst1q_u16(dst + 3 * stride, vdupq_lane_u16(left, 3)); +} + +static INLINE void highbd_h_store_16x1(uint16_t *dst, uint16x8_t left) { + vst1q_u16(dst + 0, left); + vst1q_u16(dst + 8, left); +} + +static INLINE void highbd_h_store_16x4(uint16_t *dst, ptrdiff_t stride, + uint16x4_t left) { + highbd_h_store_16x1(dst + 0 * stride, vdupq_lane_u16(left, 0)); + highbd_h_store_16x1(dst + 1 * stride, vdupq_lane_u16(left, 1)); + highbd_h_store_16x1(dst + 2 * stride, vdupq_lane_u16(left, 2)); + highbd_h_store_16x1(dst + 3 * stride, vdupq_lane_u16(left, 3)); +} + +static INLINE void highbd_h_store_32x1(uint16_t *dst, uint16x8_t left) { + vst1q_u16(dst + 0, left); + vst1q_u16(dst + 8, left); + vst1q_u16(dst + 16, left); + vst1q_u16(dst + 24, left); +} + +static INLINE void highbd_h_store_32x4(uint16_t *dst, ptrdiff_t stride, + uint16x4_t left) { + highbd_h_store_32x1(dst + 0 * stride, vdupq_lane_u16(left, 0)); + highbd_h_store_32x1(dst + 1 * stride, vdupq_lane_u16(left, 1)); + highbd_h_store_32x1(dst + 2 * stride, vdupq_lane_u16(left, 2)); + highbd_h_store_32x1(dst + 3 * stride, vdupq_lane_u16(left, 3)); +} + +static INLINE void highbd_h_store_64x1(uint16_t *dst, uint16x8_t left) { + vst1q_u16(dst + 0, left); + vst1q_u16(dst + 8, left); + vst1q_u16(dst + 16, left); + vst1q_u16(dst + 24, left); + vst1q_u16(dst + 32, left); + vst1q_u16(dst + 40, left); + vst1q_u16(dst + 48, left); + vst1q_u16(dst + 56, left); +} + +static INLINE void highbd_h_store_64x4(uint16_t *dst, ptrdiff_t stride, + uint16x4_t left) { + highbd_h_store_64x1(dst + 0 * stride, vdupq_lane_u16(left, 0)); + highbd_h_store_64x1(dst + 1 * stride, vdupq_lane_u16(left, 1)); + highbd_h_store_64x1(dst + 2 * stride, vdupq_lane_u16(left, 2)); + highbd_h_store_64x1(dst + 3 * stride, vdupq_lane_u16(left, 3)); +} + +void aom_highbd_h_predictor_4x4_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + (void)above; + (void)bd; + highbd_h_store_4x4(dst, stride, vld1_u16(left)); +} + +void aom_highbd_h_predictor_4x8_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + (void)above; + (void)bd; + uint16x8_t l = vld1q_u16(left); + highbd_h_store_4x4(dst + 0 * stride, stride, vget_low_u16(l)); + highbd_h_store_4x4(dst + 4 * stride, stride, vget_high_u16(l)); +} + +void aom_highbd_h_predictor_8x4_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + (void)above; + (void)bd; + highbd_h_store_8x4(dst, stride, vld1_u16(left)); +} + +void aom_highbd_h_predictor_8x8_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + (void)above; + (void)bd; + uint16x8_t l = vld1q_u16(left); + highbd_h_store_8x4(dst + 0 * stride, stride, vget_low_u16(l)); + highbd_h_store_8x4(dst + 4 * stride, stride, vget_high_u16(l)); +} + +void aom_highbd_h_predictor_16x4_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + (void)above; + (void)bd; + highbd_h_store_16x4(dst, stride, vld1_u16(left)); +} + +void aom_highbd_h_predictor_16x8_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + (void)above; + (void)bd; + uint16x8_t l = 
vld1q_u16(left); + highbd_h_store_16x4(dst + 0 * stride, stride, vget_low_u16(l)); + highbd_h_store_16x4(dst + 4 * stride, stride, vget_high_u16(l)); +} + +void aom_highbd_h_predictor_32x8_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, int bd) { + (void)above; + (void)bd; + uint16x8_t l = vld1q_u16(left); + highbd_h_store_32x4(dst + 0 * stride, stride, vget_low_u16(l)); + highbd_h_store_32x4(dst + 4 * stride, stride, vget_high_u16(l)); +} + +// For cases where height >= 16 we use pairs of loads to get LDP instructions. +#define HIGHBD_H_WXH_LARGE(w, h) \ + void aom_highbd_h_predictor_##w##x##h##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)above; \ + (void)bd; \ + for (int i = 0; i < (h) / 16; ++i) { \ + uint16x8_t l0 = vld1q_u16(left + 0); \ + uint16x8_t l1 = vld1q_u16(left + 8); \ + highbd_h_store_##w##x4(dst + 0 * stride, stride, vget_low_u16(l0)); \ + highbd_h_store_##w##x4(dst + 4 * stride, stride, vget_high_u16(l0)); \ + highbd_h_store_##w##x4(dst + 8 * stride, stride, vget_low_u16(l1)); \ + highbd_h_store_##w##x4(dst + 12 * stride, stride, vget_high_u16(l1)); \ + left += 16; \ + dst += 16 * stride; \ + } \ + } + +HIGHBD_H_WXH_LARGE(4, 16) +HIGHBD_H_WXH_LARGE(8, 16) +HIGHBD_H_WXH_LARGE(8, 32) +HIGHBD_H_WXH_LARGE(16, 16) +HIGHBD_H_WXH_LARGE(16, 32) +HIGHBD_H_WXH_LARGE(16, 64) +HIGHBD_H_WXH_LARGE(32, 16) +HIGHBD_H_WXH_LARGE(32, 32) +HIGHBD_H_WXH_LARGE(32, 64) +HIGHBD_H_WXH_LARGE(64, 16) +HIGHBD_H_WXH_LARGE(64, 32) +HIGHBD_H_WXH_LARGE(64, 64) + +#undef HIGHBD_H_WXH_LARGE + +// ----------------------------------------------------------------------------- +// PAETH + +static INLINE void highbd_paeth_4or8_x_h_neon(uint16_t *dest, ptrdiff_t stride, + const uint16_t *const top_row, + const uint16_t *const left_column, + int width, int height) { + const uint16x8_t top_left = vdupq_n_u16(top_row[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row[-1] + top_row[-1]); + uint16x8_t top; + if (width == 4) { + top = vcombine_u16(vld1_u16(top_row), vdup_n_u16(0)); + } else { // width == 8 + top = vld1q_u16(top_row); + } + + for (int y = 0; y < height; ++y) { + const uint16x8_t left = vdupq_n_u16(left_column[y]); + + const uint16x8_t left_dist = vabdq_u16(top, top_left); + const uint16x8_t top_dist = vabdq_u16(left, top_left); + const uint16x8_t top_left_dist = + vabdq_u16(vaddq_u16(top, left), top_left_x2); + + const uint16x8_t left_le_top = vcleq_u16(left_dist, top_dist); + const uint16x8_t left_le_top_left = vcleq_u16(left_dist, top_left_dist); + const uint16x8_t top_le_top_left = vcleq_u16(top_dist, top_left_dist); + + // if (left_dist <= top_dist && left_dist <= top_left_dist) + const uint16x8_t left_mask = vandq_u16(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + uint16x8_t result = vbslq_u16(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. 
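+    // (In scalar terms, with base = top + left - top_left the distances
+    // above are left_dist = |base - left|, top_dist = |base - top| and
+    // top_left_dist = |base - top_left|, so the selection below picks
+    // whichever of left, top and top_left is closest to base, matching the
+    // C Paeth predictor.)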
+ const uint16x8_t left_or_top_mask = vorrq_u16(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + result = vbslq_u16(left_or_top_mask, result, top_left); + + if (width == 4) { + vst1_u16(dest, vget_low_u16(result)); + } else { // width == 8 + vst1q_u16(dest, result); + } + dest += stride; + } +} + +#define HIGHBD_PAETH_NXM(W, H) \ + void aom_highbd_paeth_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + highbd_paeth_4or8_x_h_neon(dst, stride, above, left, W, H); \ + } + +HIGHBD_PAETH_NXM(4, 4) +HIGHBD_PAETH_NXM(4, 8) +HIGHBD_PAETH_NXM(4, 16) +HIGHBD_PAETH_NXM(8, 4) +HIGHBD_PAETH_NXM(8, 8) +HIGHBD_PAETH_NXM(8, 16) +HIGHBD_PAETH_NXM(8, 32) + +// Select the closest values and collect them. +static INLINE uint16x8_t select_paeth(const uint16x8_t top, + const uint16x8_t left, + const uint16x8_t top_left, + const uint16x8_t left_le_top, + const uint16x8_t left_le_top_left, + const uint16x8_t top_le_top_left) { + // if (left_dist <= top_dist && left_dist <= top_left_dist) + const uint16x8_t left_mask = vandq_u16(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + const uint16x8_t result = vbslq_u16(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. + const uint16x8_t left_or_top_mask = vorrq_u16(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + return vbslq_u16(left_or_top_mask, result, top_left); +} + +#define PAETH_PREDICTOR(num) \ + do { \ + const uint16x8_t left_dist = vabdq_u16(top[num], top_left); \ + const uint16x8_t top_left_dist = \ + vabdq_u16(vaddq_u16(top[num], left), top_left_x2); \ + const uint16x8_t left_le_top = vcleq_u16(left_dist, top_dist); \ + const uint16x8_t left_le_top_left = vcleq_u16(left_dist, top_left_dist); \ + const uint16x8_t top_le_top_left = vcleq_u16(top_dist, top_left_dist); \ + const uint16x8_t result = \ + select_paeth(top[num], left, top_left, left_le_top, left_le_top_left, \ + top_le_top_left); \ + vst1q_u16(dest + (num * 8), result); \ + } while (0) + +#define LOAD_TOP_ROW(num) vld1q_u16(top_row + (num * 8)) + +static INLINE void highbd_paeth16_plus_x_h_neon( + uint16_t *dest, ptrdiff_t stride, const uint16_t *const top_row, + const uint16_t *const left_column, int width, int height) { + const uint16x8_t top_left = vdupq_n_u16(top_row[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row[-1] + top_row[-1]); + uint16x8_t top[8]; + top[0] = LOAD_TOP_ROW(0); + top[1] = LOAD_TOP_ROW(1); + if (width > 16) { + top[2] = LOAD_TOP_ROW(2); + top[3] = LOAD_TOP_ROW(3); + if (width == 64) { + top[4] = LOAD_TOP_ROW(4); + top[5] = LOAD_TOP_ROW(5); + top[6] = LOAD_TOP_ROW(6); + top[7] = LOAD_TOP_ROW(7); + } + } + + for (int y = 0; y < height; ++y) { + const uint16x8_t left = vdupq_n_u16(left_column[y]); + const uint16x8_t top_dist = vabdq_u16(left, top_left); + PAETH_PREDICTOR(0); + PAETH_PREDICTOR(1); + if (width > 16) { + PAETH_PREDICTOR(2); + PAETH_PREDICTOR(3); + if (width == 64) { + PAETH_PREDICTOR(4); + PAETH_PREDICTOR(5); + PAETH_PREDICTOR(6); + PAETH_PREDICTOR(7); + } + } + dest += stride; + } +} + +#define HIGHBD_PAETH_NXM_WIDE(W, H) \ + void aom_highbd_paeth_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + 
highbd_paeth16_plus_x_h_neon(dst, stride, above, left, W, H); \ + } + +HIGHBD_PAETH_NXM_WIDE(16, 4) +HIGHBD_PAETH_NXM_WIDE(16, 8) +HIGHBD_PAETH_NXM_WIDE(16, 16) +HIGHBD_PAETH_NXM_WIDE(16, 32) +HIGHBD_PAETH_NXM_WIDE(16, 64) +HIGHBD_PAETH_NXM_WIDE(32, 8) +HIGHBD_PAETH_NXM_WIDE(32, 16) +HIGHBD_PAETH_NXM_WIDE(32, 32) +HIGHBD_PAETH_NXM_WIDE(32, 64) +HIGHBD_PAETH_NXM_WIDE(64, 16) +HIGHBD_PAETH_NXM_WIDE(64, 32) +HIGHBD_PAETH_NXM_WIDE(64, 64) + +// ----------------------------------------------------------------------------- +// SMOOTH + +// 256 - v = vneg_s8(v) +static INLINE uint16x4_t negate_s8(const uint16x4_t v) { + return vreinterpret_u16_s8(vneg_s8(vreinterpret_s8_u16(v))); +} + +static INLINE void highbd_smooth_4xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const top_row, + const uint16_t *const left_column, + const int height) { + const uint16_t top_right = top_row[3]; + const uint16_t bottom_left = left_column[height - 1]; + const uint16_t *const weights_y = smooth_weights_u16 + height - 4; + + const uint16x4_t top_v = vld1_u16(top_row); + const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left); + const uint16x4_t weights_x_v = vld1_u16(smooth_weights_u16); + const uint16x4_t scaled_weights_x = negate_s8(weights_x_v); + const uint32x4_t weighted_tr = vmull_n_u16(scaled_weights_x, top_right); + + for (int y = 0; y < height; ++y) { + // Each variable in the running summation is named for the last item to be + // accumulated. + const uint32x4_t weighted_top = + vmlal_n_u16(weighted_tr, top_v, weights_y[y]); + const uint32x4_t weighted_left = + vmlal_n_u16(weighted_top, weights_x_v, left_column[y]); + const uint32x4_t weighted_bl = + vmlal_n_u16(weighted_left, bottom_left_v, 256 - weights_y[y]); + + const uint16x4_t pred = + vrshrn_n_u32(weighted_bl, SMOOTH_WEIGHT_LOG2_SCALE + 1); + vst1_u16(dst, pred); + dst += stride; + } +} + +// Common code between 8xH and [16|32|64]xH. +static INLINE void highbd_calculate_pred8( + uint16_t *dst, const uint32x4_t weighted_corners_low, + const uint32x4_t weighted_corners_high, const uint16x4x2_t top_vals, + const uint16x4x2_t weights_x, const uint16_t left_y, + const uint16_t weight_y) { + // Each variable in the running summation is named for the last item to be + // accumulated. 
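+  // The caller has already folded (256 - weight_y) * bottom_left and
+  // (256 - weights_x) * top_right into weighted_corners_*; accumulating
+  // weight_y * top and weights_x * left here and rounding by
+  // (SMOOTH_WEIGHT_LOG2_SCALE + 1) completes the smooth predictor sum,
+  // whose weights total 2 * 256.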
+ const uint32x4_t weighted_top_low = + vmlal_n_u16(weighted_corners_low, top_vals.val[0], weight_y); + const uint32x4_t weighted_edges_low = + vmlal_n_u16(weighted_top_low, weights_x.val[0], left_y); + + const uint16x4_t pred_low = + vrshrn_n_u32(weighted_edges_low, SMOOTH_WEIGHT_LOG2_SCALE + 1); + vst1_u16(dst, pred_low); + + const uint32x4_t weighted_top_high = + vmlal_n_u16(weighted_corners_high, top_vals.val[1], weight_y); + const uint32x4_t weighted_edges_high = + vmlal_n_u16(weighted_top_high, weights_x.val[1], left_y); + + const uint16x4_t pred_high = + vrshrn_n_u32(weighted_edges_high, SMOOTH_WEIGHT_LOG2_SCALE + 1); + vst1_u16(dst + 4, pred_high); +} + +static void highbd_smooth_8xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const top_row, + const uint16_t *const left_column, + const int height) { + const uint16_t top_right = top_row[7]; + const uint16_t bottom_left = left_column[height - 1]; + const uint16_t *const weights_y = smooth_weights_u16 + height - 4; + + const uint16x4x2_t top_vals = { { vld1_u16(top_row), + vld1_u16(top_row + 4) } }; + const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left); + const uint16x4x2_t weights_x = { { vld1_u16(smooth_weights_u16 + 4), + vld1_u16(smooth_weights_u16 + 8) } }; + const uint32x4_t weighted_tr_low = + vmull_n_u16(negate_s8(weights_x.val[0]), top_right); + const uint32x4_t weighted_tr_high = + vmull_n_u16(negate_s8(weights_x.val[1]), top_right); + + for (int y = 0; y < height; ++y) { + const uint32x4_t weighted_bl = + vmull_n_u16(bottom_left_v, 256 - weights_y[y]); + const uint32x4_t weighted_corners_low = + vaddq_u32(weighted_bl, weighted_tr_low); + const uint32x4_t weighted_corners_high = + vaddq_u32(weighted_bl, weighted_tr_high); + highbd_calculate_pred8(dst, weighted_corners_low, weighted_corners_high, + top_vals, weights_x, left_column[y], weights_y[y]); + dst += stride; + } +} + +#define HIGHBD_SMOOTH_NXM(W, H) \ + void aom_highbd_smooth_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + highbd_smooth_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +HIGHBD_SMOOTH_NXM(4, 4) +HIGHBD_SMOOTH_NXM(4, 8) +HIGHBD_SMOOTH_NXM(8, 4) +HIGHBD_SMOOTH_NXM(8, 8) +HIGHBD_SMOOTH_NXM(4, 16) +HIGHBD_SMOOTH_NXM(8, 16) +HIGHBD_SMOOTH_NXM(8, 32) + +#undef HIGHBD_SMOOTH_NXM + +// For width 16 and above. +#define HIGHBD_SMOOTH_PREDICTOR(W) \ + static void highbd_smooth_##W##xh_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *const top_row, \ + const uint16_t *const left_column, const int height) { \ + const uint16_t top_right = top_row[(W)-1]; \ + const uint16_t bottom_left = left_column[height - 1]; \ + const uint16_t *const weights_y = smooth_weights_u16 + height - 4; \ + \ + /* Precompute weighted values that don't vary with |y|. 
*/ \ + uint32x4_t weighted_tr_low[(W) >> 3]; \ + uint32x4_t weighted_tr_high[(W) >> 3]; \ + for (int i = 0; i < (W) >> 3; ++i) { \ + const int x = i << 3; \ + const uint16x4_t weights_x_low = \ + vld1_u16(smooth_weights_u16 + (W)-4 + x); \ + weighted_tr_low[i] = vmull_n_u16(negate_s8(weights_x_low), top_right); \ + const uint16x4_t weights_x_high = \ + vld1_u16(smooth_weights_u16 + (W) + x); \ + weighted_tr_high[i] = vmull_n_u16(negate_s8(weights_x_high), top_right); \ + } \ + \ + const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left); \ + for (int y = 0; y < height; ++y) { \ + const uint32x4_t weighted_bl = \ + vmull_n_u16(bottom_left_v, 256 - weights_y[y]); \ + uint16_t *dst_x = dst; \ + for (int i = 0; i < (W) >> 3; ++i) { \ + const int x = i << 3; \ + const uint16x4x2_t top_vals = { { vld1_u16(top_row + x), \ + vld1_u16(top_row + x + 4) } }; \ + const uint32x4_t weighted_corners_low = \ + vaddq_u32(weighted_bl, weighted_tr_low[i]); \ + const uint32x4_t weighted_corners_high = \ + vaddq_u32(weighted_bl, weighted_tr_high[i]); \ + /* Accumulate weighted edge values and store. */ \ + const uint16x4x2_t weights_x = { \ + { vld1_u16(smooth_weights_u16 + (W)-4 + x), \ + vld1_u16(smooth_weights_u16 + (W) + x) } \ + }; \ + highbd_calculate_pred8(dst_x, weighted_corners_low, \ + weighted_corners_high, top_vals, weights_x, \ + left_column[y], weights_y[y]); \ + dst_x += 8; \ + } \ + dst += stride; \ + } \ + } + +HIGHBD_SMOOTH_PREDICTOR(16) +HIGHBD_SMOOTH_PREDICTOR(32) +HIGHBD_SMOOTH_PREDICTOR(64) + +#undef HIGHBD_SMOOTH_PREDICTOR + +#define HIGHBD_SMOOTH_NXM_WIDE(W, H) \ + void aom_highbd_smooth_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + highbd_smooth_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +HIGHBD_SMOOTH_NXM_WIDE(16, 4) +HIGHBD_SMOOTH_NXM_WIDE(16, 8) +HIGHBD_SMOOTH_NXM_WIDE(16, 16) +HIGHBD_SMOOTH_NXM_WIDE(16, 32) +HIGHBD_SMOOTH_NXM_WIDE(16, 64) +HIGHBD_SMOOTH_NXM_WIDE(32, 8) +HIGHBD_SMOOTH_NXM_WIDE(32, 16) +HIGHBD_SMOOTH_NXM_WIDE(32, 32) +HIGHBD_SMOOTH_NXM_WIDE(32, 64) +HIGHBD_SMOOTH_NXM_WIDE(64, 16) +HIGHBD_SMOOTH_NXM_WIDE(64, 32) +HIGHBD_SMOOTH_NXM_WIDE(64, 64) + +#undef HIGHBD_SMOOTH_NXM_WIDE + +static void highbd_smooth_v_4xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const top_row, + const uint16_t *const left_column, + const int height) { + const uint16_t bottom_left = left_column[height - 1]; + const uint16_t *const weights_y = smooth_weights_u16 + height - 4; + + const uint16x4_t top_v = vld1_u16(top_row); + const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left); + + for (int y = 0; y < height; ++y) { + const uint32x4_t weighted_bl = + vmull_n_u16(bottom_left_v, 256 - weights_y[y]); + const uint32x4_t weighted_top = + vmlal_n_u16(weighted_bl, top_v, weights_y[y]); + vst1_u16(dst, vrshrn_n_u32(weighted_top, SMOOTH_WEIGHT_LOG2_SCALE)); + + dst += stride; + } +} + +static void highbd_smooth_v_8xh_neon(uint16_t *dst, const ptrdiff_t stride, + const uint16_t *const top_row, + const uint16_t *const left_column, + const int height) { + const uint16_t bottom_left = left_column[height - 1]; + const uint16_t *const weights_y = smooth_weights_u16 + height - 4; + + const uint16x4_t top_low = vld1_u16(top_row); + const uint16x4_t top_high = vld1_u16(top_row + 4); + const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left); + + for (int y = 0; y < height; ++y) { + const uint32x4_t weighted_bl = + vmull_n_u16(bottom_left_v, 256 - weights_y[y]); + + const uint32x4_t 
weighted_top_low = + vmlal_n_u16(weighted_bl, top_low, weights_y[y]); + vst1_u16(dst, vrshrn_n_u32(weighted_top_low, SMOOTH_WEIGHT_LOG2_SCALE)); + + const uint32x4_t weighted_top_high = + vmlal_n_u16(weighted_bl, top_high, weights_y[y]); + vst1_u16(dst + 4, + vrshrn_n_u32(weighted_top_high, SMOOTH_WEIGHT_LOG2_SCALE)); + dst += stride; + } +} + +#define HIGHBD_SMOOTH_V_NXM(W, H) \ + void aom_highbd_smooth_v_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + highbd_smooth_v_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +HIGHBD_SMOOTH_V_NXM(4, 4) +HIGHBD_SMOOTH_V_NXM(4, 8) +HIGHBD_SMOOTH_V_NXM(4, 16) +HIGHBD_SMOOTH_V_NXM(8, 4) +HIGHBD_SMOOTH_V_NXM(8, 8) +HIGHBD_SMOOTH_V_NXM(8, 16) +HIGHBD_SMOOTH_V_NXM(8, 32) + +#undef HIGHBD_SMOOTH_V_NXM + +// For width 16 and above. +#define HIGHBD_SMOOTH_V_PREDICTOR(W) \ + static void highbd_smooth_v_##W##xh_neon( \ + uint16_t *dst, const ptrdiff_t stride, const uint16_t *const top_row, \ + const uint16_t *const left_column, const int height) { \ + const uint16_t bottom_left = left_column[height - 1]; \ + const uint16_t *const weights_y = smooth_weights_u16 + height - 4; \ + \ + uint16x4x2_t top_vals[(W) >> 3]; \ + for (int i = 0; i < (W) >> 3; ++i) { \ + const int x = i << 3; \ + top_vals[i].val[0] = vld1_u16(top_row + x); \ + top_vals[i].val[1] = vld1_u16(top_row + x + 4); \ + } \ + \ + const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left); \ + for (int y = 0; y < height; ++y) { \ + const uint32x4_t weighted_bl = \ + vmull_n_u16(bottom_left_v, 256 - weights_y[y]); \ + \ + uint16_t *dst_x = dst; \ + for (int i = 0; i < (W) >> 3; ++i) { \ + const uint32x4_t weighted_top_low = \ + vmlal_n_u16(weighted_bl, top_vals[i].val[0], weights_y[y]); \ + vst1_u16(dst_x, \ + vrshrn_n_u32(weighted_top_low, SMOOTH_WEIGHT_LOG2_SCALE)); \ + \ + const uint32x4_t weighted_top_high = \ + vmlal_n_u16(weighted_bl, top_vals[i].val[1], weights_y[y]); \ + vst1_u16(dst_x + 4, \ + vrshrn_n_u32(weighted_top_high, SMOOTH_WEIGHT_LOG2_SCALE)); \ + dst_x += 8; \ + } \ + dst += stride; \ + } \ + } + +HIGHBD_SMOOTH_V_PREDICTOR(16) +HIGHBD_SMOOTH_V_PREDICTOR(32) +HIGHBD_SMOOTH_V_PREDICTOR(64) + +#undef HIGHBD_SMOOTH_V_PREDICTOR + +#define HIGHBD_SMOOTH_V_NXM_WIDE(W, H) \ + void aom_highbd_smooth_v_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + highbd_smooth_v_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +HIGHBD_SMOOTH_V_NXM_WIDE(16, 4) +HIGHBD_SMOOTH_V_NXM_WIDE(16, 8) +HIGHBD_SMOOTH_V_NXM_WIDE(16, 16) +HIGHBD_SMOOTH_V_NXM_WIDE(16, 32) +HIGHBD_SMOOTH_V_NXM_WIDE(16, 64) +HIGHBD_SMOOTH_V_NXM_WIDE(32, 8) +HIGHBD_SMOOTH_V_NXM_WIDE(32, 16) +HIGHBD_SMOOTH_V_NXM_WIDE(32, 32) +HIGHBD_SMOOTH_V_NXM_WIDE(32, 64) +HIGHBD_SMOOTH_V_NXM_WIDE(64, 16) +HIGHBD_SMOOTH_V_NXM_WIDE(64, 32) +HIGHBD_SMOOTH_V_NXM_WIDE(64, 64) + +#undef HIGHBD_SMOOTH_V_NXM_WIDE + +static INLINE void highbd_smooth_h_4xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const top_row, + const uint16_t *const left_column, + const int height) { + const uint16_t top_right = top_row[3]; + + const uint16x4_t weights_x = vld1_u16(smooth_weights_u16); + const uint16x4_t scaled_weights_x = negate_s8(weights_x); + + const uint32x4_t weighted_tr = vmull_n_u16(scaled_weights_x, top_right); + for (int y = 0; y < height; ++y) { + const uint32x4_t weighted_left = + vmlal_n_u16(weighted_tr, weights_x, left_column[y]); + vst1_u16(dst, 
vrshrn_n_u32(weighted_left, SMOOTH_WEIGHT_LOG2_SCALE)); + dst += stride; + } +} + +static INLINE void highbd_smooth_h_8xh_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *const top_row, + const uint16_t *const left_column, + const int height) { + const uint16_t top_right = top_row[7]; + + const uint16x4x2_t weights_x = { { vld1_u16(smooth_weights_u16 + 4), + vld1_u16(smooth_weights_u16 + 8) } }; + + const uint32x4_t weighted_tr_low = + vmull_n_u16(negate_s8(weights_x.val[0]), top_right); + const uint32x4_t weighted_tr_high = + vmull_n_u16(negate_s8(weights_x.val[1]), top_right); + + for (int y = 0; y < height; ++y) { + const uint16_t left_y = left_column[y]; + const uint32x4_t weighted_left_low = + vmlal_n_u16(weighted_tr_low, weights_x.val[0], left_y); + vst1_u16(dst, vrshrn_n_u32(weighted_left_low, SMOOTH_WEIGHT_LOG2_SCALE)); + + const uint32x4_t weighted_left_high = + vmlal_n_u16(weighted_tr_high, weights_x.val[1], left_y); + vst1_u16(dst + 4, + vrshrn_n_u32(weighted_left_high, SMOOTH_WEIGHT_LOG2_SCALE)); + dst += stride; + } +} + +#define HIGHBD_SMOOTH_H_NXM(W, H) \ + void aom_highbd_smooth_h_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + highbd_smooth_h_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +HIGHBD_SMOOTH_H_NXM(4, 4) +HIGHBD_SMOOTH_H_NXM(4, 8) +HIGHBD_SMOOTH_H_NXM(4, 16) +HIGHBD_SMOOTH_H_NXM(8, 4) +HIGHBD_SMOOTH_H_NXM(8, 8) +HIGHBD_SMOOTH_H_NXM(8, 16) +HIGHBD_SMOOTH_H_NXM(8, 32) + +#undef HIGHBD_SMOOTH_H_NXM + +// For width 16 and above. +#define HIGHBD_SMOOTH_H_PREDICTOR(W) \ + void highbd_smooth_h_##W##xh_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *const top_row, \ + const uint16_t *const left_column, const int height) { \ + const uint16_t top_right = top_row[(W)-1]; \ + \ + uint16x4_t weights_x_low[(W) >> 3]; \ + uint16x4_t weights_x_high[(W) >> 3]; \ + uint32x4_t weighted_tr_low[(W) >> 3]; \ + uint32x4_t weighted_tr_high[(W) >> 3]; \ + for (int i = 0; i < (W) >> 3; ++i) { \ + const int x = i << 3; \ + weights_x_low[i] = vld1_u16(smooth_weights_u16 + (W)-4 + x); \ + weighted_tr_low[i] = \ + vmull_n_u16(negate_s8(weights_x_low[i]), top_right); \ + weights_x_high[i] = vld1_u16(smooth_weights_u16 + (W) + x); \ + weighted_tr_high[i] = \ + vmull_n_u16(negate_s8(weights_x_high[i]), top_right); \ + } \ + \ + for (int y = 0; y < height; ++y) { \ + uint16_t *dst_x = dst; \ + const uint16_t left_y = left_column[y]; \ + for (int i = 0; i < (W) >> 3; ++i) { \ + const uint32x4_t weighted_left_low = \ + vmlal_n_u16(weighted_tr_low[i], weights_x_low[i], left_y); \ + vst1_u16(dst_x, \ + vrshrn_n_u32(weighted_left_low, SMOOTH_WEIGHT_LOG2_SCALE)); \ + \ + const uint32x4_t weighted_left_high = \ + vmlal_n_u16(weighted_tr_high[i], weights_x_high[i], left_y); \ + vst1_u16(dst_x + 4, \ + vrshrn_n_u32(weighted_left_high, SMOOTH_WEIGHT_LOG2_SCALE)); \ + dst_x += 8; \ + } \ + dst += stride; \ + } \ + } + +HIGHBD_SMOOTH_H_PREDICTOR(16) +HIGHBD_SMOOTH_H_PREDICTOR(32) +HIGHBD_SMOOTH_H_PREDICTOR(64) + +#undef HIGHBD_SMOOTH_H_PREDICTOR + +#define HIGHBD_SMOOTH_H_NXM_WIDE(W, H) \ + void aom_highbd_smooth_h_predictor_##W##x##H##_neon( \ + uint16_t *dst, ptrdiff_t y_stride, const uint16_t *above, \ + const uint16_t *left, int bd) { \ + (void)bd; \ + highbd_smooth_h_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +HIGHBD_SMOOTH_H_NXM_WIDE(16, 4) +HIGHBD_SMOOTH_H_NXM_WIDE(16, 8) +HIGHBD_SMOOTH_H_NXM_WIDE(16, 16) +HIGHBD_SMOOTH_H_NXM_WIDE(16, 32) +HIGHBD_SMOOTH_H_NXM_WIDE(16, 64) 
+HIGHBD_SMOOTH_H_NXM_WIDE(32, 8) +HIGHBD_SMOOTH_H_NXM_WIDE(32, 16) +HIGHBD_SMOOTH_H_NXM_WIDE(32, 32) +HIGHBD_SMOOTH_H_NXM_WIDE(32, 64) +HIGHBD_SMOOTH_H_NXM_WIDE(64, 16) +HIGHBD_SMOOTH_H_NXM_WIDE(64, 32) +HIGHBD_SMOOTH_H_NXM_WIDE(64, 64) + +#undef HIGHBD_SMOOTH_H_NXM_WIDE + +// ----------------------------------------------------------------------------- +// Z1 + +static int16_t iota1_s16[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8 }; +static int16_t iota2_s16[] = { 0, 2, 4, 6, 8, 10, 12, 14 }; + +static AOM_FORCE_INLINE uint16x4_t highbd_dr_z1_apply_shift_x4(uint16x4_t a0, + uint16x4_t a1, + int shift) { + // The C implementation of the z1 predictor uses (32 - shift) and a right + // shift by 5, however we instead double shift to avoid an unnecessary right + // shift by 1. + uint32x4_t res = vmull_n_u16(a1, shift); + res = vmlal_n_u16(res, a0, 64 - shift); + return vrshrn_n_u32(res, 6); +} + +static AOM_FORCE_INLINE uint16x8_t highbd_dr_z1_apply_shift_x8(uint16x8_t a0, + uint16x8_t a1, + int shift) { + return vcombine_u16( + highbd_dr_z1_apply_shift_x4(vget_low_u16(a0), vget_low_u16(a1), shift), + highbd_dr_z1_apply_shift_x4(vget_high_u16(a0), vget_high_u16(a1), shift)); +} + +static void highbd_dr_prediction_z1_upsample0_neon(uint16_t *dst, + ptrdiff_t stride, int bw, + int bh, + const uint16_t *above, + int dx) { + assert(bw % 4 == 0); + assert(bh % 4 == 0); + assert(dx > 0); + + const int max_base_x = (bw + bh) - 1; + const int above_max = above[max_base_x]; + + const int16x8_t iota1x8 = vld1q_s16(iota1_s16); + const int16x4_t iota1x4 = vget_low_s16(iota1x8); + + int x = dx; + int r = 0; + do { + const int base = x >> 6; + if (base >= max_base_x) { + for (int i = r; i < bh; ++i) { + aom_memset16(dst, above_max, bw); + dst += stride; + } + return; + } + + // The C implementation of the z1 predictor when not upsampling uses: + // ((x & 0x3f) >> 1) + // The right shift is unnecessary here since we instead shift by +1 later, + // so adjust the mask to 0x3e to ensure we don't consider the extra bit. 
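+    // Concretely, the scalar blend ((32 - s) * a0 + s * a1 + 16) >> 5 with
+    // s = (x & 0x3f) >> 1 becomes ((64 - shift) * a0 + shift * a1 + 32) >> 6
+    // with shift = x & 0x3e in highbd_dr_z1_apply_shift_x4(), which is
+    // bit-identical.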
+ const int shift = x & 0x3e; + + if (bw == 4) { + const uint16x4_t a0 = vld1_u16(&above[base]); + const uint16x4_t a1 = vld1_u16(&above[base + 1]); + const uint16x4_t val = highbd_dr_z1_apply_shift_x4(a0, a1, shift); + const uint16x4_t cmp = vcgt_s16(vdup_n_s16(max_base_x - base), iota1x4); + const uint16x4_t res = vbsl_u16(cmp, val, vdup_n_u16(above_max)); + vst1_u16(dst, res); + } else { + int c = 0; + do { + const uint16x8_t a0 = vld1q_u16(&above[base + c]); + const uint16x8_t a1 = vld1q_u16(&above[base + c + 1]); + const uint16x8_t val = highbd_dr_z1_apply_shift_x8(a0, a1, shift); + const uint16x8_t cmp = + vcgtq_s16(vdupq_n_s16(max_base_x - base - c), iota1x8); + const uint16x8_t res = vbslq_u16(cmp, val, vdupq_n_u16(above_max)); + vst1q_u16(dst + c, res); + c += 8; + } while (c < bw); + } + + dst += stride; + x += dx; + } while (++r < bh); +} + +static void highbd_dr_prediction_z1_upsample1_neon(uint16_t *dst, + ptrdiff_t stride, int bw, + int bh, + const uint16_t *above, + int dx) { + assert(bw % 4 == 0); + assert(bh % 4 == 0); + assert(dx > 0); + + const int max_base_x = ((bw + bh) - 1) << 1; + const int above_max = above[max_base_x]; + + const int16x8_t iota2x8 = vld1q_s16(iota2_s16); + const int16x4_t iota2x4 = vget_low_s16(iota2x8); + + int x = dx; + int r = 0; + do { + const int base = x >> 5; + if (base >= max_base_x) { + for (int i = r; i < bh; ++i) { + aom_memset16(dst, above_max, bw); + dst += stride; + } + return; + } + + // The C implementation of the z1 predictor when upsampling uses: + // (((x << 1) & 0x3f) >> 1) + // The right shift is unnecessary here since we instead shift by +1 later, + // so adjust the mask to 0x3e to ensure we don't consider the extra bit. + const int shift = (x << 1) & 0x3e; + + if (bw == 4) { + const uint16x4x2_t a01 = vld2_u16(&above[base]); + const uint16x4_t val = + highbd_dr_z1_apply_shift_x4(a01.val[0], a01.val[1], shift); + const uint16x4_t cmp = vcgt_s16(vdup_n_s16(max_base_x - base), iota2x4); + const uint16x4_t res = vbsl_u16(cmp, val, vdup_n_u16(above_max)); + vst1_u16(dst, res); + } else { + int c = 0; + do { + const uint16x8x2_t a01 = vld2q_u16(&above[base + 2 * c]); + const uint16x8_t val = + highbd_dr_z1_apply_shift_x8(a01.val[0], a01.val[1], shift); + const uint16x8_t cmp = + vcgtq_s16(vdupq_n_s16(max_base_x - base - 2 * c), iota2x8); + const uint16x8_t res = vbslq_u16(cmp, val, vdupq_n_u16(above_max)); + vst1q_u16(dst + c, res); + c += 8; + } while (c < bw); + } + + dst += stride; + x += dx; + } while (++r < bh); +} + +// Directional prediction, zone 1: 0 < angle < 90 +void av1_highbd_dr_prediction_z1_neon(uint16_t *dst, ptrdiff_t stride, int bw, + int bh, const uint16_t *above, + const uint16_t *left, int upsample_above, + int dx, int dy, int bd) { + (void)left; + (void)dy; + (void)bd; + assert(dy == 1); + + if (upsample_above) { + highbd_dr_prediction_z1_upsample1_neon(dst, stride, bw, bh, above, dx); + } else { + highbd_dr_prediction_z1_upsample0_neon(dst, stride, bw, bh, above, dx); + } +} + +// ----------------------------------------------------------------------------- +// Z2 + +#if AOM_ARCH_AARCH64 +// Incrementally shift more elements from `above` into the result, merging with +// existing `left` elements. 
+// X0, X1, X2, X3 +// Y0, X0, X1, X2 +// Y0, Y1, X0, X1 +// Y0, Y1, Y2, X0 +// Y0, Y1, Y2, Y3 +// clang-format off +static const uint8_t z2_merge_shuffles_u16x4[5][8] = { + { 8, 9, 10, 11, 12, 13, 14, 15 }, + { 0, 1, 8, 9, 10, 11, 12, 13 }, + { 0, 1, 2, 3, 8, 9, 10, 11 }, + { 0, 1, 2, 3, 4, 5, 8, 9 }, + { 0, 1, 2, 3, 4, 5, 6, 7 }, +}; +// clang-format on + +// Incrementally shift more elements from `above` into the result, merging with +// existing `left` elements. +// X0, X1, X2, X3, X4, X5, X6, X7 +// Y0, X0, X1, X2, X3, X4, X5, X6 +// Y0, Y1, X0, X1, X2, X3, X4, X5 +// Y0, Y1, Y2, X0, X1, X2, X3, X4 +// Y0, Y1, Y2, Y3, X0, X1, X2, X3 +// Y0, Y1, Y2, Y3, Y4, X0, X1, X2 +// Y0, Y1, Y2, Y3, Y4, Y5, X0, X1 +// Y0, Y1, Y2, Y3, Y4, Y5, Y6, X0 +// Y0, Y1, Y2, Y3, Y4, Y5, Y6, Y7 +// clang-format off +static const uint8_t z2_merge_shuffles_u16x8[9][16] = { + { 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }, + { 0, 1, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29 }, + { 0, 1, 2, 3, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27 }, + { 0, 1, 2, 3, 4, 5, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 }, + { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }, + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17, 18, 19, 20, 21 }, + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19 }, + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17 }, + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, +}; +// clang-format on + +// clang-format off +static const uint16_t z2_y_iter_masks_u16x4[5][4] = { + { 0U, 0U, 0U, 0U }, + { 0xffffU, 0U, 0U, 0U }, + { 0xffffU, 0xffffU, 0U, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0xffffU }, +}; +// clang-format on + +// clang-format off +static const uint16_t z2_y_iter_masks_u16x8[9][8] = { + { 0U, 0U, 0U, 0U, 0U, 0U, 0U, 0U }, + { 0xffffU, 0U, 0U, 0U, 0U, 0U, 0U, 0U }, + { 0xffffU, 0xffffU, 0U, 0U, 0U, 0U, 0U, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0U, 0U, 0U, 0U, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0U, 0U, 0U, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0U, 0U, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0U, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0U }, + { 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU, 0xffffU }, +}; +// clang-format on + +static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_tbl_left_x4_from_x8( + const uint16x8_t left_data, const int16x4_t indices, int base, int n) { + // Need to adjust indices to operate on 0-based indices rather than + // `base`-based indices and then adjust from uint16x4 indices to uint8x8 + // indices so we can use a tbl instruction (which only operates on bytes). + uint8x8_t left_indices = + vreinterpret_u8_s16(vsub_s16(indices, vdup_n_s16(base))); + left_indices = vtrn1_u8(left_indices, left_indices); + left_indices = vadd_u8(left_indices, left_indices); + left_indices = vadd_u8(left_indices, vreinterpret_u8_u16(vdup_n_u16(0x0100))); + const uint16x4_t ret = vreinterpret_u16_u8( + vqtbl1_u8(vreinterpretq_u8_u16(left_data), left_indices)); + return vand_u16(ret, vld1_u16(z2_y_iter_masks_u16x4[n])); +} + +static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_tbl_left_x4_from_x16( + const uint16x8x2_t left_data, const int16x4_t indices, int base, int n) { + // Need to adjust indices to operate on 0-based indices rather than + // `base`-based indices and then adjust from uint16x4 indices to uint8x8 + // indices so we can use a tbl instruction (which only operates on bytes). 
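+  // In byte terms each 16-bit index i must become the pair (2*i, 2*i + 1):
+  // vtrn1 duplicates the low byte of i into both byte lanes, the add doubles
+  // it, and adding 0x0100 per 16-bit lane bumps the upper byte so it selects
+  // the second byte of element i of left_data.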
+ uint8x8_t left_indices = + vreinterpret_u8_s16(vsub_s16(indices, vdup_n_s16(base))); + left_indices = vtrn1_u8(left_indices, left_indices); + left_indices = vadd_u8(left_indices, left_indices); + left_indices = vadd_u8(left_indices, vreinterpret_u8_u16(vdup_n_u16(0x0100))); + uint8x16x2_t data_u8 = { { vreinterpretq_u8_u16(left_data.val[0]), + vreinterpretq_u8_u16(left_data.val[1]) } }; + const uint16x4_t ret = vreinterpret_u16_u8(vqtbl2_u8(data_u8, left_indices)); + return vand_u16(ret, vld1_u16(z2_y_iter_masks_u16x4[n])); +} + +static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_tbl_left_x8_from_x8( + const uint16x8_t left_data, const int16x8_t indices, int base, int n) { + // Need to adjust indices to operate on 0-based indices rather than + // `base`-based indices and then adjust from uint16x4 indices to uint8x8 + // indices so we can use a tbl instruction (which only operates on bytes). + uint8x16_t left_indices = + vreinterpretq_u8_s16(vsubq_s16(indices, vdupq_n_s16(base))); + left_indices = vtrn1q_u8(left_indices, left_indices); + left_indices = vaddq_u8(left_indices, left_indices); + left_indices = + vaddq_u8(left_indices, vreinterpretq_u8_u16(vdupq_n_u16(0x0100))); + const uint16x8_t ret = vreinterpretq_u16_u8( + vqtbl1q_u8(vreinterpretq_u8_u16(left_data), left_indices)); + return vandq_u16(ret, vld1q_u16(z2_y_iter_masks_u16x8[n])); +} + +static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_tbl_left_x8_from_x16( + const uint16x8x2_t left_data, const int16x8_t indices, int base, int n) { + // Need to adjust indices to operate on 0-based indices rather than + // `base`-based indices and then adjust from uint16x4 indices to uint8x8 + // indices so we can use a tbl instruction (which only operates on bytes). + uint8x16_t left_indices = + vreinterpretq_u8_s16(vsubq_s16(indices, vdupq_n_s16(base))); + left_indices = vtrn1q_u8(left_indices, left_indices); + left_indices = vaddq_u8(left_indices, left_indices); + left_indices = + vaddq_u8(left_indices, vreinterpretq_u8_u16(vdupq_n_u16(0x0100))); + uint8x16x2_t data_u8 = { { vreinterpretq_u8_u16(left_data.val[0]), + vreinterpretq_u8_u16(left_data.val[1]) } }; + const uint16x8_t ret = + vreinterpretq_u16_u8(vqtbl2q_u8(data_u8, left_indices)); + return vandq_u16(ret, vld1q_u16(z2_y_iter_masks_u16x8[n])); +} +#endif // AOM_ARCH_AARCH64 + +static AOM_FORCE_INLINE uint16x4x2_t highbd_dr_prediction_z2_gather_left_x4( + const uint16_t *left, const int16x4_t indices, int n) { + assert(n > 0); + assert(n <= 4); + // Load two elements at a time and then uzp them into separate vectors, to + // reduce the number of memory accesses. + uint32x2_t ret0_u32 = vdup_n_u32(0); + uint32x2_t ret1_u32 = vdup_n_u32(0); + + // Use a single vget_lane_u64 to minimize vector to general purpose register + // transfers and then mask off the bits we actually want. + const uint64_t indices0123 = vget_lane_u64(vreinterpret_u64_s16(indices), 0); + const int idx0 = (int16_t)((indices0123 >> 0) & 0xffffU); + const int idx1 = (int16_t)((indices0123 >> 16) & 0xffffU); + const int idx2 = (int16_t)((indices0123 >> 32) & 0xffffU); + const int idx3 = (int16_t)((indices0123 >> 48) & 0xffffU); + + // At time of writing both Clang and GCC produced better code with these + // nested if-statements compared to a switch statement with fallthrough. 
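+  // Each u32 lane load below fetches the pair (left[idx], left[idx + 1]);
+  // the final vuzp then splits the pairs so that val[0] holds left[base_y]
+  // and val[1] holds left[base_y + 1], i.e. the two samples blended by the
+  // shift.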
+ ret0_u32 = vld1_lane_u32((const uint32_t *)(left + idx0), ret0_u32, 0); + if (n > 1) { + ret0_u32 = vld1_lane_u32((const uint32_t *)(left + idx1), ret0_u32, 1); + if (n > 2) { + ret1_u32 = vld1_lane_u32((const uint32_t *)(left + idx2), ret1_u32, 0); + if (n > 3) { + ret1_u32 = vld1_lane_u32((const uint32_t *)(left + idx3), ret1_u32, 1); + } + } + } + return vuzp_u16(vreinterpret_u16_u32(ret0_u32), + vreinterpret_u16_u32(ret1_u32)); +} + +static AOM_FORCE_INLINE uint16x8x2_t highbd_dr_prediction_z2_gather_left_x8( + const uint16_t *left, const int16x8_t indices, int n) { + assert(n > 0); + assert(n <= 8); + // Load two elements at a time and then uzp them into separate vectors, to + // reduce the number of memory accesses. + uint32x4_t ret0_u32 = vdupq_n_u32(0); + uint32x4_t ret1_u32 = vdupq_n_u32(0); + + // Use a pair of vget_lane_u64 to minimize vector to general purpose register + // transfers and then mask off the bits we actually want. + const uint64_t indices0123 = + vgetq_lane_u64(vreinterpretq_u64_s16(indices), 0); + const uint64_t indices4567 = + vgetq_lane_u64(vreinterpretq_u64_s16(indices), 1); + const int idx0 = (int16_t)((indices0123 >> 0) & 0xffffU); + const int idx1 = (int16_t)((indices0123 >> 16) & 0xffffU); + const int idx2 = (int16_t)((indices0123 >> 32) & 0xffffU); + const int idx3 = (int16_t)((indices0123 >> 48) & 0xffffU); + const int idx4 = (int16_t)((indices4567 >> 0) & 0xffffU); + const int idx5 = (int16_t)((indices4567 >> 16) & 0xffffU); + const int idx6 = (int16_t)((indices4567 >> 32) & 0xffffU); + const int idx7 = (int16_t)((indices4567 >> 48) & 0xffffU); + + // At time of writing both Clang and GCC produced better code with these + // nested if-statements compared to a switch statement with fallthrough. + ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx0), ret0_u32, 0); + if (n > 1) { + ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx1), ret0_u32, 1); + if (n > 2) { + ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx2), ret0_u32, 2); + if (n > 3) { + ret0_u32 = vld1q_lane_u32((const uint32_t *)(left + idx3), ret0_u32, 3); + if (n > 4) { + ret1_u32 = + vld1q_lane_u32((const uint32_t *)(left + idx4), ret1_u32, 0); + if (n > 5) { + ret1_u32 = + vld1q_lane_u32((const uint32_t *)(left + idx5), ret1_u32, 1); + if (n > 6) { + ret1_u32 = + vld1q_lane_u32((const uint32_t *)(left + idx6), ret1_u32, 2); + if (n > 7) { + ret1_u32 = vld1q_lane_u32((const uint32_t *)(left + idx7), + ret1_u32, 3); + } + } + } + } + } + } + } + return vuzpq_u16(vreinterpretq_u16_u32(ret0_u32), + vreinterpretq_u16_u32(ret1_u32)); +} + +static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_merge_x4( + uint16x4_t out_x, uint16x4_t out_y, int base_shift) { + assert(base_shift >= 0); + assert(base_shift <= 4); + // On AArch64 we can permute the data from the `above` and `left` vectors + // into a single vector in a single load (of the permute vector) + tbl. 
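+  // For example, base_shift == 2 selects z2_merge_shuffles_u16x4[2] =
+  // { 0, 1, 2, 3, 8, 9, 10, 11 }: bytes 0-3 keep the first two `left` lanes
+  // of out_y and bytes 8-11 take the first two `above` lanes of out_x,
+  // producing { Y0, Y1, X0, X1 } as in the table above.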
+#if AOM_ARCH_AARCH64 + const uint8x8x2_t out_yx = { { vreinterpret_u8_u16(out_y), + vreinterpret_u8_u16(out_x) } }; + return vreinterpret_u16_u8( + vtbl2_u8(out_yx, vld1_u8(z2_merge_shuffles_u16x4[base_shift]))); +#else + uint16x4_t out = out_y; + for (int c2 = base_shift, x_idx = 0; c2 < 4; ++c2, ++x_idx) { + out[c2] = out_x[x_idx]; + } + return out; +#endif +} + +static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_merge_x8( + uint16x8_t out_x, uint16x8_t out_y, int base_shift) { + assert(base_shift >= 0); + assert(base_shift <= 8); + // On AArch64 we can permute the data from the `above` and `left` vectors + // into a single vector in a single load (of the permute vector) + tbl. +#if AOM_ARCH_AARCH64 + const uint8x16x2_t out_yx = { { vreinterpretq_u8_u16(out_y), + vreinterpretq_u8_u16(out_x) } }; + return vreinterpretq_u16_u8( + vqtbl2q_u8(out_yx, vld1q_u8(z2_merge_shuffles_u16x8[base_shift]))); +#else + uint16x8_t out = out_y; + for (int c2 = base_shift, x_idx = 0; c2 < 8; ++c2, ++x_idx) { + out[c2] = out_x[x_idx]; + } + return out; +#endif +} + +static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_apply_shift_x4( + uint16x4_t a0, uint16x4_t a1, int16x4_t shift) { + uint32x4_t res = vmull_u16(a1, vreinterpret_u16_s16(shift)); + res = + vmlal_u16(res, a0, vsub_u16(vdup_n_u16(32), vreinterpret_u16_s16(shift))); + return vrshrn_n_u32(res, 5); +} + +static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_apply_shift_x8( + uint16x8_t a0, uint16x8_t a1, int16x8_t shift) { + return vcombine_u16( + highbd_dr_prediction_z2_apply_shift_x4(vget_low_u16(a0), vget_low_u16(a1), + vget_low_s16(shift)), + highbd_dr_prediction_z2_apply_shift_x4( + vget_high_u16(a0), vget_high_u16(a1), vget_high_s16(shift))); +} + +static AOM_FORCE_INLINE uint16x4_t highbd_dr_prediction_z2_step_x4( + const uint16_t *above, const uint16x4_t above0, const uint16x4_t above1, + const uint16_t *left, int dx, int dy, int r, int c) { + const int16x4_t iota = vld1_s16(iota1_s16); + + const int x0 = (c << 6) - (r + 1) * dx; + const int y0 = (r << 6) - (c + 1) * dy; + + const int16x4_t x0123 = vadd_s16(vdup_n_s16(x0), vshl_n_s16(iota, 6)); + const int16x4_t y0123 = vsub_s16(vdup_n_s16(y0), vmul_n_s16(iota, dy)); + const int16x4_t shift_x0123 = + vshr_n_s16(vand_s16(x0123, vdup_n_s16(0x3F)), 1); + const int16x4_t shift_y0123 = + vshr_n_s16(vand_s16(y0123, vdup_n_s16(0x3F)), 1); + const int16x4_t base_y0123 = vshr_n_s16(y0123, 6); + + const int base_shift = ((((r + 1) * dx) - 1) >> 6) - c; + + // Based on the value of `base_shift` there are three possible cases to + // compute the result: + // 1) base_shift <= 0: We can load and operate entirely on data from the + // `above` input vector. + // 2) base_shift < vl: We can load from `above[-1]` and shift + // `vl - base_shift` elements across to the end of the + // vector, then compute the remainder from `left`. + // 3) base_shift >= vl: We can load and operate entirely on data from the + // `left` input vector. + + if (base_shift <= 0) { + const int base_x = x0 >> 6; + const uint16x4_t a0 = vld1_u16(above + base_x); + const uint16x4_t a1 = vld1_u16(above + base_x + 1); + return highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123); + } else if (base_shift < 4) { + const uint16x4x2_t l01 = highbd_dr_prediction_z2_gather_left_x4( + left + 1, base_y0123, base_shift); + const uint16x4_t out16_y = highbd_dr_prediction_z2_apply_shift_x4( + l01.val[0], l01.val[1], shift_y0123); + + // No need to reload from above in the loop, just use pre-loaded constants. 
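+    // Each lane below is the usual two-tap blend, i.e.
+    //   out[i] = (a0[i] * (32 - shift[i]) + a1[i] * shift[i] + 16) >> 5,
+    // as implemented by highbd_dr_prediction_z2_apply_shift_x4().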
+ const uint16x4_t out16_x = + highbd_dr_prediction_z2_apply_shift_x4(above0, above1, shift_x0123); + + return highbd_dr_prediction_z2_merge_x4(out16_x, out16_y, base_shift); + } else { + const uint16x4x2_t l01 = + highbd_dr_prediction_z2_gather_left_x4(left + 1, base_y0123, 4); + return highbd_dr_prediction_z2_apply_shift_x4(l01.val[0], l01.val[1], + shift_y0123); + } +} + +static AOM_FORCE_INLINE uint16x8_t highbd_dr_prediction_z2_step_x8( + const uint16_t *above, const uint16x8_t above0, const uint16x8_t above1, + const uint16_t *left, int dx, int dy, int r, int c) { + const int16x8_t iota = vld1q_s16(iota1_s16); + + const int x0 = (c << 6) - (r + 1) * dx; + const int y0 = (r << 6) - (c + 1) * dy; + + const int16x8_t x01234567 = vaddq_s16(vdupq_n_s16(x0), vshlq_n_s16(iota, 6)); + const int16x8_t y01234567 = vsubq_s16(vdupq_n_s16(y0), vmulq_n_s16(iota, dy)); + const int16x8_t shift_x01234567 = + vshrq_n_s16(vandq_s16(x01234567, vdupq_n_s16(0x3F)), 1); + const int16x8_t shift_y01234567 = + vshrq_n_s16(vandq_s16(y01234567, vdupq_n_s16(0x3F)), 1); + const int16x8_t base_y01234567 = vshrq_n_s16(y01234567, 6); + + const int base_shift = ((((r + 1) * dx) - 1) >> 6) - c; + + // Based on the value of `base_shift` there are three possible cases to + // compute the result: + // 1) base_shift <= 0: We can load and operate entirely on data from the + // `above` input vector. + // 2) base_shift < vl: We can load from `above[-1]` and shift + // `vl - base_shift` elements across to the end of the + // vector, then compute the remainder from `left`. + // 3) base_shift >= vl: We can load and operate entirely on data from the + // `left` input vector. + + if (base_shift <= 0) { + const int base_x = x0 >> 6; + const uint16x8_t a0 = vld1q_u16(above + base_x); + const uint16x8_t a1 = vld1q_u16(above + base_x + 1); + return highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567); + } else if (base_shift < 8) { + const uint16x8x2_t l01 = highbd_dr_prediction_z2_gather_left_x8( + left + 1, base_y01234567, base_shift); + const uint16x8_t out16_y = highbd_dr_prediction_z2_apply_shift_x8( + l01.val[0], l01.val[1], shift_y01234567); + + // No need to reload from above in the loop, just use pre-loaded constants. + const uint16x8_t out16_x = + highbd_dr_prediction_z2_apply_shift_x8(above0, above1, shift_x01234567); + + return highbd_dr_prediction_z2_merge_x8(out16_x, out16_y, base_shift); + } else { + const uint16x8x2_t l01 = + highbd_dr_prediction_z2_gather_left_x8(left + 1, base_y01234567, 8); + return highbd_dr_prediction_z2_apply_shift_x8(l01.val[0], l01.val[1], + shift_y01234567); + } +} + +// Left array is accessed from -1 through `bh - 1` inclusive. +// Above array is accessed from -1 through `bw - 1` inclusive. 
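+// For example, HIGHBD_DR_PREDICTOR_Z2_WXH(16, 8) below expands to
+// highbd_dr_prediction_z2_16x8_neon(), which copies left[-1..7] into
+// `left_data` and then computes each of its 8 rows as two x8 steps
+// (c = 0 and c = 8), reusing the pre-loaded a0 = above[-1..6] and
+// a1 = above[0..7].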
+#define HIGHBD_DR_PREDICTOR_Z2_WXH(bw, bh) \ + static void highbd_dr_prediction_z2_##bw##x##bh##_neon( \ + uint16_t *dst, ptrdiff_t stride, const uint16_t *above, \ + const uint16_t *left, int upsample_above, int upsample_left, int dx, \ + int dy, int bd) { \ + (void)bd; \ + (void)upsample_above; \ + (void)upsample_left; \ + assert(!upsample_above); \ + assert(!upsample_left); \ + assert(bw % 4 == 0); \ + assert(bh % 4 == 0); \ + assert(dx > 0); \ + assert(dy > 0); \ + \ + uint16_t left_data[bh + 1]; \ + memcpy(left_data, left - 1, (bh + 1) * sizeof(uint16_t)); \ + \ + uint16x8_t a0, a1; \ + if (bw == 4) { \ + a0 = vcombine_u16(vld1_u16(above - 1), vdup_n_u16(0)); \ + a1 = vcombine_u16(vld1_u16(above + 0), vdup_n_u16(0)); \ + } else { \ + a0 = vld1q_u16(above - 1); \ + a1 = vld1q_u16(above + 0); \ + } \ + \ + int r = 0; \ + do { \ + if (bw == 4) { \ + vst1_u16(dst, highbd_dr_prediction_z2_step_x4( \ + above, vget_low_u16(a0), vget_low_u16(a1), \ + left_data, dx, dy, r, 0)); \ + } else { \ + int c = 0; \ + do { \ + vst1q_u16(dst + c, highbd_dr_prediction_z2_step_x8( \ + above, a0, a1, left_data, dx, dy, r, c)); \ + c += 8; \ + } while (c < bw); \ + } \ + dst += stride; \ + } while (++r < bh); \ + } + +HIGHBD_DR_PREDICTOR_Z2_WXH(4, 16) +HIGHBD_DR_PREDICTOR_Z2_WXH(8, 16) +HIGHBD_DR_PREDICTOR_Z2_WXH(8, 32) +HIGHBD_DR_PREDICTOR_Z2_WXH(16, 4) +HIGHBD_DR_PREDICTOR_Z2_WXH(16, 8) +HIGHBD_DR_PREDICTOR_Z2_WXH(16, 16) +HIGHBD_DR_PREDICTOR_Z2_WXH(16, 32) +HIGHBD_DR_PREDICTOR_Z2_WXH(16, 64) +HIGHBD_DR_PREDICTOR_Z2_WXH(32, 8) +HIGHBD_DR_PREDICTOR_Z2_WXH(32, 16) +HIGHBD_DR_PREDICTOR_Z2_WXH(32, 32) +HIGHBD_DR_PREDICTOR_Z2_WXH(32, 64) +HIGHBD_DR_PREDICTOR_Z2_WXH(64, 16) +HIGHBD_DR_PREDICTOR_Z2_WXH(64, 32) +HIGHBD_DR_PREDICTOR_Z2_WXH(64, 64) + +#undef HIGHBD_DR_PREDICTOR_Z2_WXH + +typedef void (*highbd_dr_prediction_z2_ptr)(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, + int upsample_above, + int upsample_left, int dx, int dy, + int bd); + +static void highbd_dr_prediction_z2_4x4_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, + int upsample_above, + int upsample_left, int dx, int dy, + int bd) { + (void)bd; + assert(dx > 0); + assert(dy > 0); + + const int frac_bits_x = 6 - upsample_above; + const int frac_bits_y = 6 - upsample_left; + const int min_base_x = -(1 << (upsample_above + frac_bits_x)); + + // if `upsample_left` then we need -2 through 6 inclusive from `left`. + // else we only need -1 through 3 inclusive. 
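+  // Concretely, the loads below give left_data0 = left[-2..5] and
+  // left_data1 = left[-1..6] when upsampling, or left[-1..2] (plus zero
+  // padding) and left[0..3] otherwise, so on the AArch64 tbl path l0/l1
+  // evaluate to left[base_y] and left[base_y + 1] for each lane.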
+ +#if AOM_ARCH_AARCH64 + uint16x8_t left_data0, left_data1; + if (upsample_left) { + left_data0 = vld1q_u16(left - 2); + left_data1 = vld1q_u16(left - 1); + } else { + left_data0 = vcombine_u16(vld1_u16(left - 1), vdup_n_u16(0)); + left_data1 = vcombine_u16(vld1_u16(left + 0), vdup_n_u16(0)); + } +#endif + + const int16x4_t iota0123 = vld1_s16(iota1_s16); + const int16x4_t iota1234 = vld1_s16(iota1_s16 + 1); + + for (int r = 0; r < 4; ++r) { + const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6; + const int x0 = (r + 1) * dx; + const int16x4_t x0123 = vsub_s16(vshl_n_s16(iota0123, 6), vdup_n_s16(x0)); + const int base_x0 = (-x0) >> frac_bits_x; + if (base_shift <= 0) { + uint16x4_t a0, a1; + int16x4_t shift_x0123; + if (upsample_above) { + const uint16x4x2_t a01 = vld2_u16(above + base_x0); + a0 = a01.val[0]; + a1 = a01.val[1]; + shift_x0123 = vand_s16(x0123, vdup_n_s16(0x1F)); + } else { + a0 = vld1_u16(above + base_x0); + a1 = vld1_u16(above + base_x0 + 1); + shift_x0123 = vshr_n_s16(vand_s16(x0123, vdup_n_s16(0x3F)), 1); + } + vst1_u16(dst, + highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123)); + } else if (base_shift < 4) { + // Calculate Y component from `left`. + const int y_iters = base_shift; + const int16x4_t y0123 = + vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy)); + const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y)); + const int16x4_t shift_y0123 = vshr_n_s16( + vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1); + uint16x4_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? -2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data0, base_y0123, + left_data_base, y_iters); + l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data1, base_y0123, + left_data_base, y_iters); +#else + const uint16x4x2_t l01 = + highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, y_iters); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + + const uint16x4_t out_y = + highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123); + + // Calculate X component from `above`. + const int16x4_t shift_x0123 = vshr_n_s16( + vand_s16(vmul_n_s16(x0123, 1 << upsample_above), vdup_n_s16(0x3F)), + 1); + uint16x4_t a0, a1; + if (upsample_above) { + const uint16x4x2_t a01 = vld2_u16(above + (base_x0 % 2 == 0 ? -2 : -1)); + a0 = a01.val[0]; + a1 = a01.val[1]; + } else { + a0 = vld1_u16(above - 1); + a1 = vld1_u16(above + 0); + } + const uint16x4_t out_x = + highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123); + + // Combine X and Y vectors. + const uint16x4_t out = + highbd_dr_prediction_z2_merge_x4(out_x, out_y, base_shift); + vst1_u16(dst, out); + } else { + const int16x4_t y0123 = + vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy)); + const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y)); + const int16x4_t shift_y0123 = vshr_n_s16( + vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1); + uint16x4_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? 
-2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data0, base_y0123, + left_data_base, 4); + l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x8(left_data1, base_y0123, + left_data_base, 4); +#else + const uint16x4x2_t l01 = + highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, 4); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + vst1_u16(dst, + highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123)); + } + dst += stride; + } +} + +static void highbd_dr_prediction_z2_4x8_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, + int upsample_above, + int upsample_left, int dx, int dy, + int bd) { + (void)bd; + assert(dx > 0); + assert(dy > 0); + + const int frac_bits_x = 6 - upsample_above; + const int frac_bits_y = 6 - upsample_left; + const int min_base_x = -(1 << (upsample_above + frac_bits_x)); + + // if `upsample_left` then we need -2 through 14 inclusive from `left`. + // else we only need -1 through 6 inclusive. + +#if AOM_ARCH_AARCH64 + uint16x8x2_t left_data0, left_data1; + if (upsample_left) { + left_data0 = vld1q_u16_x2(left - 2); + left_data1 = vld1q_u16_x2(left - 1); + } else { + left_data0 = (uint16x8x2_t){ { vld1q_u16(left - 1), vdupq_n_u16(0) } }; + left_data1 = (uint16x8x2_t){ { vld1q_u16(left + 0), vdupq_n_u16(0) } }; + } +#endif + + const int16x4_t iota0123 = vld1_s16(iota1_s16); + const int16x4_t iota1234 = vld1_s16(iota1_s16 + 1); + + for (int r = 0; r < 8; ++r) { + const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6; + const int x0 = (r + 1) * dx; + const int16x4_t x0123 = vsub_s16(vshl_n_s16(iota0123, 6), vdup_n_s16(x0)); + const int base_x0 = (-x0) >> frac_bits_x; + if (base_shift <= 0) { + uint16x4_t a0, a1; + int16x4_t shift_x0123; + if (upsample_above) { + const uint16x4x2_t a01 = vld2_u16(above + base_x0); + a0 = a01.val[0]; + a1 = a01.val[1]; + shift_x0123 = vand_s16(x0123, vdup_n_s16(0x1F)); + } else { + a0 = vld1_u16(above + base_x0); + a1 = vld1_u16(above + base_x0 + 1); + shift_x0123 = vand_s16(vshr_n_s16(x0123, 1), vdup_n_s16(0x1F)); + } + vst1_u16(dst, + highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123)); + } else if (base_shift < 4) { + // Calculate Y component from `left`. + const int y_iters = base_shift; + const int16x4_t y0123 = + vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy)); + const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y)); + const int16x4_t shift_y0123 = vshr_n_s16( + vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1); + + uint16x4_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? -2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x16( + left_data0, base_y0123, left_data_base, y_iters); + l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x16( + left_data1, base_y0123, left_data_base, y_iters); +#else + const uint16x4x2_t l01 = + highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, y_iters); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + + const uint16x4_t out_y = + highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123); + + // Calculate X component from `above`. + uint16x4_t a0, a1; + int16x4_t shift_x0123; + if (upsample_above) { + const uint16x4x2_t a01 = vld2_u16(above + (base_x0 % 2 == 0 ? 
-2 : -1)); + a0 = a01.val[0]; + a1 = a01.val[1]; + shift_x0123 = vand_s16(x0123, vdup_n_s16(0x1F)); + } else { + a0 = vld1_u16(above - 1); + a1 = vld1_u16(above + 0); + shift_x0123 = vand_s16(vshr_n_s16(x0123, 1), vdup_n_s16(0x1F)); + } + const uint16x4_t out_x = + highbd_dr_prediction_z2_apply_shift_x4(a0, a1, shift_x0123); + + // Combine X and Y vectors. + const uint16x4_t out = + highbd_dr_prediction_z2_merge_x4(out_x, out_y, base_shift); + vst1_u16(dst, out); + } else { + const int16x4_t y0123 = + vsub_s16(vdup_n_s16(r << 6), vmul_n_s16(iota1234, dy)); + const int16x4_t base_y0123 = vshl_s16(y0123, vdup_n_s16(-frac_bits_y)); + const int16x4_t shift_y0123 = vshr_n_s16( + vand_s16(vmul_n_s16(y0123, 1 << upsample_left), vdup_n_s16(0x3F)), 1); + + uint16x4_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? -2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x4_from_x16(left_data0, base_y0123, + left_data_base, 4); + l1 = highbd_dr_prediction_z2_tbl_left_x4_from_x16(left_data1, base_y0123, + left_data_base, 4); +#else + const uint16x4x2_t l01 = + highbd_dr_prediction_z2_gather_left_x4(left, base_y0123, 4); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + + vst1_u16(dst, + highbd_dr_prediction_z2_apply_shift_x4(l0, l1, shift_y0123)); + } + dst += stride; + } +} + +static void highbd_dr_prediction_z2_8x4_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, + int upsample_above, + int upsample_left, int dx, int dy, + int bd) { + (void)bd; + assert(dx > 0); + assert(dy > 0); + + const int frac_bits_x = 6 - upsample_above; + const int frac_bits_y = 6 - upsample_left; + const int min_base_x = -(1 << (upsample_above + frac_bits_x)); + + // if `upsample_left` then we need -2 through 6 inclusive from `left`. + // else we only need -1 through 3 inclusive. + +#if AOM_ARCH_AARCH64 + uint16x8_t left_data0, left_data1; + if (upsample_left) { + left_data0 = vld1q_u16(left - 2); + left_data1 = vld1q_u16(left - 1); + } else { + left_data0 = vcombine_u16(vld1_u16(left - 1), vdup_n_u16(0)); + left_data1 = vcombine_u16(vld1_u16(left + 0), vdup_n_u16(0)); + } +#endif + + const int16x8_t iota01234567 = vld1q_s16(iota1_s16); + const int16x8_t iota12345678 = vld1q_s16(iota1_s16 + 1); + + for (int r = 0; r < 4; ++r) { + const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6; + const int x0 = (r + 1) * dx; + const int16x8_t x01234567 = + vsubq_s16(vshlq_n_s16(iota01234567, 6), vdupq_n_s16(x0)); + const int base_x0 = (-x0) >> frac_bits_x; + if (base_shift <= 0) { + uint16x8_t a0, a1; + int16x8_t shift_x01234567; + if (upsample_above) { + const uint16x8x2_t a01 = vld2q_u16(above + base_x0); + a0 = a01.val[0]; + a1 = a01.val[1]; + shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F)); + } else { + a0 = vld1q_u16(above + base_x0); + a1 = vld1q_u16(above + base_x0 + 1); + shift_x01234567 = + vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F)); + } + vst1q_u16( + dst, highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567)); + } else if (base_shift < 8) { + // Calculate Y component from `left`. + const int y_iters = base_shift; + const int16x8_t y01234567 = + vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy)); + const int16x8_t base_y01234567 = + vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y)); + const int16x8_t shift_y01234567 = + vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left), + vdupq_n_s16(0x3F)), + 1); + + uint16x8_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? 
-2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x8( + left_data0, base_y01234567, left_data_base, y_iters); + l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x8( + left_data1, base_y01234567, left_data_base, y_iters); +#else + const uint16x8x2_t l01 = + highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, y_iters); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + + const uint16x8_t out_y = + highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567); + + // Calculate X component from `above`. + uint16x8_t a0, a1; + int16x8_t shift_x01234567; + if (upsample_above) { + const uint16x8x2_t a01 = + vld2q_u16(above + (base_x0 % 2 == 0 ? -2 : -1)); + a0 = a01.val[0]; + a1 = a01.val[1]; + shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F)); + } else { + a0 = vld1q_u16(above - 1); + a1 = vld1q_u16(above + 0); + shift_x01234567 = + vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F)); + } + const uint16x8_t out_x = + highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567); + + // Combine X and Y vectors. + const uint16x8_t out = + highbd_dr_prediction_z2_merge_x8(out_x, out_y, base_shift); + vst1q_u16(dst, out); + } else { + const int16x8_t y01234567 = + vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy)); + const int16x8_t base_y01234567 = + vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y)); + const int16x8_t shift_y01234567 = + vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left), + vdupq_n_s16(0x3F)), + 1); + + uint16x8_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? -2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x8( + left_data0, base_y01234567, left_data_base, 8); + l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x8( + left_data1, base_y01234567, left_data_base, 8); +#else + const uint16x8x2_t l01 = + highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, 8); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + + vst1q_u16( + dst, highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567)); + } + dst += stride; + } +} + +static void highbd_dr_prediction_z2_8x8_neon(uint16_t *dst, ptrdiff_t stride, + const uint16_t *above, + const uint16_t *left, + int upsample_above, + int upsample_left, int dx, int dy, + int bd) { + (void)bd; + assert(dx > 0); + assert(dy > 0); + + const int frac_bits_x = 6 - upsample_above; + const int frac_bits_y = 6 - upsample_left; + const int min_base_x = -(1 << (upsample_above + frac_bits_x)); + + // if `upsample_left` then we need -2 through 14 inclusive from `left`. + // else we only need -1 through 6 inclusive. 
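+  // Concretely, vld1q_u16_x2 below gives left_data0 = left[-2..13] and
+  // left_data1 = left[-1..14] when upsampling; otherwise only the first
+  // q register is populated, covering left[-1..6] and left[0..7].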
+ +#if AOM_ARCH_AARCH64 + uint16x8x2_t left_data0, left_data1; + if (upsample_left) { + left_data0 = vld1q_u16_x2(left - 2); + left_data1 = vld1q_u16_x2(left - 1); + } else { + left_data0 = (uint16x8x2_t){ { vld1q_u16(left - 1), vdupq_n_u16(0) } }; + left_data1 = (uint16x8x2_t){ { vld1q_u16(left + 0), vdupq_n_u16(0) } }; + } +#endif + + const int16x8_t iota01234567 = vld1q_s16(iota1_s16); + const int16x8_t iota12345678 = vld1q_s16(iota1_s16 + 1); + + for (int r = 0; r < 8; ++r) { + const int base_shift = (min_base_x + (r + 1) * dx + 63) >> 6; + const int x0 = (r + 1) * dx; + const int16x8_t x01234567 = + vsubq_s16(vshlq_n_s16(iota01234567, 6), vdupq_n_s16(x0)); + const int base_x0 = (-x0) >> frac_bits_x; + if (base_shift <= 0) { + uint16x8_t a0, a1; + int16x8_t shift_x01234567; + if (upsample_above) { + const uint16x8x2_t a01 = vld2q_u16(above + base_x0); + a0 = a01.val[0]; + a1 = a01.val[1]; + shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F)); + } else { + a0 = vld1q_u16(above + base_x0); + a1 = vld1q_u16(above + base_x0 + 1); + shift_x01234567 = + vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F)); + } + vst1q_u16( + dst, highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567)); + } else if (base_shift < 8) { + // Calculate Y component from `left`. + const int y_iters = base_shift; + const int16x8_t y01234567 = + vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy)); + const int16x8_t base_y01234567 = + vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y)); + const int16x8_t shift_y01234567 = + vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left), + vdupq_n_s16(0x3F)), + 1); + + uint16x8_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? -2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x16( + left_data0, base_y01234567, left_data_base, y_iters); + l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x16( + left_data1, base_y01234567, left_data_base, y_iters); +#else + const uint16x8x2_t l01 = + highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, y_iters); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + + const uint16x8_t out_y = + highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567); + + // Calculate X component from `above`. + uint16x8_t a0, a1; + int16x8_t shift_x01234567; + if (upsample_above) { + const uint16x8x2_t a01 = + vld2q_u16(above + (base_x0 % 2 == 0 ? -2 : -1)); + a0 = a01.val[0]; + a1 = a01.val[1]; + shift_x01234567 = vandq_s16(x01234567, vdupq_n_s16(0x1F)); + } else { + a0 = vld1q_u16(above - 1); + a1 = vld1q_u16(above + 0); + shift_x01234567 = + vandq_s16(vshrq_n_s16(x01234567, 1), vdupq_n_s16(0x1F)); + } + const uint16x8_t out_x = + highbd_dr_prediction_z2_apply_shift_x8(a0, a1, shift_x01234567); + + // Combine X and Y vectors. + const uint16x8_t out = + highbd_dr_prediction_z2_merge_x8(out_x, out_y, base_shift); + vst1q_u16(dst, out); + } else { + const int16x8_t y01234567 = + vsubq_s16(vdupq_n_s16(r << 6), vmulq_n_s16(iota12345678, dy)); + const int16x8_t base_y01234567 = + vshlq_s16(y01234567, vdupq_n_s16(-frac_bits_y)); + const int16x8_t shift_y01234567 = + vshrq_n_s16(vandq_s16(vmulq_n_s16(y01234567, 1 << upsample_left), + vdupq_n_s16(0x3F)), + 1); + + uint16x8_t l0, l1; +#if AOM_ARCH_AARCH64 + const int left_data_base = upsample_left ? 
-2 : -1; + l0 = highbd_dr_prediction_z2_tbl_left_x8_from_x16( + left_data0, base_y01234567, left_data_base, 8); + l1 = highbd_dr_prediction_z2_tbl_left_x8_from_x16( + left_data1, base_y01234567, left_data_base, 8); +#else + const uint16x8x2_t l01 = + highbd_dr_prediction_z2_gather_left_x8(left, base_y01234567, 8); + l0 = l01.val[0]; + l1 = l01.val[1]; +#endif + + vst1q_u16( + dst, highbd_dr_prediction_z2_apply_shift_x8(l0, l1, shift_y01234567)); + } + dst += stride; + } +} + +static highbd_dr_prediction_z2_ptr dr_predictor_z2_arr_neon[7][7] = { + { NULL, NULL, NULL, NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL, NULL, NULL, NULL }, + { NULL, NULL, &highbd_dr_prediction_z2_4x4_neon, + &highbd_dr_prediction_z2_4x8_neon, &highbd_dr_prediction_z2_4x16_neon, NULL, + NULL }, + { NULL, NULL, &highbd_dr_prediction_z2_8x4_neon, + &highbd_dr_prediction_z2_8x8_neon, &highbd_dr_prediction_z2_8x16_neon, + &highbd_dr_prediction_z2_8x32_neon, NULL }, + { NULL, NULL, &highbd_dr_prediction_z2_16x4_neon, + &highbd_dr_prediction_z2_16x8_neon, &highbd_dr_prediction_z2_16x16_neon, + &highbd_dr_prediction_z2_16x32_neon, &highbd_dr_prediction_z2_16x64_neon }, + { NULL, NULL, NULL, &highbd_dr_prediction_z2_32x8_neon, + &highbd_dr_prediction_z2_32x16_neon, &highbd_dr_prediction_z2_32x32_neon, + &highbd_dr_prediction_z2_32x64_neon }, + { NULL, NULL, NULL, NULL, &highbd_dr_prediction_z2_64x16_neon, + &highbd_dr_prediction_z2_64x32_neon, &highbd_dr_prediction_z2_64x64_neon }, +}; + +// Directional prediction, zone 2: 90 < angle < 180 +void av1_highbd_dr_prediction_z2_neon(uint16_t *dst, ptrdiff_t stride, int bw, + int bh, const uint16_t *above, + const uint16_t *left, int upsample_above, + int upsample_left, int dx, int dy, + int bd) { + highbd_dr_prediction_z2_ptr f = + dr_predictor_z2_arr_neon[get_msb(bw)][get_msb(bh)]; + assert(f != NULL); + f(dst, stride, above, left, upsample_above, upsample_left, dx, dy, bd); +} + +// ----------------------------------------------------------------------------- +// Z3 + +// Both the lane to the use and the shift amount must be immediates. 
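+// Each step below computes, per lane,
+//   res[i] = (in0[i] * s0[lane] + in1[i] * s1[lane] + (1 << (shift - 1))) >> shift
+// and substitutes left[max_base_y] for lanes whose source position
+// (base + iota[i]) has run past max_base_y.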
+#define HIGHBD_DR_PREDICTOR_Z3_STEP_X4(out, iota, base, in0, in1, s0, s1, \ + lane, shift) \ + do { \ + uint32x4_t val = vmull_lane_u16((in0), (s0), (lane)); \ + val = vmlal_lane_u16(val, (in1), (s1), (lane)); \ + const uint16x4_t cmp = vadd_u16((iota), vdup_n_u16(base)); \ + const uint16x4_t res = vrshrn_n_u32(val, (shift)); \ + *(out) = vbsl_u16(vclt_u16(cmp, vdup_n_u16(max_base_y)), res, \ + vdup_n_u16(left_max)); \ + } while (0) + +#define HIGHBD_DR_PREDICTOR_Z3_STEP_X8(out, iota, base, in0, in1, s0, s1, \ + lane, shift) \ + do { \ + uint32x4_t val_lo = vmull_lane_u16(vget_low_u16(in0), (s0), (lane)); \ + val_lo = vmlal_lane_u16(val_lo, vget_low_u16(in1), (s1), (lane)); \ + uint32x4_t val_hi = vmull_lane_u16(vget_high_u16(in0), (s0), (lane)); \ + val_hi = vmlal_lane_u16(val_hi, vget_high_u16(in1), (s1), (lane)); \ + const uint16x8_t cmp = vaddq_u16((iota), vdupq_n_u16(base)); \ + const uint16x8_t res = vcombine_u16(vrshrn_n_u32(val_lo, (shift)), \ + vrshrn_n_u32(val_hi, (shift))); \ + *(out) = vbslq_u16(vcltq_u16(cmp, vdupq_n_u16(max_base_y)), res, \ + vdupq_n_u16(left_max)); \ + } while (0) + +static void highbd_dr_prediction_z3_upsample0_neon(uint16_t *dst, + ptrdiff_t stride, int bw, + int bh, const uint16_t *left, + int dy) { + assert(bw % 4 == 0); + assert(bh % 4 == 0); + assert(dy > 0); + + // Factor out left + 1 to give the compiler a better chance of recognising + // that the offsets used for the loads from left and left + 1 are otherwise + // identical. + const uint16_t *left1 = left + 1; + + const int max_base_y = (bw + bh - 1); + const int left_max = left[max_base_y]; + const int frac_bits = 6; + + const uint16x8_t iota1x8 = vreinterpretq_u16_s16(vld1q_s16(iota1_s16)); + const uint16x4_t iota1x4 = vget_low_u16(iota1x8); + + // The C implementation of the z3 predictor when not upsampling uses: + // ((y & 0x3f) >> 1) + // The right shift is unnecessary here since we instead shift by +1 later, + // so adjust the mask to 0x3e to ensure we don't consider the extra bit. + const uint16x4_t shift_mask = vdup_n_u16(0x3e); + + if (bh == 4) { + int y = dy; + int c = 0; + do { + // Fully unroll the 4x4 block to allow us to use immediate lane-indexed + // multiply instructions. 
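+      // For example, y == 7 gives shifts1 = 6 and shifts0 = 58 in lane 0, so
+      // the >> 6 blend (l0 * 58 + l1 * 6 + 32) >> 6 equals the C version's
+      // (l0 * 29 + l1 * 3 + 16) >> 5 with shift == (7 & 0x3f) >> 1 == 3.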
+ const uint16x4_t shifts1 = + vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask); + const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(64), shifts1); + const int base0 = (y + 0 * dy) >> frac_bits; + const int base1 = (y + 1 * dy) >> frac_bits; + const int base2 = (y + 2 * dy) >> frac_bits; + const int base3 = (y + 3 * dy) >> frac_bits; + uint16x4_t out[4]; + if (base0 >= max_base_y) { + out[0] = vdup_n_u16(left_max); + } else { + const uint16x4_t l00 = vld1_u16(left + base0); + const uint16x4_t l01 = vld1_u16(left1 + base0); + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[0], iota1x4, base0, l00, l01, + shifts0, shifts1, 0, 6); + } + if (base1 >= max_base_y) { + out[1] = vdup_n_u16(left_max); + } else { + const uint16x4_t l10 = vld1_u16(left + base1); + const uint16x4_t l11 = vld1_u16(left1 + base1); + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[1], iota1x4, base1, l10, l11, + shifts0, shifts1, 1, 6); + } + if (base2 >= max_base_y) { + out[2] = vdup_n_u16(left_max); + } else { + const uint16x4_t l20 = vld1_u16(left + base2); + const uint16x4_t l21 = vld1_u16(left1 + base2); + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[2], iota1x4, base2, l20, l21, + shifts0, shifts1, 2, 6); + } + if (base3 >= max_base_y) { + out[3] = vdup_n_u16(left_max); + } else { + const uint16x4_t l30 = vld1_u16(left + base3); + const uint16x4_t l31 = vld1_u16(left1 + base3); + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[3], iota1x4, base3, l30, l31, + shifts0, shifts1, 3, 6); + } + transpose_array_inplace_u16_4x4(out); + for (int r2 = 0; r2 < 4; ++r2) { + vst1_u16(dst + r2 * stride + c, out[r2]); + } + y += 4 * dy; + c += 4; + } while (c < bw); + } else { + int y = dy; + int c = 0; + do { + int r = 0; + do { + // Fully unroll the 4x4 block to allow us to use immediate lane-indexed + // multiply instructions. 
+ const uint16x4_t shifts1 = + vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask); + const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(64), shifts1); + const int base0 = ((y + 0 * dy) >> frac_bits) + r; + const int base1 = ((y + 1 * dy) >> frac_bits) + r; + const int base2 = ((y + 2 * dy) >> frac_bits) + r; + const int base3 = ((y + 3 * dy) >> frac_bits) + r; + uint16x8_t out[4]; + if (base0 >= max_base_y) { + out[0] = vdupq_n_u16(left_max); + } else { + const uint16x8_t l00 = vld1q_u16(left + base0); + const uint16x8_t l01 = vld1q_u16(left1 + base0); + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota1x8, base0, l00, l01, + shifts0, shifts1, 0, 6); + } + if (base1 >= max_base_y) { + out[1] = vdupq_n_u16(left_max); + } else { + const uint16x8_t l10 = vld1q_u16(left + base1); + const uint16x8_t l11 = vld1q_u16(left1 + base1); + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota1x8, base1, l10, l11, + shifts0, shifts1, 1, 6); + } + if (base2 >= max_base_y) { + out[2] = vdupq_n_u16(left_max); + } else { + const uint16x8_t l20 = vld1q_u16(left + base2); + const uint16x8_t l21 = vld1q_u16(left1 + base2); + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota1x8, base2, l20, l21, + shifts0, shifts1, 2, 6); + } + if (base3 >= max_base_y) { + out[3] = vdupq_n_u16(left_max); + } else { + const uint16x8_t l30 = vld1q_u16(left + base3); + const uint16x8_t l31 = vld1q_u16(left1 + base3); + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota1x8, base3, l30, l31, + shifts0, shifts1, 3, 6); + } + transpose_array_inplace_u16_4x8(out); + for (int r2 = 0; r2 < 4; ++r2) { + vst1_u16(dst + (r + r2) * stride + c, vget_low_u16(out[r2])); + } + for (int r2 = 0; r2 < 4; ++r2) { + vst1_u16(dst + (r + r2 + 4) * stride + c, vget_high_u16(out[r2])); + } + r += 8; + } while (r < bh); + y += 4 * dy; + c += 4; + } while (c < bw); + } +} + +static void highbd_dr_prediction_z3_upsample1_neon(uint16_t *dst, + ptrdiff_t stride, int bw, + int bh, const uint16_t *left, + int dy) { + assert(bw % 4 == 0); + assert(bh % 4 == 0); + assert(dy > 0); + + const int max_base_y = (bw + bh - 1) << 1; + const int left_max = left[max_base_y]; + const int frac_bits = 5; + + const uint16x4_t iota1x4 = vreinterpret_u16_s16(vld1_s16(iota1_s16)); + const uint16x8_t iota2x8 = vreinterpretq_u16_s16(vld1q_s16(iota2_s16)); + const uint16x4_t iota2x4 = vget_low_u16(iota2x8); + + // The C implementation of the z3 predictor when upsampling uses: + // (((x << 1) & 0x3f) >> 1) + // The two shifts are unnecessary here since the lowest bit is guaranteed to + // be zero when the mask is applied, so adjust the mask to 0x1f to avoid + // needing the shifts at all. + const uint16x4_t shift_mask = vdup_n_u16(0x1F); + + if (bh == 4) { + int y = dy; + int c = 0; + do { + // Fully unroll the 4x4 block to allow us to use immediate lane-indexed + // multiply instructions. 
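+      // With upsampling, vld2 deinterleaves the doubled left array so that
+      // lane i of lN.val[0]/lN.val[1] holds left[baseN + 2 * i] and
+      // left[baseN + 2 * i + 1], and shifts1 = ((y + i * dy) & 0x1f) matches
+      // the C version's ((y << 1) & 0x3f) >> 1 weight directly.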
+ const uint16x4_t shifts1 = + vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask); + const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(32), shifts1); + const int base0 = (y + 0 * dy) >> frac_bits; + const int base1 = (y + 1 * dy) >> frac_bits; + const int base2 = (y + 2 * dy) >> frac_bits; + const int base3 = (y + 3 * dy) >> frac_bits; + const uint16x4x2_t l0 = vld2_u16(left + base0); + const uint16x4x2_t l1 = vld2_u16(left + base1); + const uint16x4x2_t l2 = vld2_u16(left + base2); + const uint16x4x2_t l3 = vld2_u16(left + base3); + uint16x4_t out[4]; + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[0], iota2x4, base0, l0.val[0], + l0.val[1], shifts0, shifts1, 0, 5); + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[1], iota2x4, base1, l1.val[0], + l1.val[1], shifts0, shifts1, 1, 5); + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[2], iota2x4, base2, l2.val[0], + l2.val[1], shifts0, shifts1, 2, 5); + HIGHBD_DR_PREDICTOR_Z3_STEP_X4(&out[3], iota2x4, base3, l3.val[0], + l3.val[1], shifts0, shifts1, 3, 5); + transpose_array_inplace_u16_4x4(out); + for (int r2 = 0; r2 < 4; ++r2) { + vst1_u16(dst + r2 * stride + c, out[r2]); + } + y += 4 * dy; + c += 4; + } while (c < bw); + } else { + assert(bh % 8 == 0); + + int y = dy; + int c = 0; + do { + int r = 0; + do { + // Fully unroll the 4x8 block to allow us to use immediate lane-indexed + // multiply instructions. + const uint16x4_t shifts1 = + vand_u16(vmla_n_u16(vdup_n_u16(y), iota1x4, dy), shift_mask); + const uint16x4_t shifts0 = vsub_u16(vdup_n_u16(32), shifts1); + const int base0 = ((y + 0 * dy) >> frac_bits) + (r * 2); + const int base1 = ((y + 1 * dy) >> frac_bits) + (r * 2); + const int base2 = ((y + 2 * dy) >> frac_bits) + (r * 2); + const int base3 = ((y + 3 * dy) >> frac_bits) + (r * 2); + const uint16x8x2_t l0 = vld2q_u16(left + base0); + const uint16x8x2_t l1 = vld2q_u16(left + base1); + const uint16x8x2_t l2 = vld2q_u16(left + base2); + const uint16x8x2_t l3 = vld2q_u16(left + base3); + uint16x8_t out[4]; + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota2x8, base0, l0.val[0], + l0.val[1], shifts0, shifts1, 0, 5); + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota2x8, base1, l1.val[0], + l1.val[1], shifts0, shifts1, 1, 5); + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota2x8, base2, l2.val[0], + l2.val[1], shifts0, shifts1, 2, 5); + HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota2x8, base3, l3.val[0], + l3.val[1], shifts0, shifts1, 3, 5); + transpose_array_inplace_u16_4x8(out); + for (int r2 = 0; r2 < 4; ++r2) { + vst1_u16(dst + (r + r2) * stride + c, vget_low_u16(out[r2])); + } + for (int r2 = 0; r2 < 4; ++r2) { + vst1_u16(dst + (r + r2 + 4) * stride + c, vget_high_u16(out[r2])); + } + r += 8; + } while (r < bh); + y += 4 * dy; + c += 4; + } while (c < bw); + } +} + +// Directional prediction, zone 3: 180 < angle < 270 +void av1_highbd_dr_prediction_z3_neon(uint16_t *dst, ptrdiff_t stride, int bw, + int bh, const uint16_t *above, + const uint16_t *left, int upsample_left, + int dx, int dy, int bd) { + (void)above; + (void)dx; + (void)bd; + assert(bw % 4 == 0); + assert(bh % 4 == 0); + assert(dx == 1); + assert(dy > 0); + + if (upsample_left) { + highbd_dr_prediction_z3_upsample1_neon(dst, stride, bw, bh, left, dy); + } else { + highbd_dr_prediction_z3_upsample0_neon(dst, stride, bw, bh, left, dy); + } +} + +#undef HIGHBD_DR_PREDICTOR_Z3_STEP_X4 +#undef HIGHBD_DR_PREDICTOR_Z3_STEP_X8 diff --git a/third_party/aom/aom_dsp/arm/highbd_loopfilter_neon.c b/third_party/aom/aom_dsp/arm/highbd_loopfilter_neon.c new file mode 100644 index 0000000000..77727b7665 --- /dev/null +++ 
b/third_party/aom/aom_dsp/arm/highbd_loopfilter_neon.c
@@ -0,0 +1,1265 @@
+/*
+ * Copyright (c) 2022, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+#include <arm_neon.h>
+
+#include "config/aom_dsp_rtcd.h"
+#include "config/aom_config.h"
+
+#include "aom/aom_integer.h"
+#include "aom_dsp/arm/transpose_neon.h"
+
+static INLINE int16x4_t clip3_s16(const int16x4_t val, const int16x4_t low,
+                                  const int16x4_t high) {
+  return vmin_s16(vmax_s16(val, low), high);
+}
+
+static INLINE uint16x8_t convert_to_unsigned_pixel_u16(int16x8_t val,
+                                                       int bitdepth) {
+  const int16x8_t low = vdupq_n_s16(0);
+  const uint16x8_t high = vdupq_n_u16((1 << bitdepth) - 1);
+
+  return vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(val, low)), high);
+}
+
+// (abs(p1 - p0) > thresh) || (abs(q1 - q0) > thresh)
+static INLINE uint16x4_t hev(const uint16x8_t abd_p0p1_q0q1,
+                             const uint16_t thresh) {
+  const uint16x8_t a = vcgtq_u16(abd_p0p1_q0q1, vdupq_n_u16(thresh));
+  return vorr_u16(vget_low_u16(a), vget_high_u16(a));
+}
+
+// abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh
+static INLINE uint16x4_t outer_threshold(const uint16x4_t p1,
+                                         const uint16x4_t p0,
+                                         const uint16x4_t q0,
+                                         const uint16x4_t q1,
+                                         const uint16_t outer_thresh) {
+  const uint16x4_t abd_p0q0 = vabd_u16(p0, q0);
+  const uint16x4_t abd_p1q1 = vabd_u16(p1, q1);
+  const uint16x4_t p0q0_double = vshl_n_u16(abd_p0q0, 1);
+  const uint16x4_t p1q1_half = vshr_n_u16(abd_p1q1, 1);
+  const uint16x4_t sum = vadd_u16(p0q0_double, p1q1_half);
+  return vcle_u16(sum, vdup_n_u16(outer_thresh));
+}
+
+// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
+// outer_threshold()
+static INLINE uint16x4_t needs_filter4(const uint16x8_t abd_p0p1_q0q1,
+                                       const uint16_t inner_thresh,
+                                       const uint16x4_t outer_mask) {
+  const uint16x8_t a = vcleq_u16(abd_p0p1_q0q1, vdupq_n_u16(inner_thresh));
+  const uint16x4_t inner_mask = vand_u16(vget_low_u16(a), vget_high_u16(a));
+  return vand_u16(inner_mask, outer_mask);
+}
+
+// abs(p2 - p1) <= inner_thresh && abs(p1 - p0) <= inner_thresh &&
+// abs(q1 - q0) <= inner_thresh && abs(q2 - q1) <= inner_thresh &&
+// outer_threshold()
+static INLINE uint16x4_t needs_filter6(const uint16x8_t abd_p0p1_q0q1,
+                                       const uint16x8_t abd_p1p2_q1q2,
+                                       const uint16_t inner_thresh,
+                                       const uint16x4_t outer_mask) {
+  const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p1p2_q1q2);
+  const uint16x8_t b = vcleq_u16(a, vdupq_n_u16(inner_thresh));
+  const uint16x4_t inner_mask = vand_u16(vget_low_u16(b), vget_high_u16(b));
+  return vand_u16(inner_mask, outer_mask);
+}
+
+// abs(p3 - p2) <= inner_thresh && abs(p2 - p1) <= inner_thresh &&
+// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
+// abs(q2 - q1) <= inner_thresh && abs(q3 - q2) <= inner_thresh
+// outer_threshold()
+static INLINE uint16x4_t needs_filter8(const uint16x8_t abd_p0p1_q0q1,
+                                       const uint16x8_t abd_p1p2_q1q2,
+                                       const uint16x8_t abd_p2p3_q2q3,
+                                       const uint16_t inner_thresh,
+                                       const uint16x4_t outer_mask) {
+  const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p1p2_q1q2);
+  const uint16x8_t b = vmaxq_u16(a, abd_p2p3_q2q3);
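+  // At this point b holds max(|p0 - p1|, |p1 - p2|, |p2 - p3|) in its low
+  // half and the corresponding q-side maxima in its high half, so a single
+  // comparison against inner_thresh covers all six conditions above.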
+ const uint16x8_t c = vcleq_u16(b, vdupq_n_u16(inner_thresh)); + const uint16x4_t inner_mask = vand_u16(vget_low_u16(c), vget_high_u16(c)); + return vand_u16(inner_mask, outer_mask); +} + +// ----------------------------------------------------------------------------- +// filterN_masks functions. + +static INLINE void filter4_masks(const uint16x8_t p0q0, const uint16x8_t p1q1, + const uint16_t hev_thresh, + const uint16x4_t outer_mask, + const uint16_t inner_thresh, + uint16x4_t *const hev_mask, + uint16x4_t *const needs_filter4_mask) { + const uint16x8_t p0p1_q0q1 = vabdq_u16(p0q0, p1q1); + // This includes cases where needs_filter4() is not true and so filter2() will + // not be applied. + const uint16x4_t hev_tmp_mask = hev(p0p1_q0q1, hev_thresh); + + *needs_filter4_mask = needs_filter4(p0p1_q0q1, inner_thresh, outer_mask); + + // filter2() will only be applied if both needs_filter4() and hev() are true. + *hev_mask = vand_u16(hev_tmp_mask, *needs_filter4_mask); +} + +// abs(p1 - p0) <= flat_thresh && abs(q1 - q0) <= flat_thresh && +// abs(p2 - p0) <= flat_thresh && abs(q2 - q0) <= flat_thresh +// |flat_thresh| == 4 for 10 bit decode. +static INLINE uint16x4_t is_flat3(const uint16x8_t abd_p0p1_q0q1, + const uint16x8_t abd_p0p2_q0q2, + const int bitdepth) { + const int flat_thresh = 1 << (bitdepth - 8); + const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p0p2_q0q2); + const uint16x8_t b = vcleq_u16(a, vdupq_n_u16(flat_thresh)); + return vand_u16(vget_low_u16(b), vget_high_u16(b)); +} + +static INLINE void filter6_masks( + const uint16x8_t p2q2, const uint16x8_t p1q1, const uint16x8_t p0q0, + const uint16_t hev_thresh, const uint16x4_t outer_mask, + const uint16_t inner_thresh, const int bitdepth, + uint16x4_t *const needs_filter6_mask, uint16x4_t *const is_flat3_mask, + uint16x4_t *const hev_mask) { + const uint16x8_t abd_p0p1_q0q1 = vabdq_u16(p0q0, p1q1); + *hev_mask = hev(abd_p0p1_q0q1, hev_thresh); + *is_flat3_mask = is_flat3(abd_p0p1_q0q1, vabdq_u16(p0q0, p2q2), bitdepth); + *needs_filter6_mask = needs_filter6(abd_p0p1_q0q1, vabdq_u16(p1q1, p2q2), + inner_thresh, outer_mask); +} + +// is_flat4 uses N=1, IsFlatOuter4 uses N=4. +// abs(p[N] - p0) <= flat_thresh && abs(q[N] - q0) <= flat_thresh && +// abs(p[N+1] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh && +// abs(p[N+2] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh +// |flat_thresh| == 4 for 10 bit decode. 
+static INLINE uint16x4_t is_flat4(const uint16x8_t abd_pnp0_qnq0, + const uint16x8_t abd_pn1p0_qn1q0, + const uint16x8_t abd_pn2p0_qn2q0, + const int bitdepth) { + const int flat_thresh = 1 << (bitdepth - 8); + const uint16x8_t a = vmaxq_u16(abd_pnp0_qnq0, abd_pn1p0_qn1q0); + const uint16x8_t b = vmaxq_u16(a, abd_pn2p0_qn2q0); + const uint16x8_t c = vcleq_u16(b, vdupq_n_u16(flat_thresh)); + return vand_u16(vget_low_u16(c), vget_high_u16(c)); +} + +static INLINE void filter8_masks( + const uint16x8_t p3q3, const uint16x8_t p2q2, const uint16x8_t p1q1, + const uint16x8_t p0q0, const uint16_t hev_thresh, + const uint16x4_t outer_mask, const uint16_t inner_thresh, + const int bitdepth, uint16x4_t *const needs_filter8_mask, + uint16x4_t *const is_flat4_mask, uint16x4_t *const hev_mask) { + const uint16x8_t abd_p0p1_q0q1 = vabdq_u16(p0q0, p1q1); + *hev_mask = hev(abd_p0p1_q0q1, hev_thresh); + const uint16x4_t v_is_flat4 = is_flat4(abd_p0p1_q0q1, vabdq_u16(p0q0, p2q2), + vabdq_u16(p0q0, p3q3), bitdepth); + *needs_filter8_mask = + needs_filter8(abd_p0p1_q0q1, vabdq_u16(p1q1, p2q2), vabdq_u16(p2q2, p3q3), + inner_thresh, outer_mask); + // |is_flat4_mask| is used to decide where to use the result of filter8. + // In rare cases, |is_flat4| can be true where |needs_filter8_mask| is false, + // overriding the question of whether to use filter8. Because filter4 doesn't + // apply to p2q2, |is_flat4_mask| chooses directly between filter8 and the + // source value. To be correct, the mask must account for this override. + *is_flat4_mask = vand_u16(v_is_flat4, *needs_filter8_mask); +} + +// ----------------------------------------------------------------------------- +// filterN functions. + +// Calculate filter4() or filter2() based on |hev_mask|. +static INLINE void filter4(const uint16x8_t p0q0, const uint16x8_t p0q1, + const uint16x8_t p1q1, const uint16x4_t hev_mask, + int bitdepth, uint16x8_t *const p1q1_result, + uint16x8_t *const p0q0_result) { + const uint16x8_t q0p1 = vextq_u16(p0q0, p1q1, 4); + // a = 3 * (q0 - p0) + Clip3(p1 - q1, min_signed_val, max_signed_val); + // q0mp0 means "q0 minus p0". + const int16x8_t q0mp0_p1mq1 = vreinterpretq_s16_u16(vsubq_u16(q0p1, p0q1)); + const int16x4_t q0mp0_3 = vmul_n_s16(vget_low_s16(q0mp0_p1mq1), 3); + + // If this is for filter2() then include |p1mq1|. Otherwise zero it. + const int16x4_t min_signed_pixel = vdup_n_s16(-(1 << (bitdepth - 1))); + const int16x4_t max_signed_pixel = vdup_n_s16((1 << (bitdepth - 1)) - 1); + const int16x4_t p1mq1 = vget_high_s16(q0mp0_p1mq1); + const int16x4_t p1mq1_saturated = + clip3_s16(p1mq1, min_signed_pixel, max_signed_pixel); + const int16x4_t hev_option = + vand_s16(vreinterpret_s16_u16(hev_mask), p1mq1_saturated); + + const int16x4_t a = vadd_s16(q0mp0_3, hev_option); + + // Need to figure out what's going on here because there are some unnecessary + // tricks to accommodate 8x8 as smallest 8bpp vector + + // We can not shift with rounding because the clamp comes *before* the + // shifting. 
a1 = Clip3(a + 4, min_signed_val, max_signed_val) >> 3; a2 = + // Clip3(a + 3, min_signed_val, max_signed_val) >> 3; + const int16x4_t plus_four = + clip3_s16(vadd_s16(a, vdup_n_s16(4)), min_signed_pixel, max_signed_pixel); + const int16x4_t plus_three = + clip3_s16(vadd_s16(a, vdup_n_s16(3)), min_signed_pixel, max_signed_pixel); + const int16x4_t a1 = vshr_n_s16(plus_four, 3); + const int16x4_t a2 = vshr_n_s16(plus_three, 3); + + // a3 = (a1 + 1) >> 1; + const int16x4_t a3 = vrshr_n_s16(a1, 1); + + const int16x8_t a3_ma3 = vcombine_s16(a3, vneg_s16(a3)); + const int16x8_t p1q1_a3 = vaddq_s16(vreinterpretq_s16_u16(p1q1), a3_ma3); + + // Need to shift the second term or we end up with a2_ma2. + const int16x8_t a2_ma1 = vcombine_s16(a2, vneg_s16(a1)); + const int16x8_t p0q0_a = vaddq_s16(vreinterpretq_s16_u16(p0q0), a2_ma1); + *p1q1_result = convert_to_unsigned_pixel_u16(p1q1_a3, bitdepth); + *p0q0_result = convert_to_unsigned_pixel_u16(p0q0_a, bitdepth); +} + +void aom_highbd_lpf_horizontal_4_neon(uint16_t *s, int pitch, + const uint8_t *blimit, + const uint8_t *limit, + const uint8_t *thresh, int bd) { + uint16_t *const dst_p1 = (uint16_t *)(s - 2 * pitch); + uint16_t *const dst_p0 = (uint16_t *)(s - pitch); + uint16_t *const dst_q0 = (uint16_t *)(s); + uint16_t *const dst_q1 = (uint16_t *)(s + pitch); + + const uint16x4_t src[4] = { vld1_u16(dst_p1), vld1_u16(dst_p0), + vld1_u16(dst_q0), vld1_u16(dst_q1) }; + + // Adjust thresholds to bitdepth. + const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = + outer_threshold(src[0], src[1], src[2], src[3], outer_thresh); + uint16x4_t hev_mask; + uint16x4_t needs_filter4_mask; + const uint16x8_t p0q0 = vcombine_u16(src[1], src[2]); + const uint16x8_t p1q1 = vcombine_u16(src[0], src[3]); + filter4_masks(p0q0, p1q1, hev_thresh, outer_mask, inner_thresh, &hev_mask, + &needs_filter4_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter4_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + + // Copy the masks to the high bits for packed comparisons later. + const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t needs_filter4_mask_8 = + vcombine_u16(needs_filter4_mask, needs_filter4_mask); + + uint16x8_t f_p1q1; + uint16x8_t f_p0q0; + const uint16x8_t p0q1 = vcombine_u16(src[1], src[3]); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f_p1q1, &f_p0q0); + + // Already integrated the hev mask when calculating the filtered values. + const uint16x8_t p0q0_output = vbslq_u16(needs_filter4_mask_8, f_p0q0, p0q0); + + // p1/q1 are unmodified if only hev() is true. This works because it was and'd + // with |needs_filter4_mask| previously. 
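+  // Since hev_mask is already a subset of needs_filter4_mask, the veor below
+  // is equivalent to needs_filter4_mask & ~hev_mask, i.e. "filter4 applies
+  // but hev() does not", which is exactly where f_p1q1 should be used.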
+ const uint16x8_t p1q1_mask = veorq_u16(hev_mask_8, needs_filter4_mask_8); + const uint16x8_t p1q1_output = vbslq_u16(p1q1_mask, f_p1q1, p1q1); + + vst1_u16(dst_p1, vget_low_u16(p1q1_output)); + vst1_u16(dst_p0, vget_low_u16(p0q0_output)); + vst1_u16(dst_q0, vget_high_u16(p0q0_output)); + vst1_u16(dst_q1, vget_high_u16(p1q1_output)); +} + +void aom_highbd_lpf_horizontal_4_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_horizontal_4_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_horizontal_4_neon(s + 4, pitch, blimit1, limit1, thresh1, bd); +} + +void aom_highbd_lpf_vertical_4_neon(uint16_t *s, int pitch, + const uint8_t *blimit, const uint8_t *limit, + const uint8_t *thresh, int bd) { + // Offset by 2 uint16_t values to load from first p1 position. + uint16_t *dst = s - 2; + uint16_t *dst_p1 = dst; + uint16_t *dst_p0 = dst + pitch; + uint16_t *dst_q0 = dst + pitch * 2; + uint16_t *dst_q1 = dst + pitch * 3; + + uint16x4_t src[4] = { vld1_u16(dst_p1), vld1_u16(dst_p0), vld1_u16(dst_q0), + vld1_u16(dst_q1) }; + transpose_array_inplace_u16_4x4(src); + + // Adjust thresholds to bitdepth. + const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = + outer_threshold(src[0], src[1], src[2], src[3], outer_thresh); + uint16x4_t hev_mask; + uint16x4_t needs_filter4_mask; + const uint16x8_t p0q0 = vcombine_u16(src[1], src[2]); + const uint16x8_t p1q1 = vcombine_u16(src[0], src[3]); + filter4_masks(p0q0, p1q1, hev_thresh, outer_mask, inner_thresh, &hev_mask, + &needs_filter4_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter4_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + + // Copy the masks to the high bits for packed comparisons later. + const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t needs_filter4_mask_8 = + vcombine_u16(needs_filter4_mask, needs_filter4_mask); + + uint16x8_t f_p1q1; + uint16x8_t f_p0q0; + const uint16x8_t p0q1 = vcombine_u16(src[1], src[3]); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f_p1q1, &f_p0q0); + + // Already integrated the hev mask when calculating the filtered values. + const uint16x8_t p0q0_output = vbslq_u16(needs_filter4_mask_8, f_p0q0, p0q0); + + // p1/q1 are unmodified if only hev() is true. This works because it was and'd + // with |needs_filter4_mask| previously. 
+ const uint16x8_t p1q1_mask = veorq_u16(hev_mask_8, needs_filter4_mask_8); + const uint16x8_t p1q1_output = vbslq_u16(p1q1_mask, f_p1q1, p1q1); + + uint16x4_t output[4] = { + vget_low_u16(p1q1_output), + vget_low_u16(p0q0_output), + vget_high_u16(p0q0_output), + vget_high_u16(p1q1_output), + }; + transpose_array_inplace_u16_4x4(output); + + vst1_u16(dst_p1, output[0]); + vst1_u16(dst_p0, output[1]); + vst1_u16(dst_q0, output[2]); + vst1_u16(dst_q1, output[3]); +} + +void aom_highbd_lpf_vertical_4_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_vertical_4_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_vertical_4_neon(s + 4 * pitch, pitch, blimit1, limit1, thresh1, + bd); +} + +static INLINE void filter6(const uint16x8_t p2q2, const uint16x8_t p1q1, + const uint16x8_t p0q0, uint16x8_t *const p1q1_output, + uint16x8_t *const p0q0_output) { + // Sum p1 and q1 output from opposite directions. + // The formula is regrouped to allow 3 doubling operations to be combined. + // + // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0 + // ^^^^^^^^ + // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2) + // ^^^^^^^^ + // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0 + // ^^^^^^^^^^^ + uint16x8_t sum = vaddq_u16(p2q2, p1q1); + + // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0 + // ^^^^^^ + sum = vaddq_u16(sum, p0q0); + + // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0 + // ^^^^^ + sum = vshlq_n_u16(sum, 1); + + // p1q1 = p2q2 + 2 * (p2q2 + p1q1 + p0q0) + q0p0 + // ^^^^^^ ^^^^^^ + // Should dual issue with the left shift. + const uint16x8_t q0p0 = vextq_u16(p0q0, p0q0, 4); + const uint16x8_t outer_sum = vaddq_u16(p2q2, q0p0); + sum = vaddq_u16(sum, outer_sum); + + *p1q1_output = vrshrq_n_u16(sum, 3); + + // Convert to p0 and q0 output: + // p0 = p1 - (2 * p2) + q0 + q1 + // q0 = q1 - (2 * q2) + p0 + p1 + // p0q0 = p1q1 - (2 * p2q2) + q0p0 + q1p1 + // ^^^^^^^^ + const uint16x8_t p2q2_double = vshlq_n_u16(p2q2, 1); + // p0q0 = p1q1 - (2 * p2q2) + q0p0 + q1p1 + // ^^^^^^^^ + sum = vsubq_u16(sum, p2q2_double); + const uint16x8_t q1p1 = vextq_u16(p1q1, p1q1, 4); + sum = vaddq_u16(sum, vaddq_u16(q0p0, q1p1)); + + *p0q0_output = vrshrq_n_u16(sum, 3); +} + +void aom_highbd_lpf_horizontal_6_neon(uint16_t *s, int pitch, + const uint8_t *blimit, + const uint8_t *limit, + const uint8_t *thresh, int bd) { + uint16_t *const dst_p2 = s - 3 * pitch; + uint16_t *const dst_p1 = s - 2 * pitch; + uint16_t *const dst_p0 = s - pitch; + uint16_t *const dst_q0 = s; + uint16_t *const dst_q1 = s + pitch; + uint16_t *const dst_q2 = s + 2 * pitch; + + const uint16x4_t src[6] = { vld1_u16(dst_p2), vld1_u16(dst_p1), + vld1_u16(dst_p0), vld1_u16(dst_q0), + vld1_u16(dst_q1), vld1_u16(dst_q2) }; + + // Adjust thresholds to bitdepth. 
+ const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = + outer_threshold(src[1], src[2], src[3], src[4], outer_thresh); + uint16x4_t hev_mask; + uint16x4_t needs_filter_mask; + uint16x4_t is_flat3_mask; + const uint16x8_t p0q0 = vcombine_u16(src[2], src[3]); + const uint16x8_t p1q1 = vcombine_u16(src[1], src[4]); + const uint16x8_t p2q2 = vcombine_u16(src[0], src[5]); + filter6_masks(p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh, bd, + &needs_filter_mask, &is_flat3_mask, &hev_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + + // Copy the masks to the high bits for packed comparisons later. + const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t is_flat3_mask_8 = vcombine_u16(is_flat3_mask, is_flat3_mask); + const uint16x8_t needs_filter_mask_8 = + vcombine_u16(needs_filter_mask, needs_filter_mask); + + uint16x8_t f4_p1q1; + uint16x8_t f4_p0q0; + // ZIP1 p0q0, p1q1 may perform better here. + const uint16x8_t p0q1 = vcombine_u16(src[2], src[4]); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f4_p1q1, &f4_p0q0); + f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1); + + uint16x8_t p0q0_output, p1q1_output; + // Because we did not return after testing |needs_filter_mask| we know it is + // nonzero. |is_flat3_mask| controls whether the needed filter is filter4 or + // filter6. Therefore if it is false when |needs_filter_mask| is true, filter6 + // output is not used. + uint16x8_t f6_p1q1, f6_p0q0; + const uint64x1_t need_filter6 = vreinterpret_u64_u16(is_flat3_mask); + if (vget_lane_u64(need_filter6, 0) == 0) { + // filter6() does not apply, but filter4() applies to one or more values. + p0q0_output = p0q0; + p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1); + p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0); + } else { + filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0); + p1q1_output = vbslq_u16(is_flat3_mask_8, f6_p1q1, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(is_flat3_mask_8, f6_p0q0, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } + + vst1_u16(dst_p1, vget_low_u16(p1q1_output)); + vst1_u16(dst_p0, vget_low_u16(p0q0_output)); + vst1_u16(dst_q0, vget_high_u16(p0q0_output)); + vst1_u16(dst_q1, vget_high_u16(p1q1_output)); +} + +void aom_highbd_lpf_horizontal_6_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_horizontal_6_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_horizontal_6_neon(s + 4, pitch, blimit1, limit1, thresh1, bd); +} + +void aom_highbd_lpf_vertical_6_neon(uint16_t *s, int pitch, + const uint8_t *blimit, const uint8_t *limit, + const uint8_t *thresh, int bd) { + // Left side of the filter window. + uint16_t *const dst = s - 3; + uint16_t *const dst_0 = dst; + uint16_t *const dst_1 = dst + pitch; + uint16_t *const dst_2 = dst + 2 * pitch; + uint16_t *const dst_3 = dst + 3 * pitch; + + // Overread by 2 values. These overreads become the high halves of src_raw[2] + // and src_raw[3] after transpose. 
+ uint16x8_t src_raw[4] = { vld1q_u16(dst_0), vld1q_u16(dst_1), + vld1q_u16(dst_2), vld1q_u16(dst_3) }; + transpose_array_inplace_u16_4x8(src_raw); + // p2, p1, p0, q0, q1, q2 + const uint16x4_t src[6] = { + vget_low_u16(src_raw[0]), vget_low_u16(src_raw[1]), + vget_low_u16(src_raw[2]), vget_low_u16(src_raw[3]), + vget_high_u16(src_raw[0]), vget_high_u16(src_raw[1]), + }; + + // Adjust thresholds to bitdepth. + const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = + outer_threshold(src[1], src[2], src[3], src[4], outer_thresh); + uint16x4_t hev_mask; + uint16x4_t needs_filter_mask; + uint16x4_t is_flat3_mask; + const uint16x8_t p0q0 = vcombine_u16(src[2], src[3]); + const uint16x8_t p1q1 = vcombine_u16(src[1], src[4]); + const uint16x8_t p2q2 = vcombine_u16(src[0], src[5]); + filter6_masks(p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh, bd, + &needs_filter_mask, &is_flat3_mask, &hev_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + + // Copy the masks to the high bits for packed comparisons later. + const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t is_flat3_mask_8 = vcombine_u16(is_flat3_mask, is_flat3_mask); + const uint16x8_t needs_filter_mask_8 = + vcombine_u16(needs_filter_mask, needs_filter_mask); + + uint16x8_t f4_p1q1; + uint16x8_t f4_p0q0; + // ZIP1 p0q0, p1q1 may perform better here. + const uint16x8_t p0q1 = vcombine_u16(src[2], src[4]); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f4_p1q1, &f4_p0q0); + f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1); + + uint16x8_t p0q0_output, p1q1_output; + // Because we did not return after testing |needs_filter_mask| we know it is + // nonzero. |is_flat3_mask| controls whether the needed filter is filter4 or + // filter6. Therefore if it is false when |needs_filter_mask| is true, filter6 + // output is not used. + uint16x8_t f6_p1q1, f6_p0q0; + const uint64x1_t need_filter6 = vreinterpret_u64_u16(is_flat3_mask); + if (vget_lane_u64(need_filter6, 0) == 0) { + // filter6() does not apply, but filter4() applies to one or more values. + p0q0_output = p0q0; + p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1); + p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0); + } else { + filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0); + p1q1_output = vbslq_u16(is_flat3_mask_8, f6_p1q1, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(is_flat3_mask_8, f6_p0q0, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } + + uint16x4_t output[4] = { + vget_low_u16(p1q1_output), + vget_low_u16(p0q0_output), + vget_high_u16(p0q0_output), + vget_high_u16(p1q1_output), + }; + transpose_array_inplace_u16_4x4(output); + + // dst_n starts at p2, so adjust to p1. 
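+ // After the 4x4 transpose, output[n] holds { p1, p0, q0, q1 } for row n.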
+ vst1_u16(dst_0 + 1, output[0]); + vst1_u16(dst_1 + 1, output[1]); + vst1_u16(dst_2 + 1, output[2]); + vst1_u16(dst_3 + 1, output[3]); +} + +void aom_highbd_lpf_vertical_6_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_vertical_6_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_vertical_6_neon(s + 4 * pitch, pitch, blimit1, limit1, thresh1, + bd); +} + +static INLINE void filter8(const uint16x8_t p3q3, const uint16x8_t p2q2, + const uint16x8_t p1q1, const uint16x8_t p0q0, + uint16x8_t *const p2q2_output, + uint16x8_t *const p1q1_output, + uint16x8_t *const p0q0_output) { + // Sum p2 and q2 output from opposite directions. + // The formula is regrouped to allow 2 doubling operations to be combined. + // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0 + // ^^^^^^^^ + // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3) + // ^^^^^^^^ + // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0 + // ^^^^^^^^^^^ + const uint16x8_t p23q23 = vaddq_u16(p3q3, p2q2); + + // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0 + // ^^^^^ + uint16x8_t sum = vshlq_n_u16(p23q23, 1); + + // Add two other terms to make dual issue with shift more likely. + // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0 + // ^^^^^^^^^^^ + const uint16x8_t p01q01 = vaddq_u16(p0q0, p1q1); + + // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0 + // ^^^^^^^^^^^^^ + sum = vaddq_u16(sum, p01q01); + + // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0 + // ^^^^^^ + sum = vaddq_u16(sum, p3q3); + + // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0 + // ^^^^^^ + const uint16x8_t q0p0 = vextq_u16(p0q0, p0q0, 4); + sum = vaddq_u16(sum, q0p0); + + *p2q2_output = vrshrq_n_u16(sum, 3); + + // Convert to p1 and q1 output: + // p1 = p2 - p3 - p2 + p1 + q1 + // q1 = q2 - q3 - q2 + q0 + p1 + sum = vsubq_u16(sum, p23q23); + const uint16x8_t q1p1 = vextq_u16(p1q1, p1q1, 4); + sum = vaddq_u16(sum, vaddq_u16(p1q1, q1p1)); + + *p1q1_output = vrshrq_n_u16(sum, 3); + + // Convert to p0 and q0 output: + // p0 = p1 - p3 - p1 + p0 + q2 + // q0 = q1 - q3 - q1 + q0 + p2 + sum = vsubq_u16(sum, vaddq_u16(p3q3, p1q1)); + const uint16x8_t q2p2 = vextq_u16(p2q2, p2q2, 4); + sum = vaddq_u16(sum, vaddq_u16(p0q0, q2p2)); + + *p0q0_output = vrshrq_n_u16(sum, 3); +} + +void aom_highbd_lpf_horizontal_8_neon(uint16_t *s, int pitch, + const uint8_t *blimit, + const uint8_t *limit, + const uint8_t *thresh, int bd) { + uint16_t *const dst_p3 = s - 4 * pitch; + uint16_t *const dst_p2 = s - 3 * pitch; + uint16_t *const dst_p1 = s - 2 * pitch; + uint16_t *const dst_p0 = s - pitch; + uint16_t *const dst_q0 = s; + uint16_t *const dst_q1 = s + pitch; + uint16_t *const dst_q2 = s + 2 * pitch; + uint16_t *const dst_q3 = s + 3 * pitch; + + const uint16x4_t src[8] = { vld1_u16(dst_p3), vld1_u16(dst_p2), + vld1_u16(dst_p1), vld1_u16(dst_p0), + vld1_u16(dst_q0), vld1_u16(dst_q1), + vld1_u16(dst_q2), vld1_u16(dst_q3) }; + + // Adjust thresholds to bitdepth. 
+ const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = + outer_threshold(src[2], src[3], src[4], src[5], outer_thresh); + uint16x4_t hev_mask; + uint16x4_t needs_filter_mask; + uint16x4_t is_flat4_mask; + const uint16x8_t p0q0 = vcombine_u16(src[3], src[4]); + const uint16x8_t p1q1 = vcombine_u16(src[2], src[5]); + const uint16x8_t p2q2 = vcombine_u16(src[1], src[6]); + const uint16x8_t p3q3 = vcombine_u16(src[0], src[7]); + filter8_masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh, + bd, &needs_filter_mask, &is_flat4_mask, &hev_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + + // Copy the masks to the high bits for packed comparisons later. + const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t needs_filter_mask_8 = + vcombine_u16(needs_filter_mask, needs_filter_mask); + + uint16x8_t f4_p1q1; + uint16x8_t f4_p0q0; + // ZIP1 p0q0, p1q1 may perform better here. + const uint16x8_t p0q1 = vcombine_u16(src[3], src[5]); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f4_p1q1, &f4_p0q0); + f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1); + + uint16x8_t p0q0_output, p1q1_output, p2q2_output; + // Because we did not return after testing |needs_filter_mask| we know it is + // nonzero. |is_flat4_mask| controls whether the needed filter is filter4 or + // filter8. Therefore if it is false when |needs_filter_mask| is true, filter8 + // output is not used. + uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0; + const uint64x1_t need_filter8 = vreinterpret_u64_u16(is_flat4_mask); + if (vget_lane_u64(need_filter8, 0) == 0) { + // filter8() does not apply, but filter4() applies to one or more values. 
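+ // filter4() never writes p2/q2, so they pass through unmodified.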
+ p2q2_output = p2q2; + p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1); + p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0); + } else { + const uint16x8_t is_flat4_mask_8 = + vcombine_u16(is_flat4_mask, is_flat4_mask); + filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + p2q2_output = vbslq_u16(is_flat4_mask_8, f8_p2q2, p2q2); + p1q1_output = vbslq_u16(is_flat4_mask_8, f8_p1q1, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(is_flat4_mask_8, f8_p0q0, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } + + vst1_u16(dst_p2, vget_low_u16(p2q2_output)); + vst1_u16(dst_p1, vget_low_u16(p1q1_output)); + vst1_u16(dst_p0, vget_low_u16(p0q0_output)); + vst1_u16(dst_q0, vget_high_u16(p0q0_output)); + vst1_u16(dst_q1, vget_high_u16(p1q1_output)); + vst1_u16(dst_q2, vget_high_u16(p2q2_output)); +} + +void aom_highbd_lpf_horizontal_8_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_horizontal_8_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_horizontal_8_neon(s + 4, pitch, blimit1, limit1, thresh1, bd); +} + +static INLINE uint16x8_t reverse_low_half(const uint16x8_t a) { + return vcombine_u16(vrev64_u16(vget_low_u16(a)), vget_high_u16(a)); +} + +void aom_highbd_lpf_vertical_8_neon(uint16_t *s, int pitch, + const uint8_t *blimit, const uint8_t *limit, + const uint8_t *thresh, int bd) { + uint16_t *const dst = s - 4; + uint16_t *const dst_0 = dst; + uint16_t *const dst_1 = dst + pitch; + uint16_t *const dst_2 = dst + 2 * pitch; + uint16_t *const dst_3 = dst + 3 * pitch; + + // src_raw[n] contains p3, p2, p1, p0, q0, q1, q2, q3 for row n. + // To get desired pairs after transpose, one half should be reversed. + uint16x8_t src[4] = { vld1q_u16(dst_0), vld1q_u16(dst_1), vld1q_u16(dst_2), + vld1q_u16(dst_3) }; + + // src[0] = p0q0 + // src[1] = p1q1 + // src[2] = p2q2 + // src[3] = p3q3 + loop_filter_transpose_u16_4x8q(src); + + // Adjust thresholds to bitdepth. + const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = outer_threshold( + vget_low_u16(src[1]), vget_low_u16(src[0]), vget_high_u16(src[0]), + vget_high_u16(src[1]), outer_thresh); + uint16x4_t hev_mask; + uint16x4_t needs_filter_mask; + uint16x4_t is_flat4_mask; + const uint16x8_t p0q0 = src[0]; + const uint16x8_t p1q1 = src[1]; + const uint16x8_t p2q2 = src[2]; + const uint16x8_t p3q3 = src[3]; + filter8_masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh, + bd, &needs_filter_mask, &is_flat4_mask, &hev_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + + // Copy the masks to the high bits for packed comparisons later. 
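+ // Duplicating the 4-lane masks lets a single vbslq_u16 select both the p
+ // and q halves of each pXqX vector.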
+ const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t needs_filter_mask_8 = + vcombine_u16(needs_filter_mask, needs_filter_mask); + + uint16x8_t f4_p1q1; + uint16x8_t f4_p0q0; + const uint16x8_t p0q1 = vcombine_u16(vget_low_u16(p0q0), vget_high_u16(p1q1)); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f4_p1q1, &f4_p0q0); + f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1); + + uint16x8_t p0q0_output, p1q1_output, p2q2_output; + // Because we did not return after testing |needs_filter_mask| we know it is + // nonzero. |is_flat4_mask| controls whether the needed filter is filter4 or + // filter8. Therefore if it is false when |needs_filter_mask| is true, filter8 + // output is not used. + const uint64x1_t need_filter8 = vreinterpret_u64_u16(is_flat4_mask); + if (vget_lane_u64(need_filter8, 0) == 0) { + // filter8() does not apply, but filter4() applies to one or more values. + p2q2_output = p2q2; + p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1); + p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0); + } else { + const uint16x8_t is_flat4_mask_8 = + vcombine_u16(is_flat4_mask, is_flat4_mask); + uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0; + filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + p2q2_output = vbslq_u16(is_flat4_mask_8, f8_p2q2, p2q2); + p1q1_output = vbslq_u16(is_flat4_mask_8, f8_p1q1, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(is_flat4_mask_8, f8_p0q0, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } + + uint16x8_t output[4] = { p0q0_output, p1q1_output, p2q2_output, p3q3 }; + // After transpose, |output| will contain rows of the form: + // p0 p1 p2 p3 q0 q1 q2 q3 + transpose_array_inplace_u16_4x8(output); + + // Reverse p values to produce original order: + // p3 p2 p1 p0 q0 q1 q2 q3 + vst1q_u16(dst_0, reverse_low_half(output[0])); + vst1q_u16(dst_1, reverse_low_half(output[1])); + vst1q_u16(dst_2, reverse_low_half(output[2])); + vst1q_u16(dst_3, reverse_low_half(output[3])); +} + +void aom_highbd_lpf_vertical_8_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_vertical_8_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_vertical_8_neon(s + 4 * pitch, pitch, blimit1, limit1, thresh1, + bd); +} + +static INLINE void filter14( + const uint16x8_t p6q6, const uint16x8_t p5q5, const uint16x8_t p4q4, + const uint16x8_t p3q3, const uint16x8_t p2q2, const uint16x8_t p1q1, + const uint16x8_t p0q0, uint16x8_t *const p5q5_output, + uint16x8_t *const p4q4_output, uint16x8_t *const p3q3_output, + uint16x8_t *const p2q2_output, uint16x8_t *const p1q1_output, + uint16x8_t *const p0q0_output) { + // Sum p5 and q5 output from opposite directions. 
+ // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^ + const uint16x8_t p6q6_x7 = vsubq_u16(vshlq_n_u16(p6q6, 3), p6q6); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^^^^^^^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^^^^^^^^^^^^^ + uint16x8_t sum = vshlq_n_u16(vaddq_u16(p5q5, p4q4), 1); + sum = vaddq_u16(sum, p6q6_x7); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^ + sum = vaddq_u16(vaddq_u16(p3q3, p2q2), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^^^^^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^^^^^^ + sum = vaddq_u16(vaddq_u16(p1q1, p0q0), sum); + + // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0 + // ^^ + // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6) + // ^^ + const uint16x8_t q0p0 = vextq_u16(p0q0, p0q0, 4); + sum = vaddq_u16(sum, q0p0); + + *p5q5_output = vrshrq_n_u16(sum, 4); + + // Convert to p4 and q4 output: + // p4 = p5 - (2 * p6) + p3 + q1 + // q4 = q5 - (2 * q6) + q3 + p1 + sum = vsubq_u16(sum, vshlq_n_u16(p6q6, 1)); + const uint16x8_t q1p1 = vextq_u16(p1q1, p1q1, 4); + sum = vaddq_u16(vaddq_u16(p3q3, q1p1), sum); + + *p4q4_output = vrshrq_n_u16(sum, 4); + + // Convert to p3 and q3 output: + // p3 = p4 - p6 - p5 + p2 + q2 + // q3 = q4 - q6 - q5 + q2 + p2 + sum = vsubq_u16(sum, vaddq_u16(p6q6, p5q5)); + const uint16x8_t q2p2 = vextq_u16(p2q2, p2q2, 4); + sum = vaddq_u16(vaddq_u16(p2q2, q2p2), sum); + + *p3q3_output = vrshrq_n_u16(sum, 4); + + // Convert to p2 and q2 output: + // p2 = p3 - p6 - p4 + p1 + q3 + // q2 = q3 - q6 - q4 + q1 + p3 + sum = vsubq_u16(sum, vaddq_u16(p6q6, p4q4)); + const uint16x8_t q3p3 = vextq_u16(p3q3, p3q3, 4); + sum = vaddq_u16(vaddq_u16(p1q1, q3p3), sum); + + *p2q2_output = vrshrq_n_u16(sum, 4); + + // Convert to p1 and q1 output: + // p1 = p2 - p6 - p3 + p0 + q4 + // q1 = q2 - q6 - q3 + q0 + p4 + sum = vsubq_u16(sum, vaddq_u16(p6q6, p3q3)); + const uint16x8_t q4p4 = vextq_u16(p4q4, p4q4, 4); + sum = vaddq_u16(vaddq_u16(p0q0, q4p4), sum); + + *p1q1_output = vrshrq_n_u16(sum, 4); + + // Convert to p0 and q0 output: + // p0 = p1 - p6 - p2 + q0 + q5 + // q0 = q1 - q6 - q2 + p0 + p5 + sum = vsubq_u16(sum, vaddq_u16(p6q6, p2q2)); + const uint16x8_t q5p5 = vextq_u16(p5q5, p5q5, 4); + sum = vaddq_u16(vaddq_u16(q0p0, q5p5), sum); + + *p0q0_output = vrshrq_n_u16(sum, 4); +} + +void aom_highbd_lpf_horizontal_14_neon(uint16_t *s, int pitch, + const uint8_t *blimit, + const uint8_t *limit, + const uint8_t *thresh, int bd) { + uint16_t *const dst_p6 = s - 7 * pitch; + uint16_t *const dst_p5 = s - 6 * pitch; + uint16_t *const dst_p4 = s - 5 * pitch; + uint16_t *const dst_p3 = s - 4 * pitch; + uint16_t *const dst_p2 = s - 3 * pitch; + uint16_t *const dst_p1 = s - 2 * pitch; + uint16_t *const dst_p0 = s - pitch; + uint16_t *const dst_q0 = s; + uint16_t *const dst_q1 = s + pitch; + uint16_t *const dst_q2 = s + 2 * pitch; + uint16_t *const dst_q3 = s + 3 * pitch; + uint16_t *const dst_q4 = s + 4 * pitch; + uint16_t *const dst_q5 = s + 5 * pitch; + uint16_t *const dst_q6 = s + 6 * pitch; + + const uint16x4_t src[14] = { + vld1_u16(dst_p6), vld1_u16(dst_p5), vld1_u16(dst_p4), vld1_u16(dst_p3), + vld1_u16(dst_p2), vld1_u16(dst_p1), vld1_u16(dst_p0), vld1_u16(dst_q0), + 
vld1_u16(dst_q1), vld1_u16(dst_q2), vld1_u16(dst_q3), vld1_u16(dst_q4), + vld1_u16(dst_q5), vld1_u16(dst_q6) + }; + + // Adjust thresholds to bitdepth. + const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = + outer_threshold(src[5], src[6], src[7], src[8], outer_thresh); + uint16x4_t hev_mask; + uint16x4_t needs_filter_mask; + uint16x4_t is_flat4_mask; + const uint16x8_t p0q0 = vcombine_u16(src[6], src[7]); + const uint16x8_t p1q1 = vcombine_u16(src[5], src[8]); + const uint16x8_t p2q2 = vcombine_u16(src[4], src[9]); + const uint16x8_t p3q3 = vcombine_u16(src[3], src[10]); + filter8_masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh, + bd, &needs_filter_mask, &is_flat4_mask, &hev_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + const uint16x8_t p4q4 = vcombine_u16(src[2], src[11]); + const uint16x8_t p5q5 = vcombine_u16(src[1], src[12]); + const uint16x8_t p6q6 = vcombine_u16(src[0], src[13]); + // Mask to choose between the outputs of filter8 and filter14. + // As with the derivation of |is_flat4_mask|, the question of whether to use + // filter14 is only raised where |is_flat4_mask| is true. + const uint16x4_t is_flat4_outer_mask = vand_u16( + is_flat4_mask, is_flat4(vabdq_u16(p0q0, p4q4), vabdq_u16(p0q0, p5q5), + vabdq_u16(p0q0, p6q6), bd)); + // Copy the masks to the high bits for packed comparisons later. + const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t needs_filter_mask_8 = + vcombine_u16(needs_filter_mask, needs_filter_mask); + + uint16x8_t f4_p1q1; + uint16x8_t f4_p0q0; + // ZIP1 p0q0, p1q1 may perform better here. + const uint16x8_t p0q1 = vcombine_u16(src[6], src[8]); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f4_p1q1, &f4_p0q0); + f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1); + + uint16x8_t p0q0_output, p1q1_output, p2q2_output, p3q3_output, p4q4_output, + p5q5_output; + // Because we did not return after testing |needs_filter_mask| we know it is + // nonzero. |is_flat4_mask| controls whether the needed filter is filter4 or + // filter8. Therefore if it is false when |needs_filter_mask| is true, filter8 + // output is not used. + uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0; + const uint64x1_t need_filter8 = vreinterpret_u64_u16(is_flat4_mask); + if (vget_lane_u64(need_filter8, 0) == 0) { + // filter8() and filter14() do not apply, but filter4() applies to one or + // more values. + p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = p2q2; + p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1); + p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0); + } else { + const uint16x8_t use_filter8_mask = + vcombine_u16(is_flat4_mask, is_flat4_mask); + filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + const uint64x1_t need_filter14 = vreinterpret_u64_u16(is_flat4_outer_mask); + if (vget_lane_u64(need_filter14, 0) == 0) { + // filter14() does not apply, but filter8() and filter4() apply to one or + // more values. 
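+ // filter8() reaches no further than p2/q2, so p3..p5 and q3..q5 keep their
+ // original values in this branch.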
+ p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = vbslq_u16(use_filter8_mask, f8_p2q2, p2q2); + p1q1_output = vbslq_u16(use_filter8_mask, f8_p1q1, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(use_filter8_mask, f8_p0q0, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } else { + // All filters may contribute values to final outputs. + const uint16x8_t use_filter14_mask = + vcombine_u16(is_flat4_outer_mask, is_flat4_outer_mask); + uint16x8_t f14_p5q5, f14_p4q4, f14_p3q3, f14_p2q2, f14_p1q1, f14_p0q0; + filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4, + &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0); + p5q5_output = vbslq_u16(use_filter14_mask, f14_p5q5, p5q5); + p4q4_output = vbslq_u16(use_filter14_mask, f14_p4q4, p4q4); + p3q3_output = vbslq_u16(use_filter14_mask, f14_p3q3, p3q3); + p2q2_output = vbslq_u16(use_filter14_mask, f14_p2q2, f8_p2q2); + p2q2_output = vbslq_u16(use_filter8_mask, p2q2_output, p2q2); + p2q2_output = vbslq_u16(needs_filter_mask_8, p2q2_output, p2q2); + p1q1_output = vbslq_u16(use_filter14_mask, f14_p1q1, f8_p1q1); + p1q1_output = vbslq_u16(use_filter8_mask, p1q1_output, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(use_filter14_mask, f14_p0q0, f8_p0q0); + p0q0_output = vbslq_u16(use_filter8_mask, p0q0_output, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } + } + + vst1_u16(dst_p5, vget_low_u16(p5q5_output)); + vst1_u16(dst_p4, vget_low_u16(p4q4_output)); + vst1_u16(dst_p3, vget_low_u16(p3q3_output)); + vst1_u16(dst_p2, vget_low_u16(p2q2_output)); + vst1_u16(dst_p1, vget_low_u16(p1q1_output)); + vst1_u16(dst_p0, vget_low_u16(p0q0_output)); + vst1_u16(dst_q0, vget_high_u16(p0q0_output)); + vst1_u16(dst_q1, vget_high_u16(p1q1_output)); + vst1_u16(dst_q2, vget_high_u16(p2q2_output)); + vst1_u16(dst_q3, vget_high_u16(p3q3_output)); + vst1_u16(dst_q4, vget_high_u16(p4q4_output)); + vst1_u16(dst_q5, vget_high_u16(p5q5_output)); +} + +void aom_highbd_lpf_horizontal_14_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_horizontal_14_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_horizontal_14_neon(s + 4, pitch, blimit1, limit1, thresh1, bd); +} + +static INLINE uint16x8x2_t permute_acdb64(const uint16x8_t ab, + const uint16x8_t cd) { + uint16x8x2_t acdb; +#if AOM_ARCH_AARCH64 + // a[b] <- [c]d + acdb.val[0] = vreinterpretq_u16_u64( + vtrn1q_u64(vreinterpretq_u64_u16(ab), vreinterpretq_u64_u16(cd))); + // [a]b <- c[d] + acdb.val[1] = vreinterpretq_u16_u64( + vtrn2q_u64(vreinterpretq_u64_u16(cd), vreinterpretq_u64_u16(ab))); +#else + // a[b] <- [c]d + acdb.val[0] = vreinterpretq_u16_u64( + vsetq_lane_u64(vgetq_lane_u64(vreinterpretq_u64_u16(cd), 0), + vreinterpretq_u64_u16(ab), 1)); + // [a]b <- c[d] + acdb.val[1] = vreinterpretq_u16_u64( + vsetq_lane_u64(vgetq_lane_u64(vreinterpretq_u64_u16(cd), 1), + vreinterpretq_u64_u16(ab), 0)); +#endif // AOM_ARCH_AARCH64 + return acdb; +} + +void aom_highbd_lpf_vertical_14_neon(uint16_t *s, int pitch, + const uint8_t *blimit, + const uint8_t *limit, + const uint8_t *thresh, int bd) { + uint16_t *const dst = s - 8; + uint16_t *const dst_0 = dst; + uint16_t *const dst_1 = dst + pitch; + uint16_t *const dst_2 = dst + 2 * pitch; + uint16_t 
*const dst_3 = dst + 3 * pitch; + + // Low halves: p7 p6 p5 p4 + // High halves: p3 p2 p1 p0 + uint16x8_t src_p[4] = { vld1q_u16(dst_0), vld1q_u16(dst_1), vld1q_u16(dst_2), + vld1q_u16(dst_3) }; + // p7 will be the low half of src_p[0]. Not used until the end. + transpose_array_inplace_u16_4x8(src_p); + + // Low halves: q0 q1 q2 q3 + // High halves: q4 q5 q6 q7 + uint16x8_t src_q[4] = { vld1q_u16(dst_0 + 8), vld1q_u16(dst_1 + 8), + vld1q_u16(dst_2 + 8), vld1q_u16(dst_3 + 8) }; + // q7 will be the high half of src_q[3]. Not used until the end. + transpose_array_inplace_u16_4x8(src_q); + + // Adjust thresholds to bitdepth. + const int outer_thresh = *blimit << (bd - 8); + const int inner_thresh = *limit << (bd - 8); + const int hev_thresh = *thresh << (bd - 8); + const uint16x4_t outer_mask = outer_threshold( + vget_high_u16(src_p[2]), vget_high_u16(src_p[3]), vget_low_u16(src_q[0]), + vget_low_u16(src_q[1]), outer_thresh); + const uint16x8_t p0q0 = vextq_u16(src_p[3], src_q[0], 4); + const uint16x8_t p1q1 = vextq_u16(src_p[2], src_q[1], 4); + const uint16x8_t p2q2 = vextq_u16(src_p[1], src_q[2], 4); + const uint16x8_t p3q3 = vextq_u16(src_p[0], src_q[3], 4); + uint16x4_t hev_mask; + uint16x4_t needs_filter_mask; + uint16x4_t is_flat4_mask; + filter8_masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh, + bd, &needs_filter_mask, &is_flat4_mask, &hev_mask); + +#if AOM_ARCH_AARCH64 + if (vaddv_u16(needs_filter_mask) == 0) { + // None of the values will be filtered. + return; + } +#endif // AOM_ARCH_AARCH64 + const uint16x8_t p4q4 = + vcombine_u16(vget_low_u16(src_p[3]), vget_high_u16(src_q[0])); + const uint16x8_t p5q5 = + vcombine_u16(vget_low_u16(src_p[2]), vget_high_u16(src_q[1])); + const uint16x8_t p6q6 = + vcombine_u16(vget_low_u16(src_p[1]), vget_high_u16(src_q[2])); + const uint16x8_t p7q7 = + vcombine_u16(vget_low_u16(src_p[0]), vget_high_u16(src_q[3])); + // Mask to choose between the outputs of filter8 and filter14. + // As with the derivation of |is_flat4_mask|, the question of whether to use + // filter14 is only raised where |is_flat4_mask| is true. + const uint16x4_t is_flat4_outer_mask = vand_u16( + is_flat4_mask, is_flat4(vabdq_u16(p0q0, p4q4), vabdq_u16(p0q0, p5q5), + vabdq_u16(p0q0, p6q6), bd)); + // Copy the masks to the high bits for packed comparisons later. + const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask); + const uint16x8_t needs_filter_mask_8 = + vcombine_u16(needs_filter_mask, needs_filter_mask); + + uint16x8_t f4_p1q1; + uint16x8_t f4_p0q0; + const uint16x8_t p0q1 = vcombine_u16(vget_low_u16(p0q0), vget_high_u16(p1q1)); + filter4(p0q0, p0q1, p1q1, hev_mask, bd, &f4_p1q1, &f4_p0q0); + f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1); + + uint16x8_t p0q0_output, p1q1_output, p2q2_output, p3q3_output, p4q4_output, + p5q5_output; + // Because we did not return after testing |needs_filter_mask| we know it is + // nonzero. |is_flat4_mask| controls whether the needed filter is filter4 or + // filter8. Therefore if it is false when |needs_filter_mask| is true, filter8 + // output is not used. + uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0; + const uint64x1_t need_filter8 = vreinterpret_u64_u16(is_flat4_mask); + if (vget_lane_u64(need_filter8, 0) == 0) { + // filter8() and filter14() do not apply, but filter4() applies to one or + // more values. 
+ p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = p2q2; + p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1); + p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0); + } else { + const uint16x8_t use_filter8_mask = + vcombine_u16(is_flat4_mask, is_flat4_mask); + filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0); + const uint64x1_t need_filter14 = vreinterpret_u64_u16(is_flat4_outer_mask); + if (vget_lane_u64(need_filter14, 0) == 0) { + // filter14() does not apply, but filter8() and filter4() apply to one or + // more values. + p5q5_output = p5q5; + p4q4_output = p4q4; + p3q3_output = p3q3; + p2q2_output = vbslq_u16(use_filter8_mask, f8_p2q2, p2q2); + p1q1_output = vbslq_u16(use_filter8_mask, f8_p1q1, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(use_filter8_mask, f8_p0q0, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } else { + // All filters may contribute values to final outputs. + const uint16x8_t use_filter14_mask = + vcombine_u16(is_flat4_outer_mask, is_flat4_outer_mask); + uint16x8_t f14_p5q5, f14_p4q4, f14_p3q3, f14_p2q2, f14_p1q1, f14_p0q0; + filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4, + &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0); + p5q5_output = vbslq_u16(use_filter14_mask, f14_p5q5, p5q5); + p4q4_output = vbslq_u16(use_filter14_mask, f14_p4q4, p4q4); + p3q3_output = vbslq_u16(use_filter14_mask, f14_p3q3, p3q3); + p2q2_output = vbslq_u16(use_filter14_mask, f14_p2q2, f8_p2q2); + p2q2_output = vbslq_u16(use_filter8_mask, p2q2_output, p2q2); + p2q2_output = vbslq_u16(needs_filter_mask_8, p2q2_output, p2q2); + p1q1_output = vbslq_u16(use_filter14_mask, f14_p1q1, f8_p1q1); + p1q1_output = vbslq_u16(use_filter8_mask, p1q1_output, f4_p1q1); + p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1); + p0q0_output = vbslq_u16(use_filter14_mask, f14_p0q0, f8_p0q0); + p0q0_output = vbslq_u16(use_filter8_mask, p0q0_output, f4_p0q0); + p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0); + } + } + // To get the correctly ordered rows from the transpose, we need: + // p7p3 p6p2 p5p1 p4p0 + // q0q4 q1q5 q2q6 q3q7 + const uint16x8x2_t p7p3_q3q7 = permute_acdb64(p7q7, p3q3_output); + const uint16x8x2_t p6p2_q2q6 = permute_acdb64(p6q6, p2q2_output); + const uint16x8x2_t p5p1_q1q5 = permute_acdb64(p5q5_output, p1q1_output); + const uint16x8x2_t p4p0_q0q4 = permute_acdb64(p4q4_output, p0q0_output); + uint16x8_t output_p[4] = { p7p3_q3q7.val[0], p6p2_q2q6.val[0], + p5p1_q1q5.val[0], p4p0_q0q4.val[0] }; + transpose_array_inplace_u16_4x8(output_p); + uint16x8_t output_q[4] = { p4p0_q0q4.val[1], p5p1_q1q5.val[1], + p6p2_q2q6.val[1], p7p3_q3q7.val[1] }; + transpose_array_inplace_u16_4x8(output_q); + + // Reverse p values to produce original order: + // p3 p2 p1 p0 q0 q1 q2 q3 + vst1q_u16(dst_0, output_p[0]); + vst1q_u16(dst_0 + 8, output_q[0]); + vst1q_u16(dst_1, output_p[1]); + vst1q_u16(dst_1 + 8, output_q[1]); + vst1q_u16(dst_2, output_p[2]); + vst1q_u16(dst_2 + 8, output_q[2]); + vst1q_u16(dst_3, output_p[3]); + vst1q_u16(dst_3 + 8, output_q[3]); +} + +void aom_highbd_lpf_vertical_14_dual_neon( + uint16_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1, int bd) { + aom_highbd_lpf_vertical_14_neon(s, pitch, blimit0, limit0, thresh0, bd); + aom_highbd_lpf_vertical_14_neon(s + 4 * pitch, pitch, 
blimit1, limit1, + thresh1, bd); +} diff --git a/third_party/aom/aom_dsp/arm/highbd_masked_sad_neon.c b/third_party/aom/aom_dsp/arm/highbd_masked_sad_neon.c new file mode 100644 index 0000000000..9262d818e9 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_masked_sad_neon.c @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/blend.h" + +static INLINE uint16x8_t masked_sad_8x1_neon(uint16x8_t sad, + const uint16_t *src, + const uint16_t *a, + const uint16_t *b, + const uint8_t *m) { + const uint16x8_t s0 = vld1q_u16(src); + const uint16x8_t a0 = vld1q_u16(a); + const uint16x8_t b0 = vld1q_u16(b); + const uint16x8_t m0 = vmovl_u8(vld1_u8(m)); + + uint16x8_t blend_u16 = alpha_blend_a64_u16x8(m0, a0, b0); + + return vaddq_u16(sad, vabdq_u16(blend_u16, s0)); +} + +static INLINE uint16x8_t masked_sad_16x1_neon(uint16x8_t sad, + const uint16_t *src, + const uint16_t *a, + const uint16_t *b, + const uint8_t *m) { + sad = masked_sad_8x1_neon(sad, src, a, b, m); + return masked_sad_8x1_neon(sad, &src[8], &a[8], &b[8], &m[8]); +} + +static INLINE uint16x8_t masked_sad_32x1_neon(uint16x8_t sad, + const uint16_t *src, + const uint16_t *a, + const uint16_t *b, + const uint8_t *m) { + sad = masked_sad_16x1_neon(sad, src, a, b, m); + return masked_sad_16x1_neon(sad, &src[16], &a[16], &b[16], &m[16]); + } + +static INLINE unsigned int masked_sad_128xh_large_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + uint32x4_t sad_u32[] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + + do { + uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + for (int h = 0; h < 4; ++h) { + sad[0] = masked_sad_32x1_neon(sad[0], src, a, b, m); + sad[1] = masked_sad_32x1_neon(sad[1], &src[32], &a[32], &b[32], &m[32]); + sad[2] = masked_sad_32x1_neon(sad[2], &src[64], &a[64], &b[64], &m[64]); + sad[3] = masked_sad_32x1_neon(sad[3], &src[96], &a[96], &b[96], &m[96]); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } + + sad_u32[0] = vpadalq_u16(sad_u32[0], sad[0]); + sad_u32[1] = vpadalq_u16(sad_u32[1], sad[1]); + sad_u32[2] = vpadalq_u16(sad_u32[2], sad[2]); + sad_u32[3] = vpadalq_u16(sad_u32[3], sad[3]); + height -= 4; + } while (height != 0); + + sad_u32[0] = vaddq_u32(sad_u32[0], sad_u32[1]); + sad_u32[2] = vaddq_u32(sad_u32[2], sad_u32[3]); + sad_u32[0] = vaddq_u32(sad_u32[0], sad_u32[2]); + + return horizontal_add_u32x4(sad_u32[0]); +} + +static INLINE unsigned int masked_sad_64xh_large_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int
a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + uint32x4_t sad_u32[] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + do { + uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0) }; + for (int h = 0; h < 4; ++h) { + sad[0] = masked_sad_32x1_neon(sad[0], src, a, b, m); + sad[1] = masked_sad_32x1_neon(sad[1], &src[32], &a[32], &b[32], &m[32]); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } + + sad_u32[0] = vpadalq_u16(sad_u32[0], sad[0]); + sad_u32[1] = vpadalq_u16(sad_u32[1], sad[1]); + height -= 4; + } while (height != 0); + + return horizontal_add_u32x4(vaddq_u32(sad_u32[0], sad_u32[1])); +} + +static INLINE unsigned int masked_sad_32xh_large_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + uint32x4_t sad_u32 = vdupq_n_u32(0); + + do { + uint16x8_t sad = vdupq_n_u16(0); + for (int h = 0; h < 4; ++h) { + sad = masked_sad_32x1_neon(sad, src, a, b, m); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } + + sad_u32 = vpadalq_u16(sad_u32, sad); + height -= 4; + } while (height != 0); + + return horizontal_add_u32x4(sad_u32); +} + +static INLINE unsigned int masked_sad_16xh_large_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + uint32x4_t sad_u32 = vdupq_n_u32(0); + + do { + uint16x8_t sad_u16 = vdupq_n_u16(0); + + for (int h = 0; h < 8; ++h) { + sad_u16 = masked_sad_16x1_neon(sad_u16, src, a, b, m); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } + + sad_u32 = vpadalq_u16(sad_u32, sad_u16); + height -= 8; + } while (height != 0); + + return horizontal_add_u32x4(sad_u32); +} + +#if !CONFIG_REALTIME_ONLY +static INLINE unsigned int masked_sad_8xh_large_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + uint32x4_t sad_u32 = vdupq_n_u32(0); + + do { + uint16x8_t sad_u16 = vdupq_n_u16(0); + + for (int h = 0; h < 16; ++h) { + sad_u16 = masked_sad_8x1_neon(sad_u16, src, a, b, m); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } + + sad_u32 = vpadalq_u16(sad_u32, sad_u16); + height -= 16; + } while (height != 0); + + return horizontal_add_u32x4(sad_u32); +} +#endif // !CONFIG_REALTIME_ONLY + +static INLINE unsigned int masked_sad_16xh_small_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + // For 12-bit data, we can only accumulate up to 128 elements in the + // uint16x8_t type sad accumulator, so we can only process up to 8 rows + // before we have to accumulate into 32-bit elements. 
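+ // Each of the 8 accumulator lanes sees at most 16 of those 128 elements:
+ // 16 * 4095 = 65520, which still fits in 16 bits.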
+ assert(height <= 8); + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + uint16x8_t sad = vdupq_n_u16(0); + + do { + sad = masked_sad_16x1_neon(sad, src, a, b, m); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } while (--height != 0); + + return horizontal_add_u16x8(sad); +} + +static INLINE unsigned int masked_sad_8xh_small_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + // For 12-bit data, we can only accumulate up to 128 elements in the + // uint16x8_t type sad accumulator, so we can only process up to 16 rows + // before we have to accumulate into 32-bit elements. + assert(height <= 16); + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + uint16x8_t sad = vdupq_n_u16(0); + + do { + sad = masked_sad_8x1_neon(sad, src, a, b, m); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } while (--height != 0); + + return horizontal_add_u16x8(sad); +} + +static INLINE unsigned int masked_sad_4xh_small_neon( + const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride, + const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride, + int height) { + // For 12-bit data, we can only accumulate up to 64 elements in the + // uint16x4_t type sad accumulator, so we can only process up to 16 rows + // before we have to accumulate into 32-bit elements. + assert(height <= 16); + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *a = CONVERT_TO_SHORTPTR(a8); + const uint16_t *b = CONVERT_TO_SHORTPTR(b8); + + uint16x4_t sad = vdup_n_u16(0); + do { + uint16x4_t m0 = vget_low_u16(vmovl_u8(load_unaligned_u8_4x1(m))); + uint16x4_t a0 = load_unaligned_u16_4x1(a); + uint16x4_t b0 = load_unaligned_u16_4x1(b); + uint16x4_t s0 = load_unaligned_u16_4x1(src); + + uint16x4_t blend_u16 = alpha_blend_a64_u16x4(m0, a0, b0); + + sad = vadd_u16(sad, vabd_u16(blend_u16, s0)); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + } while (--height != 0); + + return horizontal_add_u16x4(sad); +} + +#define HIGHBD_MASKED_SAD_WXH_SMALL_NEON(w, h) \ + unsigned int aom_highbd_masked_sad##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \ + int invert_mask) { \ + if (!invert_mask) \ + return masked_sad_##w##xh_small_neon(src, src_stride, ref, ref_stride, \ + second_pred, w, msk, msk_stride, \ + h); \ + else \ + return masked_sad_##w##xh_small_neon(src, src_stride, second_pred, w, \ + ref, ref_stride, msk, msk_stride, \ + h); \ + } + +#define HIGHBD_MASKED_SAD_WXH_LARGE_NEON(w, h) \ + unsigned int aom_highbd_masked_sad##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \ + int invert_mask) { \ + if (!invert_mask) \ + return masked_sad_##w##xh_large_neon(src, src_stride, ref, ref_stride, \ + second_pred, w, msk, msk_stride, \ + h); \ + else \ + return masked_sad_##w##xh_large_neon(src, src_stride, second_pred, w, \ + ref, ref_stride, msk, msk_stride, \ + h); \ + } + +HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 4) +HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 8) + +HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 4) 
+HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 8) +HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 16) + +HIGHBD_MASKED_SAD_WXH_SMALL_NEON(16, 8) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 16) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 32) + +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 16) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 32) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 64) + +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 32) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 64) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 128) + +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(128, 64) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 16) + +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(8, 32) + +HIGHBD_MASKED_SAD_WXH_SMALL_NEON(16, 4) +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 64) + +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 8) + +HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY diff --git a/third_party/aom/aom_dsp/arm/highbd_obmc_sad_neon.c b/third_party/aom/aom_dsp/arm/highbd_obmc_sad_neon.c new file mode 100644 index 0000000000..28699e6f41 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_obmc_sad_neon.c @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void highbd_obmc_sad_8x1_s16_neon(uint16x8_t ref, + const int32_t *mask, + const int32_t *wsrc, + uint32x4_t *sum) { + int16x8_t ref_s16 = vreinterpretq_s16_u16(ref); + + int32x4_t wsrc_lo = vld1q_s32(wsrc); + int32x4_t wsrc_hi = vld1q_s32(wsrc + 4); + + int32x4_t mask_lo = vld1q_s32(mask); + int32x4_t mask_hi = vld1q_s32(mask + 4); + + int16x8_t mask_s16 = vcombine_s16(vmovn_s32(mask_lo), vmovn_s32(mask_hi)); + + int32x4_t pre_lo = vmull_s16(vget_low_s16(ref_s16), vget_low_s16(mask_s16)); + int32x4_t pre_hi = vmull_s16(vget_high_s16(ref_s16), vget_high_s16(mask_s16)); + + uint32x4_t abs_lo = vreinterpretq_u32_s32(vabdq_s32(wsrc_lo, pre_lo)); + uint32x4_t abs_hi = vreinterpretq_u32_s32(vabdq_s32(wsrc_hi, pre_hi)); + + *sum = vrsraq_n_u32(*sum, abs_lo, 12); + *sum = vrsraq_n_u32(*sum, abs_hi, 12); +} + +static INLINE unsigned int highbd_obmc_sad_4xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, + int height) { + const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref); + uint32x4_t sum = vdupq_n_u32(0); + + int h = height / 2; + do { + uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride); + + highbd_obmc_sad_8x1_s16_neon(r, mask, wsrc, &sum); + + ref_ptr += 2 * ref_stride; + wsrc += 8; + mask += 8; + } while (--h != 0); + + return horizontal_add_u32x4(sum); +} + +static INLINE unsigned int highbd_obmc_sad_8xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, + int height) { + const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref); + uint32x4_t sum = vdupq_n_u32(0); + + do { + uint16x8_t r = vld1q_u16(ref_ptr); + + highbd_obmc_sad_8x1_s16_neon(r, mask, wsrc, &sum); + + ref_ptr +=
ref_stride; + wsrc += 8; + mask += 8; + } while (--height != 0); + + return horizontal_add_u32x4(sum); +} + +static INLINE unsigned int highbd_obmc_sad_large_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, + int width, int height) { + const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref); + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + do { + int i = 0; + do { + uint16x8_t r0 = vld1q_u16(ref_ptr + i); + highbd_obmc_sad_8x1_s16_neon(r0, mask, wsrc, &sum[0]); + + uint16x8_t r1 = vld1q_u16(ref_ptr + i + 8); + highbd_obmc_sad_8x1_s16_neon(r1, mask + 8, wsrc + 8, &sum[1]); + + wsrc += 16; + mask += 16; + i += 16; + } while (i < width); + + ref_ptr += ref_stride; + } while (--height != 0); + + return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1])); +} + +static INLINE unsigned int highbd_obmc_sad_16xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, + int h) { + return highbd_obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 16, h); +} + +static INLINE unsigned int highbd_obmc_sad_32xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, + int height) { + uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref); + + do { + uint16x8_t r0 = vld1q_u16(ref_ptr); + uint16x8_t r1 = vld1q_u16(ref_ptr + 8); + uint16x8_t r2 = vld1q_u16(ref_ptr + 16); + uint16x8_t r3 = vld1q_u16(ref_ptr + 24); + + highbd_obmc_sad_8x1_s16_neon(r0, mask, wsrc, &sum[0]); + highbd_obmc_sad_8x1_s16_neon(r1, mask + 8, wsrc + 8, &sum[1]); + highbd_obmc_sad_8x1_s16_neon(r2, mask + 16, wsrc + 16, &sum[2]); + highbd_obmc_sad_8x1_s16_neon(r3, mask + 24, wsrc + 24, &sum[3]); + + wsrc += 32; + mask += 32; + ref_ptr += ref_stride; + } while (--height != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + sum[2] = vaddq_u32(sum[2], sum[3]); + + return horizontal_add_u32x4(vaddq_u32(sum[0], sum[2])); +} + +static INLINE unsigned int highbd_obmc_sad_64xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, + int h) { + return highbd_obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 64, h); +} + +static INLINE unsigned int highbd_obmc_sad_128xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, + int h) { + return highbd_obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 128, h); +} + +#define HIGHBD_OBMC_SAD_WXH_NEON(w, h) \ + unsigned int aom_highbd_obmc_sad##w##x##h##_neon( \ + const uint8_t *ref, int ref_stride, const int32_t *wsrc, \ + const int32_t *mask) { \ + return highbd_obmc_sad_##w##xh_neon(ref, ref_stride, wsrc, mask, h); \ + } + +HIGHBD_OBMC_SAD_WXH_NEON(4, 4) +HIGHBD_OBMC_SAD_WXH_NEON(4, 8) + +HIGHBD_OBMC_SAD_WXH_NEON(8, 4) +HIGHBD_OBMC_SAD_WXH_NEON(8, 8) +HIGHBD_OBMC_SAD_WXH_NEON(8, 16) + +HIGHBD_OBMC_SAD_WXH_NEON(16, 8) +HIGHBD_OBMC_SAD_WXH_NEON(16, 16) +HIGHBD_OBMC_SAD_WXH_NEON(16, 32) + +HIGHBD_OBMC_SAD_WXH_NEON(32, 16) +HIGHBD_OBMC_SAD_WXH_NEON(32, 32) +HIGHBD_OBMC_SAD_WXH_NEON(32, 64) + +HIGHBD_OBMC_SAD_WXH_NEON(64, 32) +HIGHBD_OBMC_SAD_WXH_NEON(64, 64) +HIGHBD_OBMC_SAD_WXH_NEON(64, 128) + +HIGHBD_OBMC_SAD_WXH_NEON(128, 64) +HIGHBD_OBMC_SAD_WXH_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HIGHBD_OBMC_SAD_WXH_NEON(4, 16) + +HIGHBD_OBMC_SAD_WXH_NEON(8, 32) + +HIGHBD_OBMC_SAD_WXH_NEON(16, 4) +HIGHBD_OBMC_SAD_WXH_NEON(16, 64) + +HIGHBD_OBMC_SAD_WXH_NEON(32, 8) + +HIGHBD_OBMC_SAD_WXH_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY diff --git 
a/third_party/aom/aom_dsp/arm/highbd_obmc_variance_neon.c b/third_party/aom/aom_dsp/arm/highbd_obmc_variance_neon.c new file mode 100644 index 0000000000..d59224619b --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_obmc_variance_neon.c @@ -0,0 +1,369 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void highbd_obmc_variance_8x1_s16_neon(uint16x8_t pre, + const int32_t *wsrc, + const int32_t *mask, + uint32x4_t *sse, + int32x4_t *sum) { + int16x8_t pre_s16 = vreinterpretq_s16_u16(pre); + int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]); + int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]); + + int32x4_t mask_lo = vld1q_s32(&mask[0]); + int32x4_t mask_hi = vld1q_s32(&mask[4]); + + int16x8_t mask_s16 = vcombine_s16(vmovn_s32(mask_lo), vmovn_s32(mask_hi)); + + int32x4_t diff_lo = vmull_s16(vget_low_s16(pre_s16), vget_low_s16(mask_s16)); + int32x4_t diff_hi = + vmull_s16(vget_high_s16(pre_s16), vget_high_s16(mask_s16)); + + diff_lo = vsubq_s32(wsrc_lo, diff_lo); + diff_hi = vsubq_s32(wsrc_hi, diff_hi); + + // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away + // from zero, however vrshrq_n_s32 rounds to nearest with ties rounded up. + // This difference only affects the bit patterns at the rounding breakpoints + // exactly, so we can add -1 to all negative numbers to move the breakpoint + // one value across and into the correct rounding region. + diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31); + diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31); + int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12); + int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12); + + *sum = vaddq_s32(*sum, round_lo); + *sum = vaddq_s32(*sum, round_hi); + *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_lo), + vreinterpretq_u32_s32(round_lo)); + *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_hi), + vreinterpretq_u32_s32(round_hi)); +} + +// For 12-bit data, we can only accumulate up to 256 elements in the unsigned +// 32-bit elements (4095*4095*256 = 4292870400) before we have to accumulate +// into 64-bit elements. Therefore blocks of size 32x64, 64x32, 64x64, 64x128, +// 128x64, 128x128 are processed in a different helper function. +static INLINE void highbd_obmc_variance_xlarge_neon( + const uint8_t *pre, int pre_stride, const int32_t *wsrc, + const int32_t *mask, int width, int h, int h_limit, uint64_t *sse, + int64_t *sum) { + uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre); + int32x4_t sum_s32 = vdupq_n_s32(0); + uint64x2_t sse_u64 = vdupq_n_u64(0); + + // 'h_limit' is the number of 'w'-width rows we can process before our 32-bit + // accumulator overflows. After hitting this limit we accumulate into 64-bit + // elements. + int h_tmp = h > h_limit ?
h_limit : h; + + do { + uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + int j = 0; + + do { + int i = 0; + + do { + uint16x8_t pre0 = vld1q_u16(pre_ptr + i); + highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32[0], + &sum_s32); + + uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8); + highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32[1], + &sum_s32); + + i += 16; + wsrc += 16; + mask += 16; + } while (i < width); + + pre_ptr += pre_stride; + j++; + } while (j < h_tmp); + + sse_u64 = vpadalq_u32(sse_u64, sse_u32[0]); + sse_u64 = vpadalq_u32(sse_u64, sse_u32[1]); + h -= h_tmp; + } while (h != 0); + + *sse = horizontal_add_u64x2(sse_u64); + *sum = horizontal_long_add_s32x4(sum_s32); +} + +static INLINE void highbd_obmc_variance_xlarge_neon_128xh( + const uint8_t *pre, int pre_stride, const int32_t *wsrc, + const int32_t *mask, int h, uint64_t *sse, int64_t *sum) { + highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 128, h, 16, sse, + sum); +} + +static INLINE void highbd_obmc_variance_xlarge_neon_64xh( + const uint8_t *pre, int pre_stride, const int32_t *wsrc, + const int32_t *mask, int h, uint64_t *sse, int64_t *sum) { + highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 64, h, 32, sse, + sum); +} + +static INLINE void highbd_obmc_variance_xlarge_neon_32xh( + const uint8_t *pre, int pre_stride, const int32_t *wsrc, + const int32_t *mask, int h, uint64_t *sse, int64_t *sum) { + highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 32, h, 64, sse, + sum); +} + +static INLINE void highbd_obmc_variance_large_neon( + const uint8_t *pre, int pre_stride, const int32_t *wsrc, + const int32_t *mask, int width, int h, uint64_t *sse, int64_t *sum) { + uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre); + uint32x4_t sse_u32 = vdupq_n_u32(0); + int32x4_t sum_s32 = vdupq_n_s32(0); + + do { + int i = 0; + do { + uint16x8_t pre0 = vld1q_u16(pre_ptr + i); + highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32, &sum_s32); + + uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8); + highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32, + &sum_s32); + + i += 16; + wsrc += 16; + mask += 16; + } while (i < width); + + pre_ptr += pre_stride; + } while (--h != 0); + + *sse = horizontal_long_add_u32x4(sse_u32); + *sum = horizontal_long_add_s32x4(sum_s32); +} + +static INLINE void highbd_obmc_variance_neon_128xh( + const uint8_t *pre, int pre_stride, const int32_t *wsrc, + const int32_t *mask, int h, uint64_t *sse, int64_t *sum) { + highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse, + sum); +} + +static INLINE void highbd_obmc_variance_neon_64xh(const uint8_t *pre, + int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + uint64_t *sse, int64_t *sum) { + highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum); +} + +static INLINE void highbd_obmc_variance_neon_32xh(const uint8_t *pre, + int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + uint64_t *sse, int64_t *sum) { + highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum); +} + +static INLINE void highbd_obmc_variance_neon_16xh(const uint8_t *pre, + int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + uint64_t *sse, int64_t *sum) { + highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum); +} + +static INLINE void highbd_obmc_variance_neon_8xh(const uint8_t *pre8, + int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + uint64_t *sse, 
int64_t *sum) { + uint16_t *pre = CONVERT_TO_SHORTPTR(pre8); + uint32x4_t sse_u32 = vdupq_n_u32(0); + int32x4_t sum_s32 = vdupq_n_s32(0); + + do { + uint16x8_t pre_u16 = vld1q_u16(pre); + + highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32); + + pre += pre_stride; + wsrc += 8; + mask += 8; + } while (--h != 0); + + *sse = horizontal_long_add_u32x4(sse_u32); + *sum = horizontal_long_add_s32x4(sum_s32); +} + +static INLINE void highbd_obmc_variance_neon_4xh(const uint8_t *pre8, + int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + uint64_t *sse, int64_t *sum) { + assert(h % 2 == 0); + uint16_t *pre = CONVERT_TO_SHORTPTR(pre8); + uint32x4_t sse_u32 = vdupq_n_u32(0); + int32x4_t sum_s32 = vdupq_n_s32(0); + + do { + uint16x8_t pre_u16 = load_unaligned_u16_4x2(pre, pre_stride); + + highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32); + + pre += 2 * pre_stride; + wsrc += 8; + mask += 8; + h -= 2; + } while (h != 0); + + *sse = horizontal_long_add_u32x4(sse_u32); + *sum = horizontal_long_add_s32x4(sum_s32); +} + +static INLINE void highbd_8_obmc_variance_cast(int64_t sum64, uint64_t sse64, + int *sum, unsigned int *sse) { + *sum = (int)sum64; + *sse = (unsigned int)sse64; +} + +static INLINE void highbd_10_obmc_variance_cast(int64_t sum64, uint64_t sse64, + int *sum, unsigned int *sse) { + *sum = (int)ROUND_POWER_OF_TWO(sum64, 2); + *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 4); +} + +static INLINE void highbd_12_obmc_variance_cast(int64_t sum64, uint64_t sse64, + int *sum, unsigned int *sse) { + *sum = (int)ROUND_POWER_OF_TWO(sum64, 4); + *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8); +} + +#define HIGHBD_OBMC_VARIANCE_WXH_NEON(w, h, bitdepth) \ + unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + const uint8_t *pre, int pre_stride, const int32_t *wsrc, \ + const int32_t *mask, unsigned int *sse) { \ + int sum; \ + int64_t sum64; \ + uint64_t sse64; \ + highbd_obmc_variance_neon_##w##xh(pre, pre_stride, wsrc, mask, h, &sse64, \ + &sum64); \ + highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse); \ + return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h)); \ + } + +#define HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(w, h, bitdepth) \ + unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + const uint8_t *pre, int pre_stride, const int32_t *wsrc, \ + const int32_t *mask, unsigned int *sse) { \ + int sum; \ + int64_t sum64; \ + uint64_t sse64; \ + highbd_obmc_variance_xlarge_neon_##w##xh(pre, pre_stride, wsrc, mask, h, \ + &sse64, &sum64); \ + highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse); \ + return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h)); \ + } + +// 8-bit +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 8) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 8) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 8) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 8) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 8) 
+HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 8) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 8) +HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 8) + +// 10-bit +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 10) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 10) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 10) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 10) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 10) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 10) +HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 10) + +// 12-bit +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 12) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 12) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 12) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 12) +HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 12) +HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(32, 64, 12) + +HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 12) +HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 32, 12) +HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 64, 12) +HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 128, 12) + +HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 64, 12) +HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 128, 12) diff --git a/third_party/aom/aom_dsp/arm/highbd_quantize_neon.c b/third_party/aom/aom_dsp/arm/highbd_quantize_neon.c new file mode 100644 index 0000000000..6149c9f13e --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_quantize_neon.c @@ -0,0 +1,431 @@ +/* + * Copyright (c) 2022, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+#include <string.h>
+
+#include "config/aom_config.h"
+
+#include "aom_dsp/quantize.h"
+
+static INLINE uint32_t sum_abs_coeff(const uint32x4_t a) {
+#if AOM_ARCH_AARCH64
+  return vaddvq_u32(a);
+#else
+  const uint64x2_t b = vpaddlq_u32(a);
+  const uint64x1_t c = vadd_u64(vget_low_u64(b), vget_high_u64(b));
+  return (uint32_t)vget_lane_u64(c, 0);
+#endif
+}
+
+static INLINE uint16x4_t
+quantize_4(const tran_low_t *coeff_ptr, tran_low_t *qcoeff_ptr,
+           tran_low_t *dqcoeff_ptr, int32x4_t v_quant_s32,
+           int32x4_t v_dequant_s32, int32x4_t v_round_s32, int32x4_t v_zbin_s32,
+           int32x4_t v_quant_shift_s32, int log_scale) {
+  const int32x4_t v_coeff = vld1q_s32(coeff_ptr);
+  const int32x4_t v_coeff_sign =
+      vreinterpretq_s32_u32(vcltq_s32(v_coeff, vdupq_n_s32(0)));
+  const int32x4_t v_abs_coeff = vabsq_s32(v_coeff);
+  // if (abs_coeff < zbins[rc != 0]),
+  const uint32x4_t v_zbin_mask = vcgeq_s32(v_abs_coeff, v_zbin_s32);
+  const int32x4_t v_log_scale = vdupq_n_s32(log_scale);
+  // const int64_t tmp = (int64_t)abs_coeff + log_scaled_round;
+  const int32x4_t v_tmp = vaddq_s32(v_abs_coeff, v_round_s32);
+  // const int32_t tmpw32 = tmp * wt;
+  const int32x4_t v_tmpw32 = vmulq_s32(v_tmp, vdupq_n_s32((1 << AOM_QM_BITS)));
+  // const int32_t tmp2 = (int32_t)((tmpw32 * quant64) >> 16);
+  const int32x4_t v_tmp2 = vqdmulhq_s32(v_tmpw32, v_quant_s32);
+  // const int32_t tmp3 =
+  // ((((tmp2 + tmpw32)<< log_scale) * (int64_t)(quant_shift << 15)) >> 32);
+  const int32x4_t v_tmp3 = vqdmulhq_s32(
+      vshlq_s32(vaddq_s32(v_tmp2, v_tmpw32), v_log_scale), v_quant_shift_s32);
+  // const int abs_qcoeff = vmask ? (int)tmp3 >> AOM_QM_BITS : 0;
+  const int32x4_t v_abs_qcoeff = vandq_s32(vreinterpretq_s32_u32(v_zbin_mask),
+                                           vshrq_n_s32(v_tmp3, AOM_QM_BITS));
+  // const tran_low_t abs_dqcoeff = (abs_qcoeff * dequant_iwt) >> log_scale;
+  // vshlq_s32 will shift right if shift value is negative.
+  const int32x4_t v_abs_dqcoeff =
+      vshlq_s32(vmulq_s32(v_abs_qcoeff, v_dequant_s32), vnegq_s32(v_log_scale));
+  // qcoeff_ptr[rc] = (tran_low_t)((abs_qcoeff ^ coeff_sign) - coeff_sign);
+  const int32x4_t v_qcoeff =
+      vsubq_s32(veorq_s32(v_abs_qcoeff, v_coeff_sign), v_coeff_sign);
+  // dqcoeff_ptr[rc] = (tran_low_t)((abs_dqcoeff ^ coeff_sign) - coeff_sign);
+  const int32x4_t v_dqcoeff =
+      vsubq_s32(veorq_s32(v_abs_dqcoeff, v_coeff_sign), v_coeff_sign);
+
+  vst1q_s32(qcoeff_ptr, v_qcoeff);
+  vst1q_s32(dqcoeff_ptr, v_dqcoeff);
+
+  // Used to find eob.
+ const uint32x4_t nz_qcoeff_mask = vcgtq_s32(v_abs_qcoeff, vdupq_n_s32(0)); + return vmovn_u32(nz_qcoeff_mask); +} + +static INLINE int16x8_t get_max_lane_eob(const int16_t *iscan, + int16x8_t v_eobmax, + uint16x8_t v_mask) { + const int16x8_t v_iscan = vld1q_s16(&iscan[0]); + const int16x8_t v_iscan_plus1 = vaddq_s16(v_iscan, vdupq_n_s16(1)); + const int16x8_t v_nz_iscan = vbslq_s16(v_mask, v_iscan_plus1, vdupq_n_s16(0)); + return vmaxq_s16(v_eobmax, v_nz_iscan); +} + +#if !CONFIG_REALTIME_ONLY +static INLINE void get_min_max_lane_eob(const int16_t *iscan, + int16x8_t *v_eobmin, + int16x8_t *v_eobmax, uint16x8_t v_mask, + intptr_t n_coeffs) { + const int16x8_t v_iscan = vld1q_s16(&iscan[0]); + const int16x8_t v_nz_iscan_max = vbslq_s16(v_mask, v_iscan, vdupq_n_s16(-1)); +#if SKIP_EOB_FACTOR_ADJUST + const int16x8_t v_nz_iscan_min = + vbslq_s16(v_mask, v_iscan, vdupq_n_s16((int16_t)n_coeffs)); + *v_eobmin = vminq_s16(*v_eobmin, v_nz_iscan_min); +#else + (void)v_eobmin; +#endif + *v_eobmax = vmaxq_s16(*v_eobmax, v_nz_iscan_max); +} +#endif // !CONFIG_REALTIME_ONLY + +static INLINE uint16_t get_max_eob(int16x8_t v_eobmax) { +#if AOM_ARCH_AARCH64 + return (uint16_t)vmaxvq_s16(v_eobmax); +#else + const int16x4_t v_eobmax_3210 = + vmax_s16(vget_low_s16(v_eobmax), vget_high_s16(v_eobmax)); + const int64x1_t v_eobmax_xx32 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmax_3210), 32); + const int16x4_t v_eobmax_tmp = + vmax_s16(v_eobmax_3210, vreinterpret_s16_s64(v_eobmax_xx32)); + const int64x1_t v_eobmax_xxx3 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmax_tmp), 16); + const int16x4_t v_eobmax_final = + vmax_s16(v_eobmax_tmp, vreinterpret_s16_s64(v_eobmax_xxx3)); + return (uint16_t)vget_lane_s16(v_eobmax_final, 0); +#endif +} + +#if SKIP_EOB_FACTOR_ADJUST && !CONFIG_REALTIME_ONLY +static INLINE uint16_t get_min_eob(int16x8_t v_eobmin) { +#if AOM_ARCH_AARCH64 + return (uint16_t)vminvq_s16(v_eobmin); +#else + const int16x4_t v_eobmin_3210 = + vmin_s16(vget_low_s16(v_eobmin), vget_high_s16(v_eobmin)); + const int64x1_t v_eobmin_xx32 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmin_3210), 32); + const int16x4_t v_eobmin_tmp = + vmin_s16(v_eobmin_3210, vreinterpret_s16_s64(v_eobmin_xx32)); + const int64x1_t v_eobmin_xxx3 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmin_tmp), 16); + const int16x4_t v_eobmin_final = + vmin_s16(v_eobmin_tmp, vreinterpret_s16_s64(v_eobmin_xxx3)); + return (uint16_t)vget_lane_s16(v_eobmin_final, 0); +#endif +} +#endif // SKIP_EOB_FACTOR_ADJUST && !CONFIG_REALTIME_ONLY + +static void highbd_quantize_b_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, const int log_scale) { + (void)scan; + const int16x4_t v_quant = vld1_s16(quant_ptr); + const int16x4_t v_dequant = vld1_s16(dequant_ptr); + const int16x4_t v_zero = vdup_n_s16(0); + const uint16x4_t v_round_select = vcgt_s16(vdup_n_s16(log_scale), v_zero); + const int16x4_t v_round_no_scale = vld1_s16(round_ptr); + const int16x4_t v_round_log_scale = + vqrdmulh_n_s16(v_round_no_scale, (int16_t)(1 << (15 - log_scale))); + const int16x4_t v_round = + vbsl_s16(v_round_select, v_round_log_scale, v_round_no_scale); + const int16x4_t v_quant_shift = vld1_s16(quant_shift_ptr); + const int16x4_t v_zbin_no_scale = vld1_s16(zbin_ptr); + const int16x4_t v_zbin_log_scale = + 
vqrdmulh_n_s16(v_zbin_no_scale, (int16_t)(1 << (15 - log_scale))); + const int16x4_t v_zbin = + vbsl_s16(v_round_select, v_zbin_log_scale, v_zbin_no_scale); + int32x4_t v_round_s32 = vmovl_s16(v_round); + int32x4_t v_quant_s32 = vshlq_n_s32(vmovl_s16(v_quant), 15); + int32x4_t v_dequant_s32 = vmovl_s16(v_dequant); + int32x4_t v_quant_shift_s32 = vshlq_n_s32(vmovl_s16(v_quant_shift), 15); + int32x4_t v_zbin_s32 = vmovl_s16(v_zbin); + uint16x4_t v_mask_lo, v_mask_hi; + int16x8_t v_eobmax = vdupq_n_s16(-1); + + intptr_t non_zero_count = n_coeffs; + + assert(n_coeffs > 8); + // Pre-scan pass + const int32x4_t v_zbin_s32x = vdupq_lane_s32(vget_low_s32(v_zbin_s32), 1); + intptr_t i = n_coeffs; + do { + const int32x4_t v_coeff_a = vld1q_s32(coeff_ptr + i - 4); + const int32x4_t v_coeff_b = vld1q_s32(coeff_ptr + i - 8); + const int32x4_t v_abs_coeff_a = vabsq_s32(v_coeff_a); + const int32x4_t v_abs_coeff_b = vabsq_s32(v_coeff_b); + const uint32x4_t v_mask_a = vcgeq_s32(v_abs_coeff_a, v_zbin_s32x); + const uint32x4_t v_mask_b = vcgeq_s32(v_abs_coeff_b, v_zbin_s32x); + // If the coefficient is in the base ZBIN range, then discard. + if (sum_abs_coeff(v_mask_a) + sum_abs_coeff(v_mask_b) == 0) { + non_zero_count -= 8; + } else { + break; + } + i -= 8; + } while (i > 0); + + const intptr_t remaining_zcoeffs = n_coeffs - non_zero_count; + memset(qcoeff_ptr + non_zero_count, 0, + remaining_zcoeffs * sizeof(*qcoeff_ptr)); + memset(dqcoeff_ptr + non_zero_count, 0, + remaining_zcoeffs * sizeof(*dqcoeff_ptr)); + + // DC and first 3 AC + v_mask_lo = + quantize_4(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant_s32, v_dequant_s32, + v_round_s32, v_zbin_s32, v_quant_shift_s32, log_scale); + + // overwrite the DC constants with AC constants + v_round_s32 = vdupq_lane_s32(vget_low_s32(v_round_s32), 1); + v_quant_s32 = vdupq_lane_s32(vget_low_s32(v_quant_s32), 1); + v_dequant_s32 = vdupq_lane_s32(vget_low_s32(v_dequant_s32), 1); + v_quant_shift_s32 = vdupq_lane_s32(vget_low_s32(v_quant_shift_s32), 1); + v_zbin_s32 = vdupq_lane_s32(vget_low_s32(v_zbin_s32), 1); + + // 4 more AC + v_mask_hi = quantize_4(coeff_ptr + 4, qcoeff_ptr + 4, dqcoeff_ptr + 4, + v_quant_s32, v_dequant_s32, v_round_s32, v_zbin_s32, + v_quant_shift_s32, log_scale); + + v_eobmax = + get_max_lane_eob(iscan, v_eobmax, vcombine_u16(v_mask_lo, v_mask_hi)); + + intptr_t count = non_zero_count - 8; + for (; count > 0; count -= 8) { + coeff_ptr += 8; + qcoeff_ptr += 8; + dqcoeff_ptr += 8; + iscan += 8; + v_mask_lo = quantize_4(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant_s32, + v_dequant_s32, v_round_s32, v_zbin_s32, + v_quant_shift_s32, log_scale); + v_mask_hi = quantize_4(coeff_ptr + 4, qcoeff_ptr + 4, dqcoeff_ptr + 4, + v_quant_s32, v_dequant_s32, v_round_s32, v_zbin_s32, + v_quant_shift_s32, log_scale); + // Find the max lane eob for 8 coeffs. 
+ v_eobmax = + get_max_lane_eob(iscan, v_eobmax, vcombine_u16(v_mask_lo, v_mask_hi)); + } + + *eob_ptr = get_max_eob(v_eobmax); +} + +void aom_highbd_quantize_b_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, + const int16_t *zbin_ptr, + const int16_t *round_ptr, + const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + highbd_quantize_b_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, quant_ptr, + quant_shift_ptr, qcoeff_ptr, dqcoeff_ptr, dequant_ptr, + eob_ptr, scan, iscan, 0); +} + +void aom_highbd_quantize_b_32x32_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + highbd_quantize_b_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, quant_ptr, + quant_shift_ptr, qcoeff_ptr, dqcoeff_ptr, dequant_ptr, + eob_ptr, scan, iscan, 1); +} + +void aom_highbd_quantize_b_64x64_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + highbd_quantize_b_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, quant_ptr, + quant_shift_ptr, qcoeff_ptr, dqcoeff_ptr, dequant_ptr, + eob_ptr, scan, iscan, 2); +} + +#if !CONFIG_REALTIME_ONLY +static void highbd_quantize_b_adaptive_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, const int log_scale) { + (void)scan; + const int16x4_t v_quant = vld1_s16(quant_ptr); + const int16x4_t v_dequant = vld1_s16(dequant_ptr); + const int16x4_t v_zero = vdup_n_s16(0); + const uint16x4_t v_round_select = vcgt_s16(vdup_n_s16(log_scale), v_zero); + const int16x4_t v_round_no_scale = vld1_s16(round_ptr); + const int16x4_t v_round_log_scale = + vqrdmulh_n_s16(v_round_no_scale, (int16_t)(1 << (15 - log_scale))); + const int16x4_t v_round = + vbsl_s16(v_round_select, v_round_log_scale, v_round_no_scale); + const int16x4_t v_quant_shift = vld1_s16(quant_shift_ptr); + const int16x4_t v_zbin_no_scale = vld1_s16(zbin_ptr); + const int16x4_t v_zbin_log_scale = + vqrdmulh_n_s16(v_zbin_no_scale, (int16_t)(1 << (15 - log_scale))); + const int16x4_t v_zbin = + vbsl_s16(v_round_select, v_zbin_log_scale, v_zbin_no_scale); + int32x4_t v_round_s32 = vmovl_s16(v_round); + int32x4_t v_quant_s32 = vshlq_n_s32(vmovl_s16(v_quant), 15); + int32x4_t v_dequant_s32 = vmovl_s16(v_dequant); + int32x4_t v_quant_shift_s32 = vshlq_n_s32(vmovl_s16(v_quant_shift), 15); + int32x4_t v_zbin_s32 = vmovl_s16(v_zbin); + uint16x4_t v_mask_lo, v_mask_hi; + int16x8_t v_eobmax = vdupq_n_s16(-1); + int16x8_t v_eobmin = vdupq_n_s16((int16_t)n_coeffs); + + assert(n_coeffs > 8); + // Pre-scan pass + const int32x4_t v_zbin_s32x = vdupq_lane_s32(vget_low_s32(v_zbin_s32), 1); + const int prescan_add_1 = + ROUND_POWER_OF_TWO(dequant_ptr[1] * EOB_FACTOR, 7 + AOM_QM_BITS); + const int32x4_t v_zbin_prescan = + 
vaddq_s32(v_zbin_s32x, vdupq_n_s32(prescan_add_1)); + intptr_t non_zero_count = n_coeffs; + intptr_t i = n_coeffs; + do { + const int32x4_t v_coeff_a = vld1q_s32(coeff_ptr + i - 4); + const int32x4_t v_coeff_b = vld1q_s32(coeff_ptr + i - 8); + const int32x4_t v_abs_coeff_a = vabsq_s32(v_coeff_a); + const int32x4_t v_abs_coeff_b = vabsq_s32(v_coeff_b); + const uint32x4_t v_mask_a = vcgeq_s32(v_abs_coeff_a, v_zbin_prescan); + const uint32x4_t v_mask_b = vcgeq_s32(v_abs_coeff_b, v_zbin_prescan); + // If the coefficient is in the base ZBIN range, then discard. + if (sum_abs_coeff(v_mask_a) + sum_abs_coeff(v_mask_b) == 0) { + non_zero_count -= 8; + } else { + break; + } + i -= 8; + } while (i > 0); + + const intptr_t remaining_zcoeffs = n_coeffs - non_zero_count; + memset(qcoeff_ptr + non_zero_count, 0, + remaining_zcoeffs * sizeof(*qcoeff_ptr)); + memset(dqcoeff_ptr + non_zero_count, 0, + remaining_zcoeffs * sizeof(*dqcoeff_ptr)); + + // DC and first 3 AC + v_mask_lo = + quantize_4(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant_s32, v_dequant_s32, + v_round_s32, v_zbin_s32, v_quant_shift_s32, log_scale); + + // overwrite the DC constants with AC constants + v_round_s32 = vdupq_lane_s32(vget_low_s32(v_round_s32), 1); + v_quant_s32 = vdupq_lane_s32(vget_low_s32(v_quant_s32), 1); + v_dequant_s32 = vdupq_lane_s32(vget_low_s32(v_dequant_s32), 1); + v_quant_shift_s32 = vdupq_lane_s32(vget_low_s32(v_quant_shift_s32), 1); + v_zbin_s32 = vdupq_lane_s32(vget_low_s32(v_zbin_s32), 1); + + // 4 more AC + v_mask_hi = quantize_4(coeff_ptr + 4, qcoeff_ptr + 4, dqcoeff_ptr + 4, + v_quant_s32, v_dequant_s32, v_round_s32, v_zbin_s32, + v_quant_shift_s32, log_scale); + + get_min_max_lane_eob(iscan, &v_eobmin, &v_eobmax, + vcombine_u16(v_mask_lo, v_mask_hi), n_coeffs); + + intptr_t count = non_zero_count - 8; + for (; count > 0; count -= 8) { + coeff_ptr += 8; + qcoeff_ptr += 8; + dqcoeff_ptr += 8; + iscan += 8; + v_mask_lo = quantize_4(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant_s32, + v_dequant_s32, v_round_s32, v_zbin_s32, + v_quant_shift_s32, log_scale); + v_mask_hi = quantize_4(coeff_ptr + 4, qcoeff_ptr + 4, dqcoeff_ptr + 4, + v_quant_s32, v_dequant_s32, v_round_s32, v_zbin_s32, + v_quant_shift_s32, log_scale); + + get_min_max_lane_eob(iscan, &v_eobmin, &v_eobmax, + vcombine_u16(v_mask_lo, v_mask_hi), n_coeffs); + } + + int eob = get_max_eob(v_eobmax); + +#if SKIP_EOB_FACTOR_ADJUST + const int first = get_min_eob(v_eobmin); + if (eob >= 0 && first == eob) { + const int rc = scan[eob]; + if (qcoeff_ptr[rc] == 1 || qcoeff_ptr[rc] == -1) { + const int zbins[2] = { ROUND_POWER_OF_TWO(zbin_ptr[0], log_scale), + ROUND_POWER_OF_TWO(zbin_ptr[1], log_scale) }; + const int nzbins[2] = { zbins[0] * -1, zbins[1] * -1 }; + const qm_val_t wt = (1 << AOM_QM_BITS); + const int coeff = coeff_ptr[rc] * wt; + const int factor = EOB_FACTOR + SKIP_EOB_FACTOR_ADJUST; + const int prescan_add_val = + ROUND_POWER_OF_TWO(dequant_ptr[rc != 0] * factor, 7); + if (coeff < (zbins[rc != 0] * (1 << AOM_QM_BITS) + prescan_add_val) && + coeff > (nzbins[rc != 0] * (1 << AOM_QM_BITS) - prescan_add_val)) { + qcoeff_ptr[rc] = 0; + dqcoeff_ptr[rc] = 0; + eob = -1; + } + } + } +#endif // SKIP_EOB_FACTOR_ADJUST + *eob_ptr = eob + 1; +} + +void aom_highbd_quantize_b_adaptive_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t 
*scan, const int16_t *iscan) { + highbd_quantize_b_adaptive_neon( + coeff_ptr, n_coeffs, zbin_ptr, round_ptr, quant_ptr, quant_shift_ptr, + qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan, 0); +} + +void aom_highbd_quantize_b_32x32_adaptive_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + highbd_quantize_b_adaptive_neon( + coeff_ptr, n_coeffs, zbin_ptr, round_ptr, quant_ptr, quant_shift_ptr, + qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan, 1); +} + +void aom_highbd_quantize_b_64x64_adaptive_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + highbd_quantize_b_adaptive_neon( + coeff_ptr, n_coeffs, zbin_ptr, round_ptr, quant_ptr, quant_shift_ptr, + qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan, 2); +} +#endif // !CONFIG_REALTIME_ONLY diff --git a/third_party/aom/aom_dsp/arm/highbd_sad_neon.c b/third_party/aom/aom_dsp/arm/highbd_sad_neon.c new file mode 100644 index 0000000000..d51f639de6 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_sad_neon.c @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2023 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */
+
+#include <arm_neon.h>
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+#include "aom/aom_integer.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
+
+static INLINE uint32_t highbd_sad4xh_small_neon(const uint8_t *src_ptr,
+                                                int src_stride,
+                                                const uint8_t *ref_ptr,
+                                                int ref_stride, int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
+  uint32x4_t sum = vdupq_n_u32(0);
+
+  int i = h;
+  do {
+    uint16x4_t s = vld1_u16(src16_ptr);
+    uint16x4_t r = vld1_u16(ref16_ptr);
+    sum = vabal_u16(sum, s, r);
+
+    src16_ptr += src_stride;
+    ref16_ptr += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_u32x4(sum);
+}
+
+static INLINE uint32_t highbd_sad8xh_small_neon(const uint8_t *src_ptr,
+                                                int src_stride,
+                                                const uint8_t *ref_ptr,
+                                                int ref_stride, int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
+  uint16x8_t sum = vdupq_n_u16(0);
+
+  int i = h;
+  do {
+    uint16x8_t s = vld1q_u16(src16_ptr);
+    uint16x8_t r = vld1q_u16(ref16_ptr);
+    sum = vabaq_u16(sum, s, r);
+
+    src16_ptr += src_stride;
+    ref16_ptr += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_u16x8(sum);
+}
+
+#if !CONFIG_REALTIME_ONLY
+static INLINE uint32_t highbd_sad8xh_large_neon(const uint8_t *src_ptr,
+                                                int src_stride,
+                                                const uint8_t *ref_ptr,
+                                                int ref_stride, int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
+  uint32x4_t sum_u32 = vdupq_n_u32(0);
+
+  int i = h;
+  do {
+    uint16x8_t s = vld1q_u16(src16_ptr);
+    uint16x8_t r = vld1q_u16(ref16_ptr);
+    uint16x8_t sum_u16 = vabdq_u16(s, r);
+    sum_u32 = vpadalq_u16(sum_u32, sum_u16);
+
+    src16_ptr += src_stride;
+    ref16_ptr += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_u32x4(sum_u32);
+}
+#endif // !CONFIG_REALTIME_ONLY
+
+static INLINE uint32_t highbd_sad16xh_large_neon(const uint8_t *src_ptr,
+                                                 int src_stride,
+                                                 const uint8_t *ref_ptr,
+                                                 int ref_stride, int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
+  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h;
+  do {
+    uint16x8_t s0 = vld1q_u16(src16_ptr);
+    uint16x8_t r0 = vld1q_u16(ref16_ptr);
+    uint16x8_t diff0 = vabdq_u16(s0, r0);
+    sum[0] = vpadalq_u16(sum[0], diff0);
+
+    uint16x8_t s1 = vld1q_u16(src16_ptr + 8);
+    uint16x8_t r1 = vld1q_u16(ref16_ptr + 8);
+    uint16x8_t diff1 = vabdq_u16(s1, r1);
+    sum[1] = vpadalq_u16(sum[1], diff1);
+
+    src16_ptr += src_stride;
+    ref16_ptr += ref_stride;
+  } while (--i != 0);
+
+  sum[0] = vaddq_u32(sum[0], sum[1]);
+  return horizontal_add_u32x4(sum[0]);
+}
+
+static INLINE uint32_t highbd_sadwxh_large_neon(const uint8_t *src_ptr,
+                                                int src_stride,
+                                                const uint8_t *ref_ptr,
+                                                int ref_stride, int w, int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr);
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
+
+  int i = h;
+  do {
+    int j = 0;
+    do {
+      uint16x8_t s0 = vld1q_u16(src16_ptr + j);
+      uint16x8_t r0 = vld1q_u16(ref16_ptr + j);
+      uint16x8_t diff0 = vabdq_u16(s0, r0);
+      sum[0] = vpadalq_u16(sum[0], diff0);
+
+      uint16x8_t s1 = vld1q_u16(src16_ptr + j + 8);
+      uint16x8_t r1 = vld1q_u16(ref16_ptr + j + 8);
+      uint16x8_t diff1 = vabdq_u16(s1, r1);
+      sum[1] = vpadalq_u16(sum[1], diff1);
+
+
uint16x8_t s2 = vld1q_u16(src16_ptr + j + 16); + uint16x8_t r2 = vld1q_u16(ref16_ptr + j + 16); + uint16x8_t diff2 = vabdq_u16(s2, r2); + sum[2] = vpadalq_u16(sum[2], diff2); + + uint16x8_t s3 = vld1q_u16(src16_ptr + j + 24); + uint16x8_t r3 = vld1q_u16(ref16_ptr + j + 24); + uint16x8_t diff3 = vabdq_u16(s3, r3); + sum[3] = vpadalq_u16(sum[3], diff3); + + j += 32; + } while (j < w); + + src16_ptr += src_stride; + ref16_ptr += ref_stride; + } while (--i != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + sum[2] = vaddq_u32(sum[2], sum[3]); + sum[0] = vaddq_u32(sum[0], sum[2]); + + return horizontal_add_u32x4(sum[0]); +} + +static INLINE unsigned int highbd_sad128xh_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h) { + return highbd_sadwxh_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, 128, + h); +} + +static INLINE unsigned int highbd_sad64xh_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h) { + return highbd_sadwxh_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, 64, + h); +} + +static INLINE unsigned int highbd_sad32xh_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h) { + return highbd_sadwxh_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, 32, + h); +} + +#define HBD_SAD_WXH_SMALL_NEON(w, h) \ + unsigned int aom_highbd_sad##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, \ + int ref_stride) { \ + return highbd_sad##w##xh_small_neon(src, src_stride, ref, ref_stride, \ + (h)); \ + } + +#define HBD_SAD_WXH_LARGE_NEON(w, h) \ + unsigned int aom_highbd_sad##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, \ + int ref_stride) { \ + return highbd_sad##w##xh_large_neon(src, src_stride, ref, ref_stride, \ + (h)); \ + } + +HBD_SAD_WXH_SMALL_NEON(4, 4) +HBD_SAD_WXH_SMALL_NEON(4, 8) + +HBD_SAD_WXH_SMALL_NEON(8, 4) +HBD_SAD_WXH_SMALL_NEON(8, 8) +HBD_SAD_WXH_SMALL_NEON(8, 16) + +HBD_SAD_WXH_LARGE_NEON(16, 8) +HBD_SAD_WXH_LARGE_NEON(16, 16) +HBD_SAD_WXH_LARGE_NEON(16, 32) + +HBD_SAD_WXH_LARGE_NEON(32, 16) +HBD_SAD_WXH_LARGE_NEON(32, 32) +HBD_SAD_WXH_LARGE_NEON(32, 64) + +HBD_SAD_WXH_LARGE_NEON(64, 32) +HBD_SAD_WXH_LARGE_NEON(64, 64) +HBD_SAD_WXH_LARGE_NEON(64, 128) + +HBD_SAD_WXH_LARGE_NEON(128, 64) +HBD_SAD_WXH_LARGE_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SAD_WXH_SMALL_NEON(4, 16) + +HBD_SAD_WXH_LARGE_NEON(8, 32) + +HBD_SAD_WXH_LARGE_NEON(16, 4) +HBD_SAD_WXH_LARGE_NEON(16, 64) + +HBD_SAD_WXH_LARGE_NEON(32, 8) + +HBD_SAD_WXH_LARGE_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#define HBD_SAD_SKIP_WXH_SMALL_NEON(w, h) \ + unsigned int aom_highbd_sad_skip_##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, \ + int ref_stride) { \ + return 2 * highbd_sad##w##xh_small_neon(src, 2 * src_stride, ref, \ + 2 * ref_stride, (h) / 2); \ + } + +#define HBD_SAD_SKIP_WXH_LARGE_NEON(w, h) \ + unsigned int aom_highbd_sad_skip_##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, \ + int ref_stride) { \ + return 2 * highbd_sad##w##xh_large_neon(src, 2 * src_stride, ref, \ + 2 * ref_stride, (h) / 2); \ + } + +HBD_SAD_SKIP_WXH_SMALL_NEON(4, 4) +HBD_SAD_SKIP_WXH_SMALL_NEON(4, 8) + +HBD_SAD_SKIP_WXH_SMALL_NEON(8, 4) +HBD_SAD_SKIP_WXH_SMALL_NEON(8, 8) +HBD_SAD_SKIP_WXH_SMALL_NEON(8, 16) + +HBD_SAD_SKIP_WXH_LARGE_NEON(16, 8) +HBD_SAD_SKIP_WXH_LARGE_NEON(16, 16) +HBD_SAD_SKIP_WXH_LARGE_NEON(16, 32) + +HBD_SAD_SKIP_WXH_LARGE_NEON(32, 16) 
+HBD_SAD_SKIP_WXH_LARGE_NEON(32, 32) +HBD_SAD_SKIP_WXH_LARGE_NEON(32, 64) + +HBD_SAD_SKIP_WXH_LARGE_NEON(64, 32) +HBD_SAD_SKIP_WXH_LARGE_NEON(64, 64) +HBD_SAD_SKIP_WXH_LARGE_NEON(64, 128) + +HBD_SAD_SKIP_WXH_LARGE_NEON(128, 64) +HBD_SAD_SKIP_WXH_LARGE_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SAD_SKIP_WXH_SMALL_NEON(4, 16) + +HBD_SAD_SKIP_WXH_SMALL_NEON(8, 32) + +HBD_SAD_SKIP_WXH_LARGE_NEON(16, 4) +HBD_SAD_SKIP_WXH_LARGE_NEON(16, 64) + +HBD_SAD_SKIP_WXH_LARGE_NEON(32, 8) + +HBD_SAD_SKIP_WXH_LARGE_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +static INLINE uint32_t highbd_sad4xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr); + const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred); + uint32x4_t sum = vdupq_n_u32(0); + + int i = h; + do { + uint16x4_t s = vld1_u16(src16_ptr); + uint16x4_t r = vld1_u16(ref16_ptr); + uint16x4_t p = vld1_u16(pred16_ptr); + + uint16x4_t avg = vrhadd_u16(r, p); + sum = vabal_u16(sum, s, avg); + + src16_ptr += src_stride; + ref16_ptr += ref_stride; + pred16_ptr += 4; + } while (--i != 0); + + return horizontal_add_u32x4(sum); +} + +static INLINE uint32_t highbd_sad8xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr); + const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred); + uint32x4_t sum = vdupq_n_u32(0); + + int i = h; + do { + uint16x8_t s = vld1q_u16(src16_ptr); + uint16x8_t r = vld1q_u16(ref16_ptr); + uint16x8_t p = vld1q_u16(pred16_ptr); + + uint16x8_t avg = vrhaddq_u16(r, p); + uint16x8_t diff = vabdq_u16(s, avg); + sum = vpadalq_u16(sum, diff); + + src16_ptr += src_stride; + ref16_ptr += ref_stride; + pred16_ptr += 8; + } while (--i != 0); + + return horizontal_add_u32x4(sum); +} + +static INLINE uint32_t highbd_sad16xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr); + const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred); + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h; + do { + uint16x8_t s0, s1, r0, r1, p0, p1; + uint16x8_t avg0, avg1, diff0, diff1; + + s0 = vld1q_u16(src16_ptr); + r0 = vld1q_u16(ref16_ptr); + p0 = vld1q_u16(pred16_ptr); + avg0 = vrhaddq_u16(r0, p0); + diff0 = vabdq_u16(s0, avg0); + sum[0] = vpadalq_u16(sum[0], diff0); + + s1 = vld1q_u16(src16_ptr + 8); + r1 = vld1q_u16(ref16_ptr + 8); + p1 = vld1q_u16(pred16_ptr + 8); + avg1 = vrhaddq_u16(r1, p1); + diff1 = vabdq_u16(s1, avg1); + sum[1] = vpadalq_u16(sum[1], diff1); + + src16_ptr += src_stride; + ref16_ptr += ref_stride; + pred16_ptr += 16; + } while (--i != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + return horizontal_add_u32x4(sum[0]); +} + +static INLINE uint32_t highbd_sadwxh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int w, int h, + const uint8_t *second_pred) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr = CONVERT_TO_SHORTPTR(ref_ptr); + const uint16_t *pred16_ptr = CONVERT_TO_SHORTPTR(second_pred); + uint32x4_t sum[4] = 
{ vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + + int i = h; + do { + int j = 0; + do { + uint16x8_t s0, s1, s2, s3, r0, r1, r2, r3, p0, p1, p2, p3; + uint16x8_t avg0, avg1, avg2, avg3, diff0, diff1, diff2, diff3; + + s0 = vld1q_u16(src16_ptr + j); + r0 = vld1q_u16(ref16_ptr + j); + p0 = vld1q_u16(pred16_ptr + j); + avg0 = vrhaddq_u16(r0, p0); + diff0 = vabdq_u16(s0, avg0); + sum[0] = vpadalq_u16(sum[0], diff0); + + s1 = vld1q_u16(src16_ptr + j + 8); + r1 = vld1q_u16(ref16_ptr + j + 8); + p1 = vld1q_u16(pred16_ptr + j + 8); + avg1 = vrhaddq_u16(r1, p1); + diff1 = vabdq_u16(s1, avg1); + sum[1] = vpadalq_u16(sum[1], diff1); + + s2 = vld1q_u16(src16_ptr + j + 16); + r2 = vld1q_u16(ref16_ptr + j + 16); + p2 = vld1q_u16(pred16_ptr + j + 16); + avg2 = vrhaddq_u16(r2, p2); + diff2 = vabdq_u16(s2, avg2); + sum[2] = vpadalq_u16(sum[2], diff2); + + s3 = vld1q_u16(src16_ptr + j + 24); + r3 = vld1q_u16(ref16_ptr + j + 24); + p3 = vld1q_u16(pred16_ptr + j + 24); + avg3 = vrhaddq_u16(r3, p3); + diff3 = vabdq_u16(s3, avg3); + sum[3] = vpadalq_u16(sum[3], diff3); + + j += 32; + } while (j < w); + + src16_ptr += src_stride; + ref16_ptr += ref_stride; + pred16_ptr += w; + } while (--i != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + sum[2] = vaddq_u32(sum[2], sum[3]); + sum[0] = vaddq_u32(sum[0], sum[2]); + + return horizontal_add_u32x4(sum[0]); +} + +static INLINE unsigned int highbd_sad128xh_avg_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred) { + return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 128, + h, second_pred); +} + +static INLINE unsigned int highbd_sad64xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 64, h, + second_pred); +} + +static INLINE unsigned int highbd_sad32xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + return highbd_sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 32, h, + second_pred); +} + +#define HBD_SAD_WXH_AVG_NEON(w, h) \ + uint32_t aom_highbd_sad##w##x##h##_avg_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred) { \ + return highbd_sad##w##xh_avg_neon(src, src_stride, ref, ref_stride, (h), \ + second_pred); \ + } + +HBD_SAD_WXH_AVG_NEON(4, 4) +HBD_SAD_WXH_AVG_NEON(4, 8) + +HBD_SAD_WXH_AVG_NEON(8, 4) +HBD_SAD_WXH_AVG_NEON(8, 8) +HBD_SAD_WXH_AVG_NEON(8, 16) + +HBD_SAD_WXH_AVG_NEON(16, 8) +HBD_SAD_WXH_AVG_NEON(16, 16) +HBD_SAD_WXH_AVG_NEON(16, 32) + +HBD_SAD_WXH_AVG_NEON(32, 16) +HBD_SAD_WXH_AVG_NEON(32, 32) +HBD_SAD_WXH_AVG_NEON(32, 64) + +HBD_SAD_WXH_AVG_NEON(64, 32) +HBD_SAD_WXH_AVG_NEON(64, 64) +HBD_SAD_WXH_AVG_NEON(64, 128) + +HBD_SAD_WXH_AVG_NEON(128, 64) +HBD_SAD_WXH_AVG_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SAD_WXH_AVG_NEON(4, 16) + +HBD_SAD_WXH_AVG_NEON(8, 32) + +HBD_SAD_WXH_AVG_NEON(16, 4) +HBD_SAD_WXH_AVG_NEON(16, 64) + +HBD_SAD_WXH_AVG_NEON(32, 8) + +HBD_SAD_WXH_AVG_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY diff --git a/third_party/aom/aom_dsp/arm/highbd_sadxd_neon.c b/third_party/aom/aom_dsp/arm/highbd_sadxd_neon.c new file mode 100644 index 0000000000..85ca6732a8 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_sadxd_neon.c @@ -0,0 +1,617 @@ +/* + * Copyright (c) 2023 The WebM project authors. 
All Rights Reserved.
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+#include "aom/aom_integer.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
+
+static INLINE void highbd_sad4xhx4d_small_neon(const uint8_t *src_ptr,
+                                               int src_stride,
+                                               const uint8_t *const ref_ptr[4],
+                                               int ref_stride, uint32_t res[4],
+                                               int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]);
+  const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]);
+  const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]);
+  const uint16_t *ref16_ptr3 = CONVERT_TO_SHORTPTR(ref_ptr[3]);
+
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
+
+  int i = 0;
+  do {
+    uint16x4_t s = vld1_u16(src16_ptr + i * src_stride);
+    uint16x4_t r0 = vld1_u16(ref16_ptr0 + i * ref_stride);
+    uint16x4_t r1 = vld1_u16(ref16_ptr1 + i * ref_stride);
+    uint16x4_t r2 = vld1_u16(ref16_ptr2 + i * ref_stride);
+    uint16x4_t r3 = vld1_u16(ref16_ptr3 + i * ref_stride);
+
+    sum[0] = vabal_u16(sum[0], s, r0);
+    sum[1] = vabal_u16(sum[1], s, r1);
+    sum[2] = vabal_u16(sum[2], s, r2);
+    sum[3] = vabal_u16(sum[3], s, r3);
+
+  } while (++i < h);
+
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum));
+}
+
+static INLINE void highbd_sad8xhx4d_small_neon(const uint8_t *src_ptr,
+                                               int src_stride,
+                                               const uint8_t *const ref_ptr[4],
+                                               int ref_stride, uint32_t res[4],
+                                               int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]);
+  const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]);
+  const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]);
+  const uint16_t *ref16_ptr3 = CONVERT_TO_SHORTPTR(ref_ptr[3]);
+
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
+  uint32x4_t sum_u32[4];
+
+  int i = 0;
+  do {
+    uint16x8_t s = vld1q_u16(src16_ptr + i * src_stride);
+
+    sum[0] = vabaq_u16(sum[0], s, vld1q_u16(ref16_ptr0 + i * ref_stride));
+    sum[1] = vabaq_u16(sum[1], s, vld1q_u16(ref16_ptr1 + i * ref_stride));
+    sum[2] = vabaq_u16(sum[2], s, vld1q_u16(ref16_ptr2 + i * ref_stride));
+    sum[3] = vabaq_u16(sum[3], s, vld1q_u16(ref16_ptr3 + i * ref_stride));
+
+  } while (++i < h);
+
+  sum_u32[0] = vpaddlq_u16(sum[0]);
+  sum_u32[1] = vpaddlq_u16(sum[1]);
+  sum_u32[2] = vpaddlq_u16(sum[2]);
+  sum_u32[3] = vpaddlq_u16(sum[3]);
+  vst1q_u32(res, horizontal_add_4d_u32x4(sum_u32));
+}
+
+static INLINE void sad8_neon(uint16x8_t src, uint16x8_t ref,
+                             uint32x4_t *const sad_sum) {
+  uint16x8_t abs_diff = vabdq_u16(src, ref);
+  *sad_sum = vpadalq_u16(*sad_sum, abs_diff);
+}
+
+#if !CONFIG_REALTIME_ONLY
+static INLINE void highbd_sad8xhx4d_large_neon(const uint8_t *src_ptr,
+                                               int src_stride,
+                                               const uint8_t *const ref_ptr[4],
+                                               int ref_stride, uint32_t res[4],
+                                               int h) {
+  const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr);
+  const uint16_t *ref16_ptr0 =
CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + const uint16_t *ref16_ptr3 = CONVERT_TO_SHORTPTR(ref_ptr[3]); + + uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + + int i = 0; + do { + uint16x8_t s = vld1q_u16(src16_ptr + i * src_stride); + sad8_neon(s, vld1q_u16(ref16_ptr0 + i * ref_stride), &sum[0]); + sad8_neon(s, vld1q_u16(ref16_ptr1 + i * ref_stride), &sum[1]); + sad8_neon(s, vld1q_u16(ref16_ptr2 + i * ref_stride), &sum[2]); + sad8_neon(s, vld1q_u16(ref16_ptr3 + i * ref_stride), &sum[3]); + + } while (++i < h); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} +#endif // !CONFIG_REALTIME_ONLY + +static INLINE void highbd_sad16xhx4d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + const uint16_t *ref16_ptr3 = CONVERT_TO_SHORTPTR(ref_ptr[3]); + + uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + uint32x4_t sum[4]; + + int i = 0; + do { + uint16x8_t s0 = vld1q_u16(src16_ptr + i * src_stride); + sad8_neon(s0, vld1q_u16(ref16_ptr0 + i * ref_stride), &sum_lo[0]); + sad8_neon(s0, vld1q_u16(ref16_ptr1 + i * ref_stride), &sum_lo[1]); + sad8_neon(s0, vld1q_u16(ref16_ptr2 + i * ref_stride), &sum_lo[2]); + sad8_neon(s0, vld1q_u16(ref16_ptr3 + i * ref_stride), &sum_lo[3]); + + uint16x8_t s1 = vld1q_u16(src16_ptr + i * src_stride + 8); + sad8_neon(s1, vld1q_u16(ref16_ptr0 + i * ref_stride + 8), &sum_hi[0]); + sad8_neon(s1, vld1q_u16(ref16_ptr1 + i * ref_stride + 8), &sum_hi[1]); + sad8_neon(s1, vld1q_u16(ref16_ptr2 + i * ref_stride + 8), &sum_hi[2]); + sad8_neon(s1, vld1q_u16(ref16_ptr3 + i * ref_stride + 8), &sum_hi[3]); + + } while (++i < h); + + sum[0] = vaddq_u32(sum_lo[0], sum_hi[0]); + sum[1] = vaddq_u32(sum_lo[1], sum_hi[1]); + sum[2] = vaddq_u32(sum_lo[2], sum_hi[2]); + sum[3] = vaddq_u32(sum_lo[3], sum_hi[3]); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} + +static INLINE void highbd_sadwxhx4d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int w, int h) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + const uint16_t *ref16_ptr3 = CONVERT_TO_SHORTPTR(ref_ptr[3]); + + uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + uint32x4_t sum[4]; + + int i = 0; + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src16_ptr + i * src_stride + j); + sad8_neon(s0, vld1q_u16(ref16_ptr0 + i * ref_stride + j), &sum_lo[0]); + sad8_neon(s0, vld1q_u16(ref16_ptr1 + i * ref_stride + j), &sum_lo[1]); + sad8_neon(s0, vld1q_u16(ref16_ptr2 + i * ref_stride + j), &sum_lo[2]); + sad8_neon(s0, vld1q_u16(ref16_ptr3 + i * ref_stride + j), &sum_lo[3]); + + uint16x8_t s1 
= vld1q_u16(src16_ptr + i * src_stride + j + 8); + sad8_neon(s1, vld1q_u16(ref16_ptr0 + i * ref_stride + j + 8), &sum_hi[0]); + sad8_neon(s1, vld1q_u16(ref16_ptr1 + i * ref_stride + j + 8), &sum_hi[1]); + sad8_neon(s1, vld1q_u16(ref16_ptr2 + i * ref_stride + j + 8), &sum_hi[2]); + sad8_neon(s1, vld1q_u16(ref16_ptr3 + i * ref_stride + j + 8), &sum_hi[3]); + + uint16x8_t s2 = vld1q_u16(src16_ptr + i * src_stride + j + 16); + sad8_neon(s2, vld1q_u16(ref16_ptr0 + i * ref_stride + j + 16), + &sum_lo[0]); + sad8_neon(s2, vld1q_u16(ref16_ptr1 + i * ref_stride + j + 16), + &sum_lo[1]); + sad8_neon(s2, vld1q_u16(ref16_ptr2 + i * ref_stride + j + 16), + &sum_lo[2]); + sad8_neon(s2, vld1q_u16(ref16_ptr3 + i * ref_stride + j + 16), + &sum_lo[3]); + + uint16x8_t s3 = vld1q_u16(src16_ptr + i * src_stride + j + 24); + sad8_neon(s3, vld1q_u16(ref16_ptr0 + i * ref_stride + j + 24), + &sum_hi[0]); + sad8_neon(s3, vld1q_u16(ref16_ptr1 + i * ref_stride + j + 24), + &sum_hi[1]); + sad8_neon(s3, vld1q_u16(ref16_ptr2 + i * ref_stride + j + 24), + &sum_hi[2]); + sad8_neon(s3, vld1q_u16(ref16_ptr3 + i * ref_stride + j + 24), + &sum_hi[3]); + + j += 32; + } while (j < w); + + } while (++i < h); + + sum[0] = vaddq_u32(sum_lo[0], sum_hi[0]); + sum[1] = vaddq_u32(sum_lo[1], sum_hi[1]); + sum[2] = vaddq_u32(sum_lo[2], sum_hi[2]); + sum[3] = vaddq_u32(sum_lo[3], sum_hi[3]); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} + +static INLINE void highbd_sad128xhx4d_large_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], int h) { + highbd_sadwxhx4d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res, + 128, h); +} + +static INLINE void highbd_sad64xhx4d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + highbd_sadwxhx4d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res, 64, + h); +} + +static INLINE void highbd_sad32xhx4d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + highbd_sadwxhx4d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res, 32, + h); +} + +#define HBD_SAD_WXH_4D_SMALL_NEON(w, h) \ + void aom_highbd_sad##w##x##h##x4d_neon( \ + const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \ + int ref_stride, uint32_t sad_array[4]) { \ + highbd_sad##w##xhx4d_small_neon(src, src_stride, ref_array, ref_stride, \ + sad_array, (h)); \ + } + +#define HBD_SAD_WXH_4D_LARGE_NEON(w, h) \ + void aom_highbd_sad##w##x##h##x4d_neon( \ + const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \ + int ref_stride, uint32_t sad_array[4]) { \ + highbd_sad##w##xhx4d_large_neon(src, src_stride, ref_array, ref_stride, \ + sad_array, (h)); \ + } + +HBD_SAD_WXH_4D_SMALL_NEON(4, 4) +HBD_SAD_WXH_4D_SMALL_NEON(4, 8) + +HBD_SAD_WXH_4D_SMALL_NEON(8, 4) +HBD_SAD_WXH_4D_SMALL_NEON(8, 8) +HBD_SAD_WXH_4D_SMALL_NEON(8, 16) + +HBD_SAD_WXH_4D_LARGE_NEON(16, 8) +HBD_SAD_WXH_4D_LARGE_NEON(16, 16) +HBD_SAD_WXH_4D_LARGE_NEON(16, 32) + +HBD_SAD_WXH_4D_LARGE_NEON(32, 16) +HBD_SAD_WXH_4D_LARGE_NEON(32, 32) +HBD_SAD_WXH_4D_LARGE_NEON(32, 64) + +HBD_SAD_WXH_4D_LARGE_NEON(64, 32) +HBD_SAD_WXH_4D_LARGE_NEON(64, 64) +HBD_SAD_WXH_4D_LARGE_NEON(64, 128) + +HBD_SAD_WXH_4D_LARGE_NEON(128, 64) +HBD_SAD_WXH_4D_LARGE_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SAD_WXH_4D_SMALL_NEON(4, 16) + +HBD_SAD_WXH_4D_LARGE_NEON(8, 32) + +HBD_SAD_WXH_4D_LARGE_NEON(16, 4) 
+HBD_SAD_WXH_4D_LARGE_NEON(16, 64) + +HBD_SAD_WXH_4D_LARGE_NEON(32, 8) + +HBD_SAD_WXH_4D_LARGE_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#define HBD_SAD_SKIP_WXH_4D_SMALL_NEON(w, h) \ + void aom_highbd_sad_skip_##w##x##h##x4d_neon( \ + const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \ + int ref_stride, uint32_t sad_array[4]) { \ + highbd_sad##w##xhx4d_small_neon(src, 2 * src_stride, ref_array, \ + 2 * ref_stride, sad_array, ((h) >> 1)); \ + sad_array[0] <<= 1; \ + sad_array[1] <<= 1; \ + sad_array[2] <<= 1; \ + sad_array[3] <<= 1; \ + } + +#define HBD_SAD_SKIP_WXH_4D_LARGE_NEON(w, h) \ + void aom_highbd_sad_skip_##w##x##h##x4d_neon( \ + const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \ + int ref_stride, uint32_t sad_array[4]) { \ + highbd_sad##w##xhx4d_large_neon(src, 2 * src_stride, ref_array, \ + 2 * ref_stride, sad_array, ((h) >> 1)); \ + sad_array[0] <<= 1; \ + sad_array[1] <<= 1; \ + sad_array[2] <<= 1; \ + sad_array[3] <<= 1; \ + } + +HBD_SAD_SKIP_WXH_4D_SMALL_NEON(4, 4) +HBD_SAD_SKIP_WXH_4D_SMALL_NEON(4, 8) + +HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 4) +HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 8) +HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 16) + +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 8) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 16) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 32) + +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 16) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 32) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 64) + +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 32) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 64) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 128) + +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(128, 64) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SAD_SKIP_WXH_4D_SMALL_NEON(4, 16) + +HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 32) + +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 4) +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 64) + +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 8) + +HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +static INLINE void highbd_sad4xhx3d_small_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + + uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = 0; + do { + uint16x4_t s = vld1_u16(src16_ptr + i * src_stride); + uint16x4_t r0 = vld1_u16(ref16_ptr0 + i * ref_stride); + uint16x4_t r1 = vld1_u16(ref16_ptr1 + i * ref_stride); + uint16x4_t r2 = vld1_u16(ref16_ptr2 + i * ref_stride); + + sum[0] = vabal_u16(sum[0], s, r0); + sum[1] = vabal_u16(sum[1], s, r1); + sum[2] = vabal_u16(sum[2], s, r2); + + } while (++i < h); + + res[0] = horizontal_add_u32x4(sum[0]); + res[1] = horizontal_add_u32x4(sum[1]); + res[2] = horizontal_add_u32x4(sum[2]); +} + +static INLINE void highbd_sad8xhx3d_small_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + + uint16x8_t sum[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) }; + + int i = 0; + do { + uint16x8_t s = vld1q_u16(src16_ptr + i * src_stride); + + 
sum[0] = vabaq_u16(sum[0], s, vld1q_u16(ref16_ptr0 + i * ref_stride)); + sum[1] = vabaq_u16(sum[1], s, vld1q_u16(ref16_ptr1 + i * ref_stride)); + sum[2] = vabaq_u16(sum[2], s, vld1q_u16(ref16_ptr2 + i * ref_stride)); + + } while (++i < h); + + res[0] = horizontal_add_u32x4(vpaddlq_u16(sum[0])); + res[1] = horizontal_add_u32x4(vpaddlq_u16(sum[1])); + res[2] = horizontal_add_u32x4(vpaddlq_u16(sum[2])); +} + +#if !CONFIG_REALTIME_ONLY +static INLINE void highbd_sad8xhx3d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + + uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = 0; + do { + uint16x8_t s = vld1q_u16(src16_ptr + i * src_stride); + uint16x8_t r0 = vld1q_u16(ref16_ptr0 + i * ref_stride); + uint16x8_t r1 = vld1q_u16(ref16_ptr1 + i * ref_stride); + uint16x8_t r2 = vld1q_u16(ref16_ptr2 + i * ref_stride); + + sad8_neon(s, r0, &sum[0]); + sad8_neon(s, r1, &sum[1]); + sad8_neon(s, r2, &sum[2]); + + } while (++i < h); + + res[0] = horizontal_add_u32x4(sum[0]); + res[1] = horizontal_add_u32x4(sum[1]); + res[2] = horizontal_add_u32x4(sum[2]); +} +#endif // !CONFIG_REALTIME_ONLY + +static INLINE void highbd_sad16xhx3d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + + uint32x4_t sum_lo[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + uint32x4_t sum_hi[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = 0; + do { + uint16x8_t s0 = vld1q_u16(src16_ptr + i * src_stride); + sad8_neon(s0, vld1q_u16(ref16_ptr0 + i * ref_stride), &sum_lo[0]); + sad8_neon(s0, vld1q_u16(ref16_ptr1 + i * ref_stride), &sum_lo[1]); + sad8_neon(s0, vld1q_u16(ref16_ptr2 + i * ref_stride), &sum_lo[2]); + + uint16x8_t s1 = vld1q_u16(src16_ptr + i * src_stride + 8); + sad8_neon(s1, vld1q_u16(ref16_ptr0 + i * ref_stride + 8), &sum_hi[0]); + sad8_neon(s1, vld1q_u16(ref16_ptr1 + i * ref_stride + 8), &sum_hi[1]); + sad8_neon(s1, vld1q_u16(ref16_ptr2 + i * ref_stride + 8), &sum_hi[2]); + + } while (++i < h); + + res[0] = horizontal_add_u32x4(vaddq_u32(sum_lo[0], sum_hi[0])); + res[1] = horizontal_add_u32x4(vaddq_u32(sum_lo[1], sum_hi[1])); + res[2] = horizontal_add_u32x4(vaddq_u32(sum_lo[2], sum_hi[2])); +} + +static INLINE void highbd_sadwxhx3d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int w, int h) { + const uint16_t *src16_ptr = CONVERT_TO_SHORTPTR(src_ptr); + const uint16_t *ref16_ptr0 = CONVERT_TO_SHORTPTR(ref_ptr[0]); + const uint16_t *ref16_ptr1 = CONVERT_TO_SHORTPTR(ref_ptr[1]); + const uint16_t *ref16_ptr2 = CONVERT_TO_SHORTPTR(ref_ptr[2]); + + uint32x4_t sum_lo[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + uint32x4_t sum_hi[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + uint32x4_t sum[3]; + + int i = 0; + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src16_ptr + i * src_stride + j); + 
sad8_neon(s0, vld1q_u16(ref16_ptr0 + i * ref_stride + j), &sum_lo[0]); + sad8_neon(s0, vld1q_u16(ref16_ptr1 + i * ref_stride + j), &sum_lo[1]); + sad8_neon(s0, vld1q_u16(ref16_ptr2 + i * ref_stride + j), &sum_lo[2]); + + uint16x8_t s1 = vld1q_u16(src16_ptr + i * src_stride + j + 8); + sad8_neon(s1, vld1q_u16(ref16_ptr0 + i * ref_stride + j + 8), &sum_hi[0]); + sad8_neon(s1, vld1q_u16(ref16_ptr1 + i * ref_stride + j + 8), &sum_hi[1]); + sad8_neon(s1, vld1q_u16(ref16_ptr2 + i * ref_stride + j + 8), &sum_hi[2]); + + uint16x8_t s2 = vld1q_u16(src16_ptr + i * src_stride + j + 16); + sad8_neon(s2, vld1q_u16(ref16_ptr0 + i * ref_stride + j + 16), + &sum_lo[0]); + sad8_neon(s2, vld1q_u16(ref16_ptr1 + i * ref_stride + j + 16), + &sum_lo[1]); + sad8_neon(s2, vld1q_u16(ref16_ptr2 + i * ref_stride + j + 16), + &sum_lo[2]); + + uint16x8_t s3 = vld1q_u16(src16_ptr + i * src_stride + j + 24); + sad8_neon(s3, vld1q_u16(ref16_ptr0 + i * ref_stride + j + 24), + &sum_hi[0]); + sad8_neon(s3, vld1q_u16(ref16_ptr1 + i * ref_stride + j + 24), + &sum_hi[1]); + sad8_neon(s3, vld1q_u16(ref16_ptr2 + i * ref_stride + j + 24), + &sum_hi[2]); + + j += 32; + } while (j < w); + + } while (++i < h); + + sum[0] = vaddq_u32(sum_lo[0], sum_hi[0]); + sum[1] = vaddq_u32(sum_lo[1], sum_hi[1]); + sum[2] = vaddq_u32(sum_lo[2], sum_hi[2]); + + res[0] = horizontal_add_u32x4(sum[0]); + res[1] = horizontal_add_u32x4(sum[1]); + res[2] = horizontal_add_u32x4(sum[2]); +} + +static INLINE void highbd_sad128xhx3d_large_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], int h) { + highbd_sadwxhx3d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res, + 128, h); +} + +static INLINE void highbd_sad64xhx3d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + highbd_sadwxhx3d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res, 64, + h); +} + +static INLINE void highbd_sad32xhx3d_large_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *const ref_ptr[4], + int ref_stride, uint32_t res[4], + int h) { + highbd_sadwxhx3d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res, 32, + h); +} + +#define HBD_SAD_WXH_3D_SMALL_NEON(w, h) \ + void aom_highbd_sad##w##x##h##x3d_neon( \ + const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \ + int ref_stride, uint32_t sad_array[4]) { \ + highbd_sad##w##xhx3d_small_neon(src, src_stride, ref_array, ref_stride, \ + sad_array, (h)); \ + } + +#define HBD_SAD_WXH_3D_LARGE_NEON(w, h) \ + void aom_highbd_sad##w##x##h##x3d_neon( \ + const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \ + int ref_stride, uint32_t sad_array[4]) { \ + highbd_sad##w##xhx3d_large_neon(src, src_stride, ref_array, ref_stride, \ + sad_array, (h)); \ + } + +HBD_SAD_WXH_3D_SMALL_NEON(4, 4) +HBD_SAD_WXH_3D_SMALL_NEON(4, 8) + +HBD_SAD_WXH_3D_SMALL_NEON(8, 4) +HBD_SAD_WXH_3D_SMALL_NEON(8, 8) +HBD_SAD_WXH_3D_SMALL_NEON(8, 16) + +HBD_SAD_WXH_3D_LARGE_NEON(16, 8) +HBD_SAD_WXH_3D_LARGE_NEON(16, 16) +HBD_SAD_WXH_3D_LARGE_NEON(16, 32) + +HBD_SAD_WXH_3D_LARGE_NEON(32, 16) +HBD_SAD_WXH_3D_LARGE_NEON(32, 32) +HBD_SAD_WXH_3D_LARGE_NEON(32, 64) + +HBD_SAD_WXH_3D_LARGE_NEON(64, 32) +HBD_SAD_WXH_3D_LARGE_NEON(64, 64) +HBD_SAD_WXH_3D_LARGE_NEON(64, 128) + +HBD_SAD_WXH_3D_LARGE_NEON(128, 64) +HBD_SAD_WXH_3D_LARGE_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SAD_WXH_3D_SMALL_NEON(4, 16) + +HBD_SAD_WXH_3D_LARGE_NEON(8, 32) + 
+HBD_SAD_WXH_3D_LARGE_NEON(16, 4) +HBD_SAD_WXH_3D_LARGE_NEON(16, 64) + +HBD_SAD_WXH_3D_LARGE_NEON(32, 8) + +HBD_SAD_WXH_3D_LARGE_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY diff --git a/third_party/aom/aom_dsp/arm/highbd_sse_neon.c b/third_party/aom/aom_dsp/arm/highbd_sse_neon.c new file mode 100644 index 0000000000..184e9f9bef --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_sse_neon.c @@ -0,0 +1,284 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <arm_neon.h> + +#include "config/aom_dsp_rtcd.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void highbd_sse_8x1_init_neon(const uint16_t *src, + const uint16_t *ref, + uint32x4_t *sse_acc0, + uint32x4_t *sse_acc1) { + uint16x8_t s = vld1q_u16(src); + uint16x8_t r = vld1q_u16(ref); + + uint16x8_t abs_diff = vabdq_u16(s, r); + uint16x4_t abs_diff_lo = vget_low_u16(abs_diff); + uint16x4_t abs_diff_hi = vget_high_u16(abs_diff); + + *sse_acc0 = vmull_u16(abs_diff_lo, abs_diff_lo); + *sse_acc1 = vmull_u16(abs_diff_hi, abs_diff_hi); +} + +static INLINE void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref, + uint32x4_t *sse_acc0, + uint32x4_t *sse_acc1) { + uint16x8_t s = vld1q_u16(src); + uint16x8_t r = vld1q_u16(ref); + + uint16x8_t abs_diff = vabdq_u16(s, r); + uint16x4_t abs_diff_lo = vget_low_u16(abs_diff); + uint16x4_t abs_diff_hi = vget_high_u16(abs_diff); + + *sse_acc0 = vmlal_u16(*sse_acc0, abs_diff_lo, abs_diff_lo); + *sse_acc1 = vmlal_u16(*sse_acc1, abs_diff_hi, abs_diff_hi); +} + +static INLINE int64_t highbd_sse_128xh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint32x4_t sse[16]; + highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); + highbd_sse_8x1_init_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]); + highbd_sse_8x1_init_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]); + highbd_sse_8x1_init_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]); + highbd_sse_8x1_init_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]); + highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]); + highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]); + highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]); + highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]); + highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]); + + src += src_stride; + ref += ref_stride; + + while (--height != 0) { + highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); + highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]); +
highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]); + highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]); + highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]); + highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]); + highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]); + highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]); + highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]); + highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]); + + src += src_stride; + ref += ref_stride; + } + + return horizontal_long_add_u32x4_x16(sse); +} + +static INLINE int64_t highbd_sse_64xh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint32x4_t sse[8]; + highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); + highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]); + + src += src_stride; + ref += ref_stride; + + while (--height != 0) { + highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); + highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]); + + src += src_stride; + ref += ref_stride; + } + + return horizontal_long_add_u32x4_x8(sse); +} + +static INLINE int64_t highbd_sse_32xh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint32x4_t sse[8]; + highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); + + src += src_stride; + ref += ref_stride; + + while (--height != 0) { + highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); + highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); + + src += src_stride; + ref += ref_stride; + } + + return horizontal_long_add_u32x4_x8(sse); +} + +static INLINE int64_t highbd_sse_16xh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint32x4_t sse[4]; + highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + + src += src_stride; + ref += ref_stride; + + while (--height != 0) { + 
highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); + + src += src_stride; + ref += ref_stride; + } + + return horizontal_long_add_u32x4_x4(sse); +} + +static INLINE int64_t highbd_sse_8xh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2]; + highbd_sse_8x1_init_neon(src, ref, &sse[0], &sse[1]); + + src += src_stride; + ref += ref_stride; + + while (--height != 0) { + highbd_sse_8x1_neon(src, ref, &sse[0], &sse[1]); + + src += src_stride; + ref += ref_stride; + } + + return horizontal_long_add_u32x4_x2(sse); +} + +static INLINE int64_t highbd_sse_4xh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + // Peel the first loop iteration. + uint16x4_t s = vld1_u16(src); + uint16x4_t r = vld1_u16(ref); + + uint16x4_t abs_diff = vabd_u16(s, r); + uint32x4_t sse = vmull_u16(abs_diff, abs_diff); + + src += src_stride; + ref += ref_stride; + + while (--height != 0) { + s = vld1_u16(src); + r = vld1_u16(ref); + + abs_diff = vabd_u16(s, r); + sse = vmlal_u16(sse, abs_diff, abs_diff); + + src += src_stride; + ref += ref_stride; + } + + return horizontal_long_add_u32x4(sse); +} + +static INLINE int64_t highbd_sse_wxh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int width, int height) { + // { 0, 1, 2, 3, 4, 5, 6, 7 } + uint16x8_t k01234567 = vmovl_u8(vcreate_u8(0x0706050403020100)); + uint16x8_t remainder_mask = vcltq_u16(k01234567, vdupq_n_u16(width & 7)); + uint64_t sse = 0; + + do { + int w = width; + int offset = 0; + + do { + uint16x8_t s = vld1q_u16(src + offset); + uint16x8_t r = vld1q_u16(ref + offset); + + if (w < 8) { + // Mask out-of-range elements. + s = vandq_u16(s, remainder_mask); + r = vandq_u16(r, remainder_mask); + } + + uint16x8_t abs_diff = vabdq_u16(s, r); + uint16x4_t abs_diff_lo = vget_low_u16(abs_diff); + uint16x4_t abs_diff_hi = vget_high_u16(abs_diff); + + uint32x4_t sse_u32 = vmull_u16(abs_diff_lo, abs_diff_lo); + sse_u32 = vmlal_u16(sse_u32, abs_diff_hi, abs_diff_hi); + + sse += horizontal_long_add_u32x4(sse_u32); + + offset += 8; + w -= 8; + } while (w > 0); + + src += src_stride; + ref += ref_stride; + } while (--height != 0); + + return sse; +} + +int64_t aom_highbd_sse_neon(const uint8_t *src8, int src_stride, + const uint8_t *ref8, int ref_stride, int width, + int height) { + uint16_t *src = CONVERT_TO_SHORTPTR(src8); + uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); + + switch (width) { + case 4: + return highbd_sse_4xh_neon(src, src_stride, ref, ref_stride, height); + case 8: + return highbd_sse_8xh_neon(src, src_stride, ref, ref_stride, height); + case 16: + return highbd_sse_16xh_neon(src, src_stride, ref, ref_stride, height); + case 32: + return highbd_sse_32xh_neon(src, src_stride, ref, ref_stride, height); + case 64: + return highbd_sse_64xh_neon(src, src_stride, ref, ref_stride, height); + case 128: + return highbd_sse_128xh_neon(src, src_stride, ref, ref_stride, height); + default: + return highbd_sse_wxh_neon(src, src_stride, ref, ref_stride, width, + height); + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_sse_sve.c b/third_party/aom/aom_dsp/arm/highbd_sse_sve.c new file mode 100644 index 0000000000..b267da5cfb --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_sse_sve.c @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <arm_neon.h> + +#include "aom_dsp/arm/dot_sve.h" +#include "aom_dsp/arm/mem_neon.h" +#include "config/aom_dsp_rtcd.h" + +static INLINE void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref, + uint64x2_t *sse) { + uint16x8_t s = vld1q_u16(src); + uint16x8_t r = vld1q_u16(ref); + + uint16x8_t abs_diff = vabdq_u16(s, r); + + *sse = aom_udotq_u16(*sse, abs_diff, abs_diff); +} + +static INLINE int64_t highbd_sse_128xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0), + vdupq_n_u64(0) }; + + do { + highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); + highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]); + highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]); + highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[1]); + highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[2]); + highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[3]); + highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[1]); + highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[2]); + highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[3]); + highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[1]); + highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[2]); + highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[3]); + + src += src_stride; + ref += ref_stride; + } while (--height != 0); + + sse[0] = vaddq_u64(sse[0], sse[1]); + sse[2] = vaddq_u64(sse[2], sse[3]); + sse[0] = vaddq_u64(sse[0], sse[2]); + return vaddvq_u64(sse[0]); +} + +static INLINE int64_t highbd_sse_64xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0), + vdupq_n_u64(0) }; + + do { + highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); + highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]); + highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]); + highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[1]); + highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[2]); + highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[3]); + + src += src_stride; + ref += ref_stride; + } while (--height != 0); + + sse[0] = vaddq_u64(sse[0], sse[1]); + sse[2] = vaddq_u64(sse[2], sse[3]); + sse[0] = vaddq_u64(sse[0], sse[2]); + return vaddvq_u64(sse[0]); +} + +static INLINE int64_t highbd_sse_32xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0), + vdupq_n_u64(0) }; + + do { + highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); + highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]); + highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]); + + src += src_stride; + ref +=
ref_stride; + } while (--height != 0); + + sse[0] = vaddq_u64(sse[0], sse[1]); + sse[2] = vaddq_u64(sse[2], sse[3]); + sse[0] = vaddq_u64(sse[0], sse[2]); + return vaddvq_u64(sse[0]); +} + +static INLINE int64_t highbd_sse_16xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint64x2_t sse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; + + do { + highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]); + highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]); + + src += src_stride; + ref += ref_stride; + } while (--height != 0); + + return vaddvq_u64(vaddq_u64(sse[0], sse[1])); +} + +static INLINE int64_t highbd_sse_8xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint64x2_t sse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; + + do { + highbd_sse_8x1_neon(src + 0 * src_stride, ref + 0 * ref_stride, &sse[0]); + highbd_sse_8x1_neon(src + 1 * src_stride, ref + 1 * ref_stride, &sse[1]); + + src += 2 * src_stride; + ref += 2 * ref_stride; + height -= 2; + } while (height != 0); + + return vaddvq_u64(vaddq_u64(sse[0], sse[1])); +} + +static INLINE int64_t highbd_sse_4xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int height) { + uint64x2_t sse = vdupq_n_u64(0); + + do { + uint16x8_t s = load_unaligned_u16_4x2(src, src_stride); + uint16x8_t r = load_unaligned_u16_4x2(ref, ref_stride); + + uint16x8_t abs_diff = vabdq_u16(s, r); + sse = aom_udotq_u16(sse, abs_diff, abs_diff); + + src += 2 * src_stride; + ref += 2 * ref_stride; + height -= 2; + } while (height != 0); + + return vaddvq_u64(sse); +} + +static INLINE int64_t highbd_sse_wxh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int width, int height) { + svuint64_t sse = svdup_n_u64(0); + uint64_t step = svcnth(); + + do { + int w = 0; + const uint16_t *src_ptr = src; + const uint16_t *ref_ptr = ref; + + do { + svbool_t pred = svwhilelt_b16_u32(w, width); + svuint16_t s = svld1_u16(pred, src_ptr); + svuint16_t r = svld1_u16(pred, ref_ptr); + + svuint16_t abs_diff = svabd_u16_z(pred, s, r); + + sse = svdot_u64(sse, abs_diff, abs_diff); + + src_ptr += step; + ref_ptr += step; + w += step; + } while (w < width); + + src += src_stride; + ref += ref_stride; + } while (--height != 0); + + return svaddv_u64(svptrue_b64(), sse); +} + +int64_t aom_highbd_sse_sve(const uint8_t *src8, int src_stride, + const uint8_t *ref8, int ref_stride, int width, + int height) { + uint16_t *src = CONVERT_TO_SHORTPTR(src8); + uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); + + switch (width) { + case 4: return highbd_sse_4xh_sve(src, src_stride, ref, ref_stride, height); + case 8: return highbd_sse_8xh_sve(src, src_stride, ref, ref_stride, height); + case 16: + return highbd_sse_16xh_sve(src, src_stride, ref, ref_stride, height); + case 32: + return highbd_sse_32xh_sve(src, src_stride, ref, ref_stride, height); + case 64: + return highbd_sse_64xh_sve(src, src_stride, ref, ref_stride, height); + case 128: + return highbd_sse_128xh_sve(src, src_stride, ref, ref_stride, height); + default: + return highbd_sse_wxh_sve(src, src_stride, ref, ref_stride, width, + height); + } +} diff --git a/third_party/aom/aom_dsp/arm/highbd_subpel_variance_neon.c b/third_party/aom/aom_dsp/arm/highbd_subpel_variance_neon.c new file mode 100644 index 0000000000..686fa5f226 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_subpel_variance_neon.c @@ -0,0 +1,1497 @@ +/* + * Copyright (c) 2023 The WebM project authors. 
All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/aom_filter.h" +#include "aom_dsp/arm/dist_wtd_avg_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/variance.h" + +// The bilinear filters look like this: +// +// {{ 128, 0 }, { 112, 16 }, { 96, 32 }, { 80, 48 }, +// { 64, 64 }, { 48, 80 }, { 32, 96 }, { 16, 112 }} +// +// We can factor out the highest common multiple, such that the sum of both +// weights will be 8 instead of 128. The benefits of this are two-fold: +// +// 1) We can infer the filter values from the filter_offset parameter in the +// bilinear filter functions below - we don't have to actually load the values +// from memory: +// f0 = 8 - filter_offset +// f1 = filter_offset +// +// 2) Scaling the pixel values by 8, instead of 128 enables us to operate on +// 16-bit data types at all times, rather than widening out to 32-bit and +// requiring double the number of data processing instructions. (12-bit * 8 = +// 15-bit.) + +// Process a block exactly 4 wide and any height. +static void highbd_var_filter_block2d_bil_w4(const uint16_t *src_ptr, + uint16_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset) { + const uint16x4_t f0 = vdup_n_u16(8 - filter_offset); + const uint16x4_t f1 = vdup_n_u16(filter_offset); + + int i = dst_height; + do { + uint16x4_t s0 = load_unaligned_u16_4x1(src_ptr); + uint16x4_t s1 = load_unaligned_u16_4x1(src_ptr + pixel_step); + + uint16x4_t blend = vmul_u16(s0, f0); + blend = vmla_u16(blend, s1, f1); + blend = vrshr_n_u16(blend, 3); + + vst1_u16(dst_ptr, blend); + + src_ptr += src_stride; + dst_ptr += 4; + } while (--i != 0); +} + +// Process a block which is a multiple of 8 and any height.
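+// Worked example of the factoring described above, for filter_offset == 1:
+// the original 7-bit filter is { 112, 16 } and the factored weights are
+// f0 = 7, f1 = 1. Since 112 * a + 16 * b + 64 == 16 * (7 * a + 1 * b + 4),
+// the full-precision result (112 * a + 16 * b + 64) >> 7 is exactly equal to
+// (7 * a + 1 * b + 4) >> 3, which is what the vmul/vmla plus
+// vrshr_n_u16(blend, 3) sequences in these functions compute. With samples of
+// at most 12 bits, 7 * a + 1 * b fits in 15 bits, so the blend never
+// overflows uint16_t. A scalar sketch of the per-pixel operation (reference
+// only, not called by the kernels in this file; the name is illustrative):
+//
+//   uint16_t bilinear_ref(uint16_t s0, uint16_t s1, int filter_offset) {
+//     return (uint16_t)((s0 * (8 - filter_offset) + s1 * filter_offset + 4)
+//                       >> 3);
+//   }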
+static void highbd_var_filter_block2d_bil_large(const uint16_t *src_ptr, + uint16_t *dst_ptr, + int src_stride, int pixel_step, + int dst_width, int dst_height, + int filter_offset) { + const uint16x8_t f0 = vdupq_n_u16(8 - filter_offset); + const uint16x8_t f1 = vdupq_n_u16(filter_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src_ptr + j); + uint16x8_t s1 = vld1q_u16(src_ptr + j + pixel_step); + + uint16x8_t blend = vmulq_u16(s0, f0); + blend = vmlaq_u16(blend, s1, f1); + blend = vrshrq_n_u16(blend, 3); + + vst1q_u16(dst_ptr + j, blend); + + j += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +static void highbd_var_filter_block2d_bil_w8(const uint16_t *src_ptr, + uint16_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset) { + highbd_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, + 8, dst_height, filter_offset); +} + +static void highbd_var_filter_block2d_bil_w16(const uint16_t *src_ptr, + uint16_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset) { + highbd_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, + 16, dst_height, filter_offset); +} + +static void highbd_var_filter_block2d_bil_w32(const uint16_t *src_ptr, + uint16_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset) { + highbd_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, + 32, dst_height, filter_offset); +} + +static void highbd_var_filter_block2d_bil_w64(const uint16_t *src_ptr, + uint16_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset) { + highbd_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, + 64, dst_height, filter_offset); +} + +static void highbd_var_filter_block2d_bil_w128(const uint16_t *src_ptr, + uint16_t *dst_ptr, + int src_stride, int pixel_step, + int dst_height, + int filter_offset) { + highbd_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, + 128, dst_height, filter_offset); +} + +static void highbd_var_filter_block2d_avg(const uint16_t *src_ptr, + uint16_t *dst_ptr, int src_stride, + int pixel_step, int dst_width, + int dst_height) { + int i = dst_height; + + // We only specialize on the filter values for large block sizes (>= 16x16.) 
+ assert(dst_width >= 16 && dst_width % 16 == 0); + + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src_ptr + j); + uint16x8_t s1 = vld1q_u16(src_ptr + j + pixel_step); + uint16x8_t avg = vrhaddq_u16(s0, s1); + vst1q_u16(dst_ptr + j, avg); + + j += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +#define HBD_SUBPEL_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int aom_highbd_##bitdepth##_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, uint32_t *sse) { \ + uint16_t tmp0[w * (h + 1)]; \ + uint16_t tmp1[w * h]; \ + uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src); \ + \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, (h + 1), \ + xoffset); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + \ + return aom_highbd_##bitdepth##_variance##w##x##h(CONVERT_TO_BYTEPTR(tmp1), \ + w, ref, ref_stride, sse); \ + } + +#define HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int aom_highbd_##bitdepth##_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, unsigned int *sse) { \ + uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src); \ + \ + if (xoffset == 0) { \ + if (yoffset == 0) { \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(src_ptr), src_stride, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp[w * h]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp, src_stride, src_stride, w, \ + h); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp, src_stride, \ + src_stride, h, yoffset); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, src_stride, 1, w, h); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp0), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, src_stride, 1, w, \ + (h + 1)); \ + highbd_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, src_stride, 1, w, \ + (h + 1)); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } \ + } else { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, h, \ + xoffset); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp0), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, \ + (h + 1), xoffset); \ + highbd_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * h]; \ + 
highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, \ + (h + 1), xoffset); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } \ + } \ + } + +// 8-bit +HBD_SUBPEL_VARIANCE_WXH_NEON(8, 4, 4) +HBD_SUBPEL_VARIANCE_WXH_NEON(8, 4, 8) + +HBD_SUBPEL_VARIANCE_WXH_NEON(8, 8, 4) +HBD_SUBPEL_VARIANCE_WXH_NEON(8, 8, 8) +HBD_SUBPEL_VARIANCE_WXH_NEON(8, 8, 16) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 8) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 16) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 32) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 16) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 32) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 64) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 32) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 64) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 128) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 128, 64) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SUBPEL_VARIANCE_WXH_NEON(8, 4, 16) + +HBD_SUBPEL_VARIANCE_WXH_NEON(8, 8, 32) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 4) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 64) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 8) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 10-bit +HBD_SUBPEL_VARIANCE_WXH_NEON(10, 4, 4) +HBD_SUBPEL_VARIANCE_WXH_NEON(10, 4, 8) + +HBD_SUBPEL_VARIANCE_WXH_NEON(10, 8, 4) +HBD_SUBPEL_VARIANCE_WXH_NEON(10, 8, 8) +HBD_SUBPEL_VARIANCE_WXH_NEON(10, 8, 16) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 8) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 16) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 32) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 16) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 32) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 64) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 32) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 64) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 128) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 128, 64) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SUBPEL_VARIANCE_WXH_NEON(10, 4, 16) + +HBD_SUBPEL_VARIANCE_WXH_NEON(10, 8, 32) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 4) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 64) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 8) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 12-bit +HBD_SUBPEL_VARIANCE_WXH_NEON(12, 4, 4) +HBD_SUBPEL_VARIANCE_WXH_NEON(12, 4, 8) + +HBD_SUBPEL_VARIANCE_WXH_NEON(12, 8, 4) +HBD_SUBPEL_VARIANCE_WXH_NEON(12, 8, 8) +HBD_SUBPEL_VARIANCE_WXH_NEON(12, 8, 16) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 8) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 16) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 32) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 16) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 32) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 64) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 32) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 64) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 128) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 128, 64) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SUBPEL_VARIANCE_WXH_NEON(12, 4, 16) + 
+HBD_SUBPEL_VARIANCE_WXH_NEON(12, 8, 32) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 4) +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 64) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 8) + +HBD_SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// Combine bilinear filter with aom_highbd_comp_avg_pred for blocks having +// width 4. +static void highbd_avg_pred_var_filter_block2d_bil_w4( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred) { + const uint16x4_t f0 = vdup_n_u16(8 - filter_offset); + const uint16x4_t f1 = vdup_n_u16(filter_offset); + + int i = dst_height; + do { + uint16x4_t s0 = load_unaligned_u16_4x1(src_ptr); + uint16x4_t s1 = load_unaligned_u16_4x1(src_ptr + pixel_step); + uint16x4_t p = vld1_u16(second_pred); + + uint16x4_t blend = vmul_u16(s0, f0); + blend = vmla_u16(blend, s1, f1); + blend = vrshr_n_u16(blend, 3); + + vst1_u16(dst_ptr, vrhadd_u16(blend, p)); + + src_ptr += src_stride; + dst_ptr += 4; + second_pred += 4; + } while (--i != 0); +} + +// Combine bilinear filter with aom_highbd_comp_avg_pred for large blocks. +static void highbd_avg_pred_var_filter_block2d_bil_large( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_width, int dst_height, int filter_offset, + const uint16_t *second_pred) { + const uint16x8_t f0 = vdupq_n_u16(8 - filter_offset); + const uint16x8_t f1 = vdupq_n_u16(filter_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src_ptr + j); + uint16x8_t s1 = vld1q_u16(src_ptr + j + pixel_step); + uint16x8_t p = vld1q_u16(second_pred); + + uint16x8_t blend = vmulq_u16(s0, f0); + blend = vmlaq_u16(blend, s1, f1); + blend = vrshrq_n_u16(blend, 3); + + vst1q_u16(dst_ptr + j, vrhaddq_u16(blend, p)); + + j += 8; + second_pred += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +static void highbd_avg_pred_var_filter_block2d_bil_w8( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred) { + highbd_avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 8, dst_height, + filter_offset, second_pred); +} + +static void highbd_avg_pred_var_filter_block2d_bil_w16( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred) { + highbd_avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 16, dst_height, + filter_offset, second_pred); +} + +static void highbd_avg_pred_var_filter_block2d_bil_w32( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred) { + highbd_avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 32, dst_height, + filter_offset, second_pred); +} + +static void highbd_avg_pred_var_filter_block2d_bil_w64( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred) { + highbd_avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 64, dst_height, + filter_offset, second_pred); +} + +static void highbd_avg_pred_var_filter_block2d_bil_w128( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, 
int filter_offset, const uint16_t *second_pred) { + highbd_avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 128, dst_height, + filter_offset, second_pred); +} + +// Combine averaging subpel filter with aom_highbd_comp_avg_pred. +static void highbd_avg_pred_var_filter_block2d_avg( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_width, int dst_height, const uint16_t *second_pred) { + int i = dst_height; + + // We only specialize on the filter values for large block sizes (>= 16x16.) + assert(dst_width >= 16 && dst_width % 16 == 0); + + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src_ptr + j); + uint16x8_t s1 = vld1q_u16(src_ptr + j + pixel_step); + uint16x8_t avg = vrhaddq_u16(s0, s1); + + uint16x8_t p = vld1q_u16(second_pred); + avg = vrhaddq_u16(avg, p); + + vst1q_u16(dst_ptr + j, avg); + + j += 8; + second_pred += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +// Implementation of aom_highbd_comp_avg_pred for blocks having width >= 16. +static void highbd_avg_pred(const uint16_t *src_ptr, uint16_t *dst_ptr, + int src_stride, int dst_width, int dst_height, + const uint16_t *second_pred) { + int i = dst_height; + + // We only specialize on the filter values for large block sizes (>= 16x16.) + assert(dst_width >= 16 && dst_width % 16 == 0); + + do { + int j = 0; + do { + uint16x8_t s = vld1q_u16(src_ptr + j); + uint16x8_t p = vld1q_u16(second_pred); + + uint16x8_t avg = vrhaddq_u16(s, p); + + vst1q_u16(dst_ptr + j, avg); + + j += 8; + second_pred += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +#define HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(bitdepth, w, h) \ + uint32_t aom_highbd_##bitdepth##_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, uint32_t *sse, \ + const uint8_t *second_pred) { \ + uint16_t tmp0[w * (h + 1)]; \ + uint16_t tmp1[w * h]; \ + uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src); \ + \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, (h + 1), \ + xoffset); \ + highbd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, CONVERT_TO_SHORTPTR(second_pred)); \ + \ + return aom_highbd_##bitdepth##_variance##w##x##h(CONVERT_TO_BYTEPTR(tmp1), \ + w, ref, ref_stride, sse); \ + } + +#define HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int aom_highbd_##bitdepth##_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src, int source_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, uint32_t *sse, \ + const uint8_t *second_pred) { \ + uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src); \ + \ + if (xoffset == 0) { \ + uint16_t tmp[w * h]; \ + if (yoffset == 0) { \ + highbd_avg_pred(src_ptr, tmp, source_stride, w, h, \ + CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + highbd_avg_pred_var_filter_block2d_avg( \ + src_ptr, tmp, source_stride, source_stride, w, h, \ + CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref, ref_stride, sse); \ + } else { \ + highbd_avg_pred_var_filter_block2d_bil_w##w( \ + src_ptr, tmp, source_stride, source_stride, h, yoffset, \ + CONVERT_TO_SHORTPTR(second_pred)); \ + return 
aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_avg_pred_var_filter_block2d_avg( \ + src_ptr, tmp0, source_stride, 1, w, h, \ + CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp0), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, source_stride, 1, w, \ + (h + 1)); \ + highbd_avg_pred_var_filter_block2d_avg( \ + tmp0, tmp1, w, w, w, h, CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, source_stride, 1, w, \ + (h + 1)); \ + highbd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } \ + } else { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_avg_pred_var_filter_block2d_bil_w##w( \ + src_ptr, tmp0, source_stride, 1, h, xoffset, \ + CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp0), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, source_stride, 1, \ + (h + 1), xoffset); \ + highbd_avg_pred_var_filter_block2d_avg( \ + tmp0, tmp1, w, w, w, h, CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, source_stride, 1, \ + (h + 1), xoffset); \ + highbd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, CONVERT_TO_SHORTPTR(second_pred)); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } \ + } \ + } + +// 8-bit +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 4) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 8) + +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 4) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 8) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 16) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 8) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 16) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 32) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 16) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 32) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 64) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 32) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 64) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 128) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 128, 64) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 16) + +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 32) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 4) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 64) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 8) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 10-bit +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 4, 4) 
+HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 4, 8) + +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 4) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 8) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 16) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 8) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 16) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 32) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 16) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 32) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 64) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 32) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 64) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 128) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 128, 64) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 4, 16) + +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 32) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 4) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 64) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 8) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 12-bit +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 4, 4) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 4, 8) + +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 4) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 8) +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 16) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 8) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 16) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 32) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 16) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 32) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 64) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 32) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 64) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 128) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 128, 64) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 4, 16) + +HBD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 32) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 4) +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 64) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 8) + +HBD_SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#define HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int \ + aom_highbd_##bitdepth##_masked_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \ + const uint8_t *msk, int msk_stride, int invert_mask, \ + unsigned int *sse) { \ + uint16_t tmp0[w * (h + 1)]; \ + uint16_t tmp1[w * (h + 1)]; \ + uint16_t tmp2[w * h]; \ + uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src); \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, (h + 1), \ + xoffset); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp2), second_pred, w, \ + h, CONVERT_TO_BYTEPTR(tmp1), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h(CONVERT_TO_BYTEPTR(tmp2), \ + w, ref, ref_stride, sse); \ + } + +#define HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int \ + 
aom_highbd_##bitdepth##_masked_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \ + const uint8_t *msk, int msk_stride, int invert_mask, \ + unsigned int *sse) { \ + uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src); \ + if (xoffset == 0) { \ + uint16_t tmp0[w * h]; \ + if (yoffset == 0) { \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp0), second_pred, \ + w, h, src, src_stride, msk, msk_stride, \ + invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp0), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, src_stride, src_stride, \ + w, h); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp1), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp0), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, \ + src_stride, h, yoffset); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp1), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp0), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, src_stride, 1, w, h); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp1), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp0), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * h]; \ + uint16_t tmp2[w * h]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, src_stride, 1, w, \ + (h + 1)); \ + highbd_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp2), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp1), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp2), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * h]; \ + uint16_t tmp2[w * h]; \ + highbd_var_filter_block2d_avg(src_ptr, tmp0, src_stride, 1, w, \ + (h + 1)); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp2), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp1), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp2), w, ref, ref_stride, sse); \ + } \ + } else { \ + if (yoffset == 0) { \ + uint16_t tmp0[w * h]; \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, h, \ + xoffset); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp1), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp0), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp0[w * (h + 1)]; \ + uint16_t tmp1[w * h]; \ + uint16_t tmp2[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, \ + (h + 1), xoffset); \ + highbd_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ 
+ aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp2), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp1), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp2), w, ref, ref_stride, sse); \ + } else { \ + uint16_t tmp0[w * (h + 1)]; \ + uint16_t tmp1[w * (h + 1)]; \ + uint16_t tmp2[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src_ptr, tmp0, src_stride, 1, \ + (h + 1), xoffset); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + aom_highbd_comp_mask_pred_neon(CONVERT_TO_BYTEPTR(tmp2), second_pred, \ + w, h, CONVERT_TO_BYTEPTR(tmp1), w, msk, \ + msk_stride, invert_mask); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp2), w, ref, ref_stride, sse); \ + } \ + } \ + } + +// 8-bit +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 4, 4) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 4, 8) + +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 8, 4) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 8, 8) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 8, 16) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 8) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 16) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 32) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 16) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 32) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 64) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 32) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 64) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 128) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 128, 64) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 4, 16) + +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 8, 32) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 4) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 64) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 8) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 10-bit +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 4, 4) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 4, 8) + +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 8, 4) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 8, 8) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 8, 16) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 8) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 16) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 32) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 16) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 32) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 64) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 32) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 64) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 128) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 128, 64) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 4, 16) + +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 8, 32) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 4) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 16, 64) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 32, 8) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(10, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 12-bit +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 4, 4) 
+HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 4, 8) + +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 8, 4) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 8, 8) +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 8, 16) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 8) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 16) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 32) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 16) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 32) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 64) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 32) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 64) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 128) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 128, 64) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 4, 16) + +HBD_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 8, 32) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 4) +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 16, 64) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 32, 8) + +HBD_SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(12, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#if !CONFIG_REALTIME_ONLY +#define HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int \ + aom_highbd_##bitdepth##_obmc_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *pre, int pre_stride, int xoffset, int yoffset, \ + const int32_t *wsrc, const int32_t *mask, unsigned int *sse) { \ + uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre); \ + uint16_t tmp0[w * (h + 1)]; \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(pre_ptr, tmp0, pre_stride, 1, h + 1, \ + xoffset); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp1), w, wsrc, mask, sse); \ + } + +#define SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int \ + aom_highbd_##bitdepth##_obmc_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *pre, int pre_stride, int xoffset, int yoffset, \ + const int32_t *wsrc, const int32_t *mask, unsigned int *sse) { \ + uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre); \ + if (xoffset == 0) { \ + if (yoffset == 0) { \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + pre, pre_stride, wsrc, mask, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp[w * h]; \ + highbd_var_filter_block2d_avg(pre_ptr, tmp, pre_stride, pre_stride, w, \ + h); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp), w, wsrc, mask, sse); \ + } else { \ + uint16_t tmp[w * h]; \ + highbd_var_filter_block2d_bil_w##w(pre_ptr, tmp, pre_stride, \ + pre_stride, h, yoffset); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp), w, wsrc, mask, sse); \ + } \ + } else if (xoffset == 4) { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_var_filter_block2d_avg(pre_ptr, tmp0, pre_stride, 1, w, h); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp0), w, wsrc, mask, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(pre_ptr, tmp0, pre_stride, 1, w, h + 1); \ + highbd_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + 
CONVERT_TO_BYTEPTR(tmp1), w, wsrc, mask, sse); \ + } else { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(pre_ptr, tmp0, pre_stride, 1, w, h + 1); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp1), w, wsrc, mask, sse); \ + } \ + } else { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_var_filter_block2d_bil_w##w(pre_ptr, tmp0, pre_stride, 1, h, \ + xoffset); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp0), w, wsrc, mask, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(pre_ptr, tmp0, pre_stride, 1, \ + h + 1, xoffset); \ + highbd_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp1), w, wsrc, mask, sse); \ + } else { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(pre_ptr, tmp0, pre_stride, 1, \ + h + 1, xoffset); \ + highbd_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \ + CONVERT_TO_BYTEPTR(tmp1), w, wsrc, mask, sse); \ + } \ + } \ + } + +// 8-bit +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 4, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 4, 8) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 4, 16) + +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 8, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 8, 8) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 8, 16) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 8, 32) + +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 16, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 16, 8) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 16, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 16, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 16, 64) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 32, 8) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 32, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 32, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 32, 64) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 64, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 64, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 64, 64) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 64, 128) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 128, 64) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 128, 128) + +// 10-bit +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 4, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 4, 8) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 4, 16) + +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 8, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 8, 8) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 8, 16) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 8, 32) + +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 16, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 16, 8) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 16, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 16, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 16, 64) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 32, 8) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 32, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 32, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 32, 64) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 64, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 64, 32) 
+SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 64, 64) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 64, 128) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 128, 64) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(10, 128, 128) + +// 12-bit +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 4, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 4, 8) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 4, 16) + +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 8, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 8, 8) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 8, 16) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 8, 32) + +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 16, 4) +HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 16, 8) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 16, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 16, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 16, 64) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 32, 8) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 32, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 32, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 32, 64) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 64, 16) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 64, 32) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 64, 64) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 64, 128) + +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 128, 64) +SPECIALIZED_HIGHBD_OBMC_SUBPEL_VARIANCE_WXH_NEON(12, 128, 128) +#endif // !CONFIG_REALTIME_ONLY + +static void highbd_dist_wtd_avg_pred(const uint16_t *src_ptr, uint16_t *dst_ptr, + int src_stride, int dst_width, + int dst_height, + const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + // We only specialise on the filter values for large block sizes (>= 16x16.) + assert(dst_width >= 16 && dst_width % 16 == 0); + const uint16x8_t fwd_offset = vdupq_n_u16(jcp_param->fwd_offset); + const uint16x8_t bck_offset = vdupq_n_u16(jcp_param->bck_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint16x8_t s = vld1q_u16(src_ptr + j); + uint16x8_t p = vld1q_u16(second_pred); + + uint16x8_t avg = dist_wtd_avg_u16x8(s, p, fwd_offset, bck_offset); + + vst1q_u16(dst_ptr + j, avg); + + second_pred += 8; + j += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +static void highbd_dist_wtd_avg_pred_var_filter_block2d_avg( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_width, int dst_height, const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + // We only specialise on the filter values for large block sizes (>= 16x16.) 
+ assert(dst_width >= 16 && dst_width % 16 == 0); + const uint16x8_t fwd_offset = vdupq_n_u16(jcp_param->fwd_offset); + const uint16x8_t bck_offset = vdupq_n_u16(jcp_param->bck_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src_ptr + j); + uint16x8_t s1 = vld1q_u16(src_ptr + j + pixel_step); + uint16x8_t p = vld1q_u16(second_pred); + uint16x8_t avg = vrhaddq_u16(s0, s1); + avg = dist_wtd_avg_u16x8(avg, p, fwd_offset, bck_offset); + + vst1q_u16(dst_ptr + j, avg); + + second_pred += 8; + j += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +static void highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w4( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint16x4_t fwd_offset = vdup_n_u16(jcp_param->fwd_offset); + const uint16x4_t bck_offset = vdup_n_u16(jcp_param->bck_offset); + const uint16x4_t f0 = vdup_n_u16(8 - filter_offset); + const uint16x4_t f1 = vdup_n_u16(filter_offset); + + int i = dst_height; + do { + uint16x4_t s0 = load_unaligned_u16_4x1(src_ptr); + uint16x4_t s1 = load_unaligned_u16_4x1(src_ptr + pixel_step); + uint16x4_t p = vld1_u16(second_pred); + + uint16x4_t blend = vmul_u16(s0, f0); + blend = vmla_u16(blend, s1, f1); + blend = vrshr_n_u16(blend, 3); + + uint16x4_t avg = dist_wtd_avg_u16x4(blend, p, fwd_offset, bck_offset); + + vst1_u16(dst_ptr, avg); + + src_ptr += src_stride; + dst_ptr += 4; + second_pred += 4; + } while (--i != 0); +} + +// Combine bilinear filter with aom_dist_wtd_comp_avg_pred for large blocks. +static void highbd_dist_wtd_avg_pred_var_filter_block2d_bil_large( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_width, int dst_height, int filter_offset, + const uint16_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint16x8_t fwd_offset = vdupq_n_u16(jcp_param->fwd_offset); + const uint16x8_t bck_offset = vdupq_n_u16(jcp_param->bck_offset); + const uint16x8_t f0 = vdupq_n_u16(8 - filter_offset); + const uint16x8_t f1 = vdupq_n_u16(filter_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint16x8_t s0 = vld1q_u16(src_ptr + j); + uint16x8_t s1 = vld1q_u16(src_ptr + j + pixel_step); + uint16x8_t p = vld1q_u16(second_pred); + + uint16x8_t blend = vmulq_u16(s0, f0); + blend = vmlaq_u16(blend, s1, f1); + blend = vrshrq_n_u16(blend, 3); + + uint16x8_t avg = dist_wtd_avg_u16x8(blend, p, fwd_offset, bck_offset); + + vst1q_u16(dst_ptr + j, avg); + + second_pred += 8; + j += 8; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +static void highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w8( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 8, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 16. 
+static void highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w16( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 16, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 32. +static void highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w32( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 32, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 64. +static void highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w64( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 64, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 128. +static void highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w128( + const uint16_t *src_ptr, uint16_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint16_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 128, dst_height, filter_offset, + second_pred, jcp_param); +} + +#define HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int \ + aom_highbd_##bitdepth##_dist_wtd_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src_ptr, int source_stride, int xoffset, int yoffset, \ + const uint8_t *ref_ptr, int ref_stride, uint32_t *sse, \ + const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *second = CONVERT_TO_SHORTPTR(second_pred); \ + uint16_t tmp0[w * (h + 1)]; \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, h + 1, \ + xoffset); \ + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref_ptr, ref_stride, sse); \ + } + +#define SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(bitdepth, w, h) \ + unsigned int \ + aom_highbd_##bitdepth##_dist_wtd_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src_ptr, int source_stride, int xoffset, int yoffset, \ + const uint8_t *ref_ptr, int ref_stride, uint32_t *sse, \ + const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *second = CONVERT_TO_SHORTPTR(second_pred); \ + if (xoffset == 0) { \ + uint16_t tmp[w * h]; \ + if (yoffset == 0) { \ + highbd_dist_wtd_avg_pred(src, tmp, source_stride, w, h, second, \ + jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref_ptr, ref_stride, sse); \ + } else if (yoffset == 4) { \ + 
highbd_dist_wtd_avg_pred_var_filter_block2d_avg( \ + src, tmp, source_stride, source_stride, w, h, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref_ptr, ref_stride, sse); \ + } else { \ + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + src, tmp, source_stride, source_stride, h, yoffset, second, \ + jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp), w, ref_ptr, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_dist_wtd_avg_pred_var_filter_block2d_avg( \ + src, tmp0, source_stride, 1, w, h, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp0), w, ref_ptr, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(src, tmp0, source_stride, 1, w, h + 1); \ + highbd_dist_wtd_avg_pred_var_filter_block2d_avg(tmp0, tmp1, w, w, w, \ + h, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref_ptr, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * (h + 1)]; \ + highbd_var_filter_block2d_avg(src, tmp0, source_stride, 1, w, h + 1); \ + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref_ptr, ref_stride, sse); \ + } \ + } else { \ + uint16_t tmp0[w * (h + 1)]; \ + if (yoffset == 0) { \ + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + src, tmp0, source_stride, 1, h, xoffset, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp0), w, ref_ptr, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, h + 1, \ + xoffset); \ + highbd_dist_wtd_avg_pred_var_filter_block2d_avg(tmp0, tmp1, w, w, w, \ + h, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref_ptr, ref_stride, sse); \ + } else { \ + uint16_t tmp1[w * h]; \ + highbd_var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, h + 1, \ + xoffset); \ + highbd_dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, second, jcp_param); \ + return aom_highbd_##bitdepth##_variance##w##x##h( \ + CONVERT_TO_BYTEPTR(tmp1), w, ref_ptr, ref_stride, sse); \ + } \ + } \ + } + +// 8-bit +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 4) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 8) + +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 4) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 8) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 16) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 8) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 16) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 32) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 16) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 32) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 64) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 32) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 64) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 128) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 128, 64) 
+SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 16) + +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 32) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 4) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 64) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 8) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 10-bit +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 4, 4) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 4, 8) + +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 4) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 8) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 16) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 8) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 16) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 32) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 16) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 32) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 64) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 32) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 64) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 128) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 128, 64) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 4, 16) + +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 8, 32) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 4) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 16, 64) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 32, 8) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(10, 64, 16) +#endif // !CONFIG_REALTIME_ONLY + +// 12-bit +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 4, 4) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 4, 8) + +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 4) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 8) +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 16) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 8) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 16) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 32) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 16) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 32) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 64) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 32) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 64) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 128) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 128, 64) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 128, 128) + +#if !CONFIG_REALTIME_ONLY +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 4, 16) + +HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 8, 32) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 4) +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 16, 64) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 32, 8) + +SPECIALIZED_HBD_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(12, 64, 16) +#endif // !CONFIG_REALTIME_ONLY diff --git a/third_party/aom/aom_dsp/arm/highbd_variance_neon.c 
b/third_party/aom/aom_dsp/arm/highbd_variance_neon.c new file mode 100644 index 0000000000..18b8efff4c --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_variance_neon.c @@ -0,0 +1,502 @@ +/* + * Copyright (c) 2023 The WebM project authors. All Rights Reserved. + * Copyright (c) 2022, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/aom_filter.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/variance.h" + +// Process a block of width 4 two rows at a time. +static INLINE void highbd_variance_4xh_neon(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int h, + uint64_t *sse, int64_t *sum) { + int16x8_t sum_s16 = vdupq_n_s16(0); + int32x4_t sse_s32 = vdupq_n_s32(0); + + int i = h; + do { + const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride); + const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride); + + int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r)); + sum_s16 = vaddq_s16(sum_s16, diff); + + sse_s32 = vmlal_s16(sse_s32, vget_low_s16(diff), vget_low_s16(diff)); + sse_s32 = vmlal_s16(sse_s32, vget_high_s16(diff), vget_high_s16(diff)); + + src_ptr += 2 * src_stride; + ref_ptr += 2 * ref_stride; + i -= 2; + } while (i != 0); + + *sum = horizontal_add_s16x8(sum_s16); + *sse = horizontal_add_s32x4(sse_s32); +} + +// For 8-bit and 10-bit data, since we're using two int32x4 accumulators, all +// block sizes can be processed in 32-bit elements (1023*1023*128*32 = +// 4286582784 for a 128x128 block). 
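As a quick check of the bound quoted in the comment above: the stated worst case, 1023 * 1023 * 128 * 32 = 4286582784, does not exceed UINT32_MAX = 4294967295, so the pair of 32-bit accumulators is sufficient once they are reinterpreted as unsigned values for the final widening horizontal add. The 12-bit case discussed further below does exceed this and therefore periodically spills into 64-bit accumulators. Below is a standalone arithmetic sketch, not part of the patched file; the program is illustrative only.

  #include <assert.h>
  #include <stdint.h>
  #include <stdio.h>

  int main(void) {
    // Worst case quoted in the comment above for 10-bit input on a
    // 128x128 block.
    const uint64_t worst_case = 1023ULL * 1023ULL * 128ULL * 32ULL;
    assert(worst_case == 4286582784ULL);
    // 4286582784 <= 4294967295, so an unsigned 32-bit lane is enough.
    assert(worst_case <= UINT32_MAX);
    printf("8/10-bit worst-case accumulator value: %llu\n",
           (unsigned long long)worst_case);
    return 0;
  }
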
+static INLINE void highbd_variance_large_neon(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int w, int h, + uint64_t *sse, int64_t *sum) { + int32x4_t sum_s32 = vdupq_n_s32(0); + int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + int i = h; + do { + int j = 0; + do { + const uint16x8_t s = vld1q_u16(src_ptr + j); + const uint16x8_t r = vld1q_u16(ref_ptr + j); + + const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r)); + sum_s32 = vpadalq_s16(sum_s32, diff); + + sse_s32[0] = + vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff)); + sse_s32[1] = + vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff)); + + j += 8; + } while (j < w); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + *sum = horizontal_add_s32x4(sum_s32); + *sse = horizontal_long_add_u32x4(vaddq_u32( + vreinterpretq_u32_s32(sse_s32[0]), vreinterpretq_u32_s32(sse_s32[1]))); +} + +static INLINE void highbd_variance_8xh_neon(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int h, uint64_t *sse, + int64_t *sum) { + highbd_variance_large_neon(src, src_stride, ref, ref_stride, 8, h, sse, sum); +} + +static INLINE void highbd_variance_16xh_neon(const uint16_t *src, + int src_stride, + const uint16_t *ref, + int ref_stride, int h, + uint64_t *sse, int64_t *sum) { + highbd_variance_large_neon(src, src_stride, ref, ref_stride, 16, h, sse, sum); +} + +static INLINE void highbd_variance_32xh_neon(const uint16_t *src, + int src_stride, + const uint16_t *ref, + int ref_stride, int h, + uint64_t *sse, int64_t *sum) { + highbd_variance_large_neon(src, src_stride, ref, ref_stride, 32, h, sse, sum); +} + +static INLINE void highbd_variance_64xh_neon(const uint16_t *src, + int src_stride, + const uint16_t *ref, + int ref_stride, int h, + uint64_t *sse, int64_t *sum) { + highbd_variance_large_neon(src, src_stride, ref, ref_stride, 64, h, sse, sum); +} + +static INLINE void highbd_variance_128xh_neon(const uint16_t *src, + int src_stride, + const uint16_t *ref, + int ref_stride, int h, + uint64_t *sse, int64_t *sum) { + highbd_variance_large_neon(src, src_stride, ref, ref_stride, 128, h, sse, + sum); +} + +// For 12-bit data, we can only accumulate up to 128 elements in the sum of +// squares (4095*4095*128 = 2146435200), and because we're using two int32x4 +// accumulators, we can only process up to 32 32-element rows (32*32/8 = 128) +// or 16 64-element rows before we have to accumulate into 64-bit elements. +// Therefore blocks of size 32x64, 64x32, 64x64, 64x128, 128x64, 128x128 are +// processed in a different helper function. + +// Process a block of any size where the width is divisible by 8, with +// accumulation into 64-bit elements. +static INLINE void highbd_variance_xlarge_neon( + const uint16_t *src_ptr, int src_stride, const uint16_t *ref_ptr, + int ref_stride, int w, int h, int h_limit, uint64_t *sse, int64_t *sum) { + int32x4_t sum_s32 = vdupq_n_s32(0); + int64x2_t sse_s64 = vdupq_n_s64(0); + + // 'h_limit' is the number of 'w'-width rows we can process before our 32-bit + // accumulator overflows. After hitting this limit we accumulate into 64-bit + // elements. + int h_tmp = h > h_limit ? 
h_limit : h; + + int i = 0; + do { + int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + do { + int j = 0; + do { + const uint16x8_t s0 = vld1q_u16(src_ptr + j); + const uint16x8_t r0 = vld1q_u16(ref_ptr + j); + + const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s0, r0)); + sum_s32 = vpadalq_s16(sum_s32, diff); + + sse_s32[0] = + vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff)); + sse_s32[1] = + vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff)); + + j += 8; + } while (j < w); + + src_ptr += src_stride; + ref_ptr += ref_stride; + i++; + } while (i < h_tmp); + + sse_s64 = vpadalq_s32(sse_s64, sse_s32[0]); + sse_s64 = vpadalq_s32(sse_s64, sse_s32[1]); + h_tmp += h_limit; + } while (i < h); + + *sum = horizontal_add_s32x4(sum_s32); + *sse = (uint64_t)horizontal_add_s64x2(sse_s64); +} + +static INLINE void highbd_variance_32xh_xlarge_neon( + const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride, + int h, uint64_t *sse, int64_t *sum) { + highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 32, h, 32, sse, + sum); +} + +static INLINE void highbd_variance_64xh_xlarge_neon( + const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride, + int h, uint64_t *sse, int64_t *sum) { + highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 64, h, 16, sse, + sum); +} + +static INLINE void highbd_variance_128xh_xlarge_neon( + const uint16_t *src, int src_stride, const uint16_t *ref, int ref_stride, + int h, uint64_t *sse, int64_t *sum) { + highbd_variance_xlarge_neon(src, src_stride, ref, ref_stride, 128, h, 8, sse, + sum); +} + +#define HBD_VARIANCE_WXH_8_NEON(w, h) \ + uint32_t aom_highbd_8_variance##w##x##h##_neon( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + int sum; \ + uint64_t sse_long = 0; \ + int64_t sum_long = 0; \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \ + &sse_long, &sum_long); \ + *sse = (uint32_t)sse_long; \ + sum = (int)sum_long; \ + return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h)); \ + } + +#define HBD_VARIANCE_WXH_10_NEON(w, h) \ + uint32_t aom_highbd_10_variance##w##x##h##_neon( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + int sum; \ + int64_t var; \ + uint64_t sse_long = 0; \ + int64_t sum_long = 0; \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \ + &sse_long, &sum_long); \ + *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4); \ + sum = (int)ROUND_POWER_OF_TWO(sum_long, 2); \ + var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h)); \ + return (var >= 0) ? 
(uint32_t)var : 0; \ + } + +#define HBD_VARIANCE_WXH_12_NEON(w, h) \ + uint32_t aom_highbd_12_variance##w##x##h##_neon( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + int sum; \ + int64_t var; \ + uint64_t sse_long = 0; \ + int64_t sum_long = 0; \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \ + &sse_long, &sum_long); \ + *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8); \ + sum = (int)ROUND_POWER_OF_TWO(sum_long, 4); \ + var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h)); \ + return (var >= 0) ? (uint32_t)var : 0; \ + } + +#define HBD_VARIANCE_WXH_12_XLARGE_NEON(w, h) \ + uint32_t aom_highbd_12_variance##w##x##h##_neon( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + int sum; \ + int64_t var; \ + uint64_t sse_long = 0; \ + int64_t sum_long = 0; \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_variance_##w##xh_xlarge_neon(src, src_stride, ref, ref_stride, h, \ + &sse_long, &sum_long); \ + *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8); \ + sum = (int)ROUND_POWER_OF_TWO(sum_long, 4); \ + var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h)); \ + return (var >= 0) ? (uint32_t)var : 0; \ + } + +// 8-bit +HBD_VARIANCE_WXH_8_NEON(4, 4) +HBD_VARIANCE_WXH_8_NEON(4, 8) + +HBD_VARIANCE_WXH_8_NEON(8, 4) +HBD_VARIANCE_WXH_8_NEON(8, 8) +HBD_VARIANCE_WXH_8_NEON(8, 16) + +HBD_VARIANCE_WXH_8_NEON(16, 8) +HBD_VARIANCE_WXH_8_NEON(16, 16) +HBD_VARIANCE_WXH_8_NEON(16, 32) + +HBD_VARIANCE_WXH_8_NEON(32, 16) +HBD_VARIANCE_WXH_8_NEON(32, 32) +HBD_VARIANCE_WXH_8_NEON(32, 64) + +HBD_VARIANCE_WXH_8_NEON(64, 32) +HBD_VARIANCE_WXH_8_NEON(64, 64) +HBD_VARIANCE_WXH_8_NEON(64, 128) + +HBD_VARIANCE_WXH_8_NEON(128, 64) +HBD_VARIANCE_WXH_8_NEON(128, 128) + +// 10-bit +HBD_VARIANCE_WXH_10_NEON(4, 4) +HBD_VARIANCE_WXH_10_NEON(4, 8) + +HBD_VARIANCE_WXH_10_NEON(8, 4) +HBD_VARIANCE_WXH_10_NEON(8, 8) +HBD_VARIANCE_WXH_10_NEON(8, 16) + +HBD_VARIANCE_WXH_10_NEON(16, 8) +HBD_VARIANCE_WXH_10_NEON(16, 16) +HBD_VARIANCE_WXH_10_NEON(16, 32) + +HBD_VARIANCE_WXH_10_NEON(32, 16) +HBD_VARIANCE_WXH_10_NEON(32, 32) +HBD_VARIANCE_WXH_10_NEON(32, 64) + +HBD_VARIANCE_WXH_10_NEON(64, 32) +HBD_VARIANCE_WXH_10_NEON(64, 64) +HBD_VARIANCE_WXH_10_NEON(64, 128) + +HBD_VARIANCE_WXH_10_NEON(128, 64) +HBD_VARIANCE_WXH_10_NEON(128, 128) + +// 12-bit +HBD_VARIANCE_WXH_12_NEON(4, 4) +HBD_VARIANCE_WXH_12_NEON(4, 8) + +HBD_VARIANCE_WXH_12_NEON(8, 4) +HBD_VARIANCE_WXH_12_NEON(8, 8) +HBD_VARIANCE_WXH_12_NEON(8, 16) + +HBD_VARIANCE_WXH_12_NEON(16, 8) +HBD_VARIANCE_WXH_12_NEON(16, 16) +HBD_VARIANCE_WXH_12_NEON(16, 32) + +HBD_VARIANCE_WXH_12_NEON(32, 16) +HBD_VARIANCE_WXH_12_NEON(32, 32) +HBD_VARIANCE_WXH_12_XLARGE_NEON(32, 64) + +HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 32) +HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 64) +HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 128) + +HBD_VARIANCE_WXH_12_XLARGE_NEON(128, 64) +HBD_VARIANCE_WXH_12_XLARGE_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +// 8-bit +HBD_VARIANCE_WXH_8_NEON(4, 16) + +HBD_VARIANCE_WXH_8_NEON(8, 32) + +HBD_VARIANCE_WXH_8_NEON(16, 4) +HBD_VARIANCE_WXH_8_NEON(16, 64) + +HBD_VARIANCE_WXH_8_NEON(32, 8) + +HBD_VARIANCE_WXH_8_NEON(64, 16) + +// 10-bit +HBD_VARIANCE_WXH_10_NEON(4, 16) + +HBD_VARIANCE_WXH_10_NEON(8, 32) + +HBD_VARIANCE_WXH_10_NEON(16, 4) +HBD_VARIANCE_WXH_10_NEON(16, 64) + 
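Before the remaining 10- and 12-bit instantiations, note how the HBD_VARIANCE_WXH_10/12 macros above report their result: the accumulated SSE is rounded by 2 * (bitdepth - 8) bits and the sum by (bitdepth - 8) bits, so the returned variance is expressed on the 8-bit scale, and a negative rounding artefact is clamped to zero. Below is a standalone scalar walk-through with hypothetical accumulator totals, not part of the patched file; the rounding helper mirrors the ROUND_POWER_OF_TWO used by the macros.

  #include <stdint.h>
  #include <stdio.h>

  #define ROUND_POWER_OF_TWO(value, n) (((value) + ((1 << (n)) >> 1)) >> (n))

  int main(void) {
    // Hypothetical accumulator totals for a 10-bit 16x16 block (256 pixels).
    uint64_t sse_long = 163840;  // sum of squared differences, 10-bit scale
    int64_t sum_long = 2048;     // sum of differences, 10-bit scale

    // 10-bit path: SSE is scaled down by 2 * (10 - 8) = 4 bits, the sum by
    // 2 bits, giving a variance on the 8-bit scale.
    uint32_t sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);  // 10240
    int sum = (int)ROUND_POWER_OF_TWO(sum_long, 2);            // 512
    int64_t var = (int64_t)sse - ((int64_t)sum * sum) / (16 * 16);
    // Prints: sse=10240 sum=512 variance=9216
    printf("sse=%u sum=%d variance=%lld\n", (unsigned)sse, sum,
           (long long)(var >= 0 ? var : 0));
    return 0;
  }
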
+HBD_VARIANCE_WXH_10_NEON(32, 8) + +HBD_VARIANCE_WXH_10_NEON(64, 16) + +// 12-bit +HBD_VARIANCE_WXH_12_NEON(4, 16) + +HBD_VARIANCE_WXH_12_NEON(8, 32) + +HBD_VARIANCE_WXH_12_NEON(16, 4) +HBD_VARIANCE_WXH_12_NEON(16, 64) + +HBD_VARIANCE_WXH_12_NEON(32, 8) + +HBD_VARIANCE_WXH_12_NEON(64, 16) + +#endif // !CONFIG_REALTIME_ONLY + +static INLINE uint32_t highbd_mse_wxh_neon(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int w, int h, + unsigned int *sse) { + uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h; + do { + int j = 0; + do { + uint16x8_t s = vld1q_u16(src_ptr + j); + uint16x8_t r = vld1q_u16(ref_ptr + j); + + uint16x8_t diff = vabdq_u16(s, r); + + sse_u32[0] = + vmlal_u16(sse_u32[0], vget_low_u16(diff), vget_low_u16(diff)); + sse_u32[1] = + vmlal_u16(sse_u32[1], vget_high_u16(diff), vget_high_u16(diff)); + + j += 8; + } while (j < w); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1])); + return *sse; +} + +#define HIGHBD_MSE_WXH_NEON(w, h) \ + uint32_t aom_highbd_8_mse##w##x##h##_neon( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \ + return *sse; \ + } \ + \ + uint32_t aom_highbd_10_mse##w##x##h##_neon( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \ + *sse = ROUND_POWER_OF_TWO(*sse, 4); \ + return *sse; \ + } \ + \ + uint32_t aom_highbd_12_mse##w##x##h##_neon( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse); \ + *sse = ROUND_POWER_OF_TWO(*sse, 8); \ + return *sse; \ + } + +HIGHBD_MSE_WXH_NEON(16, 16) +HIGHBD_MSE_WXH_NEON(16, 8) +HIGHBD_MSE_WXH_NEON(8, 16) +HIGHBD_MSE_WXH_NEON(8, 8) + +#undef HIGHBD_MSE_WXH_NEON + +static INLINE uint64x2_t mse_accumulate_u16_8x2(uint64x2_t sum, uint16x8_t s0, + uint16x8_t s1, uint16x8_t d0, + uint16x8_t d1) { + uint16x8_t e0 = vabdq_u16(s0, d0); + uint16x8_t e1 = vabdq_u16(s1, d1); + + uint32x4_t mse = vmull_u16(vget_low_u16(e0), vget_low_u16(e0)); + mse = vmlal_u16(mse, vget_high_u16(e0), vget_high_u16(e0)); + mse = vmlal_u16(mse, vget_low_u16(e1), vget_low_u16(e1)); + mse = vmlal_u16(mse, vget_high_u16(e1), vget_high_u16(e1)); + + return vpadalq_u32(sum, mse); +} + +uint64_t aom_mse_wxh_16bit_highbd_neon(uint16_t *dst, int dstride, + uint16_t *src, int sstride, int w, + int h) { + assert((w == 8 || w == 4) && (h == 8 || h == 4)); + + uint64x2_t sum = vdupq_n_u64(0); + + if (w == 8) { + do { + uint16x8_t d0 = vld1q_u16(dst + 0 * dstride); + uint16x8_t d1 = vld1q_u16(dst + 1 * dstride); + uint16x8_t s0 = vld1q_u16(src + 0 * sstride); + uint16x8_t s1 = vld1q_u16(src + 1 * sstride); + + sum = mse_accumulate_u16_8x2(sum, s0, s1, d0, d1); + + dst += 2 * dstride; + src += 2 * sstride; + h -= 2; + } while (h != 0); + } else { // w == 4 + do { + uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride); + uint16x8_t d1 = 
load_unaligned_u16_4x2(dst + 2 * dstride, dstride); + uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride); + uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride); + + sum = mse_accumulate_u16_8x2(sum, s0, s1, d0, d1); + + dst += 4 * dstride; + src += 4 * sstride; + h -= 4; + } while (h != 0); + } + + return horizontal_add_u64x2(sum); +} diff --git a/third_party/aom/aom_dsp/arm/highbd_variance_neon_dotprod.c b/third_party/aom/aom_dsp/arm/highbd_variance_neon_dotprod.c new file mode 100644 index 0000000000..d56ae97571 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_variance_neon_dotprod.c @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2023 The WebM project authors. All Rights Reserved. + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "aom_dsp/arm/sum_neon.h" +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +static INLINE uint32_t highbd_mse8_8xh_neon_dotprod(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int h, + unsigned int *sse) { + uint32x4_t sse_u32 = vdupq_n_u32(0); + + int i = h / 2; + do { + uint16x8_t s0 = vld1q_u16(src_ptr); + src_ptr += src_stride; + uint16x8_t s1 = vld1q_u16(src_ptr); + src_ptr += src_stride; + uint16x8_t r0 = vld1q_u16(ref_ptr); + ref_ptr += ref_stride; + uint16x8_t r1 = vld1q_u16(ref_ptr); + ref_ptr += ref_stride; + + uint8x16_t s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1)); + uint8x16_t r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1)); + + uint8x16_t diff = vabdq_u8(s, r); + sse_u32 = vdotq_u32(sse_u32, diff, diff); + } while (--i != 0); + + *sse = horizontal_add_u32x4(sse_u32); + return *sse; +} + +static INLINE uint32_t highbd_mse8_16xh_neon_dotprod(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int h, + unsigned int *sse) { + uint32x4_t sse_u32 = vdupq_n_u32(0); + + int i = h; + do { + uint16x8_t s0 = vld1q_u16(src_ptr); + uint16x8_t s1 = vld1q_u16(src_ptr + 8); + uint16x8_t r0 = vld1q_u16(ref_ptr); + uint16x8_t r1 = vld1q_u16(ref_ptr + 8); + + uint8x16_t s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1)); + uint8x16_t r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1)); + + uint8x16_t diff = vabdq_u8(s, r); + sse_u32 = vdotq_u32(sse_u32, diff, diff); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + *sse = horizontal_add_u32x4(sse_u32); + return *sse; +} + +#define HIGHBD_MSE_WXH_NEON_DOTPROD(w, h) \ + uint32_t aom_highbd_8_mse##w##x##h##_neon_dotprod( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_mse8_##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, h, \ + sse); \ + return *sse; \ + } + +HIGHBD_MSE_WXH_NEON_DOTPROD(16, 16) +HIGHBD_MSE_WXH_NEON_DOTPROD(16, 8) +HIGHBD_MSE_WXH_NEON_DOTPROD(8, 16) +HIGHBD_MSE_WXH_NEON_DOTPROD(8, 8) + +#undef HIGHBD_MSE_WXH_NEON_DOTPROD diff --git a/third_party/aom/aom_dsp/arm/highbd_variance_sve.c 
b/third_party/aom/aom_dsp/arm/highbd_variance_sve.c new file mode 100644 index 0000000000..d0058bfa90 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/highbd_variance_sve.c @@ -0,0 +1,430 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/aom_filter.h" +#include "aom_dsp/arm/dot_sve.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/variance.h" + +// Process a block of width 4 two rows at a time. +static INLINE void highbd_variance_4xh_sve(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int h, uint64_t *sse, + int64_t *sum) { + int16x8_t sum_s16 = vdupq_n_s16(0); + int64x2_t sse_s64 = vdupq_n_s64(0); + + do { + const uint16x8_t s = load_unaligned_u16_4x2(src_ptr, src_stride); + const uint16x8_t r = load_unaligned_u16_4x2(ref_ptr, ref_stride); + + int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r)); + sum_s16 = vaddq_s16(sum_s16, diff); + + sse_s64 = aom_sdotq_s16(sse_s64, diff, diff); + + src_ptr += 2 * src_stride; + ref_ptr += 2 * ref_stride; + h -= 2; + } while (h != 0); + + *sum = vaddlvq_s16(sum_s16); + *sse = vaddvq_s64(sse_s64); +} + +static INLINE void variance_8x1_sve(const uint16_t *src, const uint16_t *ref, + int32x4_t *sum, int64x2_t *sse) { + const uint16x8_t s = vld1q_u16(src); + const uint16x8_t r = vld1q_u16(ref); + + const int16x8_t diff = vreinterpretq_s16_u16(vsubq_u16(s, r)); + *sum = vpadalq_s16(*sum, diff); + + *sse = aom_sdotq_s16(*sse, diff, diff); +} + +static INLINE void highbd_variance_8xh_sve(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int h, uint64_t *sse, + int64_t *sum) { + int32x4_t sum_s32 = vdupq_n_s32(0); + int64x2_t sse_s64 = vdupq_n_s64(0); + + do { + variance_8x1_sve(src_ptr, ref_ptr, &sum_s32, &sse_s64); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--h != 0); + + *sum = vaddlvq_s32(sum_s32); + *sse = vaddvq_s64(sse_s64); +} + +static INLINE void highbd_variance_16xh_sve(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int h, + uint64_t *sse, int64_t *sum) { + int32x4_t sum_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int64x2_t sse_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + + do { + variance_8x1_sve(src_ptr, ref_ptr, &sum_s32[0], &sse_s64[0]); + variance_8x1_sve(src_ptr + 8, ref_ptr + 8, &sum_s32[1], &sse_s64[1]); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--h != 0); + + *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[1])); + *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[1])); +} + +static INLINE void highbd_variance_large_sve(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int w, int h, + uint64_t *sse, int64_t *sum) { + int32x4_t sum_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), + vdupq_n_s32(0) }; + int64x2_t sse_s64[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0), + vdupq_n_s64(0) }; + + do { + int j = 0; + do { + 
variance_8x1_sve(src_ptr + j, ref_ptr + j, &sum_s32[0], &sse_s64[0]); + variance_8x1_sve(src_ptr + j + 8, ref_ptr + j + 8, &sum_s32[1], + &sse_s64[1]); + variance_8x1_sve(src_ptr + j + 16, ref_ptr + j + 16, &sum_s32[2], + &sse_s64[2]); + variance_8x1_sve(src_ptr + j + 24, ref_ptr + j + 24, &sum_s32[3], + &sse_s64[3]); + + j += 32; + } while (j < w); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--h != 0); + + sum_s32[0] = vaddq_s32(sum_s32[0], sum_s32[1]); + sum_s32[2] = vaddq_s32(sum_s32[2], sum_s32[3]); + *sum = vaddlvq_s32(vaddq_s32(sum_s32[0], sum_s32[2])); + sse_s64[0] = vaddq_s64(sse_s64[0], sse_s64[1]); + sse_s64[2] = vaddq_s64(sse_s64[2], sse_s64[3]); + *sse = vaddvq_s64(vaddq_s64(sse_s64[0], sse_s64[2])); +} + +static INLINE void highbd_variance_32xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int h, uint64_t *sse, + int64_t *sum) { + highbd_variance_large_sve(src, src_stride, ref, ref_stride, 32, h, sse, sum); +} + +static INLINE void highbd_variance_64xh_sve(const uint16_t *src, int src_stride, + const uint16_t *ref, int ref_stride, + int h, uint64_t *sse, + int64_t *sum) { + highbd_variance_large_sve(src, src_stride, ref, ref_stride, 64, h, sse, sum); +} + +static INLINE void highbd_variance_128xh_sve(const uint16_t *src, + int src_stride, + const uint16_t *ref, + int ref_stride, int h, + uint64_t *sse, int64_t *sum) { + highbd_variance_large_sve(src, src_stride, ref, ref_stride, 128, h, sse, sum); +} + +#define HBD_VARIANCE_WXH_8_SVE(w, h) \ + uint32_t aom_highbd_8_variance##w##x##h##_sve( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + int sum; \ + uint64_t sse_long = 0; \ + int64_t sum_long = 0; \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h, \ + &sse_long, &sum_long); \ + *sse = (uint32_t)sse_long; \ + sum = (int)sum_long; \ + return *sse - (uint32_t)(((int64_t)sum * sum) / (w * h)); \ + } + +#define HBD_VARIANCE_WXH_10_SVE(w, h) \ + uint32_t aom_highbd_10_variance##w##x##h##_sve( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + int sum; \ + int64_t var; \ + uint64_t sse_long = 0; \ + int64_t sum_long = 0; \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h, \ + &sse_long, &sum_long); \ + *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4); \ + sum = (int)ROUND_POWER_OF_TWO(sum_long, 2); \ + var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h)); \ + return (var >= 0) ? (uint32_t)var : 0; \ + } + +#define HBD_VARIANCE_WXH_12_SVE(w, h) \ + uint32_t aom_highbd_12_variance##w##x##h##_sve( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + int sum; \ + int64_t var; \ + uint64_t sse_long = 0; \ + int64_t sum_long = 0; \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_variance_##w##xh_sve(src, src_stride, ref, ref_stride, h, \ + &sse_long, &sum_long); \ + *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8); \ + sum = (int)ROUND_POWER_OF_TWO(sum_long, 4); \ + var = (int64_t)(*sse) - (((int64_t)sum * sum) / (w * h)); \ + return (var >= 0) ? 
(uint32_t)var : 0; \ + } + +// 8-bit +HBD_VARIANCE_WXH_8_SVE(4, 4) +HBD_VARIANCE_WXH_8_SVE(4, 8) + +HBD_VARIANCE_WXH_8_SVE(8, 4) +HBD_VARIANCE_WXH_8_SVE(8, 8) +HBD_VARIANCE_WXH_8_SVE(8, 16) + +HBD_VARIANCE_WXH_8_SVE(16, 8) +HBD_VARIANCE_WXH_8_SVE(16, 16) +HBD_VARIANCE_WXH_8_SVE(16, 32) + +HBD_VARIANCE_WXH_8_SVE(32, 16) +HBD_VARIANCE_WXH_8_SVE(32, 32) +HBD_VARIANCE_WXH_8_SVE(32, 64) + +HBD_VARIANCE_WXH_8_SVE(64, 32) +HBD_VARIANCE_WXH_8_SVE(64, 64) +HBD_VARIANCE_WXH_8_SVE(64, 128) + +HBD_VARIANCE_WXH_8_SVE(128, 64) +HBD_VARIANCE_WXH_8_SVE(128, 128) + +// 10-bit +HBD_VARIANCE_WXH_10_SVE(4, 4) +HBD_VARIANCE_WXH_10_SVE(4, 8) + +HBD_VARIANCE_WXH_10_SVE(8, 4) +HBD_VARIANCE_WXH_10_SVE(8, 8) +HBD_VARIANCE_WXH_10_SVE(8, 16) + +HBD_VARIANCE_WXH_10_SVE(16, 8) +HBD_VARIANCE_WXH_10_SVE(16, 16) +HBD_VARIANCE_WXH_10_SVE(16, 32) + +HBD_VARIANCE_WXH_10_SVE(32, 16) +HBD_VARIANCE_WXH_10_SVE(32, 32) +HBD_VARIANCE_WXH_10_SVE(32, 64) + +HBD_VARIANCE_WXH_10_SVE(64, 32) +HBD_VARIANCE_WXH_10_SVE(64, 64) +HBD_VARIANCE_WXH_10_SVE(64, 128) + +HBD_VARIANCE_WXH_10_SVE(128, 64) +HBD_VARIANCE_WXH_10_SVE(128, 128) + +// 12-bit +HBD_VARIANCE_WXH_12_SVE(4, 4) +HBD_VARIANCE_WXH_12_SVE(4, 8) + +HBD_VARIANCE_WXH_12_SVE(8, 4) +HBD_VARIANCE_WXH_12_SVE(8, 8) +HBD_VARIANCE_WXH_12_SVE(8, 16) + +HBD_VARIANCE_WXH_12_SVE(16, 8) +HBD_VARIANCE_WXH_12_SVE(16, 16) +HBD_VARIANCE_WXH_12_SVE(16, 32) + +HBD_VARIANCE_WXH_12_SVE(32, 16) +HBD_VARIANCE_WXH_12_SVE(32, 32) +HBD_VARIANCE_WXH_12_SVE(32, 64) + +HBD_VARIANCE_WXH_12_SVE(64, 32) +HBD_VARIANCE_WXH_12_SVE(64, 64) +HBD_VARIANCE_WXH_12_SVE(64, 128) + +HBD_VARIANCE_WXH_12_SVE(128, 64) +HBD_VARIANCE_WXH_12_SVE(128, 128) + +#if !CONFIG_REALTIME_ONLY +// 8-bit +HBD_VARIANCE_WXH_8_SVE(4, 16) + +HBD_VARIANCE_WXH_8_SVE(8, 32) + +HBD_VARIANCE_WXH_8_SVE(16, 4) +HBD_VARIANCE_WXH_8_SVE(16, 64) + +HBD_VARIANCE_WXH_8_SVE(32, 8) + +HBD_VARIANCE_WXH_8_SVE(64, 16) + +// 10-bit +HBD_VARIANCE_WXH_10_SVE(4, 16) + +HBD_VARIANCE_WXH_10_SVE(8, 32) + +HBD_VARIANCE_WXH_10_SVE(16, 4) +HBD_VARIANCE_WXH_10_SVE(16, 64) + +HBD_VARIANCE_WXH_10_SVE(32, 8) + +HBD_VARIANCE_WXH_10_SVE(64, 16) + +// 12-bit +HBD_VARIANCE_WXH_12_SVE(4, 16) + +HBD_VARIANCE_WXH_12_SVE(8, 32) + +HBD_VARIANCE_WXH_12_SVE(16, 4) +HBD_VARIANCE_WXH_12_SVE(16, 64) + +HBD_VARIANCE_WXH_12_SVE(32, 8) + +HBD_VARIANCE_WXH_12_SVE(64, 16) + +#endif // !CONFIG_REALTIME_ONLY + +#undef HBD_VARIANCE_WXH_8_SVE +#undef HBD_VARIANCE_WXH_10_SVE +#undef HBD_VARIANCE_WXH_12_SVE + +static INLINE uint32_t highbd_mse_wxh_sve(const uint16_t *src_ptr, + int src_stride, + const uint16_t *ref_ptr, + int ref_stride, int w, int h, + unsigned int *sse) { + uint64x2_t sse_u64 = vdupq_n_u64(0); + + do { + int j = 0; + do { + uint16x8_t s = vld1q_u16(src_ptr + j); + uint16x8_t r = vld1q_u16(ref_ptr + j); + + uint16x8_t diff = vabdq_u16(s, r); + + sse_u64 = aom_udotq_u16(sse_u64, diff, diff); + + j += 8; + } while (j < w); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--h != 0); + + *sse = (uint32_t)vaddvq_u64(sse_u64); + return *sse; +} + +#define HIGHBD_MSE_WXH_SVE(w, h) \ + uint32_t aom_highbd_8_mse##w##x##h##_sve( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse); \ + return *sse; \ + } \ + \ + uint32_t aom_highbd_10_mse##w##x##h##_sve( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + 
uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse); \ + *sse = ROUND_POWER_OF_TWO(*sse, 4); \ + return *sse; \ + } \ + \ + uint32_t aom_highbd_12_mse##w##x##h##_sve( \ + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ + int ref_stride, uint32_t *sse) { \ + uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr); \ + uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr); \ + highbd_mse_wxh_sve(src, src_stride, ref, ref_stride, w, h, sse); \ + *sse = ROUND_POWER_OF_TWO(*sse, 8); \ + return *sse; \ + } + +HIGHBD_MSE_WXH_SVE(16, 16) +HIGHBD_MSE_WXH_SVE(16, 8) +HIGHBD_MSE_WXH_SVE(8, 16) +HIGHBD_MSE_WXH_SVE(8, 8) + +#undef HIGHBD_MSE_WXH_SVE + +uint64_t aom_mse_wxh_16bit_highbd_sve(uint16_t *dst, int dstride, uint16_t *src, + int sstride, int w, int h) { + assert((w == 8 || w == 4) && (h == 8 || h == 4)); + + uint64x2_t sum = vdupq_n_u64(0); + + if (w == 8) { + do { + uint16x8_t d0 = vld1q_u16(dst + 0 * dstride); + uint16x8_t d1 = vld1q_u16(dst + 1 * dstride); + uint16x8_t s0 = vld1q_u16(src + 0 * sstride); + uint16x8_t s1 = vld1q_u16(src + 1 * sstride); + + uint16x8_t abs_diff0 = vabdq_u16(s0, d0); + uint16x8_t abs_diff1 = vabdq_u16(s1, d1); + + sum = aom_udotq_u16(sum, abs_diff0, abs_diff0); + sum = aom_udotq_u16(sum, abs_diff1, abs_diff1); + + dst += 2 * dstride; + src += 2 * sstride; + h -= 2; + } while (h != 0); + } else { // w == 4 + do { + uint16x8_t d0 = load_unaligned_u16_4x2(dst + 0 * dstride, dstride); + uint16x8_t d1 = load_unaligned_u16_4x2(dst + 2 * dstride, dstride); + uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride); + uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride); + + uint16x8_t abs_diff0 = vabdq_u16(s0, d0); + uint16x8_t abs_diff1 = vabdq_u16(s1, d1); + + sum = aom_udotq_u16(sum, abs_diff0, abs_diff0); + sum = aom_udotq_u16(sum, abs_diff1, abs_diff1); + + dst += 4 * dstride; + src += 4 * sstride; + h -= 4; + } while (h != 0); + } + + return vaddvq_u64(sum); +} diff --git a/third_party/aom/aom_dsp/arm/intrapred_neon.c b/third_party/aom/aom_dsp/arm/intrapred_neon.c new file mode 100644 index 0000000000..d8dc60c1fe --- /dev/null +++ b/third_party/aom/aom_dsp/arm/intrapred_neon.c @@ -0,0 +1,3110 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/reinterpret_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_dsp/intrapred_common.h" + +//------------------------------------------------------------------------------ +// DC 4x4 + +static INLINE uint16x8_t dc_load_sum_4(const uint8_t *in) { + const uint8x8_t a = load_u8_4x1(in); + const uint16x4_t p0 = vpaddl_u8(a); + const uint16x4_t p1 = vpadd_u16(p0, p0); + return vcombine_u16(p1, vdup_n_u16(0)); +} + +static INLINE void dc_store_4xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x8_t dc) { + for (int i = 0; i < h; ++i) { + store_u8_4x1(dst + i * stride, dc); + } +} + +void aom_dc_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_top = dc_load_sum_4(above); + const uint16x8_t sum_left = dc_load_sum_4(left); + const uint16x8_t sum = vaddq_u16(sum_left, sum_top); + const uint8x8_t dc0 = vrshrn_n_u16(sum, 3); + dc_store_4xh(dst, stride, 4, vdup_lane_u8(dc0, 0)); +} + +void aom_dc_left_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_left = dc_load_sum_4(left); + const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 2); + (void)above; + dc_store_4xh(dst, stride, 4, vdup_lane_u8(dc0, 0)); +} + +void aom_dc_top_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_top = dc_load_sum_4(above); + const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 2); + (void)left; + dc_store_4xh(dst, stride, 4, vdup_lane_u8(dc0, 0)); +} + +void aom_dc_128_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t dc0 = vdup_n_u8(0x80); + (void)above; + (void)left; + dc_store_4xh(dst, stride, 4, dc0); +} + +//------------------------------------------------------------------------------ +// DC 8x8 + +static INLINE uint16x8_t dc_load_sum_8(const uint8_t *in) { + // This isn't used in the case where we want to load both above and left + // vectors, since we want to avoid performing the reduction twice. + const uint8x8_t a = vld1_u8(in); + const uint16x4_t p0 = vpaddl_u8(a); + const uint16x4_t p1 = vpadd_u16(p0, p0); + const uint16x4_t p2 = vpadd_u16(p1, p1); + return vcombine_u16(p2, vdup_n_u16(0)); +} + +static INLINE uint16x8_t horizontal_add_and_broadcast_u16x8(uint16x8_t a) { +#if AOM_ARCH_AARCH64 + // On AArch64 we could also use vdupq_n_u16(vaddvq_u16(a)) here to save an + // instruction, however the addv instruction is usually slightly more + // expensive than a pairwise addition, so the need for immediately + // broadcasting the result again seems to negate any benefit. 
+ const uint16x8_t b = vpaddq_u16(a, a); + const uint16x8_t c = vpaddq_u16(b, b); + return vpaddq_u16(c, c); +#else + const uint16x4_t b = vadd_u16(vget_low_u16(a), vget_high_u16(a)); + const uint16x4_t c = vpadd_u16(b, b); + const uint16x4_t d = vpadd_u16(c, c); + return vcombine_u16(d, d); +#endif +} + +static INLINE void dc_store_8xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x8_t dc) { + for (int i = 0; i < h; ++i) { + vst1_u8(dst + i * stride, dc); + } +} + +void aom_dc_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t sum_top = vld1_u8(above); + const uint8x8_t sum_left = vld1_u8(left); + uint16x8_t sum = vaddl_u8(sum_left, sum_top); + sum = horizontal_add_and_broadcast_u16x8(sum); + const uint8x8_t dc0 = vrshrn_n_u16(sum, 4); + dc_store_8xh(dst, stride, 8, vdup_lane_u8(dc0, 0)); +} + +void aom_dc_left_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_left = dc_load_sum_8(left); + const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 3); + (void)above; + dc_store_8xh(dst, stride, 8, vdup_lane_u8(dc0, 0)); +} + +void aom_dc_top_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_top = dc_load_sum_8(above); + const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 3); + (void)left; + dc_store_8xh(dst, stride, 8, vdup_lane_u8(dc0, 0)); +} + +void aom_dc_128_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t dc0 = vdup_n_u8(0x80); + (void)above; + (void)left; + dc_store_8xh(dst, stride, 8, dc0); +} + +//------------------------------------------------------------------------------ +// DC 16x16 + +static INLINE uint16x8_t dc_load_partial_sum_16(const uint8_t *in) { + const uint8x16_t a = vld1q_u8(in); + // delay the remainder of the reduction until + // horizontal_add_and_broadcast_u16x8, since we want to do it once rather + // than twice in the case we are loading both above and left. 
+ return vpaddlq_u8(a); +} + +static INLINE uint16x8_t dc_load_sum_16(const uint8_t *in) { + return horizontal_add_and_broadcast_u16x8(dc_load_partial_sum_16(in)); +} + +static INLINE void dc_store_16xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x16_t dc) { + for (int i = 0; i < h; ++i) { + vst1q_u8(dst + i * stride, dc); + } +} + +void aom_dc_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_top = dc_load_partial_sum_16(above); + const uint16x8_t sum_left = dc_load_partial_sum_16(left); + uint16x8_t sum = vaddq_u16(sum_left, sum_top); + sum = horizontal_add_and_broadcast_u16x8(sum); + const uint8x8_t dc0 = vrshrn_n_u16(sum, 5); + dc_store_16xh(dst, stride, 16, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_left_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint16x8_t sum_left = dc_load_sum_16(left); + const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 4); + (void)above; + dc_store_16xh(dst, stride, 16, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_top_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint16x8_t sum_top = dc_load_sum_16(above); + const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 4); + (void)left; + dc_store_16xh(dst, stride, 16, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_128_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint8x16_t dc0 = vdupq_n_u8(0x80); + (void)above; + (void)left; + dc_store_16xh(dst, stride, 16, dc0); +} + +//------------------------------------------------------------------------------ +// DC 32x32 + +static INLINE uint16x8_t dc_load_partial_sum_32(const uint8_t *in) { + const uint8x16_t a0 = vld1q_u8(in); + const uint8x16_t a1 = vld1q_u8(in + 16); + // delay the remainder of the reduction until + // horizontal_add_and_broadcast_u16x8, since we want to do it once rather + // than twice in the case we are loading both above and left. 
+ return vpadalq_u8(vpaddlq_u8(a0), a1); +} + +static INLINE uint16x8_t dc_load_sum_32(const uint8_t *in) { + return horizontal_add_and_broadcast_u16x8(dc_load_partial_sum_32(in)); +} + +static INLINE void dc_store_32xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x16_t dc) { + for (int i = 0; i < h; ++i) { + vst1q_u8(dst + i * stride, dc); + vst1q_u8(dst + i * stride + 16, dc); + } +} + +void aom_dc_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_top = dc_load_partial_sum_32(above); + const uint16x8_t sum_left = dc_load_partial_sum_32(left); + uint16x8_t sum = vaddq_u16(sum_left, sum_top); + sum = horizontal_add_and_broadcast_u16x8(sum); + const uint8x8_t dc0 = vrshrn_n_u16(sum, 6); + dc_store_32xh(dst, stride, 32, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_left_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint16x8_t sum_left = dc_load_sum_32(left); + const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 5); + (void)above; + dc_store_32xh(dst, stride, 32, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_top_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint16x8_t sum_top = dc_load_sum_32(above); + const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 5); + (void)left; + dc_store_32xh(dst, stride, 32, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_128_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint8x16_t dc0 = vdupq_n_u8(0x80); + (void)above; + (void)left; + dc_store_32xh(dst, stride, 32, dc0); +} + +//------------------------------------------------------------------------------ +// DC 64x64 + +static INLINE uint16x8_t dc_load_partial_sum_64(const uint8_t *in) { + const uint8x16_t a0 = vld1q_u8(in); + const uint8x16_t a1 = vld1q_u8(in + 16); + const uint8x16_t a2 = vld1q_u8(in + 32); + const uint8x16_t a3 = vld1q_u8(in + 48); + const uint16x8_t p01 = vpadalq_u8(vpaddlq_u8(a0), a1); + const uint16x8_t p23 = vpadalq_u8(vpaddlq_u8(a2), a3); + // delay the remainder of the reduction until + // horizontal_add_and_broadcast_u16x8, since we want to do it once rather + // than twice in the case we are loading both above and left. 
+ return vaddq_u16(p01, p23); +} + +static INLINE uint16x8_t dc_load_sum_64(const uint8_t *in) { + return horizontal_add_and_broadcast_u16x8(dc_load_partial_sum_64(in)); +} + +static INLINE void dc_store_64xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x16_t dc) { + for (int i = 0; i < h; ++i) { + vst1q_u8(dst + i * stride, dc); + vst1q_u8(dst + i * stride + 16, dc); + vst1q_u8(dst + i * stride + 32, dc); + vst1q_u8(dst + i * stride + 48, dc); + } +} + +void aom_dc_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint16x8_t sum_top = dc_load_partial_sum_64(above); + const uint16x8_t sum_left = dc_load_partial_sum_64(left); + uint16x8_t sum = vaddq_u16(sum_left, sum_top); + sum = horizontal_add_and_broadcast_u16x8(sum); + const uint8x8_t dc0 = vrshrn_n_u16(sum, 7); + dc_store_64xh(dst, stride, 64, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_left_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint16x8_t sum_left = dc_load_sum_64(left); + const uint8x8_t dc0 = vrshrn_n_u16(sum_left, 6); + (void)above; + dc_store_64xh(dst, stride, 64, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_top_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint16x8_t sum_top = dc_load_sum_64(above); + const uint8x8_t dc0 = vrshrn_n_u16(sum_top, 6); + (void)left; + dc_store_64xh(dst, stride, 64, vdupq_lane_u8(dc0, 0)); +} + +void aom_dc_128_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, + const uint8_t *left) { + const uint8x16_t dc0 = vdupq_n_u8(0x80); + (void)above; + (void)left; + dc_store_64xh(dst, stride, 64, dc0); +} + +//------------------------------------------------------------------------------ +// DC rectangular cases + +#define DC_MULTIPLIER_1X2 0x5556 +#define DC_MULTIPLIER_1X4 0x3334 + +#define DC_SHIFT2 16 + +static INLINE int divide_using_multiply_shift(int num, int shift1, + int multiplier, int shift2) { + const int interm = num >> shift1; + return interm * multiplier >> shift2; +} + +static INLINE int calculate_dc_from_sum(int bw, int bh, uint32_t sum, + int shift1, int multiplier) { + const int expected_dc = divide_using_multiply_shift( + sum + ((bw + bh) >> 1), shift1, multiplier, DC_SHIFT2); + assert(expected_dc < (1 << 8)); + return expected_dc; +} + +#undef DC_SHIFT2 + +void aom_dc_predictor_4x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint8x8_t a = load_u8_4x1(above); + uint8x8_t l = vld1_u8(left); + uint32_t sum = horizontal_add_u16x8(vaddl_u8(a, l)); + uint32_t dc = calculate_dc_from_sum(4, 8, sum, 2, DC_MULTIPLIER_1X2); + dc_store_4xh(dst, stride, 8, vdup_n_u8(dc)); +} + +void aom_dc_predictor_8x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint8x8_t a = vld1_u8(above); + uint8x8_t l = load_u8_4x1(left); + uint32_t sum = horizontal_add_u16x8(vaddl_u8(a, l)); + uint32_t dc = calculate_dc_from_sum(8, 4, sum, 2, DC_MULTIPLIER_1X2); + dc_store_8xh(dst, stride, 4, vdup_n_u8(dc)); +} + +void aom_dc_predictor_4x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint8x8_t a = load_u8_4x1(above); + uint8x16_t l = vld1q_u8(left); + uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(l), a); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(4, 16, sum, 2, DC_MULTIPLIER_1X4); + dc_store_4xh(dst, stride, 16, vdup_n_u8(dc)); +} + +void 
aom_dc_predictor_16x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint8x16_t a = vld1q_u8(above); + uint8x8_t l = load_u8_4x1(left); + uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(a), l); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(16, 4, sum, 2, DC_MULTIPLIER_1X4); + dc_store_16xh(dst, stride, 4, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_8x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint8x8_t a = vld1_u8(above); + uint8x16_t l = vld1q_u8(left); + uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(l), a); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(8, 16, sum, 3, DC_MULTIPLIER_1X2); + dc_store_8xh(dst, stride, 16, vdup_n_u8(dc)); +} + +void aom_dc_predictor_16x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint8x16_t a = vld1q_u8(above); + uint8x8_t l = vld1_u8(left); + uint16x8_t sum_al = vaddw_u8(vpaddlq_u8(a), l); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(16, 8, sum, 3, DC_MULTIPLIER_1X2); + dc_store_16xh(dst, stride, 8, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_8x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint8x8_t a = vld1_u8(above); + uint16x8_t sum_left = dc_load_partial_sum_32(left); + uint16x8_t sum_al = vaddw_u8(sum_left, a); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(8, 32, sum, 3, DC_MULTIPLIER_1X4); + dc_store_8xh(dst, stride, 32, vdup_n_u8(dc)); +} + +void aom_dc_predictor_32x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint16x8_t sum_top = dc_load_partial_sum_32(above); + uint8x8_t l = vld1_u8(left); + uint16x8_t sum_al = vaddw_u8(sum_top, l); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(32, 8, sum, 3, DC_MULTIPLIER_1X4); + dc_store_32xh(dst, stride, 8, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_16x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint16x8_t sum_above = dc_load_partial_sum_16(above); + uint16x8_t sum_left = dc_load_partial_sum_32(left); + uint16x8_t sum_al = vaddq_u16(sum_left, sum_above); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(16, 32, sum, 4, DC_MULTIPLIER_1X2); + dc_store_16xh(dst, stride, 32, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_32x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint16x8_t sum_above = dc_load_partial_sum_32(above); + uint16x8_t sum_left = dc_load_partial_sum_16(left); + uint16x8_t sum_al = vaddq_u16(sum_left, sum_above); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(32, 16, sum, 4, DC_MULTIPLIER_1X2); + dc_store_32xh(dst, stride, 16, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_16x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint16x8_t sum_above = dc_load_partial_sum_16(above); + uint16x8_t sum_left = dc_load_partial_sum_64(left); + uint16x8_t sum_al = vaddq_u16(sum_left, sum_above); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(16, 64, sum, 4, DC_MULTIPLIER_1X4); + dc_store_16xh(dst, stride, 64, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_64x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint16x8_t sum_above = 
dc_load_partial_sum_64(above); + uint16x8_t sum_left = dc_load_partial_sum_16(left); + uint16x8_t sum_al = vaddq_u16(sum_above, sum_left); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(64, 16, sum, 4, DC_MULTIPLIER_1X4); + dc_store_64xh(dst, stride, 16, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_32x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint16x8_t sum_above = dc_load_partial_sum_32(above); + uint16x8_t sum_left = dc_load_partial_sum_64(left); + uint16x8_t sum_al = vaddq_u16(sum_above, sum_left); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(32, 64, sum, 5, DC_MULTIPLIER_1X2); + dc_store_32xh(dst, stride, 64, vdupq_n_u8(dc)); +} + +void aom_dc_predictor_64x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + uint16x8_t sum_above = dc_load_partial_sum_64(above); + uint16x8_t sum_left = dc_load_partial_sum_32(left); + uint16x8_t sum_al = vaddq_u16(sum_above, sum_left); + uint32_t sum = horizontal_add_u16x8(sum_al); + uint32_t dc = calculate_dc_from_sum(64, 32, sum, 5, DC_MULTIPLIER_1X2); + dc_store_64xh(dst, stride, 32, vdupq_n_u8(dc)); +} + +#undef DC_MULTIPLIER_1X2 +#undef DC_MULTIPLIER_1X4 + +#define DC_PREDICTOR_128(w, h, q) \ + void aom_dc_128_predictor_##w##x##h##_neon(uint8_t *dst, ptrdiff_t stride, \ + const uint8_t *above, \ + const uint8_t *left) { \ + (void)above; \ + (void)left; \ + dc_store_##w##xh(dst, stride, (h), vdup##q##_n_u8(0x80)); \ + } + +DC_PREDICTOR_128(4, 8, ) +DC_PREDICTOR_128(4, 16, ) +DC_PREDICTOR_128(8, 4, ) +DC_PREDICTOR_128(8, 16, ) +DC_PREDICTOR_128(8, 32, ) +DC_PREDICTOR_128(16, 4, q) +DC_PREDICTOR_128(16, 8, q) +DC_PREDICTOR_128(16, 32, q) +DC_PREDICTOR_128(16, 64, q) +DC_PREDICTOR_128(32, 8, q) +DC_PREDICTOR_128(32, 16, q) +DC_PREDICTOR_128(32, 64, q) +DC_PREDICTOR_128(64, 32, q) +DC_PREDICTOR_128(64, 16, q) + +#undef DC_PREDICTOR_128 + +#define DC_PREDICTOR_LEFT(w, h, shift, q) \ + void aom_dc_left_predictor_##w##x##h##_neon(uint8_t *dst, ptrdiff_t stride, \ + const uint8_t *above, \ + const uint8_t *left) { \ + (void)above; \ + const uint16x8_t sum = dc_load_sum_##h(left); \ + const uint8x8_t dc0 = vrshrn_n_u16(sum, (shift)); \ + dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u8(dc0, 0)); \ + } + +DC_PREDICTOR_LEFT(4, 8, 3, ) +DC_PREDICTOR_LEFT(8, 4, 2, ) +DC_PREDICTOR_LEFT(8, 16, 4, ) +DC_PREDICTOR_LEFT(16, 8, 3, q) +DC_PREDICTOR_LEFT(16, 32, 5, q) +DC_PREDICTOR_LEFT(32, 16, 4, q) +DC_PREDICTOR_LEFT(32, 64, 6, q) +DC_PREDICTOR_LEFT(64, 32, 5, q) +DC_PREDICTOR_LEFT(4, 16, 4, ) +DC_PREDICTOR_LEFT(16, 4, 2, q) +DC_PREDICTOR_LEFT(8, 32, 5, ) +DC_PREDICTOR_LEFT(32, 8, 3, q) +DC_PREDICTOR_LEFT(16, 64, 6, q) +DC_PREDICTOR_LEFT(64, 16, 4, q) + +#undef DC_PREDICTOR_LEFT + +#define DC_PREDICTOR_TOP(w, h, shift, q) \ + void aom_dc_top_predictor_##w##x##h##_neon(uint8_t *dst, ptrdiff_t stride, \ + const uint8_t *above, \ + const uint8_t *left) { \ + (void)left; \ + const uint16x8_t sum = dc_load_sum_##w(above); \ + const uint8x8_t dc0 = vrshrn_n_u16(sum, (shift)); \ + dc_store_##w##xh(dst, stride, (h), vdup##q##_lane_u8(dc0, 0)); \ + } + +DC_PREDICTOR_TOP(4, 8, 2, ) +DC_PREDICTOR_TOP(4, 16, 2, ) +DC_PREDICTOR_TOP(8, 4, 3, ) +DC_PREDICTOR_TOP(8, 16, 3, ) +DC_PREDICTOR_TOP(8, 32, 3, ) +DC_PREDICTOR_TOP(16, 4, 4, q) +DC_PREDICTOR_TOP(16, 8, 4, q) +DC_PREDICTOR_TOP(16, 32, 4, q) +DC_PREDICTOR_TOP(16, 64, 4, q) +DC_PREDICTOR_TOP(32, 8, 5, q) +DC_PREDICTOR_TOP(32, 16, 5, q) +DC_PREDICTOR_TOP(32, 64, 5, q) 
+DC_PREDICTOR_TOP(64, 16, 6, q) +DC_PREDICTOR_TOP(64, 32, 6, q) + +#undef DC_PREDICTOR_TOP + +// ----------------------------------------------------------------------------- + +static INLINE void v_store_4xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x8_t d0) { + for (int i = 0; i < h; ++i) { + store_u8_4x1(dst + i * stride, d0); + } +} + +static INLINE void v_store_8xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x8_t d0) { + for (int i = 0; i < h; ++i) { + vst1_u8(dst + i * stride, d0); + } +} + +static INLINE void v_store_16xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x16_t d0) { + for (int i = 0; i < h; ++i) { + vst1q_u8(dst + i * stride, d0); + } +} + +static INLINE void v_store_32xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x16_t d0, uint8x16_t d1) { + for (int i = 0; i < h; ++i) { + vst1q_u8(dst + 0, d0); + vst1q_u8(dst + 16, d1); + dst += stride; + } +} + +static INLINE void v_store_64xh(uint8_t *dst, ptrdiff_t stride, int h, + uint8x16_t d0, uint8x16_t d1, uint8x16_t d2, + uint8x16_t d3) { + for (int i = 0; i < h; ++i) { + vst1q_u8(dst + 0, d0); + vst1q_u8(dst + 16, d1); + vst1q_u8(dst + 32, d2); + vst1q_u8(dst + 48, d3); + dst += stride; + } +} + +void aom_v_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_4xh(dst, stride, 4, load_u8_4x1(above)); +} + +void aom_v_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_8xh(dst, stride, 8, vld1_u8(above)); +} + +void aom_v_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_16xh(dst, stride, 16, vld1q_u8(above)); +} + +void aom_v_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(above); + const uint8x16_t d1 = vld1q_u8(above + 16); + (void)left; + v_store_32xh(dst, stride, 32, d0, d1); +} + +void aom_v_predictor_4x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_4xh(dst, stride, 8, load_u8_4x1(above)); +} + +void aom_v_predictor_4x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_4xh(dst, stride, 16, load_u8_4x1(above)); +} + +void aom_v_predictor_8x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_8xh(dst, stride, 4, vld1_u8(above)); +} + +void aom_v_predictor_8x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_8xh(dst, stride, 16, vld1_u8(above)); +} + +void aom_v_predictor_8x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_8xh(dst, stride, 32, vld1_u8(above)); +} + +void aom_v_predictor_16x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_16xh(dst, stride, 4, vld1q_u8(above)); +} + +void aom_v_predictor_16x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_16xh(dst, stride, 8, vld1q_u8(above)); +} + +void aom_v_predictor_16x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + v_store_16xh(dst, stride, 32, vld1q_u8(above)); +} + +void aom_v_predictor_16x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)left; + 
v_store_16xh(dst, stride, 64, vld1q_u8(above)); +} + +void aom_v_predictor_32x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(above); + const uint8x16_t d1 = vld1q_u8(above + 16); + (void)left; + v_store_32xh(dst, stride, 8, d0, d1); +} + +void aom_v_predictor_32x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(above); + const uint8x16_t d1 = vld1q_u8(above + 16); + (void)left; + v_store_32xh(dst, stride, 16, d0, d1); +} + +void aom_v_predictor_32x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(above); + const uint8x16_t d1 = vld1q_u8(above + 16); + (void)left; + v_store_32xh(dst, stride, 64, d0, d1); +} + +void aom_v_predictor_64x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(above); + const uint8x16_t d1 = vld1q_u8(above + 16); + const uint8x16_t d2 = vld1q_u8(above + 32); + const uint8x16_t d3 = vld1q_u8(above + 48); + (void)left; + v_store_64xh(dst, stride, 16, d0, d1, d2, d3); +} + +void aom_v_predictor_64x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(above); + const uint8x16_t d1 = vld1q_u8(above + 16); + const uint8x16_t d2 = vld1q_u8(above + 32); + const uint8x16_t d3 = vld1q_u8(above + 48); + (void)left; + v_store_64xh(dst, stride, 32, d0, d1, d2, d3); +} + +void aom_v_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(above); + const uint8x16_t d1 = vld1q_u8(above + 16); + const uint8x16_t d2 = vld1q_u8(above + 32); + const uint8x16_t d3 = vld1q_u8(above + 48); + (void)left; + v_store_64xh(dst, stride, 64, d0, d1, d2, d3); +} + +// ----------------------------------------------------------------------------- + +static INLINE void h_store_4x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) { + store_u8_4x1(dst + 0 * stride, vdup_lane_u8(d0, 0)); + store_u8_4x1(dst + 1 * stride, vdup_lane_u8(d0, 1)); + store_u8_4x1(dst + 2 * stride, vdup_lane_u8(d0, 2)); + store_u8_4x1(dst + 3 * stride, vdup_lane_u8(d0, 3)); + store_u8_4x1(dst + 4 * stride, vdup_lane_u8(d0, 4)); + store_u8_4x1(dst + 5 * stride, vdup_lane_u8(d0, 5)); + store_u8_4x1(dst + 6 * stride, vdup_lane_u8(d0, 6)); + store_u8_4x1(dst + 7 * stride, vdup_lane_u8(d0, 7)); +} + +static INLINE void h_store_8x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) { + vst1_u8(dst + 0 * stride, vdup_lane_u8(d0, 0)); + vst1_u8(dst + 1 * stride, vdup_lane_u8(d0, 1)); + vst1_u8(dst + 2 * stride, vdup_lane_u8(d0, 2)); + vst1_u8(dst + 3 * stride, vdup_lane_u8(d0, 3)); + vst1_u8(dst + 4 * stride, vdup_lane_u8(d0, 4)); + vst1_u8(dst + 5 * stride, vdup_lane_u8(d0, 5)); + vst1_u8(dst + 6 * stride, vdup_lane_u8(d0, 6)); + vst1_u8(dst + 7 * stride, vdup_lane_u8(d0, 7)); +} + +static INLINE void h_store_16x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) { + vst1q_u8(dst + 0 * stride, vdupq_lane_u8(d0, 0)); + vst1q_u8(dst + 1 * stride, vdupq_lane_u8(d0, 1)); + vst1q_u8(dst + 2 * stride, vdupq_lane_u8(d0, 2)); + vst1q_u8(dst + 3 * stride, vdupq_lane_u8(d0, 3)); + vst1q_u8(dst + 4 * stride, vdupq_lane_u8(d0, 4)); + vst1q_u8(dst + 5 * stride, vdupq_lane_u8(d0, 5)); + vst1q_u8(dst + 6 * stride, vdupq_lane_u8(d0, 6)); + vst1q_u8(dst + 7 * stride, vdupq_lane_u8(d0, 7)); +} + +static INLINE void h_store_32x8(uint8_t *dst, 
ptrdiff_t stride, uint8x8_t d0) { + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 0)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 0)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 1)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 1)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 2)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 2)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 3)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 3)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 4)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 4)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 5)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 5)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 6)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 6)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 7)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 7)); +} + +static INLINE void h_store_64x8(uint8_t *dst, ptrdiff_t stride, uint8x8_t d0) { + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 0)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 0)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 0)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 0)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 1)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 1)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 1)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 1)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 2)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 2)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 2)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 2)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 3)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 3)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 3)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 3)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 4)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 4)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 4)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 4)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 5)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 5)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 5)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 5)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 6)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 6)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 6)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 6)); + dst += stride; + vst1q_u8(dst + 0, vdupq_lane_u8(d0, 7)); + vst1q_u8(dst + 16, vdupq_lane_u8(d0, 7)); + vst1q_u8(dst + 32, vdupq_lane_u8(d0, 7)); + vst1q_u8(dst + 48, vdupq_lane_u8(d0, 7)); +} + +void aom_h_predictor_4x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t d0 = load_u8_4x1(left); + (void)above; + store_u8_4x1(dst + 0 * stride, vdup_lane_u8(d0, 0)); + store_u8_4x1(dst + 1 * stride, vdup_lane_u8(d0, 1)); + store_u8_4x1(dst + 2 * stride, vdup_lane_u8(d0, 2)); + store_u8_4x1(dst + 3 * stride, vdup_lane_u8(d0, 3)); +} + +void aom_h_predictor_8x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t d0 = vld1_u8(left); + (void)above; + h_store_8x8(dst, stride, d0); +} + +void aom_h_predictor_16x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + (void)above; + h_store_16x8(dst, stride, vget_low_u8(d0)); + h_store_16x8(dst + 8 * stride, stride, vget_high_u8(d0)); +} + +void aom_h_predictor_32x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + const uint8x16_t d1 = vld1q_u8(left + 16); + (void)above; + 
h_store_32x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_32x8(dst + 8 * stride, stride, vget_high_u8(d0)); + h_store_32x8(dst + 16 * stride, stride, vget_low_u8(d1)); + h_store_32x8(dst + 24 * stride, stride, vget_high_u8(d1)); +} + +void aom_h_predictor_4x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t d0 = vld1_u8(left); + (void)above; + h_store_4x8(dst, stride, d0); +} + +void aom_h_predictor_4x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + (void)above; + h_store_4x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_4x8(dst + 8 * stride, stride, vget_high_u8(d0)); +} + +void aom_h_predictor_8x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t d0 = load_u8_4x1(left); + (void)above; + vst1_u8(dst + 0 * stride, vdup_lane_u8(d0, 0)); + vst1_u8(dst + 1 * stride, vdup_lane_u8(d0, 1)); + vst1_u8(dst + 2 * stride, vdup_lane_u8(d0, 2)); + vst1_u8(dst + 3 * stride, vdup_lane_u8(d0, 3)); +} + +void aom_h_predictor_8x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + (void)above; + h_store_8x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_8x8(dst + 8 * stride, stride, vget_high_u8(d0)); +} + +void aom_h_predictor_8x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + const uint8x16_t d1 = vld1q_u8(left + 16); + (void)above; + h_store_8x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_8x8(dst + 8 * stride, stride, vget_high_u8(d0)); + h_store_8x8(dst + 16 * stride, stride, vget_low_u8(d1)); + h_store_8x8(dst + 24 * stride, stride, vget_high_u8(d1)); +} + +void aom_h_predictor_16x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t d0 = load_u8_4x1(left); + (void)above; + vst1q_u8(dst + 0 * stride, vdupq_lane_u8(d0, 0)); + vst1q_u8(dst + 1 * stride, vdupq_lane_u8(d0, 1)); + vst1q_u8(dst + 2 * stride, vdupq_lane_u8(d0, 2)); + vst1q_u8(dst + 3 * stride, vdupq_lane_u8(d0, 3)); +} + +void aom_h_predictor_16x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t d0 = vld1_u8(left); + (void)above; + h_store_16x8(dst, stride, d0); +} + +void aom_h_predictor_16x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + const uint8x16_t d1 = vld1q_u8(left + 16); + (void)above; + h_store_16x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_16x8(dst + 8 * stride, stride, vget_high_u8(d0)); + h_store_16x8(dst + 16 * stride, stride, vget_low_u8(d1)); + h_store_16x8(dst + 24 * stride, stride, vget_high_u8(d1)); +} + +void aom_h_predictor_16x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + const uint8x16_t d1 = vld1q_u8(left + 16); + const uint8x16_t d2 = vld1q_u8(left + 32); + const uint8x16_t d3 = vld1q_u8(left + 48); + (void)above; + h_store_16x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_16x8(dst + 8 * stride, stride, vget_high_u8(d0)); + h_store_16x8(dst + 16 * stride, stride, vget_low_u8(d1)); + h_store_16x8(dst + 24 * stride, stride, vget_high_u8(d1)); + h_store_16x8(dst + 32 * stride, stride, vget_low_u8(d2)); + h_store_16x8(dst + 40 * stride, stride, vget_high_u8(d2)); + 
h_store_16x8(dst + 48 * stride, stride, vget_low_u8(d3)); + h_store_16x8(dst + 56 * stride, stride, vget_high_u8(d3)); +} + +void aom_h_predictor_32x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x8_t d0 = vld1_u8(left); + (void)above; + h_store_32x8(dst, stride, d0); +} + +void aom_h_predictor_32x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + (void)above; + h_store_32x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_32x8(dst + 8 * stride, stride, vget_high_u8(d0)); +} + +void aom_h_predictor_32x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left + 0); + const uint8x16_t d1 = vld1q_u8(left + 16); + const uint8x16_t d2 = vld1q_u8(left + 32); + const uint8x16_t d3 = vld1q_u8(left + 48); + (void)above; + h_store_32x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_32x8(dst + 8 * stride, stride, vget_high_u8(d0)); + h_store_32x8(dst + 16 * stride, stride, vget_low_u8(d1)); + h_store_32x8(dst + 24 * stride, stride, vget_high_u8(d1)); + h_store_32x8(dst + 32 * stride, stride, vget_low_u8(d2)); + h_store_32x8(dst + 40 * stride, stride, vget_high_u8(d2)); + h_store_32x8(dst + 48 * stride, stride, vget_low_u8(d3)); + h_store_32x8(dst + 56 * stride, stride, vget_high_u8(d3)); +} + +void aom_h_predictor_64x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + const uint8x16_t d0 = vld1q_u8(left); + (void)above; + h_store_64x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_64x8(dst + 8 * stride, stride, vget_high_u8(d0)); +} + +void aom_h_predictor_64x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)above; + for (int i = 0; i < 2; ++i) { + const uint8x16_t d0 = vld1q_u8(left); + h_store_64x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_64x8(dst + 8 * stride, stride, vget_high_u8(d0)); + left += 16; + dst += 16 * stride; + } +} + +void aom_h_predictor_64x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left) { + (void)above; + for (int i = 0; i < 4; ++i) { + const uint8x16_t d0 = vld1q_u8(left); + h_store_64x8(dst + 0 * stride, stride, vget_low_u8(d0)); + h_store_64x8(dst + 8 * stride, stride, vget_high_u8(d0)); + left += 16; + dst += 16 * stride; + } +} + +/* ---------------------P R E D I C T I O N Z 1--------------------------- */ + +// Low bit depth functions +static DECLARE_ALIGNED(32, uint8_t, BaseMask[33][32]) = { + { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }, +}; + +static AOM_FORCE_INLINE void dr_prediction_z1_HxW_internal_neon_64( + int H, int W, uint8x8_t *dst, const uint8_t *above, int upsample_above, + int dx) { + const int frac_bits = 6 - upsample_above; + const int max_base_x = ((W + H) - 1) << upsample_above; + + assert(dx > 0); + // pre-filter above pixels + // store in temp buffers: + // above[x] * 32 + 16 + // above[x+1] - above[x] + // final pixels will be calculated as: + // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5 + + const uint8x8_t a_mbase_x = vdup_n_u8(above[max_base_x]); + + int x = dx; + for (int r = 0; r < W; r++) { + int base = x >> frac_bits; + int base_max_diff = (max_base_x - base) >> upsample_above; + if (base_max_diff <= 0) { + for (int i = r; i < W; ++i) { + dst[i] = a_mbase_x; // save 4 values + } + return; + } + + if (base_max_diff > H) base_max_diff = H; + + uint8x8x2_t a01_128; + uint16x8_t shift; + if (upsample_above) { + a01_128 = vld2_u8(above + base); + shift = vdupq_n_u16(((x << upsample_above) & 0x3f) >> 1); + } else { + a01_128.val[0] = vld1_u8(above + base); + a01_128.val[1] = vld1_u8(above + base + 1); + shift = vdupq_n_u16((x & 0x3f) >> 1); + } + uint16x8_t diff = vsubl_u8(a01_128.val[1], a01_128.val[0]); + uint16x8_t a32 = vmlal_u8(vdupq_n_u16(16), a01_128.val[0], vdup_n_u8(32)); + uint16x8_t res = vmlaq_u16(a32, diff, shift); + + uint8x8_t mask = vld1_u8(BaseMask[base_max_diff]); + dst[r] = vbsl_u8(mask, vshrn_n_u16(res, 5), a_mbase_x); + + x += dx; + } +} + +static void dr_prediction_z1_4xN_neon(int N, uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, int upsample_above, + int dx) { + uint8x8_t dstvec[16]; + + dr_prediction_z1_HxW_internal_neon_64(4, N, dstvec, above, upsample_above, + dx); + for (int i = 0; i < N; i++) { + vst1_lane_u32((uint32_t *)(dst + stride * i), + vreinterpret_u32_u8(dstvec[i]), 0); + } +} + +static void dr_prediction_z1_8xN_neon(int N, uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, int upsample_above, + int dx) { + uint8x8_t dstvec[32]; + + dr_prediction_z1_HxW_internal_neon_64(8, N, dstvec, above, upsample_above, + dx); + for (int i = 0; i < N; i++) { + vst1_u8(dst + stride * i, dstvec[i]); + } +} + +static AOM_FORCE_INLINE void dr_prediction_z1_HxW_internal_neon( + int H, int W, uint8x16_t *dst, const uint8_t *above, int upsample_above, + int dx) { + const int frac_bits = 6 - upsample_above; + const int max_base_x = ((W + H) - 1) << upsample_above; + + assert(dx > 0); + // pre-filter above pixels + // store in temp buffers: + // above[x] * 32 + 16 + // above[x+1] - above[x] + // final pixels will be calculated as: + // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5 + + const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]); + + int x = dx; + for (int r = 0; r < W; r++) { + int base = x >> frac_bits; + int base_max_diff = (max_base_x - base) >> upsample_above; + if (base_max_diff <= 0) { + for (int i = r; i < W; ++i) { + dst[i] = a_mbase_x; // save 4 values + } + return; + } + + if 
(base_max_diff > H) base_max_diff = H; + + uint16x8_t shift; + uint8x16_t a0_128, a1_128; + if (upsample_above) { + uint8x8x2_t v_tmp_a0_128 = vld2_u8(above + base); + a0_128 = vcombine_u8(v_tmp_a0_128.val[0], v_tmp_a0_128.val[1]); + a1_128 = vextq_u8(a0_128, vdupq_n_u8(0), 8); + shift = vdupq_n_u16(x & 0x1f); + } else { + a0_128 = vld1q_u8(above + base); + a1_128 = vld1q_u8(above + base + 1); + shift = vdupq_n_u16((x & 0x3f) >> 1); + } + uint16x8_t diff_lo = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128)); + uint16x8_t diff_hi = vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128)); + uint16x8_t a32_lo = + vmlal_u8(vdupq_n_u16(16), vget_low_u8(a0_128), vdup_n_u8(32)); + uint16x8_t a32_hi = + vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_128), vdup_n_u8(32)); + uint16x8_t res_lo = vmlaq_u16(a32_lo, diff_lo, shift); + uint16x8_t res_hi = vmlaq_u16(a32_hi, diff_hi, shift); + uint8x16_t v_temp = + vcombine_u8(vshrn_n_u16(res_lo, 5), vshrn_n_u16(res_hi, 5)); + + uint8x16_t mask = vld1q_u8(BaseMask[base_max_diff]); + dst[r] = vbslq_u8(mask, v_temp, a_mbase_x); + + x += dx; + } +} + +static void dr_prediction_z1_16xN_neon(int N, uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, int upsample_above, + int dx) { + uint8x16_t dstvec[64]; + + dr_prediction_z1_HxW_internal_neon(16, N, dstvec, above, upsample_above, dx); + for (int i = 0; i < N; i++) { + vst1q_u8(dst + stride * i, dstvec[i]); + } +} + +static AOM_FORCE_INLINE void dr_prediction_z1_32xN_internal_neon( + int N, uint8x16x2_t *dstvec, const uint8_t *above, int dx) { + const int frac_bits = 6; + const int max_base_x = ((32 + N) - 1); + + // pre-filter above pixels + // store in temp buffers: + // above[x] * 32 + 16 + // above[x+1] - above[x] + // final pixels will be calculated as: + // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5 + + const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]); + + int x = dx; + for (int r = 0; r < N; r++) { + int base = x >> frac_bits; + int base_max_diff = (max_base_x - base); + if (base_max_diff <= 0) { + for (int i = r; i < N; ++i) { + dstvec[i].val[0] = a_mbase_x; // save 32 values + dstvec[i].val[1] = a_mbase_x; + } + return; + } + if (base_max_diff > 32) base_max_diff = 32; + + uint16x8_t shift = vdupq_n_u16((x & 0x3f) >> 1); + + uint8x16_t res16[2]; + for (int j = 0, jj = 0; j < 32; j += 16, jj++) { + int mdiff = base_max_diff - j; + if (mdiff <= 0) { + res16[jj] = a_mbase_x; + } else { + uint8x16_t a0_128 = vld1q_u8(above + base + j); + uint8x16_t a1_128 = vld1q_u8(above + base + j + 1); + uint16x8_t diff_lo = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128)); + uint16x8_t diff_hi = + vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128)); + uint16x8_t a32_lo = + vmlal_u8(vdupq_n_u16(16), vget_low_u8(a0_128), vdup_n_u8(32)); + uint16x8_t a32_hi = + vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_128), vdup_n_u8(32)); + uint16x8_t res_lo = vmlaq_u16(a32_lo, diff_lo, shift); + uint16x8_t res_hi = vmlaq_u16(a32_hi, diff_hi, shift); + + res16[jj] = vcombine_u8(vshrn_n_u16(res_lo, 5), vshrn_n_u16(res_hi, 5)); + } + } + + uint8x16_t mask_lo = vld1q_u8(BaseMask[base_max_diff]); + uint8x16_t mask_hi = vld1q_u8(BaseMask[base_max_diff] + 16); + dstvec[r].val[0] = vbslq_u8(mask_lo, res16[0], a_mbase_x); + dstvec[r].val[1] = vbslq_u8(mask_hi, res16[1], a_mbase_x); + x += dx; + } +} + +static void dr_prediction_z1_32xN_neon(int N, uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, int dx) { + uint8x16x2_t dstvec[64]; + + dr_prediction_z1_32xN_internal_neon(N, dstvec, above, dx); + for (int i = 
0; i < N; i++) { + vst1q_u8(dst + stride * i, dstvec[i].val[0]); + vst1q_u8(dst + stride * i + 16, dstvec[i].val[1]); + } +} + +static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, int dx) { + const int frac_bits = 6; + const int max_base_x = ((64 + N) - 1); + + // pre-filter above pixels + // store in temp buffers: + // above[x] * 32 + 16 + // above[x+1] - above[x] + // final pixels will be calculated as: + // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5 + + const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]); + const uint8x16_t max_base_x128 = vdupq_n_u8(max_base_x); + + int x = dx; + for (int r = 0; r < N; r++, dst += stride) { + int base = x >> frac_bits; + if (base >= max_base_x) { + for (int i = r; i < N; ++i) { + vst1q_u8(dst, a_mbase_x); + vst1q_u8(dst + 16, a_mbase_x); + vst1q_u8(dst + 32, a_mbase_x); + vst1q_u8(dst + 48, a_mbase_x); + dst += stride; + } + return; + } + + uint16x8_t shift = vdupq_n_u16((x & 0x3f) >> 1); + uint8x16_t base_inc128 = + vaddq_u8(vdupq_n_u8(base), vcombine_u8(vcreate_u8(0x0706050403020100), + vcreate_u8(0x0F0E0D0C0B0A0908))); + + for (int j = 0; j < 64; j += 16) { + int mdif = max_base_x - (base + j); + if (mdif <= 0) { + vst1q_u8(dst + j, a_mbase_x); + } else { + uint8x16_t a0_128 = vld1q_u8(above + base + j); + uint8x16_t a1_128 = vld1q_u8(above + base + 1 + j); + uint16x8_t diff_lo = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128)); + uint16x8_t diff_hi = + vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128)); + uint16x8_t a32_lo = + vmlal_u8(vdupq_n_u16(16), vget_low_u8(a0_128), vdup_n_u8(32)); + uint16x8_t a32_hi = + vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_128), vdup_n_u8(32)); + uint16x8_t res_lo = vmlaq_u16(a32_lo, diff_lo, shift); + uint16x8_t res_hi = vmlaq_u16(a32_hi, diff_hi, shift); + uint8x16_t v_temp = + vcombine_u8(vshrn_n_u16(res_lo, 5), vshrn_n_u16(res_hi, 5)); + + uint8x16_t mask128 = + vcgtq_u8(vqsubq_u8(max_base_x128, base_inc128), vdupq_n_u8(0)); + uint8x16_t res128 = vbslq_u8(mask128, v_temp, a_mbase_x); + vst1q_u8(dst + j, res128); + + base_inc128 = vaddq_u8(base_inc128, vdupq_n_u8(16)); + } + } + x += dx; + } +} + +// Directional prediction, zone 1: 0 < angle < 90 +void av1_dr_prediction_z1_neon(uint8_t *dst, ptrdiff_t stride, int bw, int bh, + const uint8_t *above, const uint8_t *left, + int upsample_above, int dx, int dy) { + (void)left; + (void)dy; + + switch (bw) { + case 4: + dr_prediction_z1_4xN_neon(bh, dst, stride, above, upsample_above, dx); + break; + case 8: + dr_prediction_z1_8xN_neon(bh, dst, stride, above, upsample_above, dx); + break; + case 16: + dr_prediction_z1_16xN_neon(bh, dst, stride, above, upsample_above, dx); + break; + case 32: dr_prediction_z1_32xN_neon(bh, dst, stride, above, dx); break; + case 64: dr_prediction_z1_64xN_neon(bh, dst, stride, above, dx); break; + default: break; + } +} + +/* ---------------------P R E D I C T I O N Z 2--------------------------- */ + +#if !AOM_ARCH_AARCH64 +static DECLARE_ALIGNED(16, uint8_t, LoadMaskz2[4][16]) = { + { 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, + 0, 0, 0 }, + { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff } +}; +#endif // !AOM_ARCH_AARCH64 + +static AOM_FORCE_INLINE void dr_prediction_z2_Nx4_above_neon( + const uint8_t *above, int upsample_above, int dx, 
int base_x, int y, + uint8x8_t *a0_x, uint8x8_t *a1_x, uint16x4_t *shift0) { + uint16x4_t r6 = vcreate_u16(0x00C0008000400000); + uint16x4_t ydx = vdup_n_u16(y * dx); + if (upsample_above) { + // Cannot use LD2 here since we only want to load eight bytes, but LD2 can + // only load either 16 or 32. + uint8x8_t v_tmp = vld1_u8(above + base_x); + *a0_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[0]; + *a1_x = vuzp_u8(v_tmp, vdup_n_u8(0)).val[1]; + *shift0 = vand_u16(vsub_u16(r6, ydx), vdup_n_u16(0x1f)); + } else { + *a0_x = load_u8_4x1(above + base_x); + *a1_x = load_u8_4x1(above + base_x + 1); + *shift0 = vand_u16(vhsub_u16(r6, ydx), vdup_n_u16(0x1f)); + } +} + +static AOM_FORCE_INLINE void dr_prediction_z2_Nx4_left_neon( +#if AOM_ARCH_AARCH64 + uint8x16x2_t left_vals, +#else + const uint8_t *left, +#endif + int upsample_left, int dy, int r, int min_base_y, int frac_bits_y, + uint16x4_t *a0_y, uint16x4_t *a1_y, uint16x4_t *shift1) { + int16x4_t dy64 = vdup_n_s16(dy); + int16x4_t v_1234 = vcreate_s16(0x0004000300020001); + int16x4_t v_frac_bits_y = vdup_n_s16(-frac_bits_y); + int16x4_t min_base_y64 = vdup_n_s16(min_base_y); + int16x4_t v_r6 = vdup_n_s16(r << 6); + int16x4_t y_c64 = vmls_s16(v_r6, v_1234, dy64); + int16x4_t base_y_c64 = vshl_s16(y_c64, v_frac_bits_y); + + // Values in base_y_c64 range from -2 through 14 inclusive. + base_y_c64 = vmax_s16(base_y_c64, min_base_y64); + +#if AOM_ARCH_AARCH64 + uint8x8_t left_idx0 = + vreinterpret_u8_s16(vadd_s16(base_y_c64, vdup_n_s16(2))); // [0, 16] + uint8x8_t left_idx1 = + vreinterpret_u8_s16(vadd_s16(base_y_c64, vdup_n_s16(3))); // [1, 17] + + *a0_y = vreinterpret_u16_u8(vqtbl2_u8(left_vals, left_idx0)); + *a1_y = vreinterpret_u16_u8(vqtbl2_u8(left_vals, left_idx1)); +#else // !AOM_ARCH_AARCH64 + DECLARE_ALIGNED(32, int16_t, base_y_c[4]); + + vst1_s16(base_y_c, base_y_c64); + uint8x8_t a0_y_u8 = vdup_n_u8(0); + a0_y_u8 = vld1_lane_u8(left + base_y_c[0], a0_y_u8, 0); + a0_y_u8 = vld1_lane_u8(left + base_y_c[1], a0_y_u8, 2); + a0_y_u8 = vld1_lane_u8(left + base_y_c[2], a0_y_u8, 4); + a0_y_u8 = vld1_lane_u8(left + base_y_c[3], a0_y_u8, 6); + + base_y_c64 = vadd_s16(base_y_c64, vdup_n_s16(1)); + vst1_s16(base_y_c, base_y_c64); + uint8x8_t a1_y_u8 = vdup_n_u8(0); + a1_y_u8 = vld1_lane_u8(left + base_y_c[0], a1_y_u8, 0); + a1_y_u8 = vld1_lane_u8(left + base_y_c[1], a1_y_u8, 2); + a1_y_u8 = vld1_lane_u8(left + base_y_c[2], a1_y_u8, 4); + a1_y_u8 = vld1_lane_u8(left + base_y_c[3], a1_y_u8, 6); + + *a0_y = vreinterpret_u16_u8(a0_y_u8); + *a1_y = vreinterpret_u16_u8(a1_y_u8); +#endif // AOM_ARCH_AARCH64 + + if (upsample_left) { + *shift1 = vand_u16(vreinterpret_u16_s16(y_c64), vdup_n_u16(0x1f)); + } else { + *shift1 = + vand_u16(vshr_n_u16(vreinterpret_u16_s16(y_c64), 1), vdup_n_u16(0x1f)); + } +} + +static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_Nx8_above_neon( + const uint8_t *above, int upsample_above, int dx, int base_x, int y) { + uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001), + vcreate_u16(0x0008000700060005)); + uint16x8_t ydx = vdupq_n_u16(y * dx); + uint16x8_t r6 = vshlq_n_u16(vextq_u16(c1234, vdupq_n_u16(0), 2), 6); + + uint16x8_t shift0; + uint8x8_t a0_x0; + uint8x8_t a1_x0; + if (upsample_above) { + uint8x8x2_t v_tmp = vld2_u8(above + base_x); + a0_x0 = v_tmp.val[0]; + a1_x0 = v_tmp.val[1]; + shift0 = vandq_u16(vsubq_u16(r6, ydx), vdupq_n_u16(0x1f)); + } else { + a0_x0 = vld1_u8(above + base_x); + a1_x0 = vld1_u8(above + base_x + 1); + shift0 = vandq_u16(vhsubq_u16(r6, ydx), vdupq_n_u16(0x1f)); + } + + uint16x8_t diff0 = 
vsubl_u8(a1_x0, a0_x0); // a[x+1] - a[x] + uint16x8_t a32 = + vmlal_u8(vdupq_n_u16(16), a0_x0, vdup_n_u8(32)); // a[x] * 32 + 16 + uint16x8_t res = vmlaq_u16(a32, diff0, shift0); + return vshrn_n_u16(res, 5); +} + +static AOM_FORCE_INLINE uint8x8_t dr_prediction_z2_Nx8_left_neon( +#if AOM_ARCH_AARCH64 + uint8x16x3_t left_vals, +#else + const uint8_t *left, +#endif + int upsample_left, int dy, int r, int min_base_y, int frac_bits_y) { + int16x8_t v_r6 = vdupq_n_s16(r << 6); + int16x8_t dy128 = vdupq_n_s16(dy); + int16x8_t v_frac_bits_y = vdupq_n_s16(-frac_bits_y); + int16x8_t min_base_y128 = vdupq_n_s16(min_base_y); + + uint16x8_t c1234 = vcombine_u16(vcreate_u16(0x0004000300020001), + vcreate_u16(0x0008000700060005)); + int16x8_t y_c128 = vmlsq_s16(v_r6, vreinterpretq_s16_u16(c1234), dy128); + int16x8_t base_y_c128 = vshlq_s16(y_c128, v_frac_bits_y); + + // Values in base_y_c128 range from -2 through 31 inclusive. + base_y_c128 = vmaxq_s16(base_y_c128, min_base_y128); + +#if AOM_ARCH_AARCH64 + uint8x16_t left_idx0 = + vreinterpretq_u8_s16(vaddq_s16(base_y_c128, vdupq_n_s16(2))); // [0, 33] + uint8x16_t left_idx1 = + vreinterpretq_u8_s16(vaddq_s16(base_y_c128, vdupq_n_s16(3))); // [1, 34] + uint8x16_t left_idx01 = vuzp1q_u8(left_idx0, left_idx1); + + uint8x16_t a01_x = vqtbl3q_u8(left_vals, left_idx01); + uint8x8_t a0_x1 = vget_low_u8(a01_x); + uint8x8_t a1_x1 = vget_high_u8(a01_x); +#else // !AOM_ARCH_AARCH64 + uint8x8_t a0_x1 = load_u8_gather_s16_x8(left, base_y_c128); + uint8x8_t a1_x1 = load_u8_gather_s16_x8(left + 1, base_y_c128); +#endif // AOM_ARCH_AARCH64 + + uint16x8_t shift1; + if (upsample_left) { + shift1 = vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x1f)); + } else { + shift1 = vshrq_n_u16( + vandq_u16(vreinterpretq_u16_s16(y_c128), vdupq_n_u16(0x3f)), 1); + } + + uint16x8_t diff1 = vsubl_u8(a1_x1, a0_x1); + uint16x8_t a32 = vmlal_u8(vdupq_n_u16(16), a0_x1, vdup_n_u8(32)); + uint16x8_t res = vmlaq_u16(a32, diff1, shift1); + return vshrn_n_u16(res, 5); +} + +static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_NxW_above_neon( + const uint8_t *above, int dx, int base_x, int y, int j) { + uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000), + vcreate_u16(0x0007000600050004)), + vcombine_u16(vcreate_u16(0x000B000A00090008), + vcreate_u16(0x000F000E000D000C)) } }; + uint16x8_t j256 = vdupq_n_u16(j); + uint16x8_t ydx = vdupq_n_u16((uint16_t)(y * dx)); + + const uint8x16_t a0_x128 = vld1q_u8(above + base_x + j); + const uint8x16_t a1_x128 = vld1q_u8(above + base_x + j + 1); + uint16x8_t res6_0 = vshlq_n_u16(vaddq_u16(c0123.val[0], j256), 6); + uint16x8_t res6_1 = vshlq_n_u16(vaddq_u16(c0123.val[1], j256), 6); + uint16x8_t shift0 = + vshrq_n_u16(vandq_u16(vsubq_u16(res6_0, ydx), vdupq_n_u16(0x3f)), 1); + uint16x8_t shift1 = + vshrq_n_u16(vandq_u16(vsubq_u16(res6_1, ydx), vdupq_n_u16(0x3f)), 1); + // a[x+1] - a[x] + uint16x8_t diff0 = vsubl_u8(vget_low_u8(a1_x128), vget_low_u8(a0_x128)); + uint16x8_t diff1 = vsubl_u8(vget_high_u8(a1_x128), vget_high_u8(a0_x128)); + // a[x] * 32 + 16 + uint16x8_t a32_0 = + vmlal_u8(vdupq_n_u16(16), vget_low_u8(a0_x128), vdup_n_u8(32)); + uint16x8_t a32_1 = + vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_x128), vdup_n_u8(32)); + uint16x8_t res0 = vmlaq_u16(a32_0, diff0, shift0); + uint16x8_t res1 = vmlaq_u16(a32_1, diff1, shift1); + return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5)); +} + +static AOM_FORCE_INLINE uint8x16_t dr_prediction_z2_NxW_left_neon( +#if AOM_ARCH_AARCH64 + uint8x16x4_t left_vals0, uint8x16x4_t 
left_vals1, +#else + const uint8_t *left, +#endif + int dy, int r, int j) { + // here upsample_above and upsample_left are 0 by design of + // av1_use_intra_edge_upsample + const int min_base_y = -1; + + int16x8_t min_base_y256 = vdupq_n_s16(min_base_y); + int16x8_t half_min_base_y256 = vdupq_n_s16(min_base_y >> 1); + int16x8_t dy256 = vdupq_n_s16(dy); + uint16x8_t j256 = vdupq_n_u16(j); + + uint16x8x2_t c0123 = { { vcombine_u16(vcreate_u16(0x0003000200010000), + vcreate_u16(0x0007000600050004)), + vcombine_u16(vcreate_u16(0x000B000A00090008), + vcreate_u16(0x000F000E000D000C)) } }; + uint16x8x2_t c1234 = { { vaddq_u16(c0123.val[0], vdupq_n_u16(1)), + vaddq_u16(c0123.val[1], vdupq_n_u16(1)) } }; + + int16x8_t v_r6 = vdupq_n_s16(r << 6); + + int16x8_t c256_0 = vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[0])); + int16x8_t c256_1 = vreinterpretq_s16_u16(vaddq_u16(j256, c1234.val[1])); + int16x8_t mul16_lo = vreinterpretq_s16_u16( + vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_0, dy256)), + vreinterpretq_u16_s16(half_min_base_y256))); + int16x8_t mul16_hi = vreinterpretq_s16_u16( + vminq_u16(vreinterpretq_u16_s16(vmulq_s16(c256_1, dy256)), + vreinterpretq_u16_s16(half_min_base_y256))); + int16x8_t y_c256_lo = vsubq_s16(v_r6, mul16_lo); + int16x8_t y_c256_hi = vsubq_s16(v_r6, mul16_hi); + + int16x8_t base_y_c256_lo = vshrq_n_s16(y_c256_lo, 6); + int16x8_t base_y_c256_hi = vshrq_n_s16(y_c256_hi, 6); + + base_y_c256_lo = vmaxq_s16(min_base_y256, base_y_c256_lo); + base_y_c256_hi = vmaxq_s16(min_base_y256, base_y_c256_hi); + +#if !AOM_ARCH_AARCH64 + int16_t min_y = vgetq_lane_s16(base_y_c256_hi, 7); + int16_t max_y = vgetq_lane_s16(base_y_c256_lo, 0); + int16_t offset_diff = max_y - min_y; + + uint8x8_t a0_y0; + uint8x8_t a0_y1; + uint8x8_t a1_y0; + uint8x8_t a1_y1; + if (offset_diff < 16) { + // Avoid gathers where the data we want is close together in memory. + // We don't need this for AArch64 since we can already use TBL to cover the + // full range of possible values. 
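+    // All 16 indices fit in a 16-byte window starting at left[min_y], so
+    // load that window once (plus a copy offset by 1 for the adjacent
+    // sample), mask off the lanes beyond offset_diff, and permute with VTBL
+    // using the per-lane offsets instead of gathering each byte
+    // individually.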
+ assert(offset_diff >= 0); + int16x8_t min_y256 = vdupq_lane_s16(vget_high_s16(base_y_c256_hi), 3); + + int16x8x2_t base_y_offset; + base_y_offset.val[0] = vsubq_s16(base_y_c256_lo, min_y256); + base_y_offset.val[1] = vsubq_s16(base_y_c256_hi, min_y256); + + int8x16_t base_y_offset128 = vcombine_s8(vqmovn_s16(base_y_offset.val[0]), + vqmovn_s16(base_y_offset.val[1])); + + uint8x16_t v_loadmaskz2 = vld1q_u8(LoadMaskz2[offset_diff / 4]); + uint8x16_t a0_y128 = vld1q_u8(left + min_y); + uint8x16_t a1_y128 = vld1q_u8(left + min_y + 1); + a0_y128 = vandq_u8(a0_y128, v_loadmaskz2); + a1_y128 = vandq_u8(a1_y128, v_loadmaskz2); + + uint8x8_t v_index_low = vget_low_u8(vreinterpretq_u8_s8(base_y_offset128)); + uint8x8_t v_index_high = + vget_high_u8(vreinterpretq_u8_s8(base_y_offset128)); + uint8x8x2_t v_tmp, v_res; + v_tmp.val[0] = vget_low_u8(a0_y128); + v_tmp.val[1] = vget_high_u8(a0_y128); + v_res.val[0] = vtbl2_u8(v_tmp, v_index_low); + v_res.val[1] = vtbl2_u8(v_tmp, v_index_high); + a0_y128 = vcombine_u8(v_res.val[0], v_res.val[1]); + v_tmp.val[0] = vget_low_u8(a1_y128); + v_tmp.val[1] = vget_high_u8(a1_y128); + v_res.val[0] = vtbl2_u8(v_tmp, v_index_low); + v_res.val[1] = vtbl2_u8(v_tmp, v_index_high); + a1_y128 = vcombine_u8(v_res.val[0], v_res.val[1]); + + a0_y0 = vget_low_u8(a0_y128); + a0_y1 = vget_high_u8(a0_y128); + a1_y0 = vget_low_u8(a1_y128); + a1_y1 = vget_high_u8(a1_y128); + } else { + a0_y0 = load_u8_gather_s16_x8(left, base_y_c256_lo); + a0_y1 = load_u8_gather_s16_x8(left, base_y_c256_hi); + a1_y0 = load_u8_gather_s16_x8(left + 1, base_y_c256_lo); + a1_y1 = load_u8_gather_s16_x8(left + 1, base_y_c256_hi); + } +#else + // Values in left_idx{0,1} range from 0 through 63 inclusive. + uint8x16_t left_idx0 = + vreinterpretq_u8_s16(vaddq_s16(base_y_c256_lo, vdupq_n_s16(1))); + uint8x16_t left_idx1 = + vreinterpretq_u8_s16(vaddq_s16(base_y_c256_hi, vdupq_n_s16(1))); + uint8x16_t left_idx01 = vuzp1q_u8(left_idx0, left_idx1); + + uint8x16_t a0_y01 = vqtbl4q_u8(left_vals0, left_idx01); + uint8x16_t a1_y01 = vqtbl4q_u8(left_vals1, left_idx01); + + uint8x8_t a0_y0 = vget_low_u8(a0_y01); + uint8x8_t a0_y1 = vget_high_u8(a0_y01); + uint8x8_t a1_y0 = vget_low_u8(a1_y01); + uint8x8_t a1_y1 = vget_high_u8(a1_y01); +#endif // !AOM_ARCH_AARCH64 + + uint16x8_t shifty_lo = vshrq_n_u16( + vandq_u16(vreinterpretq_u16_s16(y_c256_lo), vdupq_n_u16(0x3f)), 1); + uint16x8_t shifty_hi = vshrq_n_u16( + vandq_u16(vreinterpretq_u16_s16(y_c256_hi), vdupq_n_u16(0x3f)), 1); + + // a[x+1] - a[x] + uint16x8_t diff_lo = vsubl_u8(a1_y0, a0_y0); + uint16x8_t diff_hi = vsubl_u8(a1_y1, a0_y1); + // a[x] * 32 + 16 + uint16x8_t a32_lo = vmlal_u8(vdupq_n_u16(16), a0_y0, vdup_n_u8(32)); + uint16x8_t a32_hi = vmlal_u8(vdupq_n_u16(16), a0_y1, vdup_n_u8(32)); + + uint16x8_t res0 = vmlaq_u16(a32_lo, diff_lo, shifty_lo); + uint16x8_t res1 = vmlaq_u16(a32_hi, diff_hi, shifty_hi); + + return vcombine_u8(vshrn_n_u16(res0, 5), vshrn_n_u16(res1, 5)); +} + +static void dr_prediction_z2_Nx4_neon(int N, uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left, + int upsample_above, int upsample_left, + int dx, int dy) { + const int min_base_x = -(1 << upsample_above); + const int min_base_y = -(1 << upsample_left); + const int frac_bits_x = 6 - upsample_above; + const int frac_bits_y = 6 - upsample_left; + + assert(dx > 0); + // pre-filter above pixels + // store in temp buffers: + // above[x] * 32 + 16 + // above[x+1] - above[x] + // final pixels will be calculated as: + // (above[x] * 32 + 16 + (above[x+1] - 
above[x]) * shift) >> 5 + +#if AOM_ARCH_AARCH64 + // Use ext rather than loading left + 14 directly to avoid over-read. + const uint8x16_t left_m2 = vld1q_u8(left - 2); + const uint8x16_t left_0 = vld1q_u8(left); + const uint8x16_t left_14 = vextq_u8(left_0, left_0, 14); + const uint8x16x2_t left_vals = { { left_m2, left_14 } }; +#define LEFT left_vals +#else // !AOM_ARCH_AARCH64 +#define LEFT left +#endif // AOM_ARCH_AARCH64 + + for (int r = 0; r < N; r++) { + int y = r + 1; + int base_x = (-y * dx) >> frac_bits_x; + const int base_min_diff = + (min_base_x - ((-y * dx) >> frac_bits_x) + upsample_above) >> + upsample_above; + + if (base_min_diff <= 0) { + uint8x8_t a0_x_u8, a1_x_u8; + uint16x4_t shift0; + dr_prediction_z2_Nx4_above_neon(above, upsample_above, dx, base_x, y, + &a0_x_u8, &a1_x_u8, &shift0); + uint8x8_t a0_x = a0_x_u8; + uint8x8_t a1_x = a1_x_u8; + + uint16x8_t diff = vsubl_u8(a1_x, a0_x); // a[x+1] - a[x] + uint16x8_t a32 = + vmlal_u8(vdupq_n_u16(16), a0_x, vdup_n_u8(32)); // a[x] * 32 + 16 + uint16x8_t res = + vmlaq_u16(a32, diff, vcombine_u16(shift0, vdup_n_u16(0))); + uint8x8_t resx = vshrn_n_u16(res, 5); + vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resx), 0); + } else if (base_min_diff < 4) { + uint8x8_t a0_x_u8, a1_x_u8; + uint16x4_t shift0; + dr_prediction_z2_Nx4_above_neon(above, upsample_above, dx, base_x, y, + &a0_x_u8, &a1_x_u8, &shift0); + uint16x8_t a0_x = vmovl_u8(a0_x_u8); + uint16x8_t a1_x = vmovl_u8(a1_x_u8); + + uint16x4_t a0_y; + uint16x4_t a1_y; + uint16x4_t shift1; + dr_prediction_z2_Nx4_left_neon(LEFT, upsample_left, dy, r, min_base_y, + frac_bits_y, &a0_y, &a1_y, &shift1); + a0_x = vcombine_u16(vget_low_u16(a0_x), a0_y); + a1_x = vcombine_u16(vget_low_u16(a1_x), a1_y); + + uint16x8_t shift = vcombine_u16(shift0, shift1); + uint16x8_t diff = vsubq_u16(a1_x, a0_x); // a[x+1] - a[x] + uint16x8_t a32 = + vmlaq_n_u16(vdupq_n_u16(16), a0_x, 32); // a[x] * 32 + 16 + uint16x8_t res = vmlaq_u16(a32, diff, shift); + uint8x8_t resx = vshrn_n_u16(res, 5); + uint8x8_t resy = vext_u8(resx, vdup_n_u8(0), 4); + + uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]); + uint8x8_t v_resxy = vbsl_u8(mask, resy, resx); + vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(v_resxy), 0); + } else { + uint16x4_t a0_y, a1_y; + uint16x4_t shift1; + dr_prediction_z2_Nx4_left_neon(LEFT, upsample_left, dy, r, min_base_y, + frac_bits_y, &a0_y, &a1_y, &shift1); + uint16x4_t diff = vsub_u16(a1_y, a0_y); // a[x+1] - a[x] + uint16x4_t a32 = vmla_n_u16(vdup_n_u16(16), a0_y, 32); // a[x] * 32 + 16 + uint16x4_t res = vmla_u16(a32, diff, shift1); + uint8x8_t resy = vshrn_n_u16(vcombine_u16(res, vdup_n_u16(0)), 5); + + vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(resy), 0); + } + + dst += stride; + } +#undef LEFT +} + +static void dr_prediction_z2_Nx8_neon(int N, uint8_t *dst, ptrdiff_t stride, + const uint8_t *above, const uint8_t *left, + int upsample_above, int upsample_left, + int dx, int dy) { + const int min_base_x = -(1 << upsample_above); + const int min_base_y = -(1 << upsample_left); + const int frac_bits_x = 6 - upsample_above; + const int frac_bits_y = 6 - upsample_left; + + // pre-filter above pixels + // store in temp buffers: + // above[x] * 32 + 16 + // above[x+1] - above[x] + // final pixels will be calculated as: + // (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5 + +#if AOM_ARCH_AARCH64 + // Use ext rather than loading left + 30 directly to avoid over-read. 
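+  // left[-2] .. left[31] are packed into a 48-byte table (the tail of the
+  // last vector is don't-care data) so the left-column samples can be
+  // fetched with vqtbl3q_u8, indexed by the clamped base_y values offset to
+  // be non-negative.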
+ const uint8x16_t left_m2 = vld1q_u8(left - 2); + const uint8x16_t left_0 = vld1q_u8(left + 0); + const uint8x16_t left_16 = vld1q_u8(left + 16); + const uint8x16_t left_14 = vextq_u8(left_0, left_16, 14); + const uint8x16_t left_30 = vextq_u8(left_16, left_16, 14); + const uint8x16x3_t left_vals = { { left_m2, left_14, left_30 } }; +#define LEFT left_vals +#else // !AOM_ARCH_AARCH64 +#define LEFT left +#endif // AOM_ARCH_AARCH64 + + for (int r = 0; r < N; r++) { + int y = r + 1; + int base_x = (-y * dx) >> frac_bits_x; + int base_min_diff = + (min_base_x - base_x + upsample_above) >> upsample_above; + + if (base_min_diff <= 0) { + uint8x8_t resx = + dr_prediction_z2_Nx8_above_neon(above, upsample_above, dx, base_x, y); + vst1_u8(dst, resx); + } else if (base_min_diff < 8) { + uint8x8_t resx = + dr_prediction_z2_Nx8_above_neon(above, upsample_above, dx, base_x, y); + uint8x8_t resy = dr_prediction_z2_Nx8_left_neon( + LEFT, upsample_left, dy, r, min_base_y, frac_bits_y); + uint8x8_t mask = vld1_u8(BaseMask[base_min_diff]); + uint8x8_t resxy = vbsl_u8(mask, resy, resx); + vst1_u8(dst, resxy); + } else { + uint8x8_t resy = dr_prediction_z2_Nx8_left_neon( + LEFT, upsample_left, dy, r, min_base_y, frac_bits_y); + vst1_u8(dst, resy); + } + + dst += stride; + } +#undef LEFT +} + +static void dr_prediction_z2_HxW_neon(int H, int W, uint8_t *dst, + ptrdiff_t stride, const uint8_t *above, + const uint8_t *left, int dx, int dy) { + // here upsample_above and upsample_left are 0 by design of + // av1_use_intra_edge_upsample + const int min_base_x = -1; + +#if AOM_ARCH_AARCH64 + const uint8x16_t left_m1 = vld1q_u8(left - 1); + const uint8x16_t left_0 = vld1q_u8(left + 0); + const uint8x16_t left_16 = vld1q_u8(left + 16); + const uint8x16_t left_32 = vld1q_u8(left + 32); + const uint8x16_t left_48 = vld1q_u8(left + 48); + const uint8x16_t left_15 = vextq_u8(left_0, left_16, 15); + const uint8x16_t left_31 = vextq_u8(left_16, left_32, 15); + const uint8x16_t left_47 = vextq_u8(left_32, left_48, 15); + const uint8x16x4_t left_vals0 = { { left_m1, left_15, left_31, left_47 } }; + const uint8x16x4_t left_vals1 = { { left_0, left_16, left_32, left_48 } }; +#define LEFT left_vals0, left_vals1 +#else // !AOM_ARCH_AARCH64 +#define LEFT left +#endif // AOM_ARCH_AARCH64 + + for (int r = 0; r < H; r++) { + int y = r + 1; + int base_x = (-y * dx) >> 6; + for (int j = 0; j < W; j += 16) { + const int base_min_diff = min_base_x - base_x - j; + + if (base_min_diff <= 0) { + uint8x16_t resx = + dr_prediction_z2_NxW_above_neon(above, dx, base_x, y, j); + vst1q_u8(dst + j, resx); + } else if (base_min_diff < 16) { + uint8x16_t resx = + dr_prediction_z2_NxW_above_neon(above, dx, base_x, y, j); + uint8x16_t resy = dr_prediction_z2_NxW_left_neon(LEFT, dy, r, j); + uint8x16_t mask = vld1q_u8(BaseMask[base_min_diff]); + uint8x16_t resxy = vbslq_u8(mask, resy, resx); + vst1q_u8(dst + j, resxy); + } else { + uint8x16_t resy = dr_prediction_z2_NxW_left_neon(LEFT, dy, r, j); + vst1q_u8(dst + j, resy); + } + } // for j + dst += stride; + } +#undef LEFT +} + +// Directional prediction, zone 2: 90 < angle < 180 +void av1_dr_prediction_z2_neon(uint8_t *dst, ptrdiff_t stride, int bw, int bh, + const uint8_t *above, const uint8_t *left, + int upsample_above, int upsample_left, int dx, + int dy) { + assert(dx > 0); + assert(dy > 0); + + switch (bw) { + case 4: + dr_prediction_z2_Nx4_neon(bh, dst, stride, above, left, upsample_above, + upsample_left, dx, dy); + break; + case 8: + dr_prediction_z2_Nx8_neon(bh, dst, stride, above, left, 
upsample_above, + upsample_left, dx, dy); + break; + default: + dr_prediction_z2_HxW_neon(bh, bw, dst, stride, above, left, dx, dy); + break; + } +} + +/* ---------------------P R E D I C T I O N Z 3--------------------------- */ + +static AOM_FORCE_INLINE void z3_transpose_arrays_u8_16x4(const uint8x16_t *x, + uint8x16x2_t *d) { + uint8x16x2_t w0 = vzipq_u8(x[0], x[1]); + uint8x16x2_t w1 = vzipq_u8(x[2], x[3]); + + d[0] = aom_reinterpretq_u8_u16_x2(vzipq_u16(vreinterpretq_u16_u8(w0.val[0]), + vreinterpretq_u16_u8(w1.val[0]))); + d[1] = aom_reinterpretq_u8_u16_x2(vzipq_u16(vreinterpretq_u16_u8(w0.val[1]), + vreinterpretq_u16_u8(w1.val[1]))); +} + +static AOM_FORCE_INLINE void z3_transpose_arrays_u8_4x4(const uint8x8_t *x, + uint8x8x2_t *d) { + uint8x8x2_t w0 = vzip_u8(x[0], x[1]); + uint8x8x2_t w1 = vzip_u8(x[2], x[3]); + + *d = aom_reinterpret_u8_u16_x2( + vzip_u16(vreinterpret_u16_u8(w0.val[0]), vreinterpret_u16_u8(w1.val[0]))); +} + +static AOM_FORCE_INLINE void z3_transpose_arrays_u8_8x4(const uint8x8_t *x, + uint8x8x2_t *d) { + uint8x8x2_t w0 = vzip_u8(x[0], x[1]); + uint8x8x2_t w1 = vzip_u8(x[2], x[3]); + + d[0] = aom_reinterpret_u8_u16_x2( + vzip_u16(vreinterpret_u16_u8(w0.val[0]), vreinterpret_u16_u8(w1.val[0]))); + d[1] = aom_reinterpret_u8_u16_x2( + vzip_u16(vreinterpret_u16_u8(w0.val[1]), vreinterpret_u16_u8(w1.val[1]))); +} + +static void z3_transpose_arrays_u8_16x16(const uint8_t *src, ptrdiff_t pitchSrc, + uint8_t *dst, ptrdiff_t pitchDst) { + // The same as the normal transposes in transpose_neon.h, but with a stride + // between consecutive vectors of elements. + uint8x16_t r[16]; + uint8x16_t d[16]; + for (int i = 0; i < 16; i++) { + r[i] = vld1q_u8(src + i * pitchSrc); + } + transpose_arrays_u8_16x16(r, d); + for (int i = 0; i < 16; i++) { + vst1q_u8(dst + i * pitchDst, d[i]); + } +} + +static void z3_transpose_arrays_u8_16nx16n(const uint8_t *src, + ptrdiff_t pitchSrc, uint8_t *dst, + ptrdiff_t pitchDst, int width, + int height) { + for (int j = 0; j < height; j += 16) { + for (int i = 0; i < width; i += 16) { + z3_transpose_arrays_u8_16x16(src + i * pitchSrc + j, pitchSrc, + dst + j * pitchDst + i, pitchDst); + } + } +} + +static void dr_prediction_z3_4x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x8_t dstvec[4]; + uint8x8x2_t dest; + + dr_prediction_z1_HxW_internal_neon_64(4, 4, dstvec, left, upsample_left, dy); + z3_transpose_arrays_u8_4x4(dstvec, &dest); + store_u8x4_strided_x2(dst + stride * 0, stride, dest.val[0]); + store_u8x4_strided_x2(dst + stride * 2, stride, dest.val[1]); +} + +static void dr_prediction_z3_8x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x8_t dstvec[8]; + uint8x8_t d[8]; + + dr_prediction_z1_HxW_internal_neon_64(8, 8, dstvec, left, upsample_left, dy); + transpose_arrays_u8_8x8(dstvec, d); + store_u8_8x8(dst, stride, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7]); +} + +static void dr_prediction_z3_4x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x8_t dstvec[4]; + uint8x8x2_t d[2]; + + dr_prediction_z1_HxW_internal_neon_64(8, 4, dstvec, left, upsample_left, dy); + z3_transpose_arrays_u8_8x4(dstvec, d); + store_u8x4_strided_x2(dst + stride * 0, stride, d[0].val[0]); + store_u8x4_strided_x2(dst + stride * 2, stride, d[0].val[1]); + store_u8x4_strided_x2(dst + stride * 4, stride, d[1].val[0]); + store_u8x4_strided_x2(dst + stride * 6, stride, d[1].val[1]); +} + +static void 
dr_prediction_z3_8x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x8_t dstvec[8]; + uint8x8_t d[8]; + + dr_prediction_z1_HxW_internal_neon_64(4, 8, dstvec, left, upsample_left, dy); + transpose_arrays_u8_8x8(dstvec, d); + store_u8_8x4(dst, stride, d[0], d[1], d[2], d[3]); +} + +static void dr_prediction_z3_8x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x16_t dstvec[8]; + uint8x8_t d[16]; + + dr_prediction_z1_HxW_internal_neon(16, 8, dstvec, left, upsample_left, dy); + transpose_arrays_u8_16x8(dstvec, d); + for (int i = 0; i < 16; i++) { + vst1_u8(dst + i * stride, d[i]); + } +} + +static void dr_prediction_z3_16x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x8_t dstvec[16]; + uint8x16_t d[8]; + + dr_prediction_z1_HxW_internal_neon_64(8, 16, dstvec, left, upsample_left, dy); + transpose_arrays_u8_8x16(dstvec, d); + for (int i = 0; i < 8; i++) { + vst1q_u8(dst + i * stride, d[i]); + } +} + +static void dr_prediction_z3_4x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x16_t dstvec[4]; + uint8x16x2_t d[2]; + + dr_prediction_z1_HxW_internal_neon(16, 4, dstvec, left, upsample_left, dy); + z3_transpose_arrays_u8_16x4(dstvec, d); + store_u8x4_strided_x4(dst + stride * 0, stride, d[0].val[0]); + store_u8x4_strided_x4(dst + stride * 4, stride, d[0].val[1]); + store_u8x4_strided_x4(dst + stride * 8, stride, d[1].val[0]); + store_u8x4_strided_x4(dst + stride * 12, stride, d[1].val[1]); +} + +static void dr_prediction_z3_16x4_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x8_t dstvec[16]; + uint8x16_t d[8]; + + dr_prediction_z1_HxW_internal_neon_64(4, 16, dstvec, left, upsample_left, dy); + transpose_arrays_u8_8x16(dstvec, d); + for (int i = 0; i < 4; i++) { + vst1q_u8(dst + i * stride, d[i]); + } +} + +static void dr_prediction_z3_8x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + (void)upsample_left; + uint8x16x2_t dstvec[16]; + uint8x16_t d[32]; + uint8x16_t v_zero = vdupq_n_u8(0); + + dr_prediction_z1_32xN_internal_neon(8, dstvec, left, dy); + for (int i = 8; i < 16; i++) { + dstvec[i].val[0] = v_zero; + dstvec[i].val[1] = v_zero; + } + transpose_arrays_u8_32x16(dstvec, d); + for (int i = 0; i < 32; i++) { + vst1_u8(dst + i * stride, vget_low_u8(d[i])); + } +} + +static void dr_prediction_z3_32x8_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x8_t dstvec[32]; + uint8x16_t d[16]; + + dr_prediction_z1_HxW_internal_neon_64(8, 32, dstvec, left, upsample_left, dy); + transpose_arrays_u8_8x16(dstvec, d); + transpose_arrays_u8_8x16(dstvec + 16, d + 8); + for (int i = 0; i < 8; i++) { + vst1q_u8(dst + i * stride, d[i]); + vst1q_u8(dst + i * stride + 16, d[i + 8]); + } +} + +static void dr_prediction_z3_16x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x16_t dstvec[16]; + uint8x16_t d[16]; + + dr_prediction_z1_HxW_internal_neon(16, 16, dstvec, left, upsample_left, dy); + transpose_arrays_u8_16x16(dstvec, d); + for (int i = 0; i < 16; i++) { + vst1q_u8(dst + i * stride, d[i]); + } +} + +static void dr_prediction_z3_32x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + (void)upsample_left; + uint8x16x2_t dstvec[32]; + uint8x16_t d[64]; + + 
dr_prediction_z1_32xN_internal_neon(32, dstvec, left, dy); + transpose_arrays_u8_32x16(dstvec, d); + transpose_arrays_u8_32x16(dstvec + 16, d + 32); + for (int i = 0; i < 32; i++) { + vst1q_u8(dst + i * stride, d[i]); + vst1q_u8(dst + i * stride + 16, d[i + 32]); + } +} + +static void dr_prediction_z3_64x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + (void)upsample_left; + DECLARE_ALIGNED(16, uint8_t, dstT[64 * 64]); + + dr_prediction_z1_64xN_neon(64, dstT, 64, left, dy); + z3_transpose_arrays_u8_16nx16n(dstT, 64, dst, stride, 64, 64); +} + +static void dr_prediction_z3_16x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + (void)upsample_left; + uint8x16x2_t dstvec[16]; + uint8x16_t d[32]; + + dr_prediction_z1_32xN_internal_neon(16, dstvec, left, dy); + transpose_arrays_u8_32x16(dstvec, d); + for (int i = 0; i < 16; i++) { + vst1q_u8(dst + 2 * i * stride, d[2 * i + 0]); + vst1q_u8(dst + (2 * i + 1) * stride, d[2 * i + 1]); + } +} + +static void dr_prediction_z3_32x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x16_t dstvec[32]; + + dr_prediction_z1_HxW_internal_neon(16, 32, dstvec, left, upsample_left, dy); + for (int i = 0; i < 32; i += 16) { + uint8x16_t d[16]; + transpose_arrays_u8_16x16(dstvec + i, d); + for (int j = 0; j < 16; j++) { + vst1q_u8(dst + j * stride + i, d[j]); + } + } +} + +static void dr_prediction_z3_32x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + (void)upsample_left; + uint8_t dstT[64 * 32]; + + dr_prediction_z1_64xN_neon(32, dstT, 64, left, dy); + z3_transpose_arrays_u8_16nx16n(dstT, 64, dst, stride, 32, 64); +} + +static void dr_prediction_z3_64x32_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + (void)upsample_left; + uint8_t dstT[32 * 64]; + + dr_prediction_z1_32xN_neon(64, dstT, 32, left, dy); + z3_transpose_arrays_u8_16nx16n(dstT, 32, dst, stride, 64, 32); +} + +static void dr_prediction_z3_16x64_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + (void)upsample_left; + uint8_t dstT[64 * 16]; + + dr_prediction_z1_64xN_neon(16, dstT, 64, left, dy); + z3_transpose_arrays_u8_16nx16n(dstT, 64, dst, stride, 16, 64); +} + +static void dr_prediction_z3_64x16_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy) { + uint8x16_t dstvec[64]; + + dr_prediction_z1_HxW_internal_neon(16, 64, dstvec, left, upsample_left, dy); + for (int i = 0; i < 64; i += 16) { + uint8x16_t d[16]; + transpose_arrays_u8_16x16(dstvec + i, d); + for (int j = 0; j < 16; ++j) { + vst1q_u8(dst + j * stride + i, d[j]); + } + } +} + +typedef void (*dr_prediction_z3_fn)(uint8_t *dst, ptrdiff_t stride, + const uint8_t *left, int upsample_left, + int dy); + +static dr_prediction_z3_fn dr_prediction_z3_arr[7][7] = { + { NULL, NULL, NULL, NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL, NULL, NULL, NULL }, + { NULL, NULL, dr_prediction_z3_4x4_neon, dr_prediction_z3_4x8_neon, + dr_prediction_z3_4x16_neon, NULL, NULL }, + { NULL, NULL, dr_prediction_z3_8x4_neon, dr_prediction_z3_8x8_neon, + dr_prediction_z3_8x16_neon, dr_prediction_z3_8x32_neon, NULL }, + { NULL, NULL, dr_prediction_z3_16x4_neon, dr_prediction_z3_16x8_neon, + dr_prediction_z3_16x16_neon, dr_prediction_z3_16x32_neon, + dr_prediction_z3_16x64_neon }, + { NULL, NULL, NULL, dr_prediction_z3_32x8_neon, 
dr_prediction_z3_32x16_neon, + dr_prediction_z3_32x32_neon, dr_prediction_z3_32x64_neon }, + { NULL, NULL, NULL, NULL, dr_prediction_z3_64x16_neon, + dr_prediction_z3_64x32_neon, dr_prediction_z3_64x64_neon }, +}; + +void av1_dr_prediction_z3_neon(uint8_t *dst, ptrdiff_t stride, int bw, int bh, + const uint8_t *above, const uint8_t *left, + int upsample_left, int dx, int dy) { + (void)above; + (void)dx; + assert(dx == 1); + assert(dy > 0); + + dr_prediction_z3_fn f = dr_prediction_z3_arr[get_msb(bw)][get_msb(bh)]; + assert(f != NULL); + f(dst, stride, left, upsample_left, dy); +} + +// ----------------------------------------------------------------------------- +// SMOOTH_PRED + +// 256 - v = vneg_s8(v) +static INLINE uint8x8_t negate_s8(const uint8x8_t v) { + return vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(v))); +} + +static void smooth_4xh_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *const top_row, + const uint8_t *const left_column, + const int height) { + const uint8_t top_right = top_row[3]; + const uint8_t bottom_left = left_column[height - 1]; + const uint8_t *const weights_y = smooth_weights + height - 4; + + uint8x8_t top_v = load_u8_4x1(top_row); + const uint8x8_t top_right_v = vdup_n_u8(top_right); + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); + uint8x8_t weights_x_v = load_u8_4x1(smooth_weights); + const uint8x8_t scaled_weights_x = negate_s8(weights_x_v); + const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v); + + assert(height > 0); + int y = 0; + do { + const uint8x8_t left_v = vdup_n_u8(left_column[y]); + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = negate_s8(weights_y_v); + const uint16x8_t weighted_bl = vmull_u8(scaled_weights_y, bottom_left_v); + const uint16x8_t weighted_top_bl = + vmlal_u8(weighted_bl, weights_y_v, top_v); + const uint16x8_t weighted_left_tr = + vmlal_u8(weighted_tr, weights_x_v, left_v); + // Maximum value of each parameter: 0xFF00 + const uint16x8_t avg = vhaddq_u16(weighted_top_bl, weighted_left_tr); + const uint8x8_t result = vrshrn_n_u16(avg, SMOOTH_WEIGHT_LOG2_SCALE); + + vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(result), 0); + dst += stride; + } while (++y != height); +} + +static INLINE uint8x8_t calculate_pred(const uint16x8_t weighted_top_bl, + const uint16x8_t weighted_left_tr) { + // Maximum value of each parameter: 0xFF00 + const uint16x8_t avg = vhaddq_u16(weighted_top_bl, weighted_left_tr); + return vrshrn_n_u16(avg, SMOOTH_WEIGHT_LOG2_SCALE); +} + +static INLINE uint8x8_t calculate_weights_and_pred( + const uint8x8_t top, const uint8x8_t left, const uint16x8_t weighted_tr, + const uint8x8_t bottom_left, const uint8x8_t weights_x, + const uint8x8_t scaled_weights_y, const uint8x8_t weights_y) { + const uint16x8_t weighted_top = vmull_u8(weights_y, top); + const uint16x8_t weighted_top_bl = + vmlal_u8(weighted_top, scaled_weights_y, bottom_left); + const uint16x8_t weighted_left_tr = vmlal_u8(weighted_tr, weights_x, left); + return calculate_pred(weighted_top_bl, weighted_left_tr); +} + +static void smooth_8xh_neon(uint8_t *dst, ptrdiff_t stride, + const uint8_t *const top_row, + const uint8_t *const left_column, + const int height) { + const uint8_t top_right = top_row[7]; + const uint8_t bottom_left = left_column[height - 1]; + const uint8_t *const weights_y = smooth_weights + height - 4; + + const uint8x8_t top_v = vld1_u8(top_row); + const uint8x8_t top_right_v = vdup_n_u8(top_right); + const uint8x8_t bottom_left_v = 
vdup_n_u8(bottom_left); + const uint8x8_t weights_x_v = vld1_u8(smooth_weights + 4); + const uint8x8_t scaled_weights_x = negate_s8(weights_x_v); + const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v); + + assert(height > 0); + int y = 0; + do { + const uint8x8_t left_v = vdup_n_u8(left_column[y]); + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); + const uint8x8_t scaled_weights_y = negate_s8(weights_y_v); + const uint8x8_t result = + calculate_weights_and_pred(top_v, left_v, weighted_tr, bottom_left_v, + weights_x_v, scaled_weights_y, weights_y_v); + + vst1_u8(dst, result); + dst += stride; + } while (++y != height); +} + +#define SMOOTH_NXM(W, H) \ + void aom_smooth_predictor_##W##x##H##_neon(uint8_t *dst, ptrdiff_t y_stride, \ + const uint8_t *above, \ + const uint8_t *left) { \ + smooth_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +SMOOTH_NXM(4, 4) +SMOOTH_NXM(4, 8) +SMOOTH_NXM(8, 4) +SMOOTH_NXM(8, 8) +SMOOTH_NXM(4, 16) +SMOOTH_NXM(8, 16) +SMOOTH_NXM(8, 32) + +#undef SMOOTH_NXM + +static INLINE uint8x16_t calculate_weights_and_predq( + const uint8x16_t top, const uint8x8_t left, const uint8x8_t top_right, + const uint8x8_t weights_y, const uint8x16_t weights_x, + const uint8x16_t scaled_weights_x, const uint16x8_t weighted_bl) { + const uint16x8_t weighted_top_bl_low = + vmlal_u8(weighted_bl, weights_y, vget_low_u8(top)); + const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left); + const uint16x8_t weighted_left_tr_low = + vmlal_u8(weighted_left_low, vget_low_u8(scaled_weights_x), top_right); + const uint8x8_t result_low = + calculate_pred(weighted_top_bl_low, weighted_left_tr_low); + + const uint16x8_t weighted_top_bl_high = + vmlal_u8(weighted_bl, weights_y, vget_high_u8(top)); + const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left); + const uint16x8_t weighted_left_tr_high = + vmlal_u8(weighted_left_high, vget_high_u8(scaled_weights_x), top_right); + const uint8x8_t result_high = + calculate_pred(weighted_top_bl_high, weighted_left_tr_high); + + return vcombine_u8(result_low, result_high); +} + +// 256 - v = vneg_s8(v) +static INLINE uint8x16_t negate_s8q(const uint8x16_t v) { + return vreinterpretq_u8_s8(vnegq_s8(vreinterpretq_s8_u8(v))); +} + +// For width 16 and above. 
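As a plain-C orientation aid for the SMOOTH blend that the narrow functions above and the W >= 16 macro below both implement, here is a minimal scalar sketch of one predicted pixel. The function name smooth_scalar_ref and its direct-weight parameters are illustrative only, and SMOOTH_WEIGHT_LOG2_SCALE is assumed to be 8 (256-scale weights):

#include <stdint.h>

// One SMOOTH pixel: vertical blend of top[c] with bottom_left plus horizontal
// blend of left[r] with top_right, then a halving add and rounding shift,
// mirroring vhaddq_u16 + vrshrn_n_u16 in the NEON code above.
// w_x and w_y are the (assumed) 0..256 weights for column c and row r.
static uint8_t smooth_scalar_ref(uint8_t top_c, uint8_t left_r,
                                 uint8_t top_right, uint8_t bottom_left,
                                 int w_x, int w_y) {
  const int weighted_top_bl = w_y * top_c + (256 - w_y) * bottom_left;
  const int weighted_left_tr = w_x * left_r + (256 - w_x) * top_right;
  const int avg = (weighted_top_bl + weighted_left_tr) >> 1;  // vhaddq_u16
  return (uint8_t)((avg + 128) >> 8);                         // vrshrn_n_u16(., 8)
}

The wide macro that follows applies this same per-pixel blend 16 columns at a time, computing the (256 - w_y) * bottom_left term once per row and reusing it for every 16-wide slice.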
+#define SMOOTH_PREDICTOR(W) \ + static void smooth_##W##xh_neon( \ + uint8_t *dst, ptrdiff_t stride, const uint8_t *const top_row, \ + const uint8_t *const left_column, const int height) { \ + const uint8_t top_right = top_row[(W)-1]; \ + const uint8_t bottom_left = left_column[height - 1]; \ + const uint8_t *const weights_y = smooth_weights + height - 4; \ + \ + uint8x16_t top_v[4]; \ + top_v[0] = vld1q_u8(top_row); \ + if ((W) > 16) { \ + top_v[1] = vld1q_u8(top_row + 16); \ + if ((W) == 64) { \ + top_v[2] = vld1q_u8(top_row + 32); \ + top_v[3] = vld1q_u8(top_row + 48); \ + } \ + } \ + \ + const uint8x8_t top_right_v = vdup_n_u8(top_right); \ + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); \ + \ + uint8x16_t weights_x_v[4]; \ + weights_x_v[0] = vld1q_u8(smooth_weights + (W)-4); \ + if ((W) > 16) { \ + weights_x_v[1] = vld1q_u8(smooth_weights + (W) + 16 - 4); \ + if ((W) == 64) { \ + weights_x_v[2] = vld1q_u8(smooth_weights + (W) + 32 - 4); \ + weights_x_v[3] = vld1q_u8(smooth_weights + (W) + 48 - 4); \ + } \ + } \ + \ + uint8x16_t scaled_weights_x[4]; \ + scaled_weights_x[0] = negate_s8q(weights_x_v[0]); \ + if ((W) > 16) { \ + scaled_weights_x[1] = negate_s8q(weights_x_v[1]); \ + if ((W) == 64) { \ + scaled_weights_x[2] = negate_s8q(weights_x_v[2]); \ + scaled_weights_x[3] = negate_s8q(weights_x_v[3]); \ + } \ + } \ + \ + for (int y = 0; y < height; ++y) { \ + const uint8x8_t left_v = vdup_n_u8(left_column[y]); \ + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); \ + const uint8x8_t scaled_weights_y = negate_s8(weights_y_v); \ + const uint16x8_t weighted_bl = \ + vmull_u8(scaled_weights_y, bottom_left_v); \ + \ + vst1q_u8(dst, calculate_weights_and_predq( \ + top_v[0], left_v, top_right_v, weights_y_v, \ + weights_x_v[0], scaled_weights_x[0], weighted_bl)); \ + \ + if ((W) > 16) { \ + vst1q_u8(dst + 16, \ + calculate_weights_and_predq( \ + top_v[1], left_v, top_right_v, weights_y_v, \ + weights_x_v[1], scaled_weights_x[1], weighted_bl)); \ + if ((W) == 64) { \ + vst1q_u8(dst + 32, \ + calculate_weights_and_predq( \ + top_v[2], left_v, top_right_v, weights_y_v, \ + weights_x_v[2], scaled_weights_x[2], weighted_bl)); \ + vst1q_u8(dst + 48, \ + calculate_weights_and_predq( \ + top_v[3], left_v, top_right_v, weights_y_v, \ + weights_x_v[3], scaled_weights_x[3], weighted_bl)); \ + } \ + } \ + \ + dst += stride; \ + } \ + } + +SMOOTH_PREDICTOR(16) +SMOOTH_PREDICTOR(32) +SMOOTH_PREDICTOR(64) + +#undef SMOOTH_PREDICTOR + +#define SMOOTH_NXM_WIDE(W, H) \ + void aom_smooth_predictor_##W##x##H##_neon(uint8_t *dst, ptrdiff_t y_stride, \ + const uint8_t *above, \ + const uint8_t *left) { \ + smooth_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +SMOOTH_NXM_WIDE(16, 4) +SMOOTH_NXM_WIDE(16, 8) +SMOOTH_NXM_WIDE(16, 16) +SMOOTH_NXM_WIDE(16, 32) +SMOOTH_NXM_WIDE(16, 64) +SMOOTH_NXM_WIDE(32, 8) +SMOOTH_NXM_WIDE(32, 16) +SMOOTH_NXM_WIDE(32, 32) +SMOOTH_NXM_WIDE(32, 64) +SMOOTH_NXM_WIDE(64, 16) +SMOOTH_NXM_WIDE(64, 32) +SMOOTH_NXM_WIDE(64, 64) + +#undef SMOOTH_NXM_WIDE + +// ----------------------------------------------------------------------------- +// SMOOTH_V_PRED + +// For widths 4 and 8. 
+#define SMOOTH_V_PREDICTOR(W) \ + static void smooth_v_##W##xh_neon( \ + uint8_t *dst, ptrdiff_t stride, const uint8_t *const top_row, \ + const uint8_t *const left_column, const int height) { \ + const uint8_t bottom_left = left_column[height - 1]; \ + const uint8_t *const weights_y = smooth_weights + height - 4; \ + \ + uint8x8_t top_v; \ + if ((W) == 4) { \ + top_v = load_u8_4x1(top_row); \ + } else { /* width == 8 */ \ + top_v = vld1_u8(top_row); \ + } \ + \ + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); \ + \ + assert(height > 0); \ + int y = 0; \ + do { \ + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); \ + const uint8x8_t scaled_weights_y = negate_s8(weights_y_v); \ + \ + const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v); \ + const uint16x8_t weighted_top_bl = \ + vmlal_u8(weighted_top, scaled_weights_y, bottom_left_v); \ + const uint8x8_t pred = \ + vrshrn_n_u16(weighted_top_bl, SMOOTH_WEIGHT_LOG2_SCALE); \ + \ + if ((W) == 4) { \ + vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(pred), 0); \ + } else { /* width == 8 */ \ + vst1_u8(dst, pred); \ + } \ + dst += stride; \ + } while (++y != height); \ + } + +SMOOTH_V_PREDICTOR(4) +SMOOTH_V_PREDICTOR(8) + +#undef SMOOTH_V_PREDICTOR + +#define SMOOTH_V_NXM(W, H) \ + void aom_smooth_v_predictor_##W##x##H##_neon( \ + uint8_t *dst, ptrdiff_t y_stride, const uint8_t *above, \ + const uint8_t *left) { \ + smooth_v_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +SMOOTH_V_NXM(4, 4) +SMOOTH_V_NXM(4, 8) +SMOOTH_V_NXM(4, 16) +SMOOTH_V_NXM(8, 4) +SMOOTH_V_NXM(8, 8) +SMOOTH_V_NXM(8, 16) +SMOOTH_V_NXM(8, 32) + +#undef SMOOTH_V_NXM + +static INLINE uint8x16_t calculate_vertical_weights_and_pred( + const uint8x16_t top, const uint8x8_t weights_y, + const uint16x8_t weighted_bl) { + const uint16x8_t pred_low = + vmlal_u8(weighted_bl, weights_y, vget_low_u8(top)); + const uint16x8_t pred_high = + vmlal_u8(weighted_bl, weights_y, vget_high_u8(top)); + const uint8x8_t pred_scaled_low = + vrshrn_n_u16(pred_low, SMOOTH_WEIGHT_LOG2_SCALE); + const uint8x8_t pred_scaled_high = + vrshrn_n_u16(pred_high, SMOOTH_WEIGHT_LOG2_SCALE); + return vcombine_u8(pred_scaled_low, pred_scaled_high); +} + +// For width 16 and above. 
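SMOOTH_V drops the horizontal term and blends each column of the top row with the bottom-left sample only. A scalar sketch of one pixel, under the same assumptions as the SMOOTH sketch earlier (illustrative name, 256-scale weights, rounding shift of 8):

#include <stdint.h>

// One SMOOTH_V pixel: w_y weights top[c] against bottom_left, followed by a
// single rounding shift (vrshrn_n_u16 in the NEON code); there is no
// left/top_right contribution.
static uint8_t smooth_v_scalar_ref(uint8_t top_c, uint8_t bottom_left,
                                   int w_y) {
  const int weighted_top_bl = w_y * top_c + (256 - w_y) * bottom_left;
  return (uint8_t)((weighted_top_bl + 128) >> 8);
}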
+#define SMOOTH_V_PREDICTOR(W) \ + static void smooth_v_##W##xh_neon( \ + uint8_t *dst, ptrdiff_t stride, const uint8_t *const top_row, \ + const uint8_t *const left_column, const int height) { \ + const uint8_t bottom_left = left_column[height - 1]; \ + const uint8_t *const weights_y = smooth_weights + height - 4; \ + \ + uint8x16_t top_v[4]; \ + top_v[0] = vld1q_u8(top_row); \ + if ((W) > 16) { \ + top_v[1] = vld1q_u8(top_row + 16); \ + if ((W) == 64) { \ + top_v[2] = vld1q_u8(top_row + 32); \ + top_v[3] = vld1q_u8(top_row + 48); \ + } \ + } \ + \ + const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left); \ + \ + assert(height > 0); \ + int y = 0; \ + do { \ + const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]); \ + const uint8x8_t scaled_weights_y = negate_s8(weights_y_v); \ + const uint16x8_t weighted_bl = \ + vmull_u8(scaled_weights_y, bottom_left_v); \ + \ + const uint8x16_t pred_0 = calculate_vertical_weights_and_pred( \ + top_v[0], weights_y_v, weighted_bl); \ + vst1q_u8(dst, pred_0); \ + \ + if ((W) > 16) { \ + const uint8x16_t pred_1 = calculate_vertical_weights_and_pred( \ + top_v[1], weights_y_v, weighted_bl); \ + vst1q_u8(dst + 16, pred_1); \ + \ + if ((W) == 64) { \ + const uint8x16_t pred_2 = calculate_vertical_weights_and_pred( \ + top_v[2], weights_y_v, weighted_bl); \ + vst1q_u8(dst + 32, pred_2); \ + \ + const uint8x16_t pred_3 = calculate_vertical_weights_and_pred( \ + top_v[3], weights_y_v, weighted_bl); \ + vst1q_u8(dst + 48, pred_3); \ + } \ + } \ + \ + dst += stride; \ + } while (++y != height); \ + } + +SMOOTH_V_PREDICTOR(16) +SMOOTH_V_PREDICTOR(32) +SMOOTH_V_PREDICTOR(64) + +#undef SMOOTH_V_PREDICTOR + +#define SMOOTH_V_NXM_WIDE(W, H) \ + void aom_smooth_v_predictor_##W##x##H##_neon( \ + uint8_t *dst, ptrdiff_t y_stride, const uint8_t *above, \ + const uint8_t *left) { \ + smooth_v_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +SMOOTH_V_NXM_WIDE(16, 4) +SMOOTH_V_NXM_WIDE(16, 8) +SMOOTH_V_NXM_WIDE(16, 16) +SMOOTH_V_NXM_WIDE(16, 32) +SMOOTH_V_NXM_WIDE(16, 64) +SMOOTH_V_NXM_WIDE(32, 8) +SMOOTH_V_NXM_WIDE(32, 16) +SMOOTH_V_NXM_WIDE(32, 32) +SMOOTH_V_NXM_WIDE(32, 64) +SMOOTH_V_NXM_WIDE(64, 16) +SMOOTH_V_NXM_WIDE(64, 32) +SMOOTH_V_NXM_WIDE(64, 64) + +#undef SMOOTH_V_NXM_WIDE + +// ----------------------------------------------------------------------------- +// SMOOTH_H_PRED + +// For widths 4 and 8. +#define SMOOTH_H_PREDICTOR(W) \ + static void smooth_h_##W##xh_neon( \ + uint8_t *dst, ptrdiff_t stride, const uint8_t *const top_row, \ + const uint8_t *const left_column, const int height) { \ + const uint8_t top_right = top_row[(W)-1]; \ + \ + const uint8x8_t top_right_v = vdup_n_u8(top_right); \ + /* Over-reads for 4xN but still within the array. 
*/ \ + const uint8x8_t weights_x = vld1_u8(smooth_weights + (W)-4); \ + const uint8x8_t scaled_weights_x = negate_s8(weights_x); \ + const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v); \ + \ + assert(height > 0); \ + int y = 0; \ + do { \ + const uint8x8_t left_v = vdup_n_u8(left_column[y]); \ + const uint16x8_t weighted_left_tr = \ + vmlal_u8(weighted_tr, weights_x, left_v); \ + const uint8x8_t pred = \ + vrshrn_n_u16(weighted_left_tr, SMOOTH_WEIGHT_LOG2_SCALE); \ + \ + if ((W) == 4) { \ + vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(pred), 0); \ + } else { /* width == 8 */ \ + vst1_u8(dst, pred); \ + } \ + dst += stride; \ + } while (++y != height); \ + } + +SMOOTH_H_PREDICTOR(4) +SMOOTH_H_PREDICTOR(8) + +#undef SMOOTH_H_PREDICTOR + +#define SMOOTH_H_NXM(W, H) \ + void aom_smooth_h_predictor_##W##x##H##_neon( \ + uint8_t *dst, ptrdiff_t y_stride, const uint8_t *above, \ + const uint8_t *left) { \ + smooth_h_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +SMOOTH_H_NXM(4, 4) +SMOOTH_H_NXM(4, 8) +SMOOTH_H_NXM(4, 16) +SMOOTH_H_NXM(8, 4) +SMOOTH_H_NXM(8, 8) +SMOOTH_H_NXM(8, 16) +SMOOTH_H_NXM(8, 32) + +#undef SMOOTH_H_NXM + +static INLINE uint8x16_t calculate_horizontal_weights_and_pred( + const uint8x8_t left, const uint8x8_t top_right, const uint8x16_t weights_x, + const uint8x16_t scaled_weights_x) { + const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left); + const uint16x8_t weighted_left_tr_low = + vmlal_u8(weighted_left_low, vget_low_u8(scaled_weights_x), top_right); + const uint8x8_t pred_scaled_low = + vrshrn_n_u16(weighted_left_tr_low, SMOOTH_WEIGHT_LOG2_SCALE); + + const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left); + const uint16x8_t weighted_left_tr_high = + vmlal_u8(weighted_left_high, vget_high_u8(scaled_weights_x), top_right); + const uint8x8_t pred_scaled_high = + vrshrn_n_u16(weighted_left_tr_high, SMOOTH_WEIGHT_LOG2_SCALE); + + return vcombine_u8(pred_scaled_low, pred_scaled_high); +} + +// For width 16 and above. 
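SMOOTH_H is the mirror image: each row's left sample is blended with the top-right sample using the column weights, with no vertical term. A scalar sketch under the same assumptions (illustrative name, 256-scale weights indexed as smooth_weights + W - 4, rounding shift of 8):

#include <stdint.h>

// One SMOOTH_H pixel: w_x weights left[r] against top_right, followed by a
// single rounding shift, mirroring vmlal_u8 + vrshrn_n_u16 in the code above.
static uint8_t smooth_h_scalar_ref(uint8_t left_r, uint8_t top_right,
                                   int w_x) {
  const int weighted_left_tr = w_x * left_r + (256 - w_x) * top_right;
  return (uint8_t)((weighted_left_tr + 128) >> 8);
}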
+#define SMOOTH_H_PREDICTOR(W) \ + static void smooth_h_##W##xh_neon( \ + uint8_t *dst, ptrdiff_t stride, const uint8_t *const top_row, \ + const uint8_t *const left_column, const int height) { \ + const uint8_t top_right = top_row[(W)-1]; \ + \ + const uint8x8_t top_right_v = vdup_n_u8(top_right); \ + \ + uint8x16_t weights_x[4]; \ + weights_x[0] = vld1q_u8(smooth_weights + (W)-4); \ + if ((W) > 16) { \ + weights_x[1] = vld1q_u8(smooth_weights + (W) + 16 - 4); \ + if ((W) == 64) { \ + weights_x[2] = vld1q_u8(smooth_weights + (W) + 32 - 4); \ + weights_x[3] = vld1q_u8(smooth_weights + (W) + 48 - 4); \ + } \ + } \ + \ + uint8x16_t scaled_weights_x[4]; \ + scaled_weights_x[0] = negate_s8q(weights_x[0]); \ + if ((W) > 16) { \ + scaled_weights_x[1] = negate_s8q(weights_x[1]); \ + if ((W) == 64) { \ + scaled_weights_x[2] = negate_s8q(weights_x[2]); \ + scaled_weights_x[3] = negate_s8q(weights_x[3]); \ + } \ + } \ + \ + assert(height > 0); \ + int y = 0; \ + do { \ + const uint8x8_t left_v = vdup_n_u8(left_column[y]); \ + \ + const uint8x16_t pred_0 = calculate_horizontal_weights_and_pred( \ + left_v, top_right_v, weights_x[0], scaled_weights_x[0]); \ + vst1q_u8(dst, pred_0); \ + \ + if ((W) > 16) { \ + const uint8x16_t pred_1 = calculate_horizontal_weights_and_pred( \ + left_v, top_right_v, weights_x[1], scaled_weights_x[1]); \ + vst1q_u8(dst + 16, pred_1); \ + \ + if ((W) == 64) { \ + const uint8x16_t pred_2 = calculate_horizontal_weights_and_pred( \ + left_v, top_right_v, weights_x[2], scaled_weights_x[2]); \ + vst1q_u8(dst + 32, pred_2); \ + \ + const uint8x16_t pred_3 = calculate_horizontal_weights_and_pred( \ + left_v, top_right_v, weights_x[3], scaled_weights_x[3]); \ + vst1q_u8(dst + 48, pred_3); \ + } \ + } \ + dst += stride; \ + } while (++y != height); \ + } + +SMOOTH_H_PREDICTOR(16) +SMOOTH_H_PREDICTOR(32) +SMOOTH_H_PREDICTOR(64) + +#undef SMOOTH_H_PREDICTOR + +#define SMOOTH_H_NXM_WIDE(W, H) \ + void aom_smooth_h_predictor_##W##x##H##_neon( \ + uint8_t *dst, ptrdiff_t y_stride, const uint8_t *above, \ + const uint8_t *left) { \ + smooth_h_##W##xh_neon(dst, y_stride, above, left, H); \ + } + +SMOOTH_H_NXM_WIDE(16, 4) +SMOOTH_H_NXM_WIDE(16, 8) +SMOOTH_H_NXM_WIDE(16, 16) +SMOOTH_H_NXM_WIDE(16, 32) +SMOOTH_H_NXM_WIDE(16, 64) +SMOOTH_H_NXM_WIDE(32, 8) +SMOOTH_H_NXM_WIDE(32, 16) +SMOOTH_H_NXM_WIDE(32, 32) +SMOOTH_H_NXM_WIDE(32, 64) +SMOOTH_H_NXM_WIDE(64, 16) +SMOOTH_H_NXM_WIDE(64, 32) +SMOOTH_H_NXM_WIDE(64, 64) + +#undef SMOOTH_H_NXM_WIDE + +// ----------------------------------------------------------------------------- +// PAETH + +static INLINE void paeth_4or8_x_h_neon(uint8_t *dest, ptrdiff_t stride, + const uint8_t *const top_row, + const uint8_t *const left_column, + int width, int height) { + const uint8x8_t top_left = vdup_n_u8(top_row[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row[-1] + top_row[-1]); + uint8x8_t top; + if (width == 4) { + top = load_u8_4x1(top_row); + } else { // width == 8 + top = vld1_u8(top_row); + } + + assert(height > 0); + int y = 0; + do { + const uint8x8_t left = vdup_n_u8(left_column[y]); + + const uint8x8_t left_dist = vabd_u8(top, top_left); + const uint8x8_t top_dist = vabd_u8(left, top_left); + const uint16x8_t top_left_dist = + vabdq_u16(vaddl_u8(top, left), top_left_x2); + + const uint8x8_t left_le_top = vcle_u8(left_dist, top_dist); + const uint8x8_t left_le_top_left = + vmovn_u16(vcleq_u16(vmovl_u8(left_dist), top_left_dist)); + const uint8x8_t top_le_top_left = + vmovn_u16(vcleq_u16(vmovl_u8(top_dist), top_left_dist)); + + // if 
(left_dist <= top_dist && left_dist <= top_left_dist) + const uint8x8_t left_mask = vand_u8(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + uint8x8_t result = vbsl_u8(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. + const uint8x8_t left_or_top_mask = vorr_u8(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + result = vbsl_u8(left_or_top_mask, result, top_left); + + if (width == 4) { + store_u8_4x1(dest, result); + } else { // width == 8 + vst1_u8(dest, result); + } + dest += stride; + } while (++y != height); +} + +#define PAETH_NXM(W, H) \ + void aom_paeth_predictor_##W##x##H##_neon(uint8_t *dst, ptrdiff_t stride, \ + const uint8_t *above, \ + const uint8_t *left) { \ + paeth_4or8_x_h_neon(dst, stride, above, left, W, H); \ + } + +PAETH_NXM(4, 4) +PAETH_NXM(4, 8) +PAETH_NXM(8, 4) +PAETH_NXM(8, 8) +PAETH_NXM(8, 16) + +PAETH_NXM(4, 16) +PAETH_NXM(8, 32) + +// Calculate X distance <= TopLeft distance and pack the resulting mask into +// uint8x8_t. +static INLINE uint8x16_t x_le_top_left(const uint8x16_t x_dist, + const uint16x8_t top_left_dist_low, + const uint16x8_t top_left_dist_high) { + const uint8x16_t top_left_dist = vcombine_u8(vqmovn_u16(top_left_dist_low), + vqmovn_u16(top_left_dist_high)); + return vcleq_u8(x_dist, top_left_dist); +} + +// Select the closest values and collect them. +static INLINE uint8x16_t select_paeth(const uint8x16_t top, + const uint8x16_t left, + const uint8x16_t top_left, + const uint8x16_t left_le_top, + const uint8x16_t left_le_top_left, + const uint8x16_t top_le_top_left) { + // if (left_dist <= top_dist && left_dist <= top_left_dist) + const uint8x16_t left_mask = vandq_u8(left_le_top, left_le_top_left); + // dest[x] = left_column[y]; + // Fill all the unused spaces with 'top'. They will be overwritten when + // the positions for top_left are known. + uint8x16_t result = vbslq_u8(left_mask, left, top); + // else if (top_dist <= top_left_dist) + // dest[x] = top_row[x]; + // Add these values to the mask. They were already set. + const uint8x16_t left_or_top_mask = vorrq_u8(left_mask, top_le_top_left); + // else + // dest[x] = top_left; + return vbslq_u8(left_or_top_mask, result, top_left); +} + +// Generate numbered and high/low versions of top_left_dist. +#define TOP_LEFT_DIST(num) \ + const uint16x8_t top_left_##num##_dist_low = vabdq_u16( \ + vaddl_u8(vget_low_u8(top[num]), vget_low_u8(left)), top_left_x2); \ + const uint16x8_t top_left_##num##_dist_high = vabdq_u16( \ + vaddl_u8(vget_high_u8(top[num]), vget_low_u8(left)), top_left_x2) + +// Generate numbered versions of XLeTopLeft with x = left. +#define LEFT_LE_TOP_LEFT(num) \ + const uint8x16_t left_le_top_left_##num = \ + x_le_top_left(left_##num##_dist, top_left_##num##_dist_low, \ + top_left_##num##_dist_high) + +// Generate numbered versions of XLeTopLeft with x = top. 
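The commented if/else chain above (and in select_paeth) is the standard Paeth selection: predict p = left + top - top_left and pick whichever of left, top, or top_left is closest to p, resolving ties in that order. Collected into a runnable scalar reference (the name paeth_scalar_ref is illustrative, not libaom API):

#include <stdint.h>
#include <stdlib.h>

// Distances of the Paeth predictor p = left + top - top_left from each
// neighbour. Note |p - left| == |top - top_left|, which is why the NEON code
// can work with absolute differences of the original samples.
static uint8_t paeth_scalar_ref(uint8_t top, uint8_t left, uint8_t top_left) {
  const int left_dist = abs((int)top - (int)top_left);       // |p - left|
  const int top_dist = abs((int)left - (int)top_left);       // |p - top|
  const int top_left_dist =
      abs((int)top + (int)left - 2 * (int)top_left);         // |p - top_left|
  if (left_dist <= top_dist && left_dist <= top_left_dist) return left;
  if (top_dist <= top_left_dist) return top;
  return top_left;
}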
+#define TOP_LE_TOP_LEFT(num) \ + const uint8x16_t top_le_top_left_##num = x_le_top_left( \ + top_dist, top_left_##num##_dist_low, top_left_##num##_dist_high) + +static INLINE void paeth16_plus_x_h_neon(uint8_t *dest, ptrdiff_t stride, + const uint8_t *const top_row, + const uint8_t *const left_column, + int width, int height) { + const uint8x16_t top_left = vdupq_n_u8(top_row[-1]); + const uint16x8_t top_left_x2 = vdupq_n_u16(top_row[-1] + top_row[-1]); + uint8x16_t top[4]; + top[0] = vld1q_u8(top_row); + if (width > 16) { + top[1] = vld1q_u8(top_row + 16); + if (width == 64) { + top[2] = vld1q_u8(top_row + 32); + top[3] = vld1q_u8(top_row + 48); + } + } + + assert(height > 0); + int y = 0; + do { + const uint8x16_t left = vdupq_n_u8(left_column[y]); + + const uint8x16_t top_dist = vabdq_u8(left, top_left); + + const uint8x16_t left_0_dist = vabdq_u8(top[0], top_left); + TOP_LEFT_DIST(0); + const uint8x16_t left_0_le_top = vcleq_u8(left_0_dist, top_dist); + LEFT_LE_TOP_LEFT(0); + TOP_LE_TOP_LEFT(0); + + const uint8x16_t result_0 = + select_paeth(top[0], left, top_left, left_0_le_top, left_le_top_left_0, + top_le_top_left_0); + vst1q_u8(dest, result_0); + + if (width > 16) { + const uint8x16_t left_1_dist = vabdq_u8(top[1], top_left); + TOP_LEFT_DIST(1); + const uint8x16_t left_1_le_top = vcleq_u8(left_1_dist, top_dist); + LEFT_LE_TOP_LEFT(1); + TOP_LE_TOP_LEFT(1); + + const uint8x16_t result_1 = + select_paeth(top[1], left, top_left, left_1_le_top, + left_le_top_left_1, top_le_top_left_1); + vst1q_u8(dest + 16, result_1); + + if (width == 64) { + const uint8x16_t left_2_dist = vabdq_u8(top[2], top_left); + TOP_LEFT_DIST(2); + const uint8x16_t left_2_le_top = vcleq_u8(left_2_dist, top_dist); + LEFT_LE_TOP_LEFT(2); + TOP_LE_TOP_LEFT(2); + + const uint8x16_t result_2 = + select_paeth(top[2], left, top_left, left_2_le_top, + left_le_top_left_2, top_le_top_left_2); + vst1q_u8(dest + 32, result_2); + + const uint8x16_t left_3_dist = vabdq_u8(top[3], top_left); + TOP_LEFT_DIST(3); + const uint8x16_t left_3_le_top = vcleq_u8(left_3_dist, top_dist); + LEFT_LE_TOP_LEFT(3); + TOP_LE_TOP_LEFT(3); + + const uint8x16_t result_3 = + select_paeth(top[3], left, top_left, left_3_le_top, + left_le_top_left_3, top_le_top_left_3); + vst1q_u8(dest + 48, result_3); + } + } + + dest += stride; + } while (++y != height); +} + +#define PAETH_NXM_WIDE(W, H) \ + void aom_paeth_predictor_##W##x##H##_neon(uint8_t *dst, ptrdiff_t stride, \ + const uint8_t *above, \ + const uint8_t *left) { \ + paeth16_plus_x_h_neon(dst, stride, above, left, W, H); \ + } + +PAETH_NXM_WIDE(16, 8) +PAETH_NXM_WIDE(16, 16) +PAETH_NXM_WIDE(16, 32) +PAETH_NXM_WIDE(32, 16) +PAETH_NXM_WIDE(32, 32) +PAETH_NXM_WIDE(32, 64) +PAETH_NXM_WIDE(64, 32) +PAETH_NXM_WIDE(64, 64) + +PAETH_NXM_WIDE(16, 4) +PAETH_NXM_WIDE(16, 64) +PAETH_NXM_WIDE(32, 8) +PAETH_NXM_WIDE(64, 16) diff --git a/third_party/aom/aom_dsp/arm/loopfilter_neon.c b/third_party/aom/aom_dsp/arm/loopfilter_neon.c new file mode 100644 index 0000000000..7c64be1253 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/loopfilter_neon.c @@ -0,0 +1,1045 @@ +/* + * Copyright (c) 2018, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. 
If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_dsp_rtcd.h" +#include "config/aom_config.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" + +static INLINE uint8x8_t lpf_mask(uint8x8_t p3q3, uint8x8_t p2q2, uint8x8_t p1q1, + uint8x8_t p0q0, const uint8_t blimit, + const uint8_t limit) { + // Calculate mask values for four samples + uint32x2x2_t p0q0_p1q1; + uint16x8_t temp_16x8; + uint16x4_t temp0_16x4, temp1_16x4; + uint8x8_t mask_8x8, temp_8x8; + const uint8x8_t limit_8x8 = vdup_n_u8(limit); + const uint16x4_t blimit_16x4 = vdup_n_u16((uint16_t)blimit); + + mask_8x8 = vabd_u8(p3q3, p2q2); + mask_8x8 = vmax_u8(mask_8x8, vabd_u8(p2q2, p1q1)); + mask_8x8 = vmax_u8(mask_8x8, vabd_u8(p1q1, p0q0)); + mask_8x8 = vcle_u8(mask_8x8, limit_8x8); + + temp_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(mask_8x8))); + mask_8x8 = vand_u8(mask_8x8, temp_8x8); + + p0q0_p1q1 = vtrn_u32(vreinterpret_u32_u8(p0q0), vreinterpret_u32_u8(p1q1)); + temp_8x8 = vabd_u8(vreinterpret_u8_u32(p0q0_p1q1.val[0]), + vreinterpret_u8_u32(p0q0_p1q1.val[1])); + temp_16x8 = vmovl_u8(temp_8x8); + temp0_16x4 = vshl_n_u16(vget_low_u16(temp_16x8), 1); + temp1_16x4 = vshr_n_u16(vget_high_u16(temp_16x8), 1); + temp0_16x4 = vadd_u16(temp0_16x4, temp1_16x4); + temp0_16x4 = vcle_u16(temp0_16x4, blimit_16x4); + temp_8x8 = vmovn_u16(vcombine_u16(temp0_16x4, temp0_16x4)); + + mask_8x8 = vand_u8(mask_8x8, temp_8x8); + + return mask_8x8; +} + +static INLINE uint8x8_t lpf_mask2(uint8x8_t p1q1, uint8x8_t p0q0, + const uint8_t blimit, const uint8_t limit) { + uint32x2x2_t p0q0_p1q1; + uint16x8_t temp_16x8; + uint16x4_t temp0_16x4, temp1_16x4; + const uint16x4_t blimit_16x4 = vdup_n_u16(blimit); + const uint8x8_t limit_8x8 = vdup_n_u8(limit); + uint8x8_t mask_8x8, temp_8x8; + + mask_8x8 = vabd_u8(p1q1, p0q0); + mask_8x8 = vcle_u8(mask_8x8, limit_8x8); + + temp_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(mask_8x8))); + mask_8x8 = vand_u8(mask_8x8, temp_8x8); + + p0q0_p1q1 = vtrn_u32(vreinterpret_u32_u8(p0q0), vreinterpret_u32_u8(p1q1)); + temp_8x8 = vabd_u8(vreinterpret_u8_u32(p0q0_p1q1.val[0]), + vreinterpret_u8_u32(p0q0_p1q1.val[1])); + temp_16x8 = vmovl_u8(temp_8x8); + temp0_16x4 = vshl_n_u16(vget_low_u16(temp_16x8), 1); + temp1_16x4 = vshr_n_u16(vget_high_u16(temp_16x8), 1); + temp0_16x4 = vadd_u16(temp0_16x4, temp1_16x4); + temp0_16x4 = vcle_u16(temp0_16x4, blimit_16x4); + temp_8x8 = vmovn_u16(vcombine_u16(temp0_16x4, temp0_16x4)); + + mask_8x8 = vand_u8(mask_8x8, temp_8x8); + + return mask_8x8; +} + +static INLINE uint8x8_t lpf_flat_mask4(uint8x8_t p3q3, uint8x8_t p2q2, + uint8x8_t p1q1, uint8x8_t p0q0) { + const uint8x8_t thresh_8x8 = vdup_n_u8(1); // for bd==8 threshold is always 1 + uint8x8_t flat_8x8, temp_8x8; + + flat_8x8 = vabd_u8(p1q1, p0q0); + flat_8x8 = vmax_u8(flat_8x8, vabd_u8(p2q2, p0q0)); + flat_8x8 = vmax_u8(flat_8x8, vabd_u8(p3q3, p0q0)); + flat_8x8 = vcle_u8(flat_8x8, thresh_8x8); + + temp_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(flat_8x8))); + flat_8x8 = vand_u8(flat_8x8, temp_8x8); + + return flat_8x8; +} + +static INLINE uint8x8_t lpf_flat_mask3(uint8x8_t p2q2, uint8x8_t p1q1, + uint8x8_t p0q0) { + const uint8x8_t thresh_8x8 = vdup_n_u8(1); // for bd==8 threshold is always 1 + uint8x8_t flat_8x8, temp_8x8; + + flat_8x8 = vabd_u8(p1q1, p0q0); + flat_8x8 = 
vmax_u8(flat_8x8, vabd_u8(p2q2, p0q0)); + flat_8x8 = vcle_u8(flat_8x8, thresh_8x8); + + temp_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(flat_8x8))); + flat_8x8 = vand_u8(flat_8x8, temp_8x8); + + return flat_8x8; +} + +static INLINE uint8x8_t lpf_mask3_chroma(uint8x8_t p2q2, uint8x8_t p1q1, + uint8x8_t p0q0, const uint8_t blimit, + const uint8_t limit) { + // Calculate mask3 values for four samples + uint32x2x2_t p0q0_p1q1; + uint16x8_t temp_16x8; + uint16x4_t temp0_16x4, temp1_16x4; + uint8x8_t mask_8x8, temp_8x8; + const uint8x8_t limit_8x8 = vdup_n_u8(limit); + const uint16x4_t blimit_16x4 = vdup_n_u16((uint16_t)blimit); + + mask_8x8 = vabd_u8(p2q2, p1q1); + mask_8x8 = vmax_u8(mask_8x8, vabd_u8(p1q1, p0q0)); + mask_8x8 = vcle_u8(mask_8x8, limit_8x8); + + temp_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(mask_8x8))); + mask_8x8 = vand_u8(mask_8x8, temp_8x8); + + p0q0_p1q1 = vtrn_u32(vreinterpret_u32_u8(p0q0), vreinterpret_u32_u8(p1q1)); + temp_8x8 = vabd_u8(vreinterpret_u8_u32(p0q0_p1q1.val[0]), + vreinterpret_u8_u32(p0q0_p1q1.val[1])); + temp_16x8 = vmovl_u8(temp_8x8); + temp0_16x4 = vshl_n_u16(vget_low_u16(temp_16x8), 1); + temp1_16x4 = vshr_n_u16(vget_high_u16(temp_16x8), 1); + temp0_16x4 = vadd_u16(temp0_16x4, temp1_16x4); + temp0_16x4 = vcle_u16(temp0_16x4, blimit_16x4); + temp_8x8 = vmovn_u16(vcombine_u16(temp0_16x4, temp0_16x4)); + + mask_8x8 = vand_u8(mask_8x8, temp_8x8); + + return mask_8x8; +} + +static void lpf_14_neon(uint8x8_t *p6q6, uint8x8_t *p5q5, uint8x8_t *p4q4, + uint8x8_t *p3q3, uint8x8_t *p2q2, uint8x8_t *p1q1, + uint8x8_t *p0q0, const uint8_t blimit, + const uint8_t limit, const uint8_t thresh) { + uint16x8_t out; + uint8x8_t out_f14_pq0, out_f14_pq1, out_f14_pq2, out_f14_pq3, out_f14_pq4, + out_f14_pq5; + uint8x8_t out_f7_pq0, out_f7_pq1, out_f7_pq2; + uint8x8_t out_f4_pq0, out_f4_pq1; + uint8x8_t mask_8x8, flat_8x8, flat2_8x8; + uint8x8_t q0p0, q1p1, q2p2; + + // Calculate filter masks + mask_8x8 = lpf_mask(*p3q3, *p2q2, *p1q1, *p0q0, blimit, limit); + flat_8x8 = lpf_flat_mask4(*p3q3, *p2q2, *p1q1, *p0q0); + flat2_8x8 = lpf_flat_mask4(*p6q6, *p5q5, *p4q4, *p0q0); + { + // filter 4 + int32x2x2_t ps0_qs0, ps1_qs1; + int16x8_t filter_s16; + const uint8x8_t thresh_f4 = vdup_n_u8(thresh); + uint8x8_t temp0_8x8, temp1_8x8; + int8x8_t ps0_s8, ps1_s8, qs0_s8, qs1_s8, temp_s8; + int8x8_t op0, oq0, op1, oq1; + int8x8_t pq_s0, pq_s1; + int8x8_t filter_s8, filter1_s8, filter2_s8; + int8x8_t hev_8x8; + const int8x8_t sign_mask = vdup_n_s8(0x80); + const int8x8_t val_4 = vdup_n_s8(4); + const int8x8_t val_3 = vdup_n_s8(3); + + pq_s0 = veor_s8(vreinterpret_s8_u8(*p0q0), sign_mask); + pq_s1 = veor_s8(vreinterpret_s8_u8(*p1q1), sign_mask); + + ps0_qs0 = vtrn_s32(vreinterpret_s32_s8(pq_s0), vreinterpret_s32_s8(pq_s0)); + ps1_qs1 = vtrn_s32(vreinterpret_s32_s8(pq_s1), vreinterpret_s32_s8(pq_s1)); + ps0_s8 = vreinterpret_s8_s32(ps0_qs0.val[0]); + qs0_s8 = vreinterpret_s8_s32(ps0_qs0.val[1]); + ps1_s8 = vreinterpret_s8_s32(ps1_qs1.val[0]); + qs1_s8 = vreinterpret_s8_s32(ps1_qs1.val[1]); + + // hev_mask + temp0_8x8 = vcgt_u8(vabd_u8(*p0q0, *p1q1), thresh_f4); + temp1_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(temp0_8x8))); + hev_8x8 = vreinterpret_s8_u8(vorr_u8(temp0_8x8, temp1_8x8)); + + // add outer taps if we have high edge variance + filter_s8 = vqsub_s8(ps1_s8, qs1_s8); + filter_s8 = vand_s8(filter_s8, hev_8x8); + + // inner taps + temp_s8 = vqsub_s8(qs0_s8, ps0_s8); + filter_s16 = vmovl_s8(filter_s8); + filter_s16 = vmlal_s8(filter_s16, 
temp_s8, val_3); + filter_s8 = vqmovn_s16(filter_s16); + filter_s8 = vand_s8(filter_s8, vreinterpret_s8_u8(mask_8x8)); + + filter1_s8 = vqadd_s8(filter_s8, val_4); + filter2_s8 = vqadd_s8(filter_s8, val_3); + filter1_s8 = vshr_n_s8(filter1_s8, 3); + filter2_s8 = vshr_n_s8(filter2_s8, 3); + + oq0 = veor_s8(vqsub_s8(qs0_s8, filter1_s8), sign_mask); + op0 = veor_s8(vqadd_s8(ps0_s8, filter2_s8), sign_mask); + + hev_8x8 = vmvn_s8(hev_8x8); + filter_s8 = vrshr_n_s8(filter1_s8, 1); + filter_s8 = vand_s8(filter_s8, hev_8x8); + + oq1 = veor_s8(vqsub_s8(qs1_s8, filter_s8), sign_mask); + op1 = veor_s8(vqadd_s8(ps1_s8, filter_s8), sign_mask); + + out_f4_pq0 = vreinterpret_u8_s8(vext_s8(op0, oq0, 4)); + out_f4_pq1 = vreinterpret_u8_s8(vext_s8(op1, oq1, 4)); + } + // reverse p and q + q0p0 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p0q0))); + q1p1 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p1q1))); + q2p2 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p2q2))); + { + // filter 8 + uint16x8_t out_pq0, out_pq1, out_pq2; + out = vaddl_u8(*p3q3, *p2q2); + out = vaddw_u8(out, *p1q1); + out = vaddw_u8(out, *p0q0); + + out = vaddw_u8(out, q0p0); + out_pq1 = vaddw_u8(out, *p3q3); + out_pq2 = vaddw_u8(out_pq1, *p3q3); + out_pq2 = vaddw_u8(out_pq2, *p2q2); + out_pq1 = vaddw_u8(out_pq1, *p1q1); + out_pq1 = vaddw_u8(out_pq1, q1p1); + + out_pq0 = vaddw_u8(out, *p0q0); + out_pq0 = vaddw_u8(out_pq0, q1p1); + out_pq0 = vaddw_u8(out_pq0, q2p2); + + out_f7_pq0 = vrshrn_n_u16(out_pq0, 3); + out_f7_pq1 = vrshrn_n_u16(out_pq1, 3); + out_f7_pq2 = vrshrn_n_u16(out_pq2, 3); + } + { + // filter 14 + uint16x8_t out_pq0, out_pq1, out_pq2, out_pq3, out_pq4, out_pq5; + uint16x8_t p6q6_2, p6q6_temp, qp_sum; + uint8x8_t qp_rev; + + out = vaddw_u8(out, *p4q4); + out = vaddw_u8(out, *p5q5); + out = vaddw_u8(out, *p6q6); + + out_pq5 = vaddw_u8(out, *p4q4); + out_pq4 = vaddw_u8(out_pq5, *p3q3); + out_pq3 = vaddw_u8(out_pq4, *p2q2); + + out_pq5 = vaddw_u8(out_pq5, *p5q5); + out_pq4 = vaddw_u8(out_pq4, *p5q5); + + out_pq0 = vaddw_u8(out, *p1q1); + out_pq1 = vaddw_u8(out_pq0, *p2q2); + out_pq2 = vaddw_u8(out_pq1, *p3q3); + + out_pq0 = vaddw_u8(out_pq0, *p0q0); + out_pq1 = vaddw_u8(out_pq1, *p0q0); + + out_pq1 = vaddw_u8(out_pq1, *p6q6); + p6q6_2 = vaddl_u8(*p6q6, *p6q6); + out_pq2 = vaddq_u16(out_pq2, p6q6_2); + p6q6_temp = vaddw_u8(p6q6_2, *p6q6); + out_pq3 = vaddq_u16(out_pq3, p6q6_temp); + p6q6_temp = vaddw_u8(p6q6_temp, *p6q6); + out_pq4 = vaddq_u16(out_pq4, p6q6_temp); + p6q6_temp = vaddq_u16(p6q6_temp, p6q6_2); + out_pq5 = vaddq_u16(out_pq5, p6q6_temp); + + out_pq4 = vaddw_u8(out_pq4, q1p1); + + qp_sum = vaddl_u8(q2p2, q1p1); + out_pq3 = vaddq_u16(out_pq3, qp_sum); + + qp_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p3q3))); + qp_sum = vaddw_u8(qp_sum, qp_rev); + out_pq2 = vaddq_u16(out_pq2, qp_sum); + + qp_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p4q4))); + qp_sum = vaddw_u8(qp_sum, qp_rev); + out_pq1 = vaddq_u16(out_pq1, qp_sum); + + qp_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p5q5))); + qp_sum = vaddw_u8(qp_sum, qp_rev); + out_pq0 = vaddq_u16(out_pq0, qp_sum); + + out_pq0 = vaddw_u8(out_pq0, q0p0); + + out_f14_pq0 = vrshrn_n_u16(out_pq0, 4); + out_f14_pq1 = vrshrn_n_u16(out_pq1, 4); + out_f14_pq2 = vrshrn_n_u16(out_pq2, 4); + out_f14_pq3 = vrshrn_n_u16(out_pq3, 4); + out_f14_pq4 = vrshrn_n_u16(out_pq4, 4); + out_f14_pq5 = vrshrn_n_u16(out_pq5, 4); + } + { + uint8x8_t filter4_cond, filter8_cond, filter14_cond; + filter8_cond = vand_u8(flat_8x8, mask_8x8); + 
filter4_cond = vmvn_u8(filter8_cond); + filter14_cond = vand_u8(filter8_cond, flat2_8x8); + + // filter4 outputs + *p0q0 = vbsl_u8(filter4_cond, out_f4_pq0, *p0q0); + *p1q1 = vbsl_u8(filter4_cond, out_f4_pq1, *p1q1); + + // filter8 outputs + *p0q0 = vbsl_u8(filter8_cond, out_f7_pq0, *p0q0); + *p1q1 = vbsl_u8(filter8_cond, out_f7_pq1, *p1q1); + *p2q2 = vbsl_u8(filter8_cond, out_f7_pq2, *p2q2); + + // filter14 outputs + *p0q0 = vbsl_u8(filter14_cond, out_f14_pq0, *p0q0); + *p1q1 = vbsl_u8(filter14_cond, out_f14_pq1, *p1q1); + *p2q2 = vbsl_u8(filter14_cond, out_f14_pq2, *p2q2); + *p3q3 = vbsl_u8(filter14_cond, out_f14_pq3, *p3q3); + *p4q4 = vbsl_u8(filter14_cond, out_f14_pq4, *p4q4); + *p5q5 = vbsl_u8(filter14_cond, out_f14_pq5, *p5q5); + } +} + +static void lpf_8_neon(uint8x8_t *p3q3, uint8x8_t *p2q2, uint8x8_t *p1q1, + uint8x8_t *p0q0, const uint8_t blimit, + const uint8_t limit, const uint8_t thresh) { + uint16x8_t out; + uint8x8_t out_f7_pq0, out_f7_pq1, out_f7_pq2; + uint8x8_t out_f4_pq0, out_f4_pq1; + uint8x8_t mask_8x8, flat_8x8; + + // Calculate filter masks + mask_8x8 = lpf_mask(*p3q3, *p2q2, *p1q1, *p0q0, blimit, limit); + flat_8x8 = lpf_flat_mask4(*p3q3, *p2q2, *p1q1, *p0q0); + { + // filter 4 + int32x2x2_t ps0_qs0, ps1_qs1; + int16x8_t filter_s16; + const uint8x8_t thresh_f4 = vdup_n_u8(thresh); + uint8x8_t temp0_8x8, temp1_8x8; + int8x8_t ps0_s8, ps1_s8, qs0_s8, qs1_s8, temp_s8; + int8x8_t op0, oq0, op1, oq1; + int8x8_t pq_s0, pq_s1; + int8x8_t filter_s8, filter1_s8, filter2_s8; + int8x8_t hev_8x8; + const int8x8_t sign_mask = vdup_n_s8(0x80); + const int8x8_t val_4 = vdup_n_s8(4); + const int8x8_t val_3 = vdup_n_s8(3); + + pq_s0 = veor_s8(vreinterpret_s8_u8(*p0q0), sign_mask); + pq_s1 = veor_s8(vreinterpret_s8_u8(*p1q1), sign_mask); + + ps0_qs0 = vtrn_s32(vreinterpret_s32_s8(pq_s0), vreinterpret_s32_s8(pq_s0)); + ps1_qs1 = vtrn_s32(vreinterpret_s32_s8(pq_s1), vreinterpret_s32_s8(pq_s1)); + ps0_s8 = vreinterpret_s8_s32(ps0_qs0.val[0]); + qs0_s8 = vreinterpret_s8_s32(ps0_qs0.val[1]); + ps1_s8 = vreinterpret_s8_s32(ps1_qs1.val[0]); + qs1_s8 = vreinterpret_s8_s32(ps1_qs1.val[1]); + + // hev_mask + temp0_8x8 = vcgt_u8(vabd_u8(*p0q0, *p1q1), thresh_f4); + temp1_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(temp0_8x8))); + hev_8x8 = vreinterpret_s8_u8(vorr_u8(temp0_8x8, temp1_8x8)); + + // add outer taps if we have high edge variance + filter_s8 = vqsub_s8(ps1_s8, qs1_s8); + filter_s8 = vand_s8(filter_s8, hev_8x8); + + // inner taps + temp_s8 = vqsub_s8(qs0_s8, ps0_s8); + filter_s16 = vmovl_s8(filter_s8); + filter_s16 = vmlal_s8(filter_s16, temp_s8, val_3); + filter_s8 = vqmovn_s16(filter_s16); + filter_s8 = vand_s8(filter_s8, vreinterpret_s8_u8(mask_8x8)); + + filter1_s8 = vqadd_s8(filter_s8, val_4); + filter2_s8 = vqadd_s8(filter_s8, val_3); + filter1_s8 = vshr_n_s8(filter1_s8, 3); + filter2_s8 = vshr_n_s8(filter2_s8, 3); + + oq0 = veor_s8(vqsub_s8(qs0_s8, filter1_s8), sign_mask); + op0 = veor_s8(vqadd_s8(ps0_s8, filter2_s8), sign_mask); + + hev_8x8 = vmvn_s8(hev_8x8); + filter_s8 = vrshr_n_s8(filter1_s8, 1); + filter_s8 = vand_s8(filter_s8, hev_8x8); + + oq1 = veor_s8(vqsub_s8(qs1_s8, filter_s8), sign_mask); + op1 = veor_s8(vqadd_s8(ps1_s8, filter_s8), sign_mask); + + out_f4_pq0 = vreinterpret_u8_s8(vext_s8(op0, oq0, 4)); + out_f4_pq1 = vreinterpret_u8_s8(vext_s8(op1, oq1, 4)); + } + { + // filter 8 + uint16x8_t out_pq0, out_pq1, out_pq2; + uint8x8_t q0p0, q1p1, q2p2; + + out = vaddl_u8(*p3q3, *p2q2); + out = vaddw_u8(out, *p1q1); + out = vaddw_u8(out, *p0q0); + + // 
reverse p and q + q0p0 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p0q0))); + q1p1 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p1q1))); + q2p2 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p2q2))); + + out = vaddw_u8(out, q0p0); + out_pq1 = vaddw_u8(out, *p3q3); + out_pq2 = vaddw_u8(out_pq1, *p3q3); + out_pq2 = vaddw_u8(out_pq2, *p2q2); + out_pq1 = vaddw_u8(out_pq1, *p1q1); + out_pq1 = vaddw_u8(out_pq1, q1p1); + + out_pq0 = vaddw_u8(out, *p0q0); + out_pq0 = vaddw_u8(out_pq0, q1p1); + out_pq0 = vaddw_u8(out_pq0, q2p2); + + out_f7_pq0 = vrshrn_n_u16(out_pq0, 3); + out_f7_pq1 = vrshrn_n_u16(out_pq1, 3); + out_f7_pq2 = vrshrn_n_u16(out_pq2, 3); + } + { + uint8x8_t filter4_cond, filter8_cond; + filter8_cond = vand_u8(flat_8x8, mask_8x8); + filter4_cond = vmvn_u8(filter8_cond); + + // filter4 outputs + *p0q0 = vbsl_u8(filter4_cond, out_f4_pq0, *p0q0); + *p1q1 = vbsl_u8(filter4_cond, out_f4_pq1, *p1q1); + + // filter8 outputs + *p0q0 = vbsl_u8(filter8_cond, out_f7_pq0, *p0q0); + *p1q1 = vbsl_u8(filter8_cond, out_f7_pq1, *p1q1); + *p2q2 = vbsl_u8(filter8_cond, out_f7_pq2, *p2q2); + } +} + +static void lpf_6_neon(uint8x8_t *p2q2, uint8x8_t *p1q1, uint8x8_t *p0q0, + const uint8_t blimit, const uint8_t limit, + const uint8_t thresh) { + uint16x8_t out; + uint8x8_t out_f6_pq0, out_f6_pq1; + uint8x8_t out_f4_pq0, out_f4_pq1; + uint8x8_t mask_8x8, flat_8x8; + + // Calculate filter masks + mask_8x8 = lpf_mask3_chroma(*p2q2, *p1q1, *p0q0, blimit, limit); + flat_8x8 = lpf_flat_mask3(*p2q2, *p1q1, *p0q0); + { + // filter 4 + int32x2x2_t ps0_qs0, ps1_qs1; + int16x8_t filter_s16; + const uint8x8_t thresh_f4 = vdup_n_u8(thresh); + uint8x8_t temp0_8x8, temp1_8x8; + int8x8_t ps0_s8, ps1_s8, qs0_s8, qs1_s8, temp_s8; + int8x8_t op0, oq0, op1, oq1; + int8x8_t pq_s0, pq_s1; + int8x8_t filter_s8, filter1_s8, filter2_s8; + int8x8_t hev_8x8; + const int8x8_t sign_mask = vdup_n_s8(0x80); + const int8x8_t val_4 = vdup_n_s8(4); + const int8x8_t val_3 = vdup_n_s8(3); + + pq_s0 = veor_s8(vreinterpret_s8_u8(*p0q0), sign_mask); + pq_s1 = veor_s8(vreinterpret_s8_u8(*p1q1), sign_mask); + + ps0_qs0 = vtrn_s32(vreinterpret_s32_s8(pq_s0), vreinterpret_s32_s8(pq_s0)); + ps1_qs1 = vtrn_s32(vreinterpret_s32_s8(pq_s1), vreinterpret_s32_s8(pq_s1)); + ps0_s8 = vreinterpret_s8_s32(ps0_qs0.val[0]); + qs0_s8 = vreinterpret_s8_s32(ps0_qs0.val[1]); + ps1_s8 = vreinterpret_s8_s32(ps1_qs1.val[0]); + qs1_s8 = vreinterpret_s8_s32(ps1_qs1.val[1]); + + // hev_mask + temp0_8x8 = vcgt_u8(vabd_u8(*p0q0, *p1q1), thresh_f4); + temp1_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(temp0_8x8))); + hev_8x8 = vreinterpret_s8_u8(vorr_u8(temp0_8x8, temp1_8x8)); + + // add outer taps if we have high edge variance + filter_s8 = vqsub_s8(ps1_s8, qs1_s8); + filter_s8 = vand_s8(filter_s8, hev_8x8); + + // inner taps + temp_s8 = vqsub_s8(qs0_s8, ps0_s8); + filter_s16 = vmovl_s8(filter_s8); + filter_s16 = vmlal_s8(filter_s16, temp_s8, val_3); + filter_s8 = vqmovn_s16(filter_s16); + filter_s8 = vand_s8(filter_s8, vreinterpret_s8_u8(mask_8x8)); + + filter1_s8 = vqadd_s8(filter_s8, val_4); + filter2_s8 = vqadd_s8(filter_s8, val_3); + filter1_s8 = vshr_n_s8(filter1_s8, 3); + filter2_s8 = vshr_n_s8(filter2_s8, 3); + + oq0 = veor_s8(vqsub_s8(qs0_s8, filter1_s8), sign_mask); + op0 = veor_s8(vqadd_s8(ps0_s8, filter2_s8), sign_mask); + + filter_s8 = vrshr_n_s8(filter1_s8, 1); + filter_s8 = vbic_s8(filter_s8, hev_8x8); + + oq1 = veor_s8(vqsub_s8(qs1_s8, filter_s8), sign_mask); + op1 = veor_s8(vqadd_s8(ps1_s8, filter_s8), sign_mask); + + 
out_f4_pq0 = vreinterpret_u8_s8(vext_s8(op0, oq0, 4)); + out_f4_pq1 = vreinterpret_u8_s8(vext_s8(op1, oq1, 4)); + } + { + // filter 6 + uint16x8_t out_pq0, out_pq1; + uint8x8_t pq_rev; + + out = vaddl_u8(*p0q0, *p1q1); + out = vaddq_u16(out, out); + out = vaddw_u8(out, *p2q2); + + pq_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p0q0))); + out = vaddw_u8(out, pq_rev); + + out_pq0 = vaddw_u8(out, pq_rev); + pq_rev = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(*p1q1))); + out_pq0 = vaddw_u8(out_pq0, pq_rev); + + out_pq1 = vaddw_u8(out, *p2q2); + out_pq1 = vaddw_u8(out_pq1, *p2q2); + + out_f6_pq0 = vrshrn_n_u16(out_pq0, 3); + out_f6_pq1 = vrshrn_n_u16(out_pq1, 3); + } + { + uint8x8_t filter4_cond, filter6_cond; + filter6_cond = vand_u8(flat_8x8, mask_8x8); + filter4_cond = vmvn_u8(filter6_cond); + + // filter4 outputs + *p0q0 = vbsl_u8(filter4_cond, out_f4_pq0, *p0q0); + *p1q1 = vbsl_u8(filter4_cond, out_f4_pq1, *p1q1); + + // filter6 outputs + *p0q0 = vbsl_u8(filter6_cond, out_f6_pq0, *p0q0); + *p1q1 = vbsl_u8(filter6_cond, out_f6_pq1, *p1q1); + } +} + +static void lpf_4_neon(uint8x8_t *p1q1, uint8x8_t *p0q0, const uint8_t blimit, + const uint8_t limit, const uint8_t thresh) { + int32x2x2_t ps0_qs0, ps1_qs1; + int16x8_t filter_s16; + const uint8x8_t thresh_f4 = vdup_n_u8(thresh); + uint8x8_t mask_8x8, temp0_8x8, temp1_8x8; + int8x8_t ps0_s8, ps1_s8, qs0_s8, qs1_s8, temp_s8; + int8x8_t op0, oq0, op1, oq1; + int8x8_t pq_s0, pq_s1; + int8x8_t filter_s8, filter1_s8, filter2_s8; + int8x8_t hev_8x8; + const int8x8_t sign_mask = vdup_n_s8(0x80); + const int8x8_t val_4 = vdup_n_s8(4); + const int8x8_t val_3 = vdup_n_s8(3); + + // Calculate filter mask + mask_8x8 = lpf_mask2(*p1q1, *p0q0, blimit, limit); + + pq_s0 = veor_s8(vreinterpret_s8_u8(*p0q0), sign_mask); + pq_s1 = veor_s8(vreinterpret_s8_u8(*p1q1), sign_mask); + + ps0_qs0 = vtrn_s32(vreinterpret_s32_s8(pq_s0), vreinterpret_s32_s8(pq_s0)); + ps1_qs1 = vtrn_s32(vreinterpret_s32_s8(pq_s1), vreinterpret_s32_s8(pq_s1)); + ps0_s8 = vreinterpret_s8_s32(ps0_qs0.val[0]); + qs0_s8 = vreinterpret_s8_s32(ps0_qs0.val[1]); + ps1_s8 = vreinterpret_s8_s32(ps1_qs1.val[0]); + qs1_s8 = vreinterpret_s8_s32(ps1_qs1.val[1]); + + // hev_mask + temp0_8x8 = vcgt_u8(vabd_u8(*p0q0, *p1q1), thresh_f4); + temp1_8x8 = vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(temp0_8x8))); + hev_8x8 = vreinterpret_s8_u8(vorr_u8(temp0_8x8, temp1_8x8)); + + // add outer taps if we have high edge variance + filter_s8 = vqsub_s8(ps1_s8, qs1_s8); + filter_s8 = vand_s8(filter_s8, hev_8x8); + + // inner taps + temp_s8 = vqsub_s8(qs0_s8, ps0_s8); + filter_s16 = vmovl_s8(filter_s8); + filter_s16 = vmlal_s8(filter_s16, temp_s8, val_3); + filter_s8 = vqmovn_s16(filter_s16); + filter_s8 = vand_s8(filter_s8, vreinterpret_s8_u8(mask_8x8)); + + filter1_s8 = vqadd_s8(filter_s8, val_4); + filter2_s8 = vqadd_s8(filter_s8, val_3); + filter1_s8 = vshr_n_s8(filter1_s8, 3); + filter2_s8 = vshr_n_s8(filter2_s8, 3); + + oq0 = veor_s8(vqsub_s8(qs0_s8, filter1_s8), sign_mask); + op0 = veor_s8(vqadd_s8(ps0_s8, filter2_s8), sign_mask); + + filter_s8 = vrshr_n_s8(filter1_s8, 1); + filter_s8 = vbic_s8(filter_s8, hev_8x8); + + oq1 = veor_s8(vqsub_s8(qs1_s8, filter_s8), sign_mask); + op1 = veor_s8(vqadd_s8(ps1_s8, filter_s8), sign_mask); + + *p0q0 = vreinterpret_u8_s8(vext_s8(op0, oq0, 4)); + *p1q1 = vreinterpret_u8_s8(vext_s8(op1, oq1, 4)); +} + +void aom_lpf_vertical_14_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint8x16_t row0, 
row1, row2, row3; + uint8x8_t pxp3, p6p2, p5p1, p4p0; + uint8x8_t q0q4, q1q5, q2q6, q3qy; + uint32x2x2_t p6q6_p2q2, p5q5_p1q1, p4q4_p0q0, pxqx_p3q3; + uint32x2_t pq_rev; + uint8x8_t p0q0, p1q1, p2q2, p3q3, p4q4, p5q5, p6q6; + + // row0: x p6 p5 p4 p3 p2 p1 p0 | q0 q1 q2 q3 q4 q5 q6 y + // row1: x p6 p5 p4 p3 p2 p1 p0 | q0 q1 q2 q3 q4 q5 q6 y + // row2: x p6 p5 p4 p3 p2 p1 p0 | q0 q1 q2 q3 q4 q5 q6 y + // row3: x p6 p5 p4 p3 p2 p1 p0 | q0 q1 q2 q3 q4 q5 q6 y + load_u8_16x4(src - 8, stride, &row0, &row1, &row2, &row3); + + pxp3 = vget_low_u8(row0); + p6p2 = vget_low_u8(row1); + p5p1 = vget_low_u8(row2); + p4p0 = vget_low_u8(row3); + transpose_elems_inplace_u8_8x4(&pxp3, &p6p2, &p5p1, &p4p0); + + q0q4 = vget_high_u8(row0); + q1q5 = vget_high_u8(row1); + q2q6 = vget_high_u8(row2); + q3qy = vget_high_u8(row3); + transpose_elems_inplace_u8_8x4(&q0q4, &q1q5, &q2q6, &q3qy); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(q3qy)); + pxqx_p3q3 = vtrn_u32(vreinterpret_u32_u8(pxp3), pq_rev); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(q1q5)); + p5q5_p1q1 = vtrn_u32(vreinterpret_u32_u8(p5p1), pq_rev); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(q0q4)); + p4q4_p0q0 = vtrn_u32(vreinterpret_u32_u8(p4p0), pq_rev); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(q2q6)); + p6q6_p2q2 = vtrn_u32(vreinterpret_u32_u8(p6p2), pq_rev); + + p0q0 = vreinterpret_u8_u32(p4q4_p0q0.val[1]); + p1q1 = vreinterpret_u8_u32(p5q5_p1q1.val[1]); + p2q2 = vreinterpret_u8_u32(p6q6_p2q2.val[1]); + p3q3 = vreinterpret_u8_u32(pxqx_p3q3.val[1]); + p4q4 = vreinterpret_u8_u32(p4q4_p0q0.val[0]); + p5q5 = vreinterpret_u8_u32(p5q5_p1q1.val[0]); + p6q6 = vreinterpret_u8_u32(p6q6_p2q2.val[0]); + + lpf_14_neon(&p6q6, &p5q5, &p4q4, &p3q3, &p2q2, &p1q1, &p0q0, *blimit, *limit, + *thresh); + + pxqx_p3q3 = vtrn_u32(pxqx_p3q3.val[0], vreinterpret_u32_u8(p3q3)); + p5q5_p1q1 = vtrn_u32(vreinterpret_u32_u8(p5q5), vreinterpret_u32_u8(p1q1)); + p4q4_p0q0 = vtrn_u32(vreinterpret_u32_u8(p4q4), vreinterpret_u32_u8(p0q0)); + p6q6_p2q2 = vtrn_u32(vreinterpret_u32_u8(p6q6), vreinterpret_u32_u8(p2q2)); + + pxqx_p3q3.val[1] = vrev64_u32(pxqx_p3q3.val[1]); + p5q5_p1q1.val[1] = vrev64_u32(p5q5_p1q1.val[1]); + p4q4_p0q0.val[1] = vrev64_u32(p4q4_p0q0.val[1]); + p6q6_p2q2.val[1] = vrev64_u32(p6q6_p2q2.val[1]); + + q0q4 = vreinterpret_u8_u32(p4q4_p0q0.val[1]); + q1q5 = vreinterpret_u8_u32(p5q5_p1q1.val[1]); + q2q6 = vreinterpret_u8_u32(p6q6_p2q2.val[1]); + q3qy = vreinterpret_u8_u32(pxqx_p3q3.val[1]); + transpose_elems_inplace_u8_8x4(&q0q4, &q1q5, &q2q6, &q3qy); + + pxp3 = vreinterpret_u8_u32(pxqx_p3q3.val[0]); + p6p2 = vreinterpret_u8_u32(p6q6_p2q2.val[0]); + p5p1 = vreinterpret_u8_u32(p5q5_p1q1.val[0]); + p4p0 = vreinterpret_u8_u32(p4q4_p0q0.val[0]); + transpose_elems_inplace_u8_8x4(&pxp3, &p6p2, &p5p1, &p4p0); + + row0 = vcombine_u8(pxp3, q0q4); + row1 = vcombine_u8(p6p2, q1q5); + row2 = vcombine_u8(p5p1, q2q6); + row3 = vcombine_u8(p4p0, q3qy); + + store_u8_16x4(src - 8, stride, row0, row1, row2, row3); +} + +void aom_lpf_vertical_14_dual_neon( + uint8_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_vertical_14_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_vertical_14_neon(s + 4 * pitch, pitch, blimit1, limit1, thresh1); +} + +void aom_lpf_vertical_14_quad_neon(uint8_t *s, int pitch, const uint8_t *blimit, + const uint8_t *limit, + const uint8_t *thresh) { + aom_lpf_vertical_14_dual_neon(s, pitch, blimit, limit, thresh, blimit, limit, + 
thresh); + aom_lpf_vertical_14_dual_neon(s + 2 * MI_SIZE * pitch, pitch, blimit, limit, + thresh, blimit, limit, thresh); +} + +void aom_lpf_vertical_8_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint32x2x2_t p2q2_p1q1, p3q3_p0q0; + uint32x2_t pq_rev; + uint8x8_t p3q0, p2q1, p1q2, p0q3; + uint8x8_t p0q0, p1q1, p2q2, p3q3; + + // row0: p3 p2 p1 p0 | q0 q1 q2 q3 + // row1: p3 p2 p1 p0 | q0 q1 q2 q3 + // row2: p3 p2 p1 p0 | q0 q1 q2 q3 + // row3: p3 p2 p1 p0 | q0 q1 q2 q3 + load_u8_8x4(src - 4, stride, &p3q0, &p2q1, &p1q2, &p0q3); + + transpose_elems_inplace_u8_8x4(&p3q0, &p2q1, &p1q2, &p0q3); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p0q3)); + p3q3_p0q0 = vtrn_u32(vreinterpret_u32_u8(p3q0), pq_rev); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p1q2)); + p2q2_p1q1 = vtrn_u32(vreinterpret_u32_u8(p2q1), pq_rev); + + p0q0 = vreinterpret_u8_u32(vrev64_u32(p3q3_p0q0.val[1])); + p1q1 = vreinterpret_u8_u32(vrev64_u32(p2q2_p1q1.val[1])); + p2q2 = vreinterpret_u8_u32(p2q2_p1q1.val[0]); + p3q3 = vreinterpret_u8_u32(p3q3_p0q0.val[0]); + + lpf_8_neon(&p3q3, &p2q2, &p1q1, &p0q0, *blimit, *limit, *thresh); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p0q0)); + p3q3_p0q0 = vtrn_u32(vreinterpret_u32_u8(p3q3), pq_rev); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p1q1)); + p2q2_p1q1 = vtrn_u32(vreinterpret_u32_u8(p2q2), pq_rev); + + p0q3 = vreinterpret_u8_u32(vrev64_u32(p3q3_p0q0.val[1])); + p1q2 = vreinterpret_u8_u32(vrev64_u32(p2q2_p1q1.val[1])); + p2q1 = vreinterpret_u8_u32(p2q2_p1q1.val[0]); + p3q0 = vreinterpret_u8_u32(p3q3_p0q0.val[0]); + transpose_elems_inplace_u8_8x4(&p3q0, &p2q1, &p1q2, &p0q3); + + store_u8_8x4(src - 4, stride, p3q0, p2q1, p1q2, p0q3); +} + +void aom_lpf_vertical_8_dual_neon(uint8_t *s, int pitch, const uint8_t *blimit0, + const uint8_t *limit0, const uint8_t *thresh0, + const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_vertical_8_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_vertical_8_neon(s + 4 * pitch, pitch, blimit1, limit1, thresh1); +} + +void aom_lpf_vertical_8_quad_neon(uint8_t *s, int pitch, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + aom_lpf_vertical_8_dual_neon(s, pitch, blimit, limit, thresh, blimit, limit, + thresh); + aom_lpf_vertical_8_dual_neon(s + 2 * MI_SIZE * pitch, pitch, blimit, limit, + thresh, blimit, limit, thresh); +} + +void aom_lpf_vertical_6_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint32x2x2_t p2q2_p1q1, pxqy_p0q0; + uint32x2_t pq_rev; + uint8x8_t pxq0, p2q1, p1q2, p0qy; + uint8x8_t p0q0, p1q1, p2q2, pxqy; + + // row0: px p2 p1 p0 | q0 q1 q2 qy + // row1: px p2 p1 p0 | q0 q1 q2 qy + // row2: px p2 p1 p0 | q0 q1 q2 qy + // row3: px p2 p1 p0 | q0 q1 q2 qy + load_u8_8x4(src - 4, stride, &pxq0, &p2q1, &p1q2, &p0qy); + + transpose_elems_inplace_u8_8x4(&pxq0, &p2q1, &p1q2, &p0qy); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p0qy)); + pxqy_p0q0 = vtrn_u32(vreinterpret_u32_u8(pxq0), pq_rev); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p1q2)); + p2q2_p1q1 = vtrn_u32(vreinterpret_u32_u8(p2q1), pq_rev); + + p0q0 = vreinterpret_u8_u32(vrev64_u32(pxqy_p0q0.val[1])); + p1q1 = vreinterpret_u8_u32(vrev64_u32(p2q2_p1q1.val[1])); + p2q2 = vreinterpret_u8_u32(p2q2_p1q1.val[0]); + pxqy = vreinterpret_u8_u32(pxqy_p0q0.val[0]); + + lpf_6_neon(&p2q2, &p1q1, &p0q0, *blimit, *limit, *thresh); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p0q0)); + pxqy_p0q0 = 
vtrn_u32(vreinterpret_u32_u8(pxqy), pq_rev); + + pq_rev = vrev64_u32(vreinterpret_u32_u8(p1q1)); + p2q2_p1q1 = vtrn_u32(vreinterpret_u32_u8(p2q2), pq_rev); + + p0qy = vreinterpret_u8_u32(vrev64_u32(pxqy_p0q0.val[1])); + p1q2 = vreinterpret_u8_u32(vrev64_u32(p2q2_p1q1.val[1])); + p2q1 = vreinterpret_u8_u32(p2q2_p1q1.val[0]); + pxq0 = vreinterpret_u8_u32(pxqy_p0q0.val[0]); + transpose_elems_inplace_u8_8x4(&pxq0, &p2q1, &p1q2, &p0qy); + + store_u8_8x4(src - 4, stride, pxq0, p2q1, p1q2, p0qy); +} + +void aom_lpf_vertical_6_dual_neon(uint8_t *s, int pitch, const uint8_t *blimit0, + const uint8_t *limit0, const uint8_t *thresh0, + const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_vertical_6_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_vertical_6_neon(s + 4 * pitch, pitch, blimit1, limit1, thresh1); +} + +void aom_lpf_vertical_6_quad_neon(uint8_t *s, int pitch, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + aom_lpf_vertical_6_dual_neon(s, pitch, blimit, limit, thresh, blimit, limit, + thresh); + aom_lpf_vertical_6_dual_neon(s + 2 * MI_SIZE * pitch, pitch, blimit, limit, + thresh, blimit, limit, thresh); +} + +void aom_lpf_vertical_4_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint32x2x2_t p1q0_p0q1, p1q1_p0q0, p1p0_q1q0; + uint32x2_t pq_rev; + uint8x8_t p1p0, q0q1; + uint8x8_t p0q0, p1q1; + + // row0: p1 p0 | q0 q1 + // row1: p1 p0 | q0 q1 + // row2: p1 p0 | q0 q1 + // row3: p1 p0 | q0 q1 + load_unaligned_u8_4x4(src - 2, stride, &p1p0, &q0q1); + + transpose_elems_inplace_u8_4x4(&p1p0, &q0q1); + + p1q0_p0q1 = vtrn_u32(vreinterpret_u32_u8(p1p0), vreinterpret_u32_u8(q0q1)); + + pq_rev = vrev64_u32(p1q0_p0q1.val[1]); + p1q1_p0q0 = vtrn_u32(p1q0_p0q1.val[0], pq_rev); + + p1q1 = vreinterpret_u8_u32(p1q1_p0q0.val[0]); + p0q0 = vreinterpret_u8_u32(p1q1_p0q0.val[1]); + + lpf_4_neon(&p1q1, &p0q0, *blimit, *limit, *thresh); + + p1p0_q1q0 = vtrn_u32(vreinterpret_u32_u8(p1q1), vreinterpret_u32_u8(p0q0)); + + p1p0 = vreinterpret_u8_u32(p1p0_q1q0.val[0]); + q0q1 = vreinterpret_u8_u32(vrev64_u32(p1p0_q1q0.val[1])); + + transpose_elems_inplace_u8_4x4(&p1p0, &q0q1); + + store_u8x4_strided_x2(src - 2, 2 * stride, p1p0); + store_u8x4_strided_x2(src + stride - 2, 2 * stride, q0q1); +} + +void aom_lpf_vertical_4_dual_neon(uint8_t *s, int pitch, const uint8_t *blimit0, + const uint8_t *limit0, const uint8_t *thresh0, + const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_vertical_4_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_vertical_4_neon(s + 4 * pitch, pitch, blimit1, limit1, thresh1); +} + +void aom_lpf_vertical_4_quad_neon(uint8_t *s, int pitch, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + aom_lpf_vertical_4_dual_neon(s, pitch, blimit, limit, thresh, blimit, limit, + thresh); + aom_lpf_vertical_4_dual_neon(s + 2 * MI_SIZE * pitch, pitch, blimit, limit, + thresh, blimit, limit, thresh); +} + +void aom_lpf_horizontal_14_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint8x8_t p6q6 = load_u8_4x2(src - 7 * stride, 13 * stride); + uint8x8_t p5q5 = load_u8_4x2(src - 6 * stride, 11 * stride); + uint8x8_t p4q4 = load_u8_4x2(src - 5 * stride, 9 * stride); + uint8x8_t p3q3 = load_u8_4x2(src - 4 * stride, 7 * stride); + uint8x8_t p2q2 = load_u8_4x2(src - 3 * stride, 5 * stride); + uint8x8_t p1q1 = load_u8_4x2(src - 2 * stride, 3 * stride); + uint8x8_t p0q0 = 
load_u8_4x2(src - 1 * stride, 1 * stride); + + lpf_14_neon(&p6q6, &p5q5, &p4q4, &p3q3, &p2q2, &p1q1, &p0q0, *blimit, *limit, + *thresh); + + store_u8x4_strided_x2(src - 1 * stride, 1 * stride, p0q0); + store_u8x4_strided_x2(src - 2 * stride, 3 * stride, p1q1); + store_u8x4_strided_x2(src - 3 * stride, 5 * stride, p2q2); + store_u8x4_strided_x2(src - 4 * stride, 7 * stride, p3q3); + store_u8x4_strided_x2(src - 5 * stride, 9 * stride, p4q4); + store_u8x4_strided_x2(src - 6 * stride, 11 * stride, p5q5); +} + +void aom_lpf_horizontal_14_dual_neon( + uint8_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_horizontal_14_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_horizontal_14_neon(s + 4, pitch, blimit1, limit1, thresh1); +} + +// TODO(any): Rewrite in NEON (similar to quad SSE2 functions) for better speed +// up. +void aom_lpf_horizontal_14_quad_neon(uint8_t *s, int pitch, + const uint8_t *blimit, + const uint8_t *limit, + const uint8_t *thresh) { + aom_lpf_horizontal_14_dual_neon(s, pitch, blimit, limit, thresh, blimit, + limit, thresh); + aom_lpf_horizontal_14_dual_neon(s + 2 * MI_SIZE, pitch, blimit, limit, thresh, + blimit, limit, thresh); +} + +void aom_lpf_horizontal_8_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint8x8_t p0q0, p1q1, p2q2, p3q3; + + p3q3 = vreinterpret_u8_u32(vld1_dup_u32((uint32_t *)(src - 4 * stride))); + p2q2 = vreinterpret_u8_u32(vld1_dup_u32((uint32_t *)(src - 3 * stride))); + p1q1 = vreinterpret_u8_u32(vld1_dup_u32((uint32_t *)(src - 2 * stride))); + p0q0 = vreinterpret_u8_u32(vld1_dup_u32((uint32_t *)(src - 1 * stride))); + p0q0 = vreinterpret_u8_u32(vld1_lane_u32((uint32_t *)(src + 0 * stride), + vreinterpret_u32_u8(p0q0), 1)); + p1q1 = vreinterpret_u8_u32(vld1_lane_u32((uint32_t *)(src + 1 * stride), + vreinterpret_u32_u8(p1q1), 1)); + p2q2 = vreinterpret_u8_u32(vld1_lane_u32((uint32_t *)(src + 2 * stride), + vreinterpret_u32_u8(p2q2), 1)); + p3q3 = vreinterpret_u8_u32(vld1_lane_u32((uint32_t *)(src + 3 * stride), + vreinterpret_u32_u8(p3q3), 1)); + + lpf_8_neon(&p3q3, &p2q2, &p1q1, &p0q0, *blimit, *limit, *thresh); + + vst1_lane_u32((uint32_t *)(src - 4 * stride), vreinterpret_u32_u8(p3q3), 0); + vst1_lane_u32((uint32_t *)(src - 3 * stride), vreinterpret_u32_u8(p2q2), 0); + vst1_lane_u32((uint32_t *)(src - 2 * stride), vreinterpret_u32_u8(p1q1), 0); + vst1_lane_u32((uint32_t *)(src - 1 * stride), vreinterpret_u32_u8(p0q0), 0); + vst1_lane_u32((uint32_t *)(src + 0 * stride), vreinterpret_u32_u8(p0q0), 1); + vst1_lane_u32((uint32_t *)(src + 1 * stride), vreinterpret_u32_u8(p1q1), 1); + vst1_lane_u32((uint32_t *)(src + 2 * stride), vreinterpret_u32_u8(p2q2), 1); + vst1_lane_u32((uint32_t *)(src + 3 * stride), vreinterpret_u32_u8(p3q3), 1); +} + +void aom_lpf_horizontal_8_dual_neon( + uint8_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_horizontal_8_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_horizontal_8_neon(s + 4, pitch, blimit1, limit1, thresh1); +} + +// TODO(any): Rewrite in NEON (similar to quad SSE2 functions) for better speed +// up. 
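+// For now the quad variant below simply forwards to the dual implementation
+// twice, with the second call offset by 2 * MI_SIZE pixels.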
+void aom_lpf_horizontal_8_quad_neon(uint8_t *s, int pitch, + const uint8_t *blimit, const uint8_t *limit, + const uint8_t *thresh) { + aom_lpf_horizontal_8_dual_neon(s, pitch, blimit, limit, thresh, blimit, limit, + thresh); + aom_lpf_horizontal_8_dual_neon(s + 2 * MI_SIZE, pitch, blimit, limit, thresh, + blimit, limit, thresh); +} + +void aom_lpf_horizontal_6_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint8x8_t p0q0, p1q1, p2q2; + + p2q2 = vreinterpret_u8_u32(vld1_dup_u32((uint32_t *)(src - 3 * stride))); + p1q1 = vreinterpret_u8_u32(vld1_dup_u32((uint32_t *)(src - 2 * stride))); + p0q0 = vreinterpret_u8_u32(vld1_dup_u32((uint32_t *)(src - 1 * stride))); + p0q0 = vreinterpret_u8_u32(vld1_lane_u32((uint32_t *)(src + 0 * stride), + vreinterpret_u32_u8(p0q0), 1)); + p1q1 = vreinterpret_u8_u32(vld1_lane_u32((uint32_t *)(src + 1 * stride), + vreinterpret_u32_u8(p1q1), 1)); + p2q2 = vreinterpret_u8_u32(vld1_lane_u32((uint32_t *)(src + 2 * stride), + vreinterpret_u32_u8(p2q2), 1)); + + lpf_6_neon(&p2q2, &p1q1, &p0q0, *blimit, *limit, *thresh); + + vst1_lane_u32((uint32_t *)(src - 3 * stride), vreinterpret_u32_u8(p2q2), 0); + vst1_lane_u32((uint32_t *)(src - 2 * stride), vreinterpret_u32_u8(p1q1), 0); + vst1_lane_u32((uint32_t *)(src - 1 * stride), vreinterpret_u32_u8(p0q0), 0); + vst1_lane_u32((uint32_t *)(src + 0 * stride), vreinterpret_u32_u8(p0q0), 1); + vst1_lane_u32((uint32_t *)(src + 1 * stride), vreinterpret_u32_u8(p1q1), 1); + vst1_lane_u32((uint32_t *)(src + 2 * stride), vreinterpret_u32_u8(p2q2), 1); +} + +void aom_lpf_horizontal_6_dual_neon( + uint8_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_horizontal_6_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_horizontal_6_neon(s + 4, pitch, blimit1, limit1, thresh1); +} + +// TODO(any): Rewrite in NEON (similar to quad SSE2 functions) for better speed +// up. +void aom_lpf_horizontal_6_quad_neon(uint8_t *s, int pitch, + const uint8_t *blimit, const uint8_t *limit, + const uint8_t *thresh) { + aom_lpf_horizontal_6_dual_neon(s, pitch, blimit, limit, thresh, blimit, limit, + thresh); + aom_lpf_horizontal_6_dual_neon(s + 2 * MI_SIZE, pitch, blimit, limit, thresh, + blimit, limit, thresh); +} + +void aom_lpf_horizontal_4_neon(uint8_t *src, int stride, const uint8_t *blimit, + const uint8_t *limit, const uint8_t *thresh) { + uint8x8_t p1q1 = load_u8_4x2(src - 2 * stride, 3 * stride); + uint8x8_t p0q0 = load_u8_4x2(src - 1 * stride, 1 * stride); + + lpf_4_neon(&p1q1, &p0q0, *blimit, *limit, *thresh); + + store_u8x4_strided_x2(src - 1 * stride, 1 * stride, p0q0); + store_u8x4_strided_x2(src - 2 * stride, 3 * stride, p1q1); +} + +void aom_lpf_horizontal_4_dual_neon( + uint8_t *s, int pitch, const uint8_t *blimit0, const uint8_t *limit0, + const uint8_t *thresh0, const uint8_t *blimit1, const uint8_t *limit1, + const uint8_t *thresh1) { + aom_lpf_horizontal_4_neon(s, pitch, blimit0, limit0, thresh0); + aom_lpf_horizontal_4_neon(s + 4, pitch, blimit1, limit1, thresh1); +} + +// TODO(any): Rewrite in NEON (similar to quad SSE2 functions) for better speed +// up. 
+void aom_lpf_horizontal_4_quad_neon(uint8_t *s, int pitch, + const uint8_t *blimit, const uint8_t *limit, + const uint8_t *thresh) { + aom_lpf_horizontal_4_dual_neon(s, pitch, blimit, limit, thresh, blimit, limit, + thresh); + aom_lpf_horizontal_4_dual_neon(s + 2 * MI_SIZE, pitch, blimit, limit, thresh, + blimit, limit, thresh); +} diff --git a/third_party/aom/aom_dsp/arm/masked_sad4d_neon.c b/third_party/aom/aom_dsp/arm/masked_sad4d_neon.c new file mode 100644 index 0000000000..8f65b805ec --- /dev/null +++ b/third_party/aom/aom_dsp/arm/masked_sad4d_neon.c @@ -0,0 +1,562 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "aom/aom_integer.h" +#include "aom_dsp/blend.h" +#include "mem_neon.h" +#include "sum_neon.h" + +static INLINE uint16x8_t masked_sad_16x1_neon(uint16x8_t sad, + const uint8x16_t s0, + const uint8x16_t a0, + const uint8x16_t b0, + const uint8x16_t m0) { + uint8x16_t m0_inv = vsubq_u8(vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA), m0); + uint16x8_t blend_u16_lo = vmull_u8(vget_low_u8(m0), vget_low_u8(a0)); + uint16x8_t blend_u16_hi = vmull_u8(vget_high_u8(m0), vget_high_u8(a0)); + blend_u16_lo = vmlal_u8(blend_u16_lo, vget_low_u8(m0_inv), vget_low_u8(b0)); + blend_u16_hi = vmlal_u8(blend_u16_hi, vget_high_u8(m0_inv), vget_high_u8(b0)); + + uint8x8_t blend_u8_lo = vrshrn_n_u16(blend_u16_lo, AOM_BLEND_A64_ROUND_BITS); + uint8x8_t blend_u8_hi = vrshrn_n_u16(blend_u16_hi, AOM_BLEND_A64_ROUND_BITS); + uint8x16_t blend_u8 = vcombine_u8(blend_u8_lo, blend_u8_hi); + return vpadalq_u8(sad, vabdq_u8(blend_u8, s0)); +} + +static INLINE void masked_inv_sadwxhx4d_large_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int width, int height, int h_overflow) { + uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + int h_limit = height > h_overflow ? 
h_overflow : height; + + int ref_offset = 0; + int i = 0; + do { + uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src + j); + uint8x16_t p0 = vld1q_u8(second_pred + j); + uint8x16_t m0 = vld1q_u8(mask + j); + sum_lo[0] = masked_sad_16x1_neon(sum_lo[0], s0, p0, + vld1q_u8(ref[0] + ref_offset + j), m0); + sum_lo[1] = masked_sad_16x1_neon(sum_lo[1], s0, p0, + vld1q_u8(ref[1] + ref_offset + j), m0); + sum_lo[2] = masked_sad_16x1_neon(sum_lo[2], s0, p0, + vld1q_u8(ref[2] + ref_offset + j), m0); + sum_lo[3] = masked_sad_16x1_neon(sum_lo[3], s0, p0, + vld1q_u8(ref[3] + ref_offset + j), m0); + + uint8x16_t s1 = vld1q_u8(src + j + 16); + uint8x16_t p1 = vld1q_u8(second_pred + j + 16); + uint8x16_t m1 = vld1q_u8(mask + j + 16); + sum_hi[0] = masked_sad_16x1_neon( + sum_hi[0], s1, p1, vld1q_u8(ref[0] + ref_offset + j + 16), m1); + sum_hi[1] = masked_sad_16x1_neon( + sum_hi[1], s1, p1, vld1q_u8(ref[1] + ref_offset + j + 16), m1); + sum_hi[2] = masked_sad_16x1_neon( + sum_hi[2], s1, p1, vld1q_u8(ref[2] + ref_offset + j + 16), m1); + sum_hi[3] = masked_sad_16x1_neon( + sum_hi[3], s1, p1, vld1q_u8(ref[3] + ref_offset + j + 16), m1); + + j += 32; + } while (j < width); + + src += src_stride; + ref_offset += ref_stride; + second_pred += width; + mask += mask_stride; + } while (++i < h_limit); + + sum[0] = vpadalq_u16(sum[0], sum_lo[0]); + sum[0] = vpadalq_u16(sum[0], sum_hi[0]); + sum[1] = vpadalq_u16(sum[1], sum_lo[1]); + sum[1] = vpadalq_u16(sum[1], sum_hi[1]); + sum[2] = vpadalq_u16(sum[2], sum_lo[2]); + sum[2] = vpadalq_u16(sum[2], sum_hi[2]); + sum[3] = vpadalq_u16(sum[3], sum_lo[3]); + sum[3] = vpadalq_u16(sum[3], sum_hi[3]); + + h_limit += h_overflow; + } while (i < height); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} + +static INLINE void masked_inv_sad128xhx4d_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int h) { + masked_inv_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred, + mask, mask_stride, res, 128, h, 32); +} + +static INLINE void masked_inv_sad64xhx4d_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int h) { + masked_inv_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred, + mask, mask_stride, res, 64, h, 64); +} + +static INLINE void masked_sadwxhx4d_large_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int width, int height, int h_overflow) { + uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + int h_limit = height > h_overflow ? 
h_overflow : height; + + int ref_offset = 0; + int i = 0; + do { + uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src + j); + uint8x16_t p0 = vld1q_u8(second_pred + j); + uint8x16_t m0 = vld1q_u8(mask + j); + sum_lo[0] = masked_sad_16x1_neon( + sum_lo[0], s0, vld1q_u8(ref[0] + ref_offset + j), p0, m0); + sum_lo[1] = masked_sad_16x1_neon( + sum_lo[1], s0, vld1q_u8(ref[1] + ref_offset + j), p0, m0); + sum_lo[2] = masked_sad_16x1_neon( + sum_lo[2], s0, vld1q_u8(ref[2] + ref_offset + j), p0, m0); + sum_lo[3] = masked_sad_16x1_neon( + sum_lo[3], s0, vld1q_u8(ref[3] + ref_offset + j), p0, m0); + + uint8x16_t s1 = vld1q_u8(src + j + 16); + uint8x16_t p1 = vld1q_u8(second_pred + j + 16); + uint8x16_t m1 = vld1q_u8(mask + j + 16); + sum_hi[0] = masked_sad_16x1_neon( + sum_hi[0], s1, vld1q_u8(ref[0] + ref_offset + j + 16), p1, m1); + sum_hi[1] = masked_sad_16x1_neon( + sum_hi[1], s1, vld1q_u8(ref[1] + ref_offset + j + 16), p1, m1); + sum_hi[2] = masked_sad_16x1_neon( + sum_hi[2], s1, vld1q_u8(ref[2] + ref_offset + j + 16), p1, m1); + sum_hi[3] = masked_sad_16x1_neon( + sum_hi[3], s1, vld1q_u8(ref[3] + ref_offset + j + 16), p1, m1); + + j += 32; + } while (j < width); + + src += src_stride; + ref_offset += ref_stride; + second_pred += width; + mask += mask_stride; + } while (++i < h_limit); + + sum[0] = vpadalq_u16(sum[0], sum_lo[0]); + sum[0] = vpadalq_u16(sum[0], sum_hi[0]); + sum[1] = vpadalq_u16(sum[1], sum_lo[1]); + sum[1] = vpadalq_u16(sum[1], sum_hi[1]); + sum[2] = vpadalq_u16(sum[2], sum_lo[2]); + sum[2] = vpadalq_u16(sum[2], sum_hi[2]); + sum[3] = vpadalq_u16(sum[3], sum_lo[3]); + sum[3] = vpadalq_u16(sum[3], sum_hi[3]); + + h_limit += h_overflow; + } while (i < height); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} + +static INLINE void masked_sad128xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, + const uint8_t *second_pred, + const uint8_t *mask, int mask_stride, + uint32_t res[4], int h) { + masked_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred, + mask, mask_stride, res, 128, h, 32); +} + +static INLINE void masked_sad64xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, + const uint8_t *second_pred, + const uint8_t *mask, int mask_stride, + uint32_t res[4], int h) { + masked_sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, second_pred, + mask, mask_stride, res, 64, h, 64); +} + +static INLINE void masked_inv_sad32xhx4d_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int h) { + uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h; + do { + uint8x16_t s0 = vld1q_u8(src); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t m0 = vld1q_u8(mask); + sum_lo[0] = masked_sad_16x1_neon(sum_lo[0], s0, p0, + vld1q_u8(ref[0] + ref_offset), m0); + sum_lo[1] = masked_sad_16x1_neon(sum_lo[1], s0, p0, + vld1q_u8(ref[1] + ref_offset), m0); + sum_lo[2] = masked_sad_16x1_neon(sum_lo[2], s0, p0, + vld1q_u8(ref[2] + ref_offset), m0); + sum_lo[3] = masked_sad_16x1_neon(sum_lo[3], s0, p0, + 
vld1q_u8(ref[3] + ref_offset), m0); + + uint8x16_t s1 = vld1q_u8(src + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t m1 = vld1q_u8(mask + 16); + sum_hi[0] = masked_sad_16x1_neon(sum_hi[0], s1, p1, + vld1q_u8(ref[0] + ref_offset + 16), m1); + sum_hi[1] = masked_sad_16x1_neon(sum_hi[1], s1, p1, + vld1q_u8(ref[1] + ref_offset + 16), m1); + sum_hi[2] = masked_sad_16x1_neon(sum_hi[2], s1, p1, + vld1q_u8(ref[2] + ref_offset + 16), m1); + sum_hi[3] = masked_sad_16x1_neon(sum_hi[3], s1, p1, + vld1q_u8(ref[3] + ref_offset + 16), m1); + + src += src_stride; + ref_offset += ref_stride; + second_pred += 32; + mask += mask_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_long_add_4d_u16x8(sum_lo, sum_hi)); +} + +static INLINE void masked_sad32xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, + const uint8_t *second_pred, + const uint8_t *mask, int mask_stride, + uint32_t res[4], int h) { + uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h; + do { + uint8x16_t s0 = vld1q_u8(src); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t m0 = vld1q_u8(mask); + sum_lo[0] = masked_sad_16x1_neon(sum_lo[0], s0, + vld1q_u8(ref[0] + ref_offset), p0, m0); + sum_lo[1] = masked_sad_16x1_neon(sum_lo[1], s0, + vld1q_u8(ref[1] + ref_offset), p0, m0); + sum_lo[2] = masked_sad_16x1_neon(sum_lo[2], s0, + vld1q_u8(ref[2] + ref_offset), p0, m0); + sum_lo[3] = masked_sad_16x1_neon(sum_lo[3], s0, + vld1q_u8(ref[3] + ref_offset), p0, m0); + + uint8x16_t s1 = vld1q_u8(src + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t m1 = vld1q_u8(mask + 16); + sum_hi[0] = masked_sad_16x1_neon( + sum_hi[0], s1, vld1q_u8(ref[0] + ref_offset + 16), p1, m1); + sum_hi[1] = masked_sad_16x1_neon( + sum_hi[1], s1, vld1q_u8(ref[1] + ref_offset + 16), p1, m1); + sum_hi[2] = masked_sad_16x1_neon( + sum_hi[2], s1, vld1q_u8(ref[2] + ref_offset + 16), p1, m1); + sum_hi[3] = masked_sad_16x1_neon( + sum_hi[3], s1, vld1q_u8(ref[3] + ref_offset + 16), p1, m1); + + src += src_stride; + ref_offset += ref_stride; + second_pred += 32; + mask += mask_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_long_add_4d_u16x8(sum_lo, sum_hi)); +} + +static INLINE void masked_inv_sad16xhx4d_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int h) { + uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint32x4_t sum_u32[4]; + + int ref_offset = 0; + int i = h; + do { + uint8x16_t s0 = vld1q_u8(src); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t m0 = vld1q_u8(mask); + sum_u16[0] = masked_sad_16x1_neon(sum_u16[0], s0, p0, + vld1q_u8(ref[0] + ref_offset), m0); + sum_u16[1] = masked_sad_16x1_neon(sum_u16[1], s0, p0, + vld1q_u8(ref[1] + ref_offset), m0); + sum_u16[2] = masked_sad_16x1_neon(sum_u16[2], s0, p0, + vld1q_u8(ref[2] + ref_offset), m0); + sum_u16[3] = masked_sad_16x1_neon(sum_u16[3], s0, p0, + vld1q_u8(ref[3] + ref_offset), m0); + + src += src_stride; + ref_offset += ref_stride; + second_pred += 16; + mask += mask_stride; + } while (--i != 0); + + sum_u32[0] = vpaddlq_u16(sum_u16[0]); + sum_u32[1] = vpaddlq_u16(sum_u16[1]); + sum_u32[2] = vpaddlq_u16(sum_u16[2]); + sum_u32[3] = vpaddlq_u16(sum_u16[3]); + + vst1q_u32(res, 
horizontal_add_4d_u32x4(sum_u32)); +} + +static INLINE void masked_sad16xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, + const uint8_t *second_pred, + const uint8_t *mask, int mask_stride, + uint32_t res[4], int h) { + uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint32x4_t sum_u32[4]; + + int ref_offset = 0; + int i = h; + do { + uint8x16_t s0 = vld1q_u8(src); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t m0 = vld1q_u8(mask); + sum_u16[0] = masked_sad_16x1_neon(sum_u16[0], s0, + vld1q_u8(ref[0] + ref_offset), p0, m0); + sum_u16[1] = masked_sad_16x1_neon(sum_u16[1], s0, + vld1q_u8(ref[1] + ref_offset), p0, m0); + sum_u16[2] = masked_sad_16x1_neon(sum_u16[2], s0, + vld1q_u8(ref[2] + ref_offset), p0, m0); + sum_u16[3] = masked_sad_16x1_neon(sum_u16[3], s0, + vld1q_u8(ref[3] + ref_offset), p0, m0); + + src += src_stride; + ref_offset += ref_stride; + second_pred += 16; + mask += mask_stride; + } while (--i != 0); + + sum_u32[0] = vpaddlq_u16(sum_u16[0]); + sum_u32[1] = vpaddlq_u16(sum_u16[1]); + sum_u32[2] = vpaddlq_u16(sum_u16[2]); + sum_u32[3] = vpaddlq_u16(sum_u16[3]); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum_u32)); +} + +static INLINE uint16x8_t masked_sad_8x1_neon(uint16x8_t sad, const uint8x8_t s0, + const uint8x8_t a0, + const uint8x8_t b0, + const uint8x8_t m0) { + uint8x8_t m0_inv = vsub_u8(vdup_n_u8(AOM_BLEND_A64_MAX_ALPHA), m0); + uint16x8_t blend_u16 = vmull_u8(m0, a0); + blend_u16 = vmlal_u8(blend_u16, m0_inv, b0); + + uint8x8_t blend_u8 = vrshrn_n_u16(blend_u16, AOM_BLEND_A64_ROUND_BITS); + return vabal_u8(sad, blend_u8, s0); +} + +static INLINE void masked_inv_sad8xhx4d_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int h) { + uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h; + do { + uint8x8_t s0 = vld1_u8(src); + uint8x8_t p0 = vld1_u8(second_pred); + uint8x8_t m0 = vld1_u8(mask); + sum[0] = + masked_sad_8x1_neon(sum[0], s0, p0, vld1_u8(ref[0] + ref_offset), m0); + sum[1] = + masked_sad_8x1_neon(sum[1], s0, p0, vld1_u8(ref[1] + ref_offset), m0); + sum[2] = + masked_sad_8x1_neon(sum[2], s0, p0, vld1_u8(ref[2] + ref_offset), m0); + sum[3] = + masked_sad_8x1_neon(sum[3], s0, p0, vld1_u8(ref[3] + ref_offset), m0); + + src += src_stride; + ref_offset += ref_stride; + second_pred += 8; + mask += mask_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_add_4d_u16x8(sum)); +} + +static INLINE void masked_sad8xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, + const uint8_t *second_pred, + const uint8_t *mask, int mask_stride, + uint32_t res[4], int h) { + uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h; + do { + uint8x8_t s0 = vld1_u8(src); + uint8x8_t p0 = vld1_u8(second_pred); + uint8x8_t m0 = vld1_u8(mask); + + sum[0] = + masked_sad_8x1_neon(sum[0], s0, vld1_u8(ref[0] + ref_offset), p0, m0); + sum[1] = + masked_sad_8x1_neon(sum[1], s0, vld1_u8(ref[1] + ref_offset), p0, m0); + sum[2] = + masked_sad_8x1_neon(sum[2], s0, vld1_u8(ref[2] + ref_offset), p0, m0); + sum[3] = + masked_sad_8x1_neon(sum[3], s0, vld1_u8(ref[3] + ref_offset), p0, m0); + + src += src_stride; + ref_offset += ref_stride; + second_pred += 8; + mask += mask_stride; + } 
while (--i != 0); + + vst1q_u32(res, horizontal_add_4d_u16x8(sum)); +} + +static INLINE void masked_inv_sad4xhx4d_neon( + const uint8_t *src, int src_stride, const uint8_t *const ref[4], + int ref_stride, const uint8_t *second_pred, const uint8_t *mask, + int mask_stride, uint32_t res[4], int h) { + uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h / 2; + do { + uint8x8_t s = load_unaligned_u8(src, src_stride); + uint8x8_t r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride); + uint8x8_t r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride); + uint8x8_t r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride); + uint8x8_t r3 = load_unaligned_u8(ref[3] + ref_offset, ref_stride); + uint8x8_t p0 = vld1_u8(second_pred); + uint8x8_t m0 = load_unaligned_u8(mask, mask_stride); + + sum[0] = masked_sad_8x1_neon(sum[0], s, p0, r0, m0); + sum[1] = masked_sad_8x1_neon(sum[1], s, p0, r1, m0); + sum[2] = masked_sad_8x1_neon(sum[2], s, p0, r2, m0); + sum[3] = masked_sad_8x1_neon(sum[3], s, p0, r3, m0); + + src += 2 * src_stride; + ref_offset += 2 * ref_stride; + second_pred += 2 * 4; + mask += 2 * mask_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_add_4d_u16x8(sum)); +} + +static INLINE void masked_sad4xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, + const uint8_t *second_pred, + const uint8_t *mask, int mask_stride, + uint32_t res[4], int h) { + uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h / 2; + do { + uint8x8_t s = load_unaligned_u8(src, src_stride); + uint8x8_t r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride); + uint8x8_t r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride); + uint8x8_t r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride); + uint8x8_t r3 = load_unaligned_u8(ref[3] + ref_offset, ref_stride); + uint8x8_t p0 = vld1_u8(second_pred); + uint8x8_t m0 = load_unaligned_u8(mask, mask_stride); + + sum[0] = masked_sad_8x1_neon(sum[0], s, r0, p0, m0); + sum[1] = masked_sad_8x1_neon(sum[1], s, r1, p0, m0); + sum[2] = masked_sad_8x1_neon(sum[2], s, r2, p0, m0); + sum[3] = masked_sad_8x1_neon(sum[3], s, r3, p0, m0); + + src += 2 * src_stride; + ref_offset += 2 * ref_stride; + second_pred += 2 * 4; + mask += 2 * mask_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_add_4d_u16x8(sum)); +} + +#define MASKED_SAD4D_WXH_NEON(w, h) \ + void aom_masked_sad##w##x##h##x4d_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref[4], \ + int ref_stride, const uint8_t *second_pred, const uint8_t *msk, \ + int msk_stride, int invert_mask, uint32_t res[4]) { \ + if (invert_mask) { \ + masked_inv_sad##w##xhx4d_neon(src, src_stride, ref, ref_stride, \ + second_pred, msk, msk_stride, res, h); \ + } else { \ + masked_sad##w##xhx4d_neon(src, src_stride, ref, ref_stride, second_pred, \ + msk, msk_stride, res, h); \ + } \ + } + +MASKED_SAD4D_WXH_NEON(4, 8) +MASKED_SAD4D_WXH_NEON(4, 4) + +MASKED_SAD4D_WXH_NEON(8, 16) +MASKED_SAD4D_WXH_NEON(8, 8) +MASKED_SAD4D_WXH_NEON(8, 4) + +MASKED_SAD4D_WXH_NEON(16, 32) +MASKED_SAD4D_WXH_NEON(16, 16) +MASKED_SAD4D_WXH_NEON(16, 8) + +MASKED_SAD4D_WXH_NEON(32, 64) +MASKED_SAD4D_WXH_NEON(32, 32) +MASKED_SAD4D_WXH_NEON(32, 16) + +MASKED_SAD4D_WXH_NEON(64, 128) +MASKED_SAD4D_WXH_NEON(64, 64) +MASKED_SAD4D_WXH_NEON(64, 32) + +MASKED_SAD4D_WXH_NEON(128, 128) +MASKED_SAD4D_WXH_NEON(128, 64) + +#if !CONFIG_REALTIME_ONLY 
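+// These 4:1 and 1:4 rectangular block sizes are only needed when realtime-only
+// mode is disabled.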
+MASKED_SAD4D_WXH_NEON(4, 16) +MASKED_SAD4D_WXH_NEON(16, 4) +MASKED_SAD4D_WXH_NEON(8, 32) +MASKED_SAD4D_WXH_NEON(32, 8) +MASKED_SAD4D_WXH_NEON(16, 64) +MASKED_SAD4D_WXH_NEON(64, 16) +#endif diff --git a/third_party/aom/aom_dsp/arm/masked_sad_neon.c b/third_party/aom/aom_dsp/arm/masked_sad_neon.c new file mode 100644 index 0000000000..9d263105e3 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/masked_sad_neon.c @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/blend_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/blend.h" + +static INLINE uint16x8_t masked_sad_16x1_neon(uint16x8_t sad, + const uint8_t *src, + const uint8_t *a, + const uint8_t *b, + const uint8_t *m) { + uint8x16_t m0 = vld1q_u8(m); + uint8x16_t a0 = vld1q_u8(a); + uint8x16_t b0 = vld1q_u8(b); + uint8x16_t s0 = vld1q_u8(src); + + uint8x16_t blend_u8 = alpha_blend_a64_u8x16(m0, a0, b0); + + return vpadalq_u8(sad, vabdq_u8(blend_u8, s0)); +} + +static INLINE unsigned masked_sad_128xh_neon(const uint8_t *src, int src_stride, + const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, + const uint8_t *m, int m_stride, + int height) { + // Eight accumulator vectors are required to avoid overflow in the 128x128 + // case. + assert(height <= 128); + uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0) }; + + do { + sad[0] = masked_sad_16x1_neon(sad[0], &src[0], &a[0], &b[0], &m[0]); + sad[1] = masked_sad_16x1_neon(sad[1], &src[16], &a[16], &b[16], &m[16]); + sad[2] = masked_sad_16x1_neon(sad[2], &src[32], &a[32], &b[32], &m[32]); + sad[3] = masked_sad_16x1_neon(sad[3], &src[48], &a[48], &b[48], &m[48]); + sad[4] = masked_sad_16x1_neon(sad[4], &src[64], &a[64], &b[64], &m[64]); + sad[5] = masked_sad_16x1_neon(sad[5], &src[80], &a[80], &b[80], &m[80]); + sad[6] = masked_sad_16x1_neon(sad[6], &src[96], &a[96], &b[96], &m[96]); + sad[7] = masked_sad_16x1_neon(sad[7], &src[112], &a[112], &b[112], &m[112]); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + height--; + } while (height != 0); + + return horizontal_long_add_u16x8(sad[0], sad[1]) + + horizontal_long_add_u16x8(sad[2], sad[3]) + + horizontal_long_add_u16x8(sad[4], sad[5]) + + horizontal_long_add_u16x8(sad[6], sad[7]); +} + +static INLINE unsigned masked_sad_64xh_neon(const uint8_t *src, int src_stride, + const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, + const uint8_t *m, int m_stride, + int height) { + // Four accumulator vectors are required to avoid overflow in the 64x128 case. 
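+  // Each masked_sad_16x1_neon() call adds at most 2 * 255 = 510 to every
+  // uint16 lane, so one accumulator per 16-wide column group stays below
+  // 65535 for heights up to 128 (510 * 128 = 65280).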
+ assert(height <= 128); + uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + do { + sad[0] = masked_sad_16x1_neon(sad[0], &src[0], &a[0], &b[0], &m[0]); + sad[1] = masked_sad_16x1_neon(sad[1], &src[16], &a[16], &b[16], &m[16]); + sad[2] = masked_sad_16x1_neon(sad[2], &src[32], &a[32], &b[32], &m[32]); + sad[3] = masked_sad_16x1_neon(sad[3], &src[48], &a[48], &b[48], &m[48]); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + height--; + } while (height != 0); + + return horizontal_long_add_u16x8(sad[0], sad[1]) + + horizontal_long_add_u16x8(sad[2], sad[3]); +} + +static INLINE unsigned masked_sad_32xh_neon(const uint8_t *src, int src_stride, + const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, + const uint8_t *m, int m_stride, + int height) { + // We could use a single accumulator up to height=64 without overflow. + assert(height <= 64); + uint16x8_t sad = vdupq_n_u16(0); + + do { + sad = masked_sad_16x1_neon(sad, &src[0], &a[0], &b[0], &m[0]); + sad = masked_sad_16x1_neon(sad, &src[16], &a[16], &b[16], &m[16]); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + height--; + } while (height != 0); + + return horizontal_add_u16x8(sad); +} + +static INLINE unsigned masked_sad_16xh_neon(const uint8_t *src, int src_stride, + const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, + const uint8_t *m, int m_stride, + int height) { + // We could use a single accumulator up to height=128 without overflow. + assert(height <= 128); + uint16x8_t sad = vdupq_n_u16(0); + + do { + sad = masked_sad_16x1_neon(sad, src, a, b, m); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + height--; + } while (height != 0); + + return horizontal_add_u16x8(sad); +} + +static INLINE unsigned masked_sad_8xh_neon(const uint8_t *src, int src_stride, + const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, + const uint8_t *m, int m_stride, + int height) { + // We could use a single accumulator up to height=128 without overflow. + assert(height <= 128); + uint16x4_t sad = vdup_n_u16(0); + + do { + uint8x8_t m0 = vld1_u8(m); + uint8x8_t a0 = vld1_u8(a); + uint8x8_t b0 = vld1_u8(b); + uint8x8_t s0 = vld1_u8(src); + + uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, a0, b0); + + sad = vpadal_u8(sad, vabd_u8(blend_u8, s0)); + + src += src_stride; + a += a_stride; + b += b_stride; + m += m_stride; + height--; + } while (height != 0); + + return horizontal_add_u16x4(sad); +} + +static INLINE unsigned masked_sad_4xh_neon(const uint8_t *src, int src_stride, + const uint8_t *a, int a_stride, + const uint8_t *b, int b_stride, + const uint8_t *m, int m_stride, + int height) { + // Process two rows per loop iteration. + assert(height % 2 == 0); + + // We could use a single accumulator up to height=256 without overflow. 
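+  // Each loop iteration covers two rows (eight pixels) and adds at most
+  // 2 * 255 = 510 per uint16 lane, so 256 rows accumulate no more than
+  // 510 * 128 = 65280.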
+ assert(height <= 256); + uint16x4_t sad = vdup_n_u16(0); + + do { + uint8x8_t m0 = load_unaligned_u8(m, m_stride); + uint8x8_t a0 = load_unaligned_u8(a, a_stride); + uint8x8_t b0 = load_unaligned_u8(b, b_stride); + uint8x8_t s0 = load_unaligned_u8(src, src_stride); + + uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, a0, b0); + + sad = vpadal_u8(sad, vabd_u8(blend_u8, s0)); + + src += 2 * src_stride; + a += 2 * a_stride; + b += 2 * b_stride; + m += 2 * m_stride; + height -= 2; + } while (height != 0); + + return horizontal_add_u16x4(sad); +} + +#define MASKED_SAD_WXH_NEON(width, height) \ + unsigned aom_masked_sad##width##x##height##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred, const uint8_t *msk, int msk_stride, \ + int invert_mask) { \ + if (!invert_mask) \ + return masked_sad_##width##xh_neon(src, src_stride, ref, ref_stride, \ + second_pred, width, msk, msk_stride, \ + height); \ + else \ + return masked_sad_##width##xh_neon(src, src_stride, second_pred, width, \ + ref, ref_stride, msk, msk_stride, \ + height); \ + } + +MASKED_SAD_WXH_NEON(4, 4) +MASKED_SAD_WXH_NEON(4, 8) +MASKED_SAD_WXH_NEON(8, 4) +MASKED_SAD_WXH_NEON(8, 8) +MASKED_SAD_WXH_NEON(8, 16) +MASKED_SAD_WXH_NEON(16, 8) +MASKED_SAD_WXH_NEON(16, 16) +MASKED_SAD_WXH_NEON(16, 32) +MASKED_SAD_WXH_NEON(32, 16) +MASKED_SAD_WXH_NEON(32, 32) +MASKED_SAD_WXH_NEON(32, 64) +MASKED_SAD_WXH_NEON(64, 32) +MASKED_SAD_WXH_NEON(64, 64) +MASKED_SAD_WXH_NEON(64, 128) +MASKED_SAD_WXH_NEON(128, 64) +MASKED_SAD_WXH_NEON(128, 128) +#if !CONFIG_REALTIME_ONLY +MASKED_SAD_WXH_NEON(4, 16) +MASKED_SAD_WXH_NEON(16, 4) +MASKED_SAD_WXH_NEON(8, 32) +MASKED_SAD_WXH_NEON(32, 8) +MASKED_SAD_WXH_NEON(16, 64) +MASKED_SAD_WXH_NEON(64, 16) +#endif diff --git a/third_party/aom/aom_dsp/arm/mem_neon.h b/third_party/aom/aom_dsp/arm/mem_neon.h new file mode 100644 index 0000000000..52c7a34e3e --- /dev/null +++ b/third_party/aom/aom_dsp/arm/mem_neon.h @@ -0,0 +1,1253 @@ +/* + * Copyright (c) 2018, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef AOM_AOM_DSP_ARM_MEM_NEON_H_ +#define AOM_AOM_DSP_ARM_MEM_NEON_H_ + +#include +#include +#include "aom_dsp/aom_dsp_common.h" + +// Support for xN Neon intrinsics is lacking in some compilers. +#if defined(__arm__) || defined(_M_ARM) +#define ARM_32_BIT +#endif + +// DEFICIENT_CLANG_32_BIT includes clang-cl. +#if defined(__clang__) && defined(ARM_32_BIT) && \ + (__clang_major__ <= 6 || (defined(__ANDROID__) && __clang_major__ <= 7)) +#define DEFICIENT_CLANG_32_BIT // This includes clang-cl. 
+#endif + +#if defined(__GNUC__) && !defined(__clang__) && defined(ARM_32_BIT) +#define GCC_32_BIT +#endif + +#if defined(DEFICIENT_CLANG_32_BIT) || defined(GCC_32_BIT) + +static INLINE uint8x16x3_t vld1q_u8_x3(const uint8_t *ptr) { + uint8x16x3_t res = { { vld1q_u8(ptr + 0 * 16), vld1q_u8(ptr + 1 * 16), + vld1q_u8(ptr + 2 * 16) } }; + return res; +} + +static INLINE uint8x16x2_t vld1q_u8_x2(const uint8_t *ptr) { + uint8x16x2_t res = { { vld1q_u8(ptr + 0 * 16), vld1q_u8(ptr + 1 * 16) } }; + return res; +} + +static INLINE uint16x8x2_t vld1q_u16_x2(const uint16_t *ptr) { + uint16x8x2_t res = { { vld1q_u16(ptr + 0), vld1q_u16(ptr + 8) } }; + return res; +} + +static INLINE uint16x8x4_t vld1q_u16_x4(const uint16_t *ptr) { + uint16x8x4_t res = { { vld1q_u16(ptr + 0 * 8), vld1q_u16(ptr + 1 * 8), + vld1q_u16(ptr + 2 * 8), vld1q_u16(ptr + 3 * 8) } }; + return res; +} + +#elif defined(__GNUC__) && !defined(__clang__) // GCC 64-bit. +#if __GNUC__ < 8 + +static INLINE uint8x16x2_t vld1q_u8_x2(const uint8_t *ptr) { + uint8x16x2_t res = { { vld1q_u8(ptr + 0 * 16), vld1q_u8(ptr + 1 * 16) } }; + return res; +} + +static INLINE uint16x8x4_t vld1q_u16_x4(const uint16_t *ptr) { + uint16x8x4_t res = { { vld1q_u16(ptr + 0 * 8), vld1q_u16(ptr + 1 * 8), + vld1q_u16(ptr + 2 * 8), vld1q_u16(ptr + 3 * 8) } }; + return res; +} +#endif // __GNUC__ < 8 + +#if __GNUC__ < 9 +static INLINE uint8x16x3_t vld1q_u8_x3(const uint8_t *ptr) { + uint8x16x3_t res = { { vld1q_u8(ptr + 0 * 16), vld1q_u8(ptr + 1 * 16), + vld1q_u8(ptr + 2 * 16) } }; + return res; +} +#endif // __GNUC__ < 9 +#endif // defined(__GNUC__) && !defined(__clang__) + +static INLINE void store_u8_8x2(uint8_t *s, ptrdiff_t p, const uint8x8_t s0, + const uint8x8_t s1) { + vst1_u8(s, s0); + s += p; + vst1_u8(s, s1); + s += p; +} + +static INLINE uint8x16_t load_u8_8x2(const uint8_t *s, ptrdiff_t p) { + return vcombine_u8(vld1_u8(s), vld1_u8(s + p)); +} + +// Load four bytes into the low half of a uint8x8_t, zero the upper half. 
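+// The upper lanes are zeroed via vdup_n_u8 so the single-lane load never
+// leaves them undefined. A minimal usage sketch (assuming `buf` holds at
+// least four valid bytes):
+//   uint8x8_t v = load_u8_4x1(buf);  // lanes 0-3 = buf[0..3], lanes 4-7 = 0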
+static INLINE uint8x8_t load_u8_4x1(const uint8_t *p) { + uint8x8_t ret = vdup_n_u8(0); + ret = vreinterpret_u8_u32( + vld1_lane_u32((const uint32_t *)p, vreinterpret_u32_u8(ret), 0)); + return ret; +} + +static INLINE uint8x8_t load_u8_4x2(const uint8_t *p, int stride) { + uint8x8_t ret = vdup_n_u8(0); + ret = vreinterpret_u8_u32( + vld1_lane_u32((const uint32_t *)p, vreinterpret_u32_u8(ret), 0)); + p += stride; + ret = vreinterpret_u8_u32( + vld1_lane_u32((const uint32_t *)p, vreinterpret_u32_u8(ret), 1)); + return ret; +} + +static INLINE uint16x4_t load_u16_2x2(const uint16_t *p, int stride) { + uint16x4_t ret = vdup_n_u16(0); + ret = vreinterpret_u16_u32( + vld1_lane_u32((const uint32_t *)p, vreinterpret_u32_u16(ret), 0)); + p += stride; + ret = vreinterpret_u16_u32( + vld1_lane_u32((const uint32_t *)p, vreinterpret_u32_u16(ret), 1)); + return ret; +} + +static INLINE void load_u8_8x8(const uint8_t *s, ptrdiff_t p, + uint8x8_t *const s0, uint8x8_t *const s1, + uint8x8_t *const s2, uint8x8_t *const s3, + uint8x8_t *const s4, uint8x8_t *const s5, + uint8x8_t *const s6, uint8x8_t *const s7) { + *s0 = vld1_u8(s); + s += p; + *s1 = vld1_u8(s); + s += p; + *s2 = vld1_u8(s); + s += p; + *s3 = vld1_u8(s); + s += p; + *s4 = vld1_u8(s); + s += p; + *s5 = vld1_u8(s); + s += p; + *s6 = vld1_u8(s); + s += p; + *s7 = vld1_u8(s); +} + +static INLINE void load_u8_8x7(const uint8_t *s, ptrdiff_t p, + uint8x8_t *const s0, uint8x8_t *const s1, + uint8x8_t *const s2, uint8x8_t *const s3, + uint8x8_t *const s4, uint8x8_t *const s5, + uint8x8_t *const s6) { + *s0 = vld1_u8(s); + s += p; + *s1 = vld1_u8(s); + s += p; + *s2 = vld1_u8(s); + s += p; + *s3 = vld1_u8(s); + s += p; + *s4 = vld1_u8(s); + s += p; + *s5 = vld1_u8(s); + s += p; + *s6 = vld1_u8(s); +} + +static INLINE void load_u8_8x4(const uint8_t *s, const ptrdiff_t p, + uint8x8_t *const s0, uint8x8_t *const s1, + uint8x8_t *const s2, uint8x8_t *const s3) { + *s0 = vld1_u8(s); + s += p; + *s1 = vld1_u8(s); + s += p; + *s2 = vld1_u8(s); + s += p; + *s3 = vld1_u8(s); +} + +static INLINE void load_u16_4x4(const uint16_t *s, const ptrdiff_t p, + uint16x4_t *const s0, uint16x4_t *const s1, + uint16x4_t *const s2, uint16x4_t *const s3) { + *s0 = vld1_u16(s); + s += p; + *s1 = vld1_u16(s); + s += p; + *s2 = vld1_u16(s); + s += p; + *s3 = vld1_u16(s); + s += p; +} + +static INLINE void load_u16_4x7(const uint16_t *s, ptrdiff_t p, + uint16x4_t *const s0, uint16x4_t *const s1, + uint16x4_t *const s2, uint16x4_t *const s3, + uint16x4_t *const s4, uint16x4_t *const s5, + uint16x4_t *const s6) { + *s0 = vld1_u16(s); + s += p; + *s1 = vld1_u16(s); + s += p; + *s2 = vld1_u16(s); + s += p; + *s3 = vld1_u16(s); + s += p; + *s4 = vld1_u16(s); + s += p; + *s5 = vld1_u16(s); + s += p; + *s6 = vld1_u16(s); +} + +static INLINE void load_s16_8x2(const int16_t *s, const ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); +} + +static INLINE void load_u16_8x2(const uint16_t *s, const ptrdiff_t p, + uint16x8_t *const s0, uint16x8_t *const s1) { + *s0 = vld1q_u16(s); + s += p; + *s1 = vld1q_u16(s); +} + +static INLINE void load_u16_8x4(const uint16_t *s, const ptrdiff_t p, + uint16x8_t *const s0, uint16x8_t *const s1, + uint16x8_t *const s2, uint16x8_t *const s3) { + *s0 = vld1q_u16(s); + s += p; + *s1 = vld1q_u16(s); + s += p; + *s2 = vld1q_u16(s); + s += p; + *s3 = vld1q_u16(s); + s += p; +} + +static INLINE void load_s16_4x12(const int16_t *s, ptrdiff_t p, + int16x4_t *const s0, int16x4_t *const s1, + int16x4_t 
*const s2, int16x4_t *const s3, + int16x4_t *const s4, int16x4_t *const s5, + int16x4_t *const s6, int16x4_t *const s7, + int16x4_t *const s8, int16x4_t *const s9, + int16x4_t *const s10, int16x4_t *const s11) { + *s0 = vld1_s16(s); + s += p; + *s1 = vld1_s16(s); + s += p; + *s2 = vld1_s16(s); + s += p; + *s3 = vld1_s16(s); + s += p; + *s4 = vld1_s16(s); + s += p; + *s5 = vld1_s16(s); + s += p; + *s6 = vld1_s16(s); + s += p; + *s7 = vld1_s16(s); + s += p; + *s8 = vld1_s16(s); + s += p; + *s9 = vld1_s16(s); + s += p; + *s10 = vld1_s16(s); + s += p; + *s11 = vld1_s16(s); +} + +static INLINE void load_s16_4x11(const int16_t *s, ptrdiff_t p, + int16x4_t *const s0, int16x4_t *const s1, + int16x4_t *const s2, int16x4_t *const s3, + int16x4_t *const s4, int16x4_t *const s5, + int16x4_t *const s6, int16x4_t *const s7, + int16x4_t *const s8, int16x4_t *const s9, + int16x4_t *const s10) { + *s0 = vld1_s16(s); + s += p; + *s1 = vld1_s16(s); + s += p; + *s2 = vld1_s16(s); + s += p; + *s3 = vld1_s16(s); + s += p; + *s4 = vld1_s16(s); + s += p; + *s5 = vld1_s16(s); + s += p; + *s6 = vld1_s16(s); + s += p; + *s7 = vld1_s16(s); + s += p; + *s8 = vld1_s16(s); + s += p; + *s9 = vld1_s16(s); + s += p; + *s10 = vld1_s16(s); +} + +static INLINE void load_u16_4x11(const uint16_t *s, ptrdiff_t p, + uint16x4_t *const s0, uint16x4_t *const s1, + uint16x4_t *const s2, uint16x4_t *const s3, + uint16x4_t *const s4, uint16x4_t *const s5, + uint16x4_t *const s6, uint16x4_t *const s7, + uint16x4_t *const s8, uint16x4_t *const s9, + uint16x4_t *const s10) { + *s0 = vld1_u16(s); + s += p; + *s1 = vld1_u16(s); + s += p; + *s2 = vld1_u16(s); + s += p; + *s3 = vld1_u16(s); + s += p; + *s4 = vld1_u16(s); + s += p; + *s5 = vld1_u16(s); + s += p; + *s6 = vld1_u16(s); + s += p; + *s7 = vld1_u16(s); + s += p; + *s8 = vld1_u16(s); + s += p; + *s9 = vld1_u16(s); + s += p; + *s10 = vld1_u16(s); +} + +static INLINE void load_s16_4x8(const int16_t *s, ptrdiff_t p, + int16x4_t *const s0, int16x4_t *const s1, + int16x4_t *const s2, int16x4_t *const s3, + int16x4_t *const s4, int16x4_t *const s5, + int16x4_t *const s6, int16x4_t *const s7) { + *s0 = vld1_s16(s); + s += p; + *s1 = vld1_s16(s); + s += p; + *s2 = vld1_s16(s); + s += p; + *s3 = vld1_s16(s); + s += p; + *s4 = vld1_s16(s); + s += p; + *s5 = vld1_s16(s); + s += p; + *s6 = vld1_s16(s); + s += p; + *s7 = vld1_s16(s); +} + +static INLINE void load_s16_4x7(const int16_t *s, ptrdiff_t p, + int16x4_t *const s0, int16x4_t *const s1, + int16x4_t *const s2, int16x4_t *const s3, + int16x4_t *const s4, int16x4_t *const s5, + int16x4_t *const s6) { + *s0 = vld1_s16(s); + s += p; + *s1 = vld1_s16(s); + s += p; + *s2 = vld1_s16(s); + s += p; + *s3 = vld1_s16(s); + s += p; + *s4 = vld1_s16(s); + s += p; + *s5 = vld1_s16(s); + s += p; + *s6 = vld1_s16(s); +} + +static INLINE void load_s16_4x6(const int16_t *s, ptrdiff_t p, + int16x4_t *const s0, int16x4_t *const s1, + int16x4_t *const s2, int16x4_t *const s3, + int16x4_t *const s4, int16x4_t *const s5) { + *s0 = vld1_s16(s); + s += p; + *s1 = vld1_s16(s); + s += p; + *s2 = vld1_s16(s); + s += p; + *s3 = vld1_s16(s); + s += p; + *s4 = vld1_s16(s); + s += p; + *s5 = vld1_s16(s); +} + +static INLINE void load_s16_4x5(const int16_t *s, ptrdiff_t p, + int16x4_t *const s0, int16x4_t *const s1, + int16x4_t *const s2, int16x4_t *const s3, + int16x4_t *const s4) { + *s0 = vld1_s16(s); + s += p; + *s1 = vld1_s16(s); + s += p; + *s2 = vld1_s16(s); + s += p; + *s3 = vld1_s16(s); + s += p; + *s4 = vld1_s16(s); +} + +static INLINE void load_u16_4x5(const 
uint16_t *s, const ptrdiff_t p, + uint16x4_t *const s0, uint16x4_t *const s1, + uint16x4_t *const s2, uint16x4_t *const s3, + uint16x4_t *const s4) { + *s0 = vld1_u16(s); + s += p; + *s1 = vld1_u16(s); + s += p; + *s2 = vld1_u16(s); + s += p; + *s3 = vld1_u16(s); + s += p; + *s4 = vld1_u16(s); + s += p; +} + +static INLINE void load_u8_8x5(const uint8_t *s, ptrdiff_t p, + uint8x8_t *const s0, uint8x8_t *const s1, + uint8x8_t *const s2, uint8x8_t *const s3, + uint8x8_t *const s4) { + *s0 = vld1_u8(s); + s += p; + *s1 = vld1_u8(s); + s += p; + *s2 = vld1_u8(s); + s += p; + *s3 = vld1_u8(s); + s += p; + *s4 = vld1_u8(s); +} + +static INLINE void load_u16_8x5(const uint16_t *s, const ptrdiff_t p, + uint16x8_t *const s0, uint16x8_t *const s1, + uint16x8_t *const s2, uint16x8_t *const s3, + uint16x8_t *const s4) { + *s0 = vld1q_u16(s); + s += p; + *s1 = vld1q_u16(s); + s += p; + *s2 = vld1q_u16(s); + s += p; + *s3 = vld1q_u16(s); + s += p; + *s4 = vld1q_u16(s); + s += p; +} + +static INLINE void load_s16_4x4(const int16_t *s, ptrdiff_t p, + int16x4_t *const s0, int16x4_t *const s1, + int16x4_t *const s2, int16x4_t *const s3) { + *s0 = vld1_s16(s); + s += p; + *s1 = vld1_s16(s); + s += p; + *s2 = vld1_s16(s); + s += p; + *s3 = vld1_s16(s); +} + +static INLINE void store_u8_8x8(uint8_t *s, ptrdiff_t p, const uint8x8_t s0, + const uint8x8_t s1, const uint8x8_t s2, + const uint8x8_t s3, const uint8x8_t s4, + const uint8x8_t s5, const uint8x8_t s6, + const uint8x8_t s7) { + vst1_u8(s, s0); + s += p; + vst1_u8(s, s1); + s += p; + vst1_u8(s, s2); + s += p; + vst1_u8(s, s3); + s += p; + vst1_u8(s, s4); + s += p; + vst1_u8(s, s5); + s += p; + vst1_u8(s, s6); + s += p; + vst1_u8(s, s7); +} + +static INLINE void store_u8_8x4(uint8_t *s, ptrdiff_t p, const uint8x8_t s0, + const uint8x8_t s1, const uint8x8_t s2, + const uint8x8_t s3) { + vst1_u8(s, s0); + s += p; + vst1_u8(s, s1); + s += p; + vst1_u8(s, s2); + s += p; + vst1_u8(s, s3); +} + +static INLINE void store_u8_16x4(uint8_t *s, ptrdiff_t p, const uint8x16_t s0, + const uint8x16_t s1, const uint8x16_t s2, + const uint8x16_t s3) { + vst1q_u8(s, s0); + s += p; + vst1q_u8(s, s1); + s += p; + vst1q_u8(s, s2); + s += p; + vst1q_u8(s, s3); +} + +static INLINE void store_u16_8x8(uint16_t *s, ptrdiff_t dst_stride, + const uint16x8_t s0, const uint16x8_t s1, + const uint16x8_t s2, const uint16x8_t s3, + const uint16x8_t s4, const uint16x8_t s5, + const uint16x8_t s6, const uint16x8_t s7) { + vst1q_u16(s, s0); + s += dst_stride; + vst1q_u16(s, s1); + s += dst_stride; + vst1q_u16(s, s2); + s += dst_stride; + vst1q_u16(s, s3); + s += dst_stride; + vst1q_u16(s, s4); + s += dst_stride; + vst1q_u16(s, s5); + s += dst_stride; + vst1q_u16(s, s6); + s += dst_stride; + vst1q_u16(s, s7); +} + +static INLINE void store_u16_4x4(uint16_t *s, ptrdiff_t dst_stride, + const uint16x4_t s0, const uint16x4_t s1, + const uint16x4_t s2, const uint16x4_t s3) { + vst1_u16(s, s0); + s += dst_stride; + vst1_u16(s, s1); + s += dst_stride; + vst1_u16(s, s2); + s += dst_stride; + vst1_u16(s, s3); +} + +static INLINE void store_u16_8x2(uint16_t *s, ptrdiff_t dst_stride, + const uint16x8_t s0, const uint16x8_t s1) { + vst1q_u16(s, s0); + s += dst_stride; + vst1q_u16(s, s1); +} + +static INLINE void store_u16_8x4(uint16_t *s, ptrdiff_t dst_stride, + const uint16x8_t s0, const uint16x8_t s1, + const uint16x8_t s2, const uint16x8_t s3) { + vst1q_u16(s, s0); + s += dst_stride; + vst1q_u16(s, s1); + s += dst_stride; + vst1q_u16(s, s2); + s += dst_stride; + vst1q_u16(s, s3); +} + +static INLINE 
void store_s16_8x8(int16_t *s, ptrdiff_t dst_stride, + const int16x8_t s0, const int16x8_t s1, + const int16x8_t s2, const int16x8_t s3, + const int16x8_t s4, const int16x8_t s5, + const int16x8_t s6, const int16x8_t s7) { + vst1q_s16(s, s0); + s += dst_stride; + vst1q_s16(s, s1); + s += dst_stride; + vst1q_s16(s, s2); + s += dst_stride; + vst1q_s16(s, s3); + s += dst_stride; + vst1q_s16(s, s4); + s += dst_stride; + vst1q_s16(s, s5); + s += dst_stride; + vst1q_s16(s, s6); + s += dst_stride; + vst1q_s16(s, s7); +} + +static INLINE void store_s16_4x4(int16_t *s, ptrdiff_t dst_stride, + const int16x4_t s0, const int16x4_t s1, + const int16x4_t s2, const int16x4_t s3) { + vst1_s16(s, s0); + s += dst_stride; + vst1_s16(s, s1); + s += dst_stride; + vst1_s16(s, s2); + s += dst_stride; + vst1_s16(s, s3); +} + +static INLINE void store_s16_8x4(int16_t *s, ptrdiff_t dst_stride, + const int16x8_t s0, const int16x8_t s1, + const int16x8_t s2, const int16x8_t s3) { + vst1q_s16(s, s0); + s += dst_stride; + vst1q_s16(s, s1); + s += dst_stride; + vst1q_s16(s, s2); + s += dst_stride; + vst1q_s16(s, s3); +} + +static INLINE void load_u8_8x11(const uint8_t *s, ptrdiff_t p, + uint8x8_t *const s0, uint8x8_t *const s1, + uint8x8_t *const s2, uint8x8_t *const s3, + uint8x8_t *const s4, uint8x8_t *const s5, + uint8x8_t *const s6, uint8x8_t *const s7, + uint8x8_t *const s8, uint8x8_t *const s9, + uint8x8_t *const s10) { + *s0 = vld1_u8(s); + s += p; + *s1 = vld1_u8(s); + s += p; + *s2 = vld1_u8(s); + s += p; + *s3 = vld1_u8(s); + s += p; + *s4 = vld1_u8(s); + s += p; + *s5 = vld1_u8(s); + s += p; + *s6 = vld1_u8(s); + s += p; + *s7 = vld1_u8(s); + s += p; + *s8 = vld1_u8(s); + s += p; + *s9 = vld1_u8(s); + s += p; + *s10 = vld1_u8(s); +} + +static INLINE void load_s16_8x10(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3, + int16x8_t *const s4, int16x8_t *const s5, + int16x8_t *const s6, int16x8_t *const s7, + int16x8_t *const s8, int16x8_t *const s9) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); + s += p; + *s4 = vld1q_s16(s); + s += p; + *s5 = vld1q_s16(s); + s += p; + *s6 = vld1q_s16(s); + s += p; + *s7 = vld1q_s16(s); + s += p; + *s8 = vld1q_s16(s); + s += p; + *s9 = vld1q_s16(s); +} + +static INLINE void load_s16_8x11(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3, + int16x8_t *const s4, int16x8_t *const s5, + int16x8_t *const s6, int16x8_t *const s7, + int16x8_t *const s8, int16x8_t *const s9, + int16x8_t *const s10) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); + s += p; + *s4 = vld1q_s16(s); + s += p; + *s5 = vld1q_s16(s); + s += p; + *s6 = vld1q_s16(s); + s += p; + *s7 = vld1q_s16(s); + s += p; + *s8 = vld1q_s16(s); + s += p; + *s9 = vld1q_s16(s); + s += p; + *s10 = vld1q_s16(s); +} + +static INLINE void load_s16_8x12(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3, + int16x8_t *const s4, int16x8_t *const s5, + int16x8_t *const s6, int16x8_t *const s7, + int16x8_t *const s8, int16x8_t *const s9, + int16x8_t *const s10, int16x8_t *const s11) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); + s += p; + *s4 = vld1q_s16(s); + s += p; + *s5 = vld1q_s16(s); + s += p; + *s6 = vld1q_s16(s); + 
s += p; + *s7 = vld1q_s16(s); + s += p; + *s8 = vld1q_s16(s); + s += p; + *s9 = vld1q_s16(s); + s += p; + *s10 = vld1q_s16(s); + s += p; + *s11 = vld1q_s16(s); +} + +static INLINE void load_u16_8x11(const uint16_t *s, ptrdiff_t p, + uint16x8_t *const s0, uint16x8_t *const s1, + uint16x8_t *const s2, uint16x8_t *const s3, + uint16x8_t *const s4, uint16x8_t *const s5, + uint16x8_t *const s6, uint16x8_t *const s7, + uint16x8_t *const s8, uint16x8_t *const s9, + uint16x8_t *const s10) { + *s0 = vld1q_u16(s); + s += p; + *s1 = vld1q_u16(s); + s += p; + *s2 = vld1q_u16(s); + s += p; + *s3 = vld1q_u16(s); + s += p; + *s4 = vld1q_u16(s); + s += p; + *s5 = vld1q_u16(s); + s += p; + *s6 = vld1q_u16(s); + s += p; + *s7 = vld1q_u16(s); + s += p; + *s8 = vld1q_u16(s); + s += p; + *s9 = vld1q_u16(s); + s += p; + *s10 = vld1q_u16(s); +} + +static INLINE void load_s16_8x8(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3, + int16x8_t *const s4, int16x8_t *const s5, + int16x8_t *const s6, int16x8_t *const s7) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); + s += p; + *s4 = vld1q_s16(s); + s += p; + *s5 = vld1q_s16(s); + s += p; + *s6 = vld1q_s16(s); + s += p; + *s7 = vld1q_s16(s); +} + +static INLINE void load_u16_8x7(const uint16_t *s, ptrdiff_t p, + uint16x8_t *const s0, uint16x8_t *const s1, + uint16x8_t *const s2, uint16x8_t *const s3, + uint16x8_t *const s4, uint16x8_t *const s5, + uint16x8_t *const s6) { + *s0 = vld1q_u16(s); + s += p; + *s1 = vld1q_u16(s); + s += p; + *s2 = vld1q_u16(s); + s += p; + *s3 = vld1q_u16(s); + s += p; + *s4 = vld1q_u16(s); + s += p; + *s5 = vld1q_u16(s); + s += p; + *s6 = vld1q_u16(s); +} + +static INLINE void load_s16_8x7(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3, + int16x8_t *const s4, int16x8_t *const s5, + int16x8_t *const s6) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); + s += p; + *s4 = vld1q_s16(s); + s += p; + *s5 = vld1q_s16(s); + s += p; + *s6 = vld1q_s16(s); +} + +static INLINE void load_s16_8x6(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3, + int16x8_t *const s4, int16x8_t *const s5) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); + s += p; + *s4 = vld1q_s16(s); + s += p; + *s5 = vld1q_s16(s); +} + +static INLINE void load_s16_8x5(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3, + int16x8_t *const s4) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); + s += p; + *s4 = vld1q_s16(s); +} + +static INLINE void load_s16_8x4(const int16_t *s, ptrdiff_t p, + int16x8_t *const s0, int16x8_t *const s1, + int16x8_t *const s2, int16x8_t *const s3) { + *s0 = vld1q_s16(s); + s += p; + *s1 = vld1q_s16(s); + s += p; + *s2 = vld1q_s16(s); + s += p; + *s3 = vld1q_s16(s); +} + +// Load 2 sets of 4 bytes when alignment is not guaranteed. 
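// The fixed-size memcpy in the helpers below is the usual alignment- and
// strict-aliasing-safe idiom for reading 32 (or 16) bits from a byte pointer;
// compilers generally lower it to a single scalar load, so it costs no more
// than a direct pointer cast would.
// Illustrative use (the names here are only examples): for a 4-wide block,
//   uint8x8_t rows01 = load_unaligned_u8(src, src_stride);
// returns row 0 in lanes 0-3 and row 1 in lanes 4-7 on a little-endian target.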
+static INLINE uint8x8_t load_unaligned_u8(const uint8_t *buf, int stride) { + uint32_t a; + memcpy(&a, buf, 4); + buf += stride; + uint32x2_t a_u32 = vdup_n_u32(a); + memcpy(&a, buf, 4); + a_u32 = vset_lane_u32(a, a_u32, 1); + return vreinterpret_u8_u32(a_u32); +} + +// Load 4 sets of 4 bytes when alignment is not guaranteed. +static INLINE uint8x16_t load_unaligned_u8q(const uint8_t *buf, int stride) { + uint32_t a; + uint32x4_t a_u32; + if (stride == 4) return vld1q_u8(buf); + memcpy(&a, buf, 4); + buf += stride; + a_u32 = vdupq_n_u32(a); + memcpy(&a, buf, 4); + buf += stride; + a_u32 = vsetq_lane_u32(a, a_u32, 1); + memcpy(&a, buf, 4); + buf += stride; + a_u32 = vsetq_lane_u32(a, a_u32, 2); + memcpy(&a, buf, 4); + a_u32 = vsetq_lane_u32(a, a_u32, 3); + return vreinterpretq_u8_u32(a_u32); +} + +static INLINE uint8x8_t load_unaligned_u8_2x2(const uint8_t *buf, int stride) { + uint16_t a; + uint16x4_t a_u16; + + memcpy(&a, buf, 2); + buf += stride; + a_u16 = vdup_n_u16(a); + memcpy(&a, buf, 2); + a_u16 = vset_lane_u16(a, a_u16, 1); + return vreinterpret_u8_u16(a_u16); +} + +static INLINE uint8x8_t load_unaligned_u8_4x1(const uint8_t *buf) { + uint32_t a; + uint32x2_t a_u32; + + memcpy(&a, buf, 4); + a_u32 = vdup_n_u32(0); + a_u32 = vset_lane_u32(a, a_u32, 0); + return vreinterpret_u8_u32(a_u32); +} + +static INLINE uint8x8_t load_unaligned_dup_u8_4x2(const uint8_t *buf) { + uint32_t a; + uint32x2_t a_u32; + + memcpy(&a, buf, 4); + a_u32 = vdup_n_u32(a); + return vreinterpret_u8_u32(a_u32); +} + +static INLINE uint8x8_t load_unaligned_dup_u8_2x4(const uint8_t *buf) { + uint16_t a; + uint16x4_t a_u32; + + memcpy(&a, buf, 2); + a_u32 = vdup_n_u16(a); + return vreinterpret_u8_u16(a_u32); +} + +static INLINE uint8x8_t load_unaligned_u8_4x2(const uint8_t *buf, int stride) { + uint32_t a; + uint32x2_t a_u32; + + memcpy(&a, buf, 4); + buf += stride; + a_u32 = vdup_n_u32(a); + memcpy(&a, buf, 4); + a_u32 = vset_lane_u32(a, a_u32, 1); + return vreinterpret_u8_u32(a_u32); +} + +static INLINE void load_unaligned_u8_4x4(const uint8_t *buf, int stride, + uint8x8_t *tu0, uint8x8_t *tu1) { + *tu0 = load_unaligned_u8_4x2(buf, stride); + buf += 2 * stride; + *tu1 = load_unaligned_u8_4x2(buf, stride); +} + +static INLINE void load_unaligned_u8_3x8(const uint8_t *buf, int stride, + uint8x8_t *tu0, uint8x8_t *tu1, + uint8x8_t *tu2) { + load_unaligned_u8_4x4(buf, stride, tu0, tu1); + buf += 4 * stride; + *tu2 = load_unaligned_u8_4x2(buf, stride); +} + +static INLINE void load_unaligned_u8_4x8(const uint8_t *buf, int stride, + uint8x8_t *tu0, uint8x8_t *tu1, + uint8x8_t *tu2, uint8x8_t *tu3) { + load_unaligned_u8_4x4(buf, stride, tu0, tu1); + buf += 4 * stride; + load_unaligned_u8_4x4(buf, stride, tu2, tu3); +} + +static INLINE void load_u8_16x8(const uint8_t *s, ptrdiff_t p, + uint8x16_t *const s0, uint8x16_t *const s1, + uint8x16_t *const s2, uint8x16_t *const s3, + uint8x16_t *const s4, uint8x16_t *const s5, + uint8x16_t *const s6, uint8x16_t *const s7) { + *s0 = vld1q_u8(s); + s += p; + *s1 = vld1q_u8(s); + s += p; + *s2 = vld1q_u8(s); + s += p; + *s3 = vld1q_u8(s); + s += p; + *s4 = vld1q_u8(s); + s += p; + *s5 = vld1q_u8(s); + s += p; + *s6 = vld1q_u8(s); + s += p; + *s7 = vld1q_u8(s); +} + +static INLINE void load_u8_16x4(const uint8_t *s, ptrdiff_t p, + uint8x16_t *const s0, uint8x16_t *const s1, + uint8x16_t *const s2, uint8x16_t *const s3) { + *s0 = vld1q_u8(s); + s += p; + *s1 = vld1q_u8(s); + s += p; + *s2 = vld1q_u8(s); + s += p; + *s3 = vld1q_u8(s); +} + +static INLINE void load_u16_8x8(const 
uint16_t *s, const ptrdiff_t p, + uint16x8_t *s0, uint16x8_t *s1, uint16x8_t *s2, + uint16x8_t *s3, uint16x8_t *s4, uint16x8_t *s5, + uint16x8_t *s6, uint16x8_t *s7) { + *s0 = vld1q_u16(s); + s += p; + *s1 = vld1q_u16(s); + s += p; + *s2 = vld1q_u16(s); + s += p; + *s3 = vld1q_u16(s); + s += p; + *s4 = vld1q_u16(s); + s += p; + *s5 = vld1q_u16(s); + s += p; + *s6 = vld1q_u16(s); + s += p; + *s7 = vld1q_u16(s); +} + +static INLINE void load_u16_16x4(const uint16_t *s, ptrdiff_t p, + uint16x8_t *const s0, uint16x8_t *const s1, + uint16x8_t *const s2, uint16x8_t *const s3, + uint16x8_t *const s4, uint16x8_t *const s5, + uint16x8_t *const s6, uint16x8_t *const s7) { + *s0 = vld1q_u16(s); + *s1 = vld1q_u16(s + 8); + s += p; + *s2 = vld1q_u16(s); + *s3 = vld1q_u16(s + 8); + s += p; + *s4 = vld1q_u16(s); + *s5 = vld1q_u16(s + 8); + s += p; + *s6 = vld1q_u16(s); + *s7 = vld1q_u16(s + 8); +} + +static INLINE uint16x4_t load_unaligned_u16_2x2(const uint16_t *buf, + int stride) { + uint32_t a; + uint32x2_t a_u32; + + memcpy(&a, buf, 4); + buf += stride; + a_u32 = vdup_n_u32(a); + memcpy(&a, buf, 4); + a_u32 = vset_lane_u32(a, a_u32, 1); + return vreinterpret_u16_u32(a_u32); +} + +static INLINE uint16x4_t load_unaligned_u16_4x1(const uint16_t *buf) { + uint64_t a; + uint64x1_t a_u64 = vdup_n_u64(0); + memcpy(&a, buf, 8); + a_u64 = vset_lane_u64(a, a_u64, 0); + return vreinterpret_u16_u64(a_u64); +} + +static INLINE uint16x8_t load_unaligned_u16_4x2(const uint16_t *buf, + uint32_t stride) { + uint64_t a; + uint64x2_t a_u64; + + memcpy(&a, buf, 8); + buf += stride; + a_u64 = vdupq_n_u64(0); + a_u64 = vsetq_lane_u64(a, a_u64, 0); + memcpy(&a, buf, 8); + buf += stride; + a_u64 = vsetq_lane_u64(a, a_u64, 1); + return vreinterpretq_u16_u64(a_u64); +} + +static INLINE void load_unaligned_u16_4x4(const uint16_t *buf, uint32_t stride, + uint16x8_t *tu0, uint16x8_t *tu1) { + *tu0 = load_unaligned_u16_4x2(buf, stride); + buf += 2 * stride; + *tu1 = load_unaligned_u16_4x2(buf, stride); +} + +static INLINE void load_s32_4x4(int32_t *s, int32_t p, int32x4_t *s1, + int32x4_t *s2, int32x4_t *s3, int32x4_t *s4) { + *s1 = vld1q_s32(s); + s += p; + *s2 = vld1q_s32(s); + s += p; + *s3 = vld1q_s32(s); + s += p; + *s4 = vld1q_s32(s); +} + +static INLINE void store_s32_4x4(int32_t *s, int32_t p, int32x4_t s1, + int32x4_t s2, int32x4_t s3, int32x4_t s4) { + vst1q_s32(s, s1); + s += p; + vst1q_s32(s, s2); + s += p; + vst1q_s32(s, s3); + s += p; + vst1q_s32(s, s4); +} + +static INLINE void load_u32_4x4(uint32_t *s, int32_t p, uint32x4_t *s1, + uint32x4_t *s2, uint32x4_t *s3, + uint32x4_t *s4) { + *s1 = vld1q_u32(s); + s += p; + *s2 = vld1q_u32(s); + s += p; + *s3 = vld1q_u32(s); + s += p; + *s4 = vld1q_u32(s); +} + +static INLINE void store_u32_4x4(uint32_t *s, int32_t p, uint32x4_t s1, + uint32x4_t s2, uint32x4_t s3, uint32x4_t s4) { + vst1q_u32(s, s1); + s += p; + vst1q_u32(s, s2); + s += p; + vst1q_u32(s, s3); + s += p; + vst1q_u32(s, s4); +} + +static INLINE int16x8_t load_tran_low_to_s16q(const tran_low_t *buf) { + const int32x4_t v0 = vld1q_s32(buf); + const int32x4_t v1 = vld1q_s32(buf + 4); + const int16x4_t s0 = vmovn_s32(v0); + const int16x4_t s1 = vmovn_s32(v1); + return vcombine_s16(s0, s1); +} + +static INLINE void store_s16q_to_tran_low(tran_low_t *buf, const int16x8_t a) { + const int32x4_t v0 = vmovl_s16(vget_low_s16(a)); + const int32x4_t v1 = vmovl_s16(vget_high_s16(a)); + vst1q_s32(buf, v0); + vst1q_s32(buf + 4, v1); +} + +static INLINE void store_s16_to_tran_low(tran_low_t *buf, const int16x4_t a) { + const 
int32x4_t v0 = vmovl_s16(a); + vst1q_s32(buf, v0); +} + +static INLINE uint8x8_t load_u8_gather_s16_x8(const uint8_t *src, + int16x8_t indices) { + // Recent Clang and GCC versions correctly identify that this zero-broadcast + // is redundant. Alternatively we could load and broadcast the zeroth element + // and then replace the other lanes, however this is slower than loading a + // single element without broadcast on some micro-architectures. + uint8x8_t ret = vdup_n_u8(0); + ret = vld1_lane_u8(src + vget_lane_s16(vget_low_s16(indices), 0), ret, 0); + ret = vld1_lane_u8(src + vget_lane_s16(vget_low_s16(indices), 1), ret, 1); + ret = vld1_lane_u8(src + vget_lane_s16(vget_low_s16(indices), 2), ret, 2); + ret = vld1_lane_u8(src + vget_lane_s16(vget_low_s16(indices), 3), ret, 3); + ret = vld1_lane_u8(src + vget_lane_s16(vget_high_s16(indices), 0), ret, 4); + ret = vld1_lane_u8(src + vget_lane_s16(vget_high_s16(indices), 1), ret, 5); + ret = vld1_lane_u8(src + vget_lane_s16(vget_high_s16(indices), 2), ret, 6); + ret = vld1_lane_u8(src + vget_lane_s16(vget_high_s16(indices), 3), ret, 7); + return ret; +} + +// The `lane` parameter here must be an immediate. +#define store_u8_2x1_lane(dst, src, lane) \ + do { \ + uint16_t a = vget_lane_u16(vreinterpret_u16_u8(src), lane); \ + memcpy(dst, &a, 2); \ + } while (0) + +#define store_u8_4x1_lane(dst, src, lane) \ + do { \ + uint32_t a = vget_lane_u32(vreinterpret_u32_u8(src), lane); \ + memcpy(dst, &a, 4); \ + } while (0) + +#define store_u16_2x1_lane(dst, src, lane) \ + do { \ + uint32_t a = vget_lane_u32(vreinterpret_u32_u16(src), lane); \ + memcpy(dst, &a, 4); \ + } while (0) + +#define store_u16_4x1_lane(dst, src, lane) \ + do { \ + uint64_t a = vgetq_lane_u64(vreinterpretq_u64_u16(src), lane); \ + memcpy(dst, &a, 8); \ + } while (0) + +// Store the low 16-bits from a single vector. +static INLINE void store_u8_2x1(uint8_t *dst, const uint8x8_t src) { + store_u8_2x1_lane(dst, src, 0); +} + +// Store the low 32-bits from a single vector. +static INLINE void store_u8_4x1(uint8_t *dst, const uint8x8_t src) { + store_u8_4x1_lane(dst, src, 0); +} + +// Store two blocks of 16-bits from a single vector. +static INLINE void store_u8x2_strided_x2(uint8_t *dst, uint32_t dst_stride, + uint8x8_t src) { + store_u8_2x1_lane(dst, src, 0); + dst += dst_stride; + store_u8_2x1_lane(dst, src, 1); +} + +// Store two blocks of 32-bits from a single vector. +static INLINE void store_u8x4_strided_x2(uint8_t *dst, ptrdiff_t stride, + uint8x8_t src) { + store_u8_4x1_lane(dst, src, 0); + dst += stride; + store_u8_4x1_lane(dst, src, 1); +} + +// Store four blocks of 32-bits from a single vector. +static INLINE void store_u8x4_strided_x4(uint8_t *dst, ptrdiff_t stride, + uint8x16_t src) { + store_u8_4x1_lane(dst, vget_low_u8(src), 0); + dst += stride; + store_u8_4x1_lane(dst, vget_low_u8(src), 1); + dst += stride; + store_u8_4x1_lane(dst, vget_high_u8(src), 0); + dst += stride; + store_u8_4x1_lane(dst, vget_high_u8(src), 1); +} + +// Store the low 32-bits from a single vector. +static INLINE void store_u16_2x1(uint16_t *dst, const uint16x4_t src) { + store_u16_2x1_lane(dst, src, 0); +} + +// Store two blocks of 32-bits from a single vector. +static INLINE void store_u16x2_strided_x2(uint16_t *dst, uint32_t dst_stride, + uint16x4_t src) { + store_u16_2x1_lane(dst, src, 0); + dst += dst_stride; + store_u16_2x1_lane(dst, src, 1); +} + +// Store two blocks of 64-bits from a single vector. 
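// For example (names are only illustrative): with two 4-wide high-bitdepth
// rows packed into a single uint16x8_t as rows01,
//   store_u16x4_strided_x2(dst, dst_stride, rows01);
// writes lanes 0-3 to dst and lanes 4-7 to dst + dst_stride, again going
// through memcpy so that no alignment is assumed for the destination.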
+static INLINE void store_u16x4_strided_x2(uint16_t *dst, uint32_t dst_stride, + uint16x8_t src) { + store_u16_4x1_lane(dst, src, 0); + dst += dst_stride; + store_u16_4x1_lane(dst, src, 1); +} + +#undef store_u8_2x1_lane +#undef store_u8_4x1_lane +#undef store_u16_2x1_lane +#undef store_u16_4x1_lane + +#endif // AOM_AOM_DSP_ARM_MEM_NEON_H_ diff --git a/third_party/aom/aom_dsp/arm/obmc_sad_neon.c b/third_party/aom/aom_dsp/arm/obmc_sad_neon.c new file mode 100644 index 0000000000..a692cbb388 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/obmc_sad_neon.c @@ -0,0 +1,250 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "mem_neon.h" +#include "sum_neon.h" + +static INLINE void obmc_sad_8x1_s16_neon(int16x8_t ref_s16, const int32_t *mask, + const int32_t *wsrc, uint32x4_t *sum) { + int32x4_t wsrc_lo = vld1q_s32(wsrc); + int32x4_t wsrc_hi = vld1q_s32(wsrc + 4); + + int32x4_t mask_lo = vld1q_s32(mask); + int32x4_t mask_hi = vld1q_s32(mask + 4); + + int16x8_t mask_s16 = + vuzpq_s16(vreinterpretq_s16_s32(mask_lo), vreinterpretq_s16_s32(mask_hi)) + .val[0]; + + int32x4_t pre_lo = vmull_s16(vget_low_s16(ref_s16), vget_low_s16(mask_s16)); + int32x4_t pre_hi = vmull_s16(vget_high_s16(ref_s16), vget_high_s16(mask_s16)); + + uint32x4_t abs_lo = vreinterpretq_u32_s32(vabdq_s32(wsrc_lo, pre_lo)); + uint32x4_t abs_hi = vreinterpretq_u32_s32(vabdq_s32(wsrc_hi, pre_hi)); + + *sum = vrsraq_n_u32(*sum, abs_lo, 12); + *sum = vrsraq_n_u32(*sum, abs_hi, 12); +} + +#if AOM_ARCH_AARCH64 + +// Use tbl for doing a double-width zero extension from 8->32 bits since we can +// do this in one instruction rather than two (indices out of range (255 here) +// are set to zero by tbl). 
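// A minimal sketch of the lookup described above (not from the upstream file;
// it assumes the little-endian AArch64 configuration this block is already
// compiled for). Out-of-range indices (255) select zero, so each source byte
// lands in the low byte of its own 32-bit lane and the reinterpret yields the
// first four bytes zero-extended to 32 bits.
static INLINE uint32x4_t zero_extend_low_u8x4_example(uint8x16_t bytes) {
  const uint8x16_t idx = { 0, 255, 255, 255, 1, 255, 255, 255,
                           2, 255, 255, 255, 3, 255, 255, 255 };
  return vreinterpretq_u32_u8(vqtbl1q_u8(bytes, idx));
}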
+DECLARE_ALIGNED(16, static const uint8_t, obmc_variance_permute_idx[]) = { + 0, 255, 255, 255, 1, 255, 255, 255, 2, 255, 255, 255, 3, 255, 255, 255, + 4, 255, 255, 255, 5, 255, 255, 255, 6, 255, 255, 255, 7, 255, 255, 255, + 8, 255, 255, 255, 9, 255, 255, 255, 10, 255, 255, 255, 11, 255, 255, 255, + 12, 255, 255, 255, 13, 255, 255, 255, 14, 255, 255, 255, 15, 255, 255, 255 +}; + +static INLINE void obmc_sad_8x1_s32_neon(uint32x4_t ref_u32_lo, + uint32x4_t ref_u32_hi, + const int32_t *mask, + const int32_t *wsrc, + uint32x4_t sum[2]) { + int32x4_t wsrc_lo = vld1q_s32(wsrc); + int32x4_t wsrc_hi = vld1q_s32(wsrc + 4); + int32x4_t mask_lo = vld1q_s32(mask); + int32x4_t mask_hi = vld1q_s32(mask + 4); + + int32x4_t pre_lo = vmulq_s32(vreinterpretq_s32_u32(ref_u32_lo), mask_lo); + int32x4_t pre_hi = vmulq_s32(vreinterpretq_s32_u32(ref_u32_hi), mask_hi); + + uint32x4_t abs_lo = vreinterpretq_u32_s32(vabdq_s32(wsrc_lo, pre_lo)); + uint32x4_t abs_hi = vreinterpretq_u32_s32(vabdq_s32(wsrc_hi, pre_hi)); + + sum[0] = vrsraq_n_u32(sum[0], abs_lo, 12); + sum[1] = vrsraq_n_u32(sum[1], abs_hi, 12); +} + +static INLINE unsigned int obmc_sad_large_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, int width, + int height) { + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + // Use tbl for doing a double-width zero extension from 8->32 bits since we + // can do this in one instruction rather than two. + uint8x16_t pre_idx0 = vld1q_u8(&obmc_variance_permute_idx[0]); + uint8x16_t pre_idx1 = vld1q_u8(&obmc_variance_permute_idx[16]); + uint8x16_t pre_idx2 = vld1q_u8(&obmc_variance_permute_idx[32]); + uint8x16_t pre_idx3 = vld1q_u8(&obmc_variance_permute_idx[48]); + + int h = height; + do { + int w = width; + const uint8_t *ref_ptr = ref; + do { + uint8x16_t r = vld1q_u8(ref_ptr); + + uint32x4_t ref_u32_lo = vreinterpretq_u32_u8(vqtbl1q_u8(r, pre_idx0)); + uint32x4_t ref_u32_hi = vreinterpretq_u32_u8(vqtbl1q_u8(r, pre_idx1)); + obmc_sad_8x1_s32_neon(ref_u32_lo, ref_u32_hi, mask, wsrc, sum); + + ref_u32_lo = vreinterpretq_u32_u8(vqtbl1q_u8(r, pre_idx2)); + ref_u32_hi = vreinterpretq_u32_u8(vqtbl1q_u8(r, pre_idx3)); + obmc_sad_8x1_s32_neon(ref_u32_lo, ref_u32_hi, mask + 8, wsrc + 8, sum); + + ref_ptr += 16; + wsrc += 16; + mask += 16; + w -= 16; + } while (w != 0); + + ref += ref_stride; + } while (--h != 0); + + return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1])); +} + +#else // !AOM_ARCH_AARCH64 + +static INLINE unsigned int obmc_sad_large_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, int width, + int height) { + uint32x4_t sum = vdupq_n_u32(0); + + int h = height; + do { + int w = width; + const uint8_t *ref_ptr = ref; + do { + uint8x16_t r = vld1q_u8(ref_ptr); + + int16x8_t ref_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(r))); + obmc_sad_8x1_s16_neon(ref_s16, mask, wsrc, &sum); + + ref_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(r))); + obmc_sad_8x1_s16_neon(ref_s16, mask + 8, wsrc + 8, &sum); + + ref_ptr += 16; + wsrc += 16; + mask += 16; + w -= 16; + } while (w != 0); + + ref += ref_stride; + } while (--h != 0); + + return horizontal_add_u32x4(sum); +} + +#endif // AOM_ARCH_AARCH64 + +static INLINE unsigned int obmc_sad_128xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, int h) { + return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 128, h); +} + +static INLINE unsigned int obmc_sad_64xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + 
const int32_t *mask, int h) { + return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 64, h); +} + +static INLINE unsigned int obmc_sad_32xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, int h) { + return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 32, h); +} + +static INLINE unsigned int obmc_sad_16xh_neon(const uint8_t *ref, + int ref_stride, + const int32_t *wsrc, + const int32_t *mask, int h) { + return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 16, h); +} + +static INLINE unsigned int obmc_sad_8xh_neon(const uint8_t *ref, int ref_stride, + const int32_t *wsrc, + const int32_t *mask, int height) { + uint32x4_t sum = vdupq_n_u32(0); + + int h = height; + do { + uint8x8_t r = vld1_u8(ref); + + int16x8_t ref_s16 = vreinterpretq_s16_u16(vmovl_u8(r)); + obmc_sad_8x1_s16_neon(ref_s16, mask, wsrc, &sum); + + ref += ref_stride; + wsrc += 8; + mask += 8; + } while (--h != 0); + + return horizontal_add_u32x4(sum); +} + +static INLINE unsigned int obmc_sad_4xh_neon(const uint8_t *ref, int ref_stride, + const int32_t *wsrc, + const int32_t *mask, int height) { + uint32x4_t sum = vdupq_n_u32(0); + + int h = height / 2; + do { + uint8x8_t r = load_unaligned_u8(ref, ref_stride); + + int16x8_t ref_s16 = vreinterpretq_s16_u16(vmovl_u8(r)); + obmc_sad_8x1_s16_neon(ref_s16, mask, wsrc, &sum); + + ref += 2 * ref_stride; + wsrc += 8; + mask += 8; + } while (--h != 0); + + return horizontal_add_u32x4(sum); +} + +#define OBMC_SAD_WXH_NEON(w, h) \ + unsigned int aom_obmc_sad##w##x##h##_neon( \ + const uint8_t *ref, int ref_stride, const int32_t *wsrc, \ + const int32_t *mask) { \ + return obmc_sad_##w##xh_neon(ref, ref_stride, wsrc, mask, h); \ + } + +OBMC_SAD_WXH_NEON(4, 4) +OBMC_SAD_WXH_NEON(4, 8) +OBMC_SAD_WXH_NEON(4, 16) + +OBMC_SAD_WXH_NEON(8, 4) +OBMC_SAD_WXH_NEON(8, 8) +OBMC_SAD_WXH_NEON(8, 16) +OBMC_SAD_WXH_NEON(8, 32) + +OBMC_SAD_WXH_NEON(16, 4) +OBMC_SAD_WXH_NEON(16, 8) +OBMC_SAD_WXH_NEON(16, 16) +OBMC_SAD_WXH_NEON(16, 32) +OBMC_SAD_WXH_NEON(16, 64) + +OBMC_SAD_WXH_NEON(32, 8) +OBMC_SAD_WXH_NEON(32, 16) +OBMC_SAD_WXH_NEON(32, 32) +OBMC_SAD_WXH_NEON(32, 64) + +OBMC_SAD_WXH_NEON(64, 16) +OBMC_SAD_WXH_NEON(64, 32) +OBMC_SAD_WXH_NEON(64, 64) +OBMC_SAD_WXH_NEON(64, 128) + +OBMC_SAD_WXH_NEON(128, 64) +OBMC_SAD_WXH_NEON(128, 128) diff --git a/third_party/aom/aom_dsp/arm/obmc_variance_neon.c b/third_party/aom/aom_dsp/arm/obmc_variance_neon.c new file mode 100644 index 0000000000..50cd5f3b6a --- /dev/null +++ b/third_party/aom/aom_dsp/arm/obmc_variance_neon.c @@ -0,0 +1,290 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "mem_neon.h" +#include "sum_neon.h" + +static INLINE void obmc_variance_8x1_s16_neon(int16x8_t pre_s16, + const int32_t *wsrc, + const int32_t *mask, + int32x4_t *ssev, + int32x4_t *sumv) { + // For 4xh and 8xh we observe it is faster to avoid the double-widening of + // pre. 
Instead we do a single widening step and narrow the mask to 16-bits + // to allow us to perform a widening multiply. Widening multiply + // instructions have better throughput on some micro-architectures but for + // the larger block sizes this benefit is outweighed by the additional + // instruction needed to first narrow the mask vectors. + + int32x4_t wsrc_s32_lo = vld1q_s32(&wsrc[0]); + int32x4_t wsrc_s32_hi = vld1q_s32(&wsrc[4]); + int16x8_t mask_s16 = vuzpq_s16(vreinterpretq_s16_s32(vld1q_s32(&mask[0])), + vreinterpretq_s16_s32(vld1q_s32(&mask[4]))) + .val[0]; + + int32x4_t diff_s32_lo = + vmlsl_s16(wsrc_s32_lo, vget_low_s16(pre_s16), vget_low_s16(mask_s16)); + int32x4_t diff_s32_hi = + vmlsl_s16(wsrc_s32_hi, vget_high_s16(pre_s16), vget_high_s16(mask_s16)); + + // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away + // from zero, however vrshrq_n_s32 rounds to nearest with ties rounded up. + // This difference only affects the bit patterns at the rounding breakpoints + // exactly, so we can add -1 to all negative numbers to move the breakpoint + // one value across and into the correct rounding region. + diff_s32_lo = vsraq_n_s32(diff_s32_lo, diff_s32_lo, 31); + diff_s32_hi = vsraq_n_s32(diff_s32_hi, diff_s32_hi, 31); + int32x4_t round_s32_lo = vrshrq_n_s32(diff_s32_lo, 12); + int32x4_t round_s32_hi = vrshrq_n_s32(diff_s32_hi, 12); + + *sumv = vrsraq_n_s32(*sumv, diff_s32_lo, 12); + *sumv = vrsraq_n_s32(*sumv, diff_s32_hi, 12); + *ssev = vmlaq_s32(*ssev, round_s32_lo, round_s32_lo); + *ssev = vmlaq_s32(*ssev, round_s32_hi, round_s32_hi); +} + +#if AOM_ARCH_AARCH64 + +// Use tbl for doing a double-width zero extension from 8->32 bits since we can +// do this in one instruction rather than two (indices out of range (255 here) +// are set to zero by tbl). +DECLARE_ALIGNED(16, static const uint8_t, obmc_variance_permute_idx[]) = { + 0, 255, 255, 255, 1, 255, 255, 255, 2, 255, 255, 255, 3, 255, 255, 255, + 4, 255, 255, 255, 5, 255, 255, 255, 6, 255, 255, 255, 7, 255, 255, 255, + 8, 255, 255, 255, 9, 255, 255, 255, 10, 255, 255, 255, 11, 255, 255, 255, + 12, 255, 255, 255, 13, 255, 255, 255, 14, 255, 255, 255, 15, 255, 255, 255 +}; + +static INLINE void obmc_variance_8x1_s32_neon( + int32x4_t pre_lo, int32x4_t pre_hi, const int32_t *wsrc, + const int32_t *mask, int32x4_t *ssev, int32x4_t *sumv) { + int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]); + int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]); + int32x4_t mask_lo = vld1q_s32(&mask[0]); + int32x4_t mask_hi = vld1q_s32(&mask[4]); + + int32x4_t diff_lo = vmlsq_s32(wsrc_lo, pre_lo, mask_lo); + int32x4_t diff_hi = vmlsq_s32(wsrc_hi, pre_hi, mask_hi); + + // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away from + // zero, however vrshrq_n_s32 rounds to nearest with ties rounded up. This + // difference only affects the bit patterns at the rounding breakpoints + // exactly, so we can add -1 to all negative numbers to move the breakpoint + // one value across and into the correct rounding region. 
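// Worked example of the breakpoint adjustment: diff = -6144 sits exactly
// halfway between -2 and -1 after a 12-bit shift.
// ROUND_POWER_OF_TWO_SIGNED(-6144, 12) = -2 (ties away from zero), but
// vrshrq_n_s32(-6144, 12) = -1 (ties rounded up). Adding diff >> 31 == -1
// first gives vrshrq_n_s32(-6145, 12) = -2, matching the scalar reference;
// non-negative inputs are unchanged because x >> 31 == 0 for them.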
+ diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31); + diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31); + int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12); + int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12); + + *sumv = vrsraq_n_s32(*sumv, diff_lo, 12); + *sumv = vrsraq_n_s32(*sumv, diff_hi, 12); + *ssev = vmlaq_s32(*ssev, round_lo, round_lo); + *ssev = vmlaq_s32(*ssev, round_hi, round_hi); +} + +static INLINE void obmc_variance_large_neon(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int width, + int height, unsigned *sse, + int *sum) { + assert(width % 16 == 0); + + // Use tbl for doing a double-width zero extension from 8->32 bits since we + // can do this in one instruction rather than two. + uint8x16_t pre_idx0 = vld1q_u8(&obmc_variance_permute_idx[0]); + uint8x16_t pre_idx1 = vld1q_u8(&obmc_variance_permute_idx[16]); + uint8x16_t pre_idx2 = vld1q_u8(&obmc_variance_permute_idx[32]); + uint8x16_t pre_idx3 = vld1q_u8(&obmc_variance_permute_idx[48]); + + int32x4_t ssev = vdupq_n_s32(0); + int32x4_t sumv = vdupq_n_s32(0); + + int h = height; + do { + int w = width; + do { + uint8x16_t pre_u8 = vld1q_u8(pre); + + int32x4_t pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx0)); + int32x4_t pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx1)); + obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[0], &mask[0], + &ssev, &sumv); + + pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx2)); + pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx3)); + obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[8], &mask[8], + &ssev, &sumv); + + wsrc += 16; + mask += 16; + pre += 16; + w -= 16; + } while (w != 0); + + pre += pre_stride - width; + } while (--h != 0); + + *sse = horizontal_add_s32x4(ssev); + *sum = horizontal_add_s32x4(sumv); +} + +#else // !AOM_ARCH_AARCH64 + +static INLINE void obmc_variance_large_neon(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int width, + int height, unsigned *sse, + int *sum) { + // Non-aarch64 targets do not have a 128-bit tbl instruction, so use the + // widening version of the core kernel instead. 
+ + assert(width % 16 == 0); + + int32x4_t ssev = vdupq_n_s32(0); + int32x4_t sumv = vdupq_n_s32(0); + + int h = height; + do { + int w = width; + do { + uint8x16_t pre_u8 = vld1q_u8(pre); + + int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(pre_u8))); + obmc_variance_8x1_s16_neon(pre_s16, &wsrc[0], &mask[0], &ssev, &sumv); + + pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(pre_u8))); + obmc_variance_8x1_s16_neon(pre_s16, &wsrc[8], &mask[8], &ssev, &sumv); + + wsrc += 16; + mask += 16; + pre += 16; + w -= 16; + } while (w != 0); + + pre += pre_stride - width; + } while (--h != 0); + + *sse = horizontal_add_s32x4(ssev); + *sum = horizontal_add_s32x4(sumv); +} + +#endif // AOM_ARCH_AARCH64 + +static INLINE void obmc_variance_neon_128xh(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + unsigned *sse, int *sum) { + obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse, sum); +} + +static INLINE void obmc_variance_neon_64xh(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + unsigned *sse, int *sum) { + obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum); +} + +static INLINE void obmc_variance_neon_32xh(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + unsigned *sse, int *sum) { + obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum); +} + +static INLINE void obmc_variance_neon_16xh(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + unsigned *sse, int *sum) { + obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum); +} + +static INLINE void obmc_variance_neon_8xh(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + unsigned *sse, int *sum) { + int32x4_t ssev = vdupq_n_s32(0); + int32x4_t sumv = vdupq_n_s32(0); + + do { + uint8x8_t pre_u8 = vld1_u8(pre); + int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8)); + + obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv); + + pre += pre_stride; + wsrc += 8; + mask += 8; + } while (--h != 0); + + *sse = horizontal_add_s32x4(ssev); + *sum = horizontal_add_s32x4(sumv); +} + +static INLINE void obmc_variance_neon_4xh(const uint8_t *pre, int pre_stride, + const int32_t *wsrc, + const int32_t *mask, int h, + unsigned *sse, int *sum) { + assert(h % 2 == 0); + + int32x4_t ssev = vdupq_n_s32(0); + int32x4_t sumv = vdupq_n_s32(0); + + do { + uint8x8_t pre_u8 = load_unaligned_u8(pre, pre_stride); + int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8)); + + obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv); + + pre += 2 * pre_stride; + wsrc += 8; + mask += 8; + h -= 2; + } while (h != 0); + + *sse = horizontal_add_s32x4(ssev); + *sum = horizontal_add_s32x4(sumv); +} + +#define OBMC_VARIANCE_WXH_NEON(W, H) \ + unsigned aom_obmc_variance##W##x##H##_neon( \ + const uint8_t *pre, int pre_stride, const int32_t *wsrc, \ + const int32_t *mask, unsigned *sse) { \ + int sum; \ + obmc_variance_neon_##W##xh(pre, pre_stride, wsrc, mask, H, sse, &sum); \ + return *sse - (unsigned)(((int64_t)sum * sum) / (W * H)); \ + } + +OBMC_VARIANCE_WXH_NEON(4, 4) +OBMC_VARIANCE_WXH_NEON(4, 8) +OBMC_VARIANCE_WXH_NEON(8, 4) +OBMC_VARIANCE_WXH_NEON(8, 8) +OBMC_VARIANCE_WXH_NEON(8, 16) +OBMC_VARIANCE_WXH_NEON(16, 8) +OBMC_VARIANCE_WXH_NEON(16, 16) +OBMC_VARIANCE_WXH_NEON(16, 32) +OBMC_VARIANCE_WXH_NEON(32, 16) +OBMC_VARIANCE_WXH_NEON(32, 32) +OBMC_VARIANCE_WXH_NEON(32, 64) 
+OBMC_VARIANCE_WXH_NEON(64, 32) +OBMC_VARIANCE_WXH_NEON(64, 64) +OBMC_VARIANCE_WXH_NEON(64, 128) +OBMC_VARIANCE_WXH_NEON(128, 64) +OBMC_VARIANCE_WXH_NEON(128, 128) +OBMC_VARIANCE_WXH_NEON(4, 16) +OBMC_VARIANCE_WXH_NEON(16, 4) +OBMC_VARIANCE_WXH_NEON(8, 32) +OBMC_VARIANCE_WXH_NEON(32, 8) +OBMC_VARIANCE_WXH_NEON(16, 64) +OBMC_VARIANCE_WXH_NEON(64, 16) diff --git a/third_party/aom/aom_dsp/arm/reinterpret_neon.h b/third_party/aom/aom_dsp/arm/reinterpret_neon.h new file mode 100644 index 0000000000..f9702513ad --- /dev/null +++ b/third_party/aom/aom_dsp/arm/reinterpret_neon.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef AOM_AOM_DSP_ARM_REINTERPRET_NEON_H_ +#define AOM_AOM_DSP_ARM_REINTERPRET_NEON_H_ + +#include + +#include "aom/aom_integer.h" // For AOM_FORCE_INLINE. +#include "config/aom_config.h" + +#define REINTERPRET_NEON(u, to_sz, to_count, from_sz, from_count, n, q) \ + static AOM_FORCE_INLINE u##int##to_sz##x##to_count##x##n##_t \ + aom_reinterpret##q##_##u##to_sz##_##u##from_sz##_x##n( \ + const u##int##from_sz##x##from_count##x##n##_t src) { \ + u##int##to_sz##x##to_count##x##n##_t ret; \ + for (int i = 0; i < (n); ++i) { \ + ret.val[i] = vreinterpret##q##_##u##to_sz##_##u##from_sz(src.val[i]); \ + } \ + return ret; \ + } + +REINTERPRET_NEON(u, 8, 8, 16, 4, 2, ) // uint8x8x2_t from uint16x4x2_t +REINTERPRET_NEON(u, 8, 16, 16, 8, 2, q) // uint8x16x2_t from uint16x8x2_t + +#endif // AOM_AOM_DSP_ARM_REINTERPRET_NEON_H_ diff --git a/third_party/aom/aom_dsp/arm/sad_neon.c b/third_party/aom/aom_dsp/arm/sad_neon.c new file mode 100644 index 0000000000..46a1666331 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sad_neon.c @@ -0,0 +1,873 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/dist_wtd_avg_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE unsigned int sad128xh_neon(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride, + int h) { + // We use 8 accumulators to prevent overflow for large values of 'h', as well + // as enabling optimal UADALP instruction throughput on CPUs that have either + // 2 or 4 Neon pipes. 
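// To see that 16-bit accumulators are sufficient: each uint16x8_t covers 16
// of the 128 pixels per row, and vpadalq_u8 adds two absolute differences
// (each at most 255) into every 16-bit lane per row, i.e. at most 510 per
// row. At the largest height used here (h = 128) a lane therefore reaches at
// most 128 * 510 = 65280 < 65535, so the sums cannot overflow before the
// widening pairwise adds after the loop.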
+ uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0) }; + + int i = h; + do { + uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7; + uint8x16_t r0, r1, r2, r3, r4, r5, r6, r7; + uint8x16_t diff0, diff1, diff2, diff3, diff4, diff5, diff6, diff7; + + s0 = vld1q_u8(src_ptr); + r0 = vld1q_u8(ref_ptr); + diff0 = vabdq_u8(s0, r0); + sum[0] = vpadalq_u8(sum[0], diff0); + + s1 = vld1q_u8(src_ptr + 16); + r1 = vld1q_u8(ref_ptr + 16); + diff1 = vabdq_u8(s1, r1); + sum[1] = vpadalq_u8(sum[1], diff1); + + s2 = vld1q_u8(src_ptr + 32); + r2 = vld1q_u8(ref_ptr + 32); + diff2 = vabdq_u8(s2, r2); + sum[2] = vpadalq_u8(sum[2], diff2); + + s3 = vld1q_u8(src_ptr + 48); + r3 = vld1q_u8(ref_ptr + 48); + diff3 = vabdq_u8(s3, r3); + sum[3] = vpadalq_u8(sum[3], diff3); + + s4 = vld1q_u8(src_ptr + 64); + r4 = vld1q_u8(ref_ptr + 64); + diff4 = vabdq_u8(s4, r4); + sum[4] = vpadalq_u8(sum[4], diff4); + + s5 = vld1q_u8(src_ptr + 80); + r5 = vld1q_u8(ref_ptr + 80); + diff5 = vabdq_u8(s5, r5); + sum[5] = vpadalq_u8(sum[5], diff5); + + s6 = vld1q_u8(src_ptr + 96); + r6 = vld1q_u8(ref_ptr + 96); + diff6 = vabdq_u8(s6, r6); + sum[6] = vpadalq_u8(sum[6], diff6); + + s7 = vld1q_u8(src_ptr + 112); + r7 = vld1q_u8(ref_ptr + 112); + diff7 = vabdq_u8(s7, r7); + sum[7] = vpadalq_u8(sum[7], diff7); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + uint32x4_t sum_u32 = vpaddlq_u16(sum[0]); + sum_u32 = vpadalq_u16(sum_u32, sum[1]); + sum_u32 = vpadalq_u16(sum_u32, sum[2]); + sum_u32 = vpadalq_u16(sum_u32, sum[3]); + sum_u32 = vpadalq_u16(sum_u32, sum[4]); + sum_u32 = vpadalq_u16(sum_u32, sum[5]); + sum_u32 = vpadalq_u16(sum_u32, sum[6]); + sum_u32 = vpadalq_u16(sum_u32, sum[7]); + + return horizontal_add_u32x4(sum_u32); +} + +static INLINE unsigned int sad64xh_neon(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride, + int h) { + uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int i = h; + do { + uint8x16_t s0, s1, s2, s3, r0, r1, r2, r3; + uint8x16_t diff0, diff1, diff2, diff3; + + s0 = vld1q_u8(src_ptr); + r0 = vld1q_u8(ref_ptr); + diff0 = vabdq_u8(s0, r0); + sum[0] = vpadalq_u8(sum[0], diff0); + + s1 = vld1q_u8(src_ptr + 16); + r1 = vld1q_u8(ref_ptr + 16); + diff1 = vabdq_u8(s1, r1); + sum[1] = vpadalq_u8(sum[1], diff1); + + s2 = vld1q_u8(src_ptr + 32); + r2 = vld1q_u8(ref_ptr + 32); + diff2 = vabdq_u8(s2, r2); + sum[2] = vpadalq_u8(sum[2], diff2); + + s3 = vld1q_u8(src_ptr + 48); + r3 = vld1q_u8(ref_ptr + 48); + diff3 = vabdq_u8(s3, r3); + sum[3] = vpadalq_u8(sum[3], diff3); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + uint32x4_t sum_u32 = vpaddlq_u16(sum[0]); + sum_u32 = vpadalq_u16(sum_u32, sum[1]); + sum_u32 = vpadalq_u16(sum_u32, sum[2]); + sum_u32 = vpadalq_u16(sum_u32, sum[3]); + + return horizontal_add_u32x4(sum_u32); +} + +static INLINE unsigned int sad32xh_neon(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride, + int h) { + uint16x8_t sum[2] = { vdupq_n_u16(0), vdupq_n_u16(0) }; + + int i = h; + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t diff0 = vabdq_u8(s0, r0); + sum[0] = vpadalq_u8(sum[0], diff0); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t diff1 = vabdq_u8(s1, r1); + sum[1] = vpadalq_u8(sum[1], diff1); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } 
while (--i != 0); + + return horizontal_add_u16x8(vaddq_u16(sum[0], sum[1])); +} + +static INLINE unsigned int sad16xh_neon(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride, + int h) { + uint16x8_t sum = vdupq_n_u16(0); + + int i = h; + do { + uint8x16_t s = vld1q_u8(src_ptr); + uint8x16_t r = vld1q_u8(ref_ptr); + + uint8x16_t diff = vabdq_u8(s, r); + sum = vpadalq_u8(sum, diff); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + return horizontal_add_u16x8(sum); +} + +static INLINE unsigned int sad8xh_neon(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride, + int h) { + uint16x8_t sum = vdupq_n_u16(0); + + int i = h; + do { + uint8x8_t s = vld1_u8(src_ptr); + uint8x8_t r = vld1_u8(ref_ptr); + + sum = vabal_u8(sum, s, r); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + return horizontal_add_u16x8(sum); +} + +static INLINE unsigned int sad4xh_neon(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride, + int h) { + uint16x8_t sum = vdupq_n_u16(0); + + int i = h / 2; + do { + uint8x8_t s = load_unaligned_u8(src_ptr, src_stride); + uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride); + + sum = vabal_u8(sum, s, r); + + src_ptr += 2 * src_stride; + ref_ptr += 2 * ref_stride; + } while (--i != 0); + + return horizontal_add_u16x8(sum); +} + +#define SAD_WXH_NEON(w, h) \ + unsigned int aom_sad##w##x##h##_neon(const uint8_t *src, int src_stride, \ + const uint8_t *ref, int ref_stride) { \ + return sad##w##xh_neon(src, src_stride, ref, ref_stride, (h)); \ + } + +SAD_WXH_NEON(4, 4) +SAD_WXH_NEON(4, 8) + +SAD_WXH_NEON(8, 4) +SAD_WXH_NEON(8, 8) +SAD_WXH_NEON(8, 16) + +SAD_WXH_NEON(16, 8) +SAD_WXH_NEON(16, 16) +SAD_WXH_NEON(16, 32) + +SAD_WXH_NEON(32, 16) +SAD_WXH_NEON(32, 32) +SAD_WXH_NEON(32, 64) + +SAD_WXH_NEON(64, 32) +SAD_WXH_NEON(64, 64) +SAD_WXH_NEON(64, 128) + +SAD_WXH_NEON(128, 64) +SAD_WXH_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_NEON(4, 16) +SAD_WXH_NEON(8, 32) +SAD_WXH_NEON(16, 4) +SAD_WXH_NEON(16, 64) +SAD_WXH_NEON(32, 8) +SAD_WXH_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_NEON + +#define SAD_SKIP_WXH_NEON(w, h) \ + unsigned int aom_sad_skip_##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, \ + int ref_stride) { \ + return 2 * \ + sad##w##xh_neon(src, 2 * src_stride, ref, 2 * ref_stride, (h) / 2); \ + } + +SAD_SKIP_WXH_NEON(4, 4) +SAD_SKIP_WXH_NEON(4, 8) + +SAD_SKIP_WXH_NEON(8, 4) +SAD_SKIP_WXH_NEON(8, 8) +SAD_SKIP_WXH_NEON(8, 16) + +SAD_SKIP_WXH_NEON(16, 8) +SAD_SKIP_WXH_NEON(16, 16) +SAD_SKIP_WXH_NEON(16, 32) + +SAD_SKIP_WXH_NEON(32, 16) +SAD_SKIP_WXH_NEON(32, 32) +SAD_SKIP_WXH_NEON(32, 64) + +SAD_SKIP_WXH_NEON(64, 32) +SAD_SKIP_WXH_NEON(64, 64) +SAD_SKIP_WXH_NEON(64, 128) + +SAD_SKIP_WXH_NEON(128, 64) +SAD_SKIP_WXH_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_SKIP_WXH_NEON(4, 16) +SAD_SKIP_WXH_NEON(8, 32) +SAD_SKIP_WXH_NEON(16, 4) +SAD_SKIP_WXH_NEON(16, 64) +SAD_SKIP_WXH_NEON(32, 8) +SAD_SKIP_WXH_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_SKIP_WXH_NEON + +static INLINE unsigned int sad128xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + // We use 8 accumulators to prevent overflow for large values of 'h', as well + // as enabling optimal UADALP instruction throughput on CPUs that have either + // 2 or 4 Neon pipes. 
+ uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0) }; + + int i = h; + do { + uint8x16_t s0, s1, s2, s3, s4, s5, s6, s7; + uint8x16_t r0, r1, r2, r3, r4, r5, r6, r7; + uint8x16_t p0, p1, p2, p3, p4, p5, p6, p7; + uint8x16_t avg0, avg1, avg2, avg3, avg4, avg5, avg6, avg7; + uint8x16_t diff0, diff1, diff2, diff3, diff4, diff5, diff6, diff7; + + s0 = vld1q_u8(src_ptr); + r0 = vld1q_u8(ref_ptr); + p0 = vld1q_u8(second_pred); + avg0 = vrhaddq_u8(r0, p0); + diff0 = vabdq_u8(s0, avg0); + sum[0] = vpadalq_u8(sum[0], diff0); + + s1 = vld1q_u8(src_ptr + 16); + r1 = vld1q_u8(ref_ptr + 16); + p1 = vld1q_u8(second_pred + 16); + avg1 = vrhaddq_u8(r1, p1); + diff1 = vabdq_u8(s1, avg1); + sum[1] = vpadalq_u8(sum[1], diff1); + + s2 = vld1q_u8(src_ptr + 32); + r2 = vld1q_u8(ref_ptr + 32); + p2 = vld1q_u8(second_pred + 32); + avg2 = vrhaddq_u8(r2, p2); + diff2 = vabdq_u8(s2, avg2); + sum[2] = vpadalq_u8(sum[2], diff2); + + s3 = vld1q_u8(src_ptr + 48); + r3 = vld1q_u8(ref_ptr + 48); + p3 = vld1q_u8(second_pred + 48); + avg3 = vrhaddq_u8(r3, p3); + diff3 = vabdq_u8(s3, avg3); + sum[3] = vpadalq_u8(sum[3], diff3); + + s4 = vld1q_u8(src_ptr + 64); + r4 = vld1q_u8(ref_ptr + 64); + p4 = vld1q_u8(second_pred + 64); + avg4 = vrhaddq_u8(r4, p4); + diff4 = vabdq_u8(s4, avg4); + sum[4] = vpadalq_u8(sum[4], diff4); + + s5 = vld1q_u8(src_ptr + 80); + r5 = vld1q_u8(ref_ptr + 80); + p5 = vld1q_u8(second_pred + 80); + avg5 = vrhaddq_u8(r5, p5); + diff5 = vabdq_u8(s5, avg5); + sum[5] = vpadalq_u8(sum[5], diff5); + + s6 = vld1q_u8(src_ptr + 96); + r6 = vld1q_u8(ref_ptr + 96); + p6 = vld1q_u8(second_pred + 96); + avg6 = vrhaddq_u8(r6, p6); + diff6 = vabdq_u8(s6, avg6); + sum[6] = vpadalq_u8(sum[6], diff6); + + s7 = vld1q_u8(src_ptr + 112); + r7 = vld1q_u8(ref_ptr + 112); + p7 = vld1q_u8(second_pred + 112); + avg7 = vrhaddq_u8(r7, p7); + diff7 = vabdq_u8(s7, avg7); + sum[7] = vpadalq_u8(sum[7], diff7); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 128; + } while (--i != 0); + + uint32x4_t sum_u32 = vpaddlq_u16(sum[0]); + sum_u32 = vpadalq_u16(sum_u32, sum[1]); + sum_u32 = vpadalq_u16(sum_u32, sum[2]); + sum_u32 = vpadalq_u16(sum_u32, sum[3]); + sum_u32 = vpadalq_u16(sum_u32, sum[4]); + sum_u32 = vpadalq_u16(sum_u32, sum[5]); + sum_u32 = vpadalq_u16(sum_u32, sum[6]); + sum_u32 = vpadalq_u16(sum_u32, sum[7]); + + return horizontal_add_u32x4(sum_u32); +} + +static INLINE unsigned int sad64xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int i = h; + do { + uint8x16_t s0, s1, s2, s3, r0, r1, r2, r3, p0, p1, p2, p3; + uint8x16_t avg0, avg1, avg2, avg3, diff0, diff1, diff2, diff3; + + s0 = vld1q_u8(src_ptr); + r0 = vld1q_u8(ref_ptr); + p0 = vld1q_u8(second_pred); + avg0 = vrhaddq_u8(r0, p0); + diff0 = vabdq_u8(s0, avg0); + sum[0] = vpadalq_u8(sum[0], diff0); + + s1 = vld1q_u8(src_ptr + 16); + r1 = vld1q_u8(ref_ptr + 16); + p1 = vld1q_u8(second_pred + 16); + avg1 = vrhaddq_u8(r1, p1); + diff1 = vabdq_u8(s1, avg1); + sum[1] = vpadalq_u8(sum[1], diff1); + + s2 = vld1q_u8(src_ptr + 32); + r2 = vld1q_u8(ref_ptr + 32); + p2 = vld1q_u8(second_pred + 32); + avg2 = vrhaddq_u8(r2, p2); + diff2 = vabdq_u8(s2, avg2); + sum[2] = vpadalq_u8(sum[2], diff2); + + s3 = vld1q_u8(src_ptr + 48); + r3 = vld1q_u8(ref_ptr + 48); + p3 = 
vld1q_u8(second_pred + 48); + avg3 = vrhaddq_u8(r3, p3); + diff3 = vabdq_u8(s3, avg3); + sum[3] = vpadalq_u8(sum[3], diff3); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 64; + } while (--i != 0); + + uint32x4_t sum_u32 = vpaddlq_u16(sum[0]); + sum_u32 = vpadalq_u16(sum_u32, sum[1]); + sum_u32 = vpadalq_u16(sum_u32, sum[2]); + sum_u32 = vpadalq_u16(sum_u32, sum[3]); + + return horizontal_add_u32x4(sum_u32); +} + +static INLINE unsigned int sad32xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + uint16x8_t sum[2] = { vdupq_n_u16(0), vdupq_n_u16(0) }; + + int i = h; + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t avg0 = vrhaddq_u8(r0, p0); + uint8x16_t diff0 = vabdq_u8(s0, avg0); + sum[0] = vpadalq_u8(sum[0], diff0); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t avg1 = vrhaddq_u8(r1, p1); + uint8x16_t diff1 = vabdq_u8(s1, avg1); + sum[1] = vpadalq_u8(sum[1], diff1); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 32; + } while (--i != 0); + + return horizontal_add_u16x8(vaddq_u16(sum[0], sum[1])); +} + +static INLINE unsigned int sad16xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + uint16x8_t sum = vdupq_n_u16(0); + + int i = h; + do { + uint8x16_t s = vld1q_u8(src_ptr); + uint8x16_t r = vld1q_u8(ref_ptr); + uint8x16_t p = vld1q_u8(second_pred); + + uint8x16_t avg = vrhaddq_u8(r, p); + uint8x16_t diff = vabdq_u8(s, avg); + sum = vpadalq_u8(sum, diff); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 16; + } while (--i != 0); + + return horizontal_add_u16x8(sum); +} + +static INLINE unsigned int sad8xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + uint16x8_t sum = vdupq_n_u16(0); + + int i = h; + do { + uint8x8_t s = vld1_u8(src_ptr); + uint8x8_t r = vld1_u8(ref_ptr); + uint8x8_t p = vld1_u8(second_pred); + + uint8x8_t avg = vrhadd_u8(r, p); + sum = vabal_u8(sum, s, avg); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 8; + } while (--i != 0); + + return horizontal_add_u16x8(sum); +} + +static INLINE unsigned int sad4xh_avg_neon(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h, + const uint8_t *second_pred) { + uint16x8_t sum = vdupq_n_u16(0); + + int i = h / 2; + do { + uint8x8_t s = load_unaligned_u8(src_ptr, src_stride); + uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride); + uint8x8_t p = vld1_u8(second_pred); + + uint8x8_t avg = vrhadd_u8(r, p); + sum = vabal_u8(sum, s, avg); + + src_ptr += 2 * src_stride; + ref_ptr += 2 * ref_stride; + second_pred += 8; + } while (--i != 0); + + return horizontal_add_u16x8(sum); +} + +#define SAD_WXH_AVG_NEON(w, h) \ + unsigned int aom_sad##w##x##h##_avg_neon(const uint8_t *src, int src_stride, \ + const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred) { \ + return sad##w##xh_avg_neon(src, src_stride, ref, ref_stride, (h), \ + second_pred); \ + } + +SAD_WXH_AVG_NEON(4, 4) +SAD_WXH_AVG_NEON(4, 8) + +SAD_WXH_AVG_NEON(8, 4) +SAD_WXH_AVG_NEON(8, 8) +SAD_WXH_AVG_NEON(8, 16) + +SAD_WXH_AVG_NEON(16, 8) +SAD_WXH_AVG_NEON(16, 16) +SAD_WXH_AVG_NEON(16, 32) + 
+SAD_WXH_AVG_NEON(32, 16) +SAD_WXH_AVG_NEON(32, 32) +SAD_WXH_AVG_NEON(32, 64) + +SAD_WXH_AVG_NEON(64, 32) +SAD_WXH_AVG_NEON(64, 64) +SAD_WXH_AVG_NEON(64, 128) + +SAD_WXH_AVG_NEON(128, 64) +SAD_WXH_AVG_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_AVG_NEON(4, 16) +SAD_WXH_AVG_NEON(8, 32) +SAD_WXH_AVG_NEON(16, 4) +SAD_WXH_AVG_NEON(16, 64) +SAD_WXH_AVG_NEON(32, 8) +SAD_WXH_AVG_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_AVG_NEON + +static INLINE unsigned int dist_wtd_sad128xh_avg_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + // We use 8 accumulators to prevent overflow for large values of 'h', as well + // as enabling optimal UADALP instruction throughput on CPUs that have either + // 2 or 4 Neon pipes. + uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0), vdupq_n_u16(0) }; + + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset); + uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0); + sum[0] = vpadalq_u8(sum[0], diff0); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset); + uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1); + sum[1] = vpadalq_u8(sum[1], diff1); + + uint8x16_t s2 = vld1q_u8(src_ptr + 32); + uint8x16_t r2 = vld1q_u8(ref_ptr + 32); + uint8x16_t p2 = vld1q_u8(second_pred + 32); + uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset); + uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2); + sum[2] = vpadalq_u8(sum[2], diff2); + + uint8x16_t s3 = vld1q_u8(src_ptr + 48); + uint8x16_t r3 = vld1q_u8(ref_ptr + 48); + uint8x16_t p3 = vld1q_u8(second_pred + 48); + uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset); + uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3); + sum[3] = vpadalq_u8(sum[3], diff3); + + uint8x16_t s4 = vld1q_u8(src_ptr + 64); + uint8x16_t r4 = vld1q_u8(ref_ptr + 64); + uint8x16_t p4 = vld1q_u8(second_pred + 64); + uint8x16_t wtd_avg4 = dist_wtd_avg_u8x16(p4, r4, bck_offset, fwd_offset); + uint8x16_t diff4 = vabdq_u8(s4, wtd_avg4); + sum[4] = vpadalq_u8(sum[4], diff4); + + uint8x16_t s5 = vld1q_u8(src_ptr + 80); + uint8x16_t r5 = vld1q_u8(ref_ptr + 80); + uint8x16_t p5 = vld1q_u8(second_pred + 80); + uint8x16_t wtd_avg5 = dist_wtd_avg_u8x16(p5, r5, bck_offset, fwd_offset); + uint8x16_t diff5 = vabdq_u8(s5, wtd_avg5); + sum[5] = vpadalq_u8(sum[5], diff5); + + uint8x16_t s6 = vld1q_u8(src_ptr + 96); + uint8x16_t r6 = vld1q_u8(ref_ptr + 96); + uint8x16_t p6 = vld1q_u8(second_pred + 96); + uint8x16_t wtd_avg6 = dist_wtd_avg_u8x16(p6, r6, bck_offset, fwd_offset); + uint8x16_t diff6 = vabdq_u8(s6, wtd_avg6); + sum[6] = vpadalq_u8(sum[6], diff6); + + uint8x16_t s7 = vld1q_u8(src_ptr + 112); + uint8x16_t r7 = vld1q_u8(ref_ptr + 112); + uint8x16_t p7 = vld1q_u8(second_pred + 112); + uint8x16_t wtd_avg7 = dist_wtd_avg_u8x16(p7, r7, bck_offset, fwd_offset); + uint8x16_t diff7 = vabdq_u8(s7, wtd_avg7); + sum[7] = vpadalq_u8(sum[7], diff7); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 128; + } 
while (--h != 0); + + uint32x4_t sum_u32 = vpaddlq_u16(sum[0]); + sum_u32 = vpadalq_u16(sum_u32, sum[1]); + sum_u32 = vpadalq_u16(sum_u32, sum[2]); + sum_u32 = vpadalq_u16(sum_u32, sum[3]); + sum_u32 = vpadalq_u16(sum_u32, sum[4]); + sum_u32 = vpadalq_u16(sum_u32, sum[5]); + sum_u32 = vpadalq_u16(sum_u32, sum[6]); + sum_u32 = vpadalq_u16(sum_u32, sum[7]); + + return horizontal_add_u32x4(sum_u32); +} + +static INLINE unsigned int dist_wtd_sad64xh_avg_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset); + uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0); + sum[0] = vpadalq_u8(sum[0], diff0); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset); + uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1); + sum[1] = vpadalq_u8(sum[1], diff1); + + uint8x16_t s2 = vld1q_u8(src_ptr + 32); + uint8x16_t r2 = vld1q_u8(ref_ptr + 32); + uint8x16_t p2 = vld1q_u8(second_pred + 32); + uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset); + uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2); + sum[2] = vpadalq_u8(sum[2], diff2); + + uint8x16_t s3 = vld1q_u8(src_ptr + 48); + uint8x16_t r3 = vld1q_u8(ref_ptr + 48); + uint8x16_t p3 = vld1q_u8(second_pred + 48); + uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset); + uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3); + sum[3] = vpadalq_u8(sum[3], diff3); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 64; + } while (--h != 0); + + uint32x4_t sum_u32 = vpaddlq_u16(sum[0]); + sum_u32 = vpadalq_u16(sum_u32, sum[1]); + sum_u32 = vpadalq_u16(sum_u32, sum[2]); + sum_u32 = vpadalq_u16(sum_u32, sum[3]); + + return horizontal_add_u32x4(sum_u32); +} + +static INLINE unsigned int dist_wtd_sad32xh_avg_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + uint16x8_t sum[2] = { vdupq_n_u16(0), vdupq_n_u16(0) }; + + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset); + uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0); + sum[0] = vpadalq_u8(sum[0], diff0); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset); + uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1); + sum[1] = vpadalq_u8(sum[1], diff1); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 32; + } while (--h != 0); + + return horizontal_add_u16x8(vaddq_u16(sum[0], sum[1])); +} + +static INLINE unsigned int dist_wtd_sad16xh_avg_neon( + const uint8_t *src_ptr, int 
src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + uint16x8_t sum = vdupq_n_u16(0); + + do { + uint8x16_t s = vld1q_u8(src_ptr); + uint8x16_t r = vld1q_u8(ref_ptr); + uint8x16_t p = vld1q_u8(second_pred); + + uint8x16_t wtd_avg = dist_wtd_avg_u8x16(p, r, bck_offset, fwd_offset); + uint8x16_t diff = vabdq_u8(s, wtd_avg); + sum = vpadalq_u8(sum, diff); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 16; + } while (--h != 0); + + return horizontal_add_u16x8(sum); +} + +static INLINE unsigned int dist_wtd_sad8xh_avg_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x8_t fwd_offset = vdup_n_u8(jcp_param->fwd_offset); + const uint8x8_t bck_offset = vdup_n_u8(jcp_param->bck_offset); + uint16x8_t sum = vdupq_n_u16(0); + + do { + uint8x8_t s = vld1_u8(src_ptr); + uint8x8_t r = vld1_u8(ref_ptr); + uint8x8_t p = vld1_u8(second_pred); + + uint8x8_t wtd_avg = dist_wtd_avg_u8x8(p, r, bck_offset, fwd_offset); + sum = vabal_u8(sum, s, wtd_avg); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 8; + } while (--h != 0); + + return horizontal_add_u16x8(sum); +} + +static INLINE unsigned int dist_wtd_sad4xh_avg_neon( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x8_t fwd_offset = vdup_n_u8(jcp_param->fwd_offset); + const uint8x8_t bck_offset = vdup_n_u8(jcp_param->bck_offset); + uint16x8_t sum = vdupq_n_u16(0); + + int i = h / 2; + do { + uint8x8_t s = load_unaligned_u8(src_ptr, src_stride); + uint8x8_t r = load_unaligned_u8(ref_ptr, ref_stride); + uint8x8_t p = vld1_u8(second_pred); + + uint8x8_t wtd_avg = dist_wtd_avg_u8x8(p, r, bck_offset, fwd_offset); + sum = vabal_u8(sum, s, wtd_avg); + + src_ptr += 2 * src_stride; + ref_ptr += 2 * ref_stride; + second_pred += 8; + } while (--i != 0); + + return horizontal_add_u16x8(sum); +} + +#define DIST_WTD_SAD_WXH_AVG_NEON(w, h) \ + unsigned int aom_dist_wtd_sad##w##x##h##_avg_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { \ + return dist_wtd_sad##w##xh_avg_neon(src, src_stride, ref, ref_stride, (h), \ + second_pred, jcp_param); \ + } + +DIST_WTD_SAD_WXH_AVG_NEON(4, 4) +DIST_WTD_SAD_WXH_AVG_NEON(4, 8) + +DIST_WTD_SAD_WXH_AVG_NEON(8, 4) +DIST_WTD_SAD_WXH_AVG_NEON(8, 8) +DIST_WTD_SAD_WXH_AVG_NEON(8, 16) + +DIST_WTD_SAD_WXH_AVG_NEON(16, 8) +DIST_WTD_SAD_WXH_AVG_NEON(16, 16) +DIST_WTD_SAD_WXH_AVG_NEON(16, 32) + +DIST_WTD_SAD_WXH_AVG_NEON(32, 16) +DIST_WTD_SAD_WXH_AVG_NEON(32, 32) +DIST_WTD_SAD_WXH_AVG_NEON(32, 64) + +DIST_WTD_SAD_WXH_AVG_NEON(64, 32) +DIST_WTD_SAD_WXH_AVG_NEON(64, 64) +DIST_WTD_SAD_WXH_AVG_NEON(64, 128) + +DIST_WTD_SAD_WXH_AVG_NEON(128, 64) +DIST_WTD_SAD_WXH_AVG_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +DIST_WTD_SAD_WXH_AVG_NEON(4, 16) +DIST_WTD_SAD_WXH_AVG_NEON(8, 32) +DIST_WTD_SAD_WXH_AVG_NEON(16, 4) +DIST_WTD_SAD_WXH_AVG_NEON(16, 64) +DIST_WTD_SAD_WXH_AVG_NEON(32, 8) +DIST_WTD_SAD_WXH_AVG_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef DIST_WTD_SAD_WXH_AVG_NEON diff --git a/third_party/aom/aom_dsp/arm/sad_neon_dotprod.c 
b/third_party/aom/aom_dsp/arm/sad_neon_dotprod.c new file mode 100644 index 0000000000..5504c6838e --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sad_neon_dotprod.c @@ -0,0 +1,530 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/dist_wtd_avg_neon.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE unsigned int sadwxh_neon_dotprod(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int w, int h) { + // Only two accumulators are required for optimal instruction throughput of + // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes. + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h; + do { + int j = 0; + do { + uint8x16_t s0, s1, r0, r1, diff0, diff1; + + s0 = vld1q_u8(src_ptr + j); + r0 = vld1q_u8(ref_ptr + j); + diff0 = vabdq_u8(s0, r0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + s1 = vld1q_u8(src_ptr + j + 16); + r1 = vld1q_u8(ref_ptr + j + 16); + diff1 = vabdq_u8(s1, r1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + j += 32; + } while (j < w); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1])); +} + +static INLINE unsigned int sad128xh_neon_dotprod(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h) { + return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 128, h); +} + +static INLINE unsigned int sad64xh_neon_dotprod(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h) { + return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 64, h); +} + +static INLINE unsigned int sad32xh_neon_dotprod(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h) { + return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 32, h); +} + +static INLINE unsigned int sad16xh_neon_dotprod(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int h) { + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h / 2; + do { + uint8x16_t s0, s1, r0, r1, diff0, diff1; + + s0 = vld1q_u8(src_ptr); + r0 = vld1q_u8(ref_ptr); + diff0 = vabdq_u8(s0, r0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + + s1 = vld1q_u8(src_ptr); + r1 = vld1q_u8(ref_ptr); + diff1 = vabdq_u8(s1, r1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1])); +} + +#define SAD_WXH_NEON_DOTPROD(w, h) \ + unsigned int aom_sad##w##x##h##_neon_dotprod( \ + const uint8_t *src, int src_stride, const uint8_t *ref, \ + int ref_stride) { \ + return sad##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, (h)); \ + } + +SAD_WXH_NEON_DOTPROD(16, 8) 
+SAD_WXH_NEON_DOTPROD(16, 16) +SAD_WXH_NEON_DOTPROD(16, 32) + +SAD_WXH_NEON_DOTPROD(32, 16) +SAD_WXH_NEON_DOTPROD(32, 32) +SAD_WXH_NEON_DOTPROD(32, 64) + +SAD_WXH_NEON_DOTPROD(64, 32) +SAD_WXH_NEON_DOTPROD(64, 64) +SAD_WXH_NEON_DOTPROD(64, 128) + +SAD_WXH_NEON_DOTPROD(128, 64) +SAD_WXH_NEON_DOTPROD(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_NEON_DOTPROD(16, 4) +SAD_WXH_NEON_DOTPROD(16, 64) +SAD_WXH_NEON_DOTPROD(32, 8) +SAD_WXH_NEON_DOTPROD(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_NEON_DOTPROD + +#define SAD_SKIP_WXH_NEON_DOTPROD(w, h) \ + unsigned int aom_sad_skip_##w##x##h##_neon_dotprod( \ + const uint8_t *src, int src_stride, const uint8_t *ref, \ + int ref_stride) { \ + return 2 * sad##w##xh_neon_dotprod(src, 2 * src_stride, ref, \ + 2 * ref_stride, (h) / 2); \ + } + +SAD_SKIP_WXH_NEON_DOTPROD(16, 8) +SAD_SKIP_WXH_NEON_DOTPROD(16, 16) +SAD_SKIP_WXH_NEON_DOTPROD(16, 32) + +SAD_SKIP_WXH_NEON_DOTPROD(32, 16) +SAD_SKIP_WXH_NEON_DOTPROD(32, 32) +SAD_SKIP_WXH_NEON_DOTPROD(32, 64) + +SAD_SKIP_WXH_NEON_DOTPROD(64, 32) +SAD_SKIP_WXH_NEON_DOTPROD(64, 64) +SAD_SKIP_WXH_NEON_DOTPROD(64, 128) + +SAD_SKIP_WXH_NEON_DOTPROD(128, 64) +SAD_SKIP_WXH_NEON_DOTPROD(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_SKIP_WXH_NEON_DOTPROD(16, 4) +SAD_SKIP_WXH_NEON_DOTPROD(16, 64) +SAD_SKIP_WXH_NEON_DOTPROD(32, 8) +SAD_SKIP_WXH_NEON_DOTPROD(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_SKIP_WXH_NEON_DOTPROD + +static INLINE unsigned int sadwxh_avg_neon_dotprod(const uint8_t *src_ptr, + int src_stride, + const uint8_t *ref_ptr, + int ref_stride, int w, int h, + const uint8_t *second_pred) { + // Only two accumulators are required for optimal instruction throughput of + // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes. + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h; + do { + int j = 0; + do { + uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1; + + s0 = vld1q_u8(src_ptr + j); + r0 = vld1q_u8(ref_ptr + j); + p0 = vld1q_u8(second_pred); + avg0 = vrhaddq_u8(r0, p0); + diff0 = vabdq_u8(s0, avg0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + s1 = vld1q_u8(src_ptr + j + 16); + r1 = vld1q_u8(ref_ptr + j + 16); + p1 = vld1q_u8(second_pred + 16); + avg1 = vrhaddq_u8(r1, p1); + diff1 = vabdq_u8(s1, avg1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + j += 32; + second_pred += 32; + } while (j < w); + + src_ptr += src_stride; + ref_ptr += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1])); +} + +static INLINE unsigned int sad128xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred) { + return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 128, + h, second_pred); +} + +static INLINE unsigned int sad64xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred) { + return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 64, + h, second_pred); +} + +static INLINE unsigned int sad32xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred) { + return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 32, + h, second_pred); +} + +static INLINE unsigned int sad16xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred) { + 
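+  // Average the reference and second prediction with a rounding halving add,
+  // then accumulate the absolute differences against the source using UDOT
+  // with a vector of ones. Two accumulators keep both rows of each pair in
+  // flight on CPUs with either 2 or 4 Neon pipes.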
uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h / 2; + do { + uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1; + + s0 = vld1q_u8(src_ptr); + r0 = vld1q_u8(ref_ptr); + p0 = vld1q_u8(second_pred); + avg0 = vrhaddq_u8(r0, p0); + diff0 = vabdq_u8(s0, avg0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 16; + + s1 = vld1q_u8(src_ptr); + r1 = vld1q_u8(ref_ptr); + p1 = vld1q_u8(second_pred); + avg1 = vrhaddq_u8(r1, p1); + diff1 = vabdq_u8(s1, avg1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 16; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1])); +} + +#define SAD_WXH_AVG_NEON_DOTPROD(w, h) \ + unsigned int aom_sad##w##x##h##_avg_neon_dotprod( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred) { \ + return sad##w##xh_avg_neon_dotprod(src, src_stride, ref, ref_stride, (h), \ + second_pred); \ + } + +SAD_WXH_AVG_NEON_DOTPROD(16, 8) +SAD_WXH_AVG_NEON_DOTPROD(16, 16) +SAD_WXH_AVG_NEON_DOTPROD(16, 32) + +SAD_WXH_AVG_NEON_DOTPROD(32, 16) +SAD_WXH_AVG_NEON_DOTPROD(32, 32) +SAD_WXH_AVG_NEON_DOTPROD(32, 64) + +SAD_WXH_AVG_NEON_DOTPROD(64, 32) +SAD_WXH_AVG_NEON_DOTPROD(64, 64) +SAD_WXH_AVG_NEON_DOTPROD(64, 128) + +SAD_WXH_AVG_NEON_DOTPROD(128, 64) +SAD_WXH_AVG_NEON_DOTPROD(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_AVG_NEON_DOTPROD(16, 4) +SAD_WXH_AVG_NEON_DOTPROD(16, 64) +SAD_WXH_AVG_NEON_DOTPROD(32, 8) +SAD_WXH_AVG_NEON_DOTPROD(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_AVG_NEON_DOTPROD + +static INLINE unsigned int dist_wtd_sad128xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + // We use 8 accumulators to minimize the accumulation and loop carried + // dependencies for better instruction throughput. 
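+  // Multiplying the absolute differences by a vector of ones with UDOT widens
+  // and sums four byte lanes per 32-bit result lane in a single instruction,
+  // so no intermediate 16-bit pairwise-add stage is needed here.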
+ uint32x4_t sum[8] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0), vdupq_n_u32(0) }; + + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset); + uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset); + uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + uint8x16_t s2 = vld1q_u8(src_ptr + 32); + uint8x16_t r2 = vld1q_u8(ref_ptr + 32); + uint8x16_t p2 = vld1q_u8(second_pred + 32); + uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset); + uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2); + sum[2] = vdotq_u32(sum[2], diff2, vdupq_n_u8(1)); + + uint8x16_t s3 = vld1q_u8(src_ptr + 48); + uint8x16_t r3 = vld1q_u8(ref_ptr + 48); + uint8x16_t p3 = vld1q_u8(second_pred + 48); + uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset); + uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3); + sum[3] = vdotq_u32(sum[3], diff3, vdupq_n_u8(1)); + + uint8x16_t s4 = vld1q_u8(src_ptr + 64); + uint8x16_t r4 = vld1q_u8(ref_ptr + 64); + uint8x16_t p4 = vld1q_u8(second_pred + 64); + uint8x16_t wtd_avg4 = dist_wtd_avg_u8x16(p4, r4, bck_offset, fwd_offset); + uint8x16_t diff4 = vabdq_u8(s4, wtd_avg4); + sum[4] = vdotq_u32(sum[4], diff4, vdupq_n_u8(1)); + + uint8x16_t s5 = vld1q_u8(src_ptr + 80); + uint8x16_t r5 = vld1q_u8(ref_ptr + 80); + uint8x16_t p5 = vld1q_u8(second_pred + 80); + uint8x16_t wtd_avg5 = dist_wtd_avg_u8x16(p5, r5, bck_offset, fwd_offset); + uint8x16_t diff5 = vabdq_u8(s5, wtd_avg5); + sum[5] = vdotq_u32(sum[5], diff5, vdupq_n_u8(1)); + + uint8x16_t s6 = vld1q_u8(src_ptr + 96); + uint8x16_t r6 = vld1q_u8(ref_ptr + 96); + uint8x16_t p6 = vld1q_u8(second_pred + 96); + uint8x16_t wtd_avg6 = dist_wtd_avg_u8x16(p6, r6, bck_offset, fwd_offset); + uint8x16_t diff6 = vabdq_u8(s6, wtd_avg6); + sum[6] = vdotq_u32(sum[6], diff6, vdupq_n_u8(1)); + + uint8x16_t s7 = vld1q_u8(src_ptr + 112); + uint8x16_t r7 = vld1q_u8(ref_ptr + 112); + uint8x16_t p7 = vld1q_u8(second_pred + 112); + uint8x16_t wtd_avg7 = dist_wtd_avg_u8x16(p7, r7, bck_offset, fwd_offset); + uint8x16_t diff7 = vabdq_u8(s7, wtd_avg7); + sum[7] = vdotq_u32(sum[7], diff7, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 128; + } while (--h != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + sum[2] = vaddq_u32(sum[2], sum[3]); + sum[4] = vaddq_u32(sum[4], sum[5]); + sum[6] = vaddq_u32(sum[6], sum[7]); + sum[0] = vaddq_u32(sum[0], sum[2]); + sum[4] = vaddq_u32(sum[4], sum[6]); + sum[0] = vaddq_u32(sum[0], sum[4]); + return horizontal_add_u32x4(sum[0]); +} + +static INLINE unsigned int dist_wtd_sad64xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = 
vld1q_u8(second_pred); + uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset); + uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset); + uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + uint8x16_t s2 = vld1q_u8(src_ptr + 32); + uint8x16_t r2 = vld1q_u8(ref_ptr + 32); + uint8x16_t p2 = vld1q_u8(second_pred + 32); + uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset); + uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2); + sum[2] = vdotq_u32(sum[2], diff2, vdupq_n_u8(1)); + + uint8x16_t s3 = vld1q_u8(src_ptr + 48); + uint8x16_t r3 = vld1q_u8(ref_ptr + 48); + uint8x16_t p3 = vld1q_u8(second_pred + 48); + uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset); + uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3); + sum[3] = vdotq_u32(sum[3], diff3, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 64; + } while (--h != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + sum[2] = vaddq_u32(sum[2], sum[3]); + sum[0] = vaddq_u32(sum[0], sum[2]); + return horizontal_add_u32x4(sum[0]); +} + +static INLINE unsigned int dist_wtd_sad32xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset); + uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + uint8x16_t s1 = vld1q_u8(src_ptr + 16); + uint8x16_t r1 = vld1q_u8(ref_ptr + 16); + uint8x16_t p1 = vld1q_u8(second_pred + 16); + uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset); + uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 32; + } while (--h != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + return horizontal_add_u32x4(sum[0]); +} + +static INLINE unsigned int dist_wtd_sad16xh_avg_neon_dotprod( + const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, + int ref_stride, int h, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h / 2; + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + uint8x16_t r0 = vld1q_u8(ref_ptr); + uint8x16_t p0 = vld1q_u8(second_pred); + uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset); + uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0); + sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 16; + + uint8x16_t s1 = vld1q_u8(src_ptr); + uint8x16_t r1 = vld1q_u8(ref_ptr); + uint8x16_t p1 = vld1q_u8(second_pred); + uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, 
fwd_offset); + uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1); + sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1)); + + src_ptr += src_stride; + ref_ptr += ref_stride; + second_pred += 16; + } while (--i != 0); + + sum[0] = vaddq_u32(sum[0], sum[1]); + return horizontal_add_u32x4(sum[0]); +} + +#define DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(w, h) \ + unsigned int aom_dist_wtd_sad##w##x##h##_avg_neon_dotprod( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { \ + return dist_wtd_sad##w##xh_avg_neon_dotprod( \ + src, src_stride, ref, ref_stride, (h), second_pred, jcp_param); \ + } + +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 8) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 16) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 32) + +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 16) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 32) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 64) + +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 32) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 64) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 128) + +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(128, 64) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(128, 128) + +#if !CONFIG_REALTIME_ONLY +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 4) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 64) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 8) +DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD diff --git a/third_party/aom/aom_dsp/arm/sadxd_neon.c b/third_party/aom/aom_dsp/arm/sadxd_neon.c new file mode 100644 index 0000000000..e89e1c5a73 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sadxd_neon.c @@ -0,0 +1,514 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref, + uint16x8_t *const sad_sum) { + uint8x16_t abs_diff = vabdq_u8(src, ref); + *sad_sum = vpadalq_u8(*sad_sum, abs_diff); +} + +static INLINE void sadwxhx3d_large_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[3], + int ref_stride, uint32_t res[3], int w, + int h, int h_overflow) { + uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + int h_limit = h > h_overflow ? 
h_overflow : h; + + int ref_offset = 0; + int i = 0; + do { + uint16x8_t sum_lo[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) }; + uint16x8_t sum_hi[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) }; + + do { + int j = 0; + do { + const uint8x16_t s0 = vld1q_u8(src + j); + sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]); + sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]); + sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]); + + const uint8x16_t s1 = vld1q_u8(src + j + 16); + sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]); + sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]); + sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]); + + j += 32; + } while (j < w); + + src += src_stride; + ref_offset += ref_stride; + } while (++i < h_limit); + + sum[0] = vpadalq_u16(sum[0], sum_lo[0]); + sum[0] = vpadalq_u16(sum[0], sum_hi[0]); + sum[1] = vpadalq_u16(sum[1], sum_lo[1]); + sum[1] = vpadalq_u16(sum[1], sum_hi[1]); + sum[2] = vpadalq_u16(sum[2], sum_lo[2]); + sum[2] = vpadalq_u16(sum[2], sum_hi[2]); + + h_limit += h_overflow; + } while (i < h); + + res[0] = horizontal_add_u32x4(sum[0]); + res[1] = horizontal_add_u32x4(sum[1]); + res[2] = horizontal_add_u32x4(sum[2]); +} + +static INLINE void sad128xhx3d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[3], int ref_stride, + uint32_t res[3], int h) { + sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 128, h, 32); +} + +static INLINE void sad64xhx3d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[3], int ref_stride, + uint32_t res[3], int h) { + sadwxhx3d_large_neon(src, src_stride, ref, ref_stride, res, 64, h, 64); +} + +static INLINE void sad32xhx3d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[3], int ref_stride, + uint32_t res[3], int h) { + uint16x8_t sum_lo[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) }; + uint16x8_t sum_hi[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h; + do { + const uint8x16_t s0 = vld1q_u8(src); + sad16_neon(s0, vld1q_u8(ref[0] + ref_offset), &sum_lo[0]); + sad16_neon(s0, vld1q_u8(ref[1] + ref_offset), &sum_lo[1]); + sad16_neon(s0, vld1q_u8(ref[2] + ref_offset), &sum_lo[2]); + + const uint8x16_t s1 = vld1q_u8(src + 16); + sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + 16), &sum_hi[0]); + sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + 16), &sum_hi[1]); + sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + 16), &sum_hi[2]); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + res[0] = horizontal_long_add_u16x8(sum_lo[0], sum_hi[0]); + res[1] = horizontal_long_add_u16x8(sum_lo[1], sum_hi[1]); + res[2] = horizontal_long_add_u16x8(sum_lo[2], sum_hi[2]); +} + +static INLINE void sad16xhx3d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[3], int ref_stride, + uint32_t res[3], int h) { + uint16x8_t sum[3] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h; + do { + const uint8x16_t s = vld1q_u8(src); + sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]); + sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]); + sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + res[0] = horizontal_add_u16x8(sum[0]); + res[1] = horizontal_add_u16x8(sum[1]); + res[2] = horizontal_add_u16x8(sum[2]); +} + +static INLINE void sad8xhx3d_neon(const uint8_t *src, int 
src_stride, + const uint8_t *const ref[3], int ref_stride, + uint32_t res[3], int h) { + uint16x8_t sum[3]; + + uint8x8_t s = vld1_u8(src); + sum[0] = vabdl_u8(s, vld1_u8(ref[0])); + sum[1] = vabdl_u8(s, vld1_u8(ref[1])); + sum[2] = vabdl_u8(s, vld1_u8(ref[2])); + + src += src_stride; + int ref_offset = ref_stride; + int i = h - 1; + do { + s = vld1_u8(src); + sum[0] = vabal_u8(sum[0], s, vld1_u8(ref[0] + ref_offset)); + sum[1] = vabal_u8(sum[1], s, vld1_u8(ref[1] + ref_offset)); + sum[2] = vabal_u8(sum[2], s, vld1_u8(ref[2] + ref_offset)); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + res[0] = horizontal_add_u16x8(sum[0]); + res[1] = horizontal_add_u16x8(sum[1]); + res[2] = horizontal_add_u16x8(sum[2]); +} + +static INLINE void sad4xhx3d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[3], int ref_stride, + uint32_t res[3], int h) { + assert(h % 2 == 0); + uint16x8_t sum[3]; + + uint8x8_t s = load_unaligned_u8(src, src_stride); + uint8x8_t r0 = load_unaligned_u8(ref[0], ref_stride); + uint8x8_t r1 = load_unaligned_u8(ref[1], ref_stride); + uint8x8_t r2 = load_unaligned_u8(ref[2], ref_stride); + + sum[0] = vabdl_u8(s, r0); + sum[1] = vabdl_u8(s, r1); + sum[2] = vabdl_u8(s, r2); + + src += 2 * src_stride; + int ref_offset = 2 * ref_stride; + int i = (h / 2) - 1; + do { + s = load_unaligned_u8(src, src_stride); + r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride); + r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride); + r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride); + + sum[0] = vabal_u8(sum[0], s, r0); + sum[1] = vabal_u8(sum[1], s, r1); + sum[2] = vabal_u8(sum[2], s, r2); + + src += 2 * src_stride; + ref_offset += 2 * ref_stride; + } while (--i != 0); + + res[0] = horizontal_add_u16x8(sum[0]); + res[1] = horizontal_add_u16x8(sum[1]); + res[2] = horizontal_add_u16x8(sum[2]); +} + +#define SAD_WXH_3D_NEON(w, h) \ + void aom_sad##w##x##h##x3d_neon(const uint8_t *src, int src_stride, \ + const uint8_t *const ref[4], int ref_stride, \ + uint32_t res[4]) { \ + sad##w##xhx3d_neon(src, src_stride, ref, ref_stride, res, (h)); \ + } + +SAD_WXH_3D_NEON(4, 4) +SAD_WXH_3D_NEON(4, 8) + +SAD_WXH_3D_NEON(8, 4) +SAD_WXH_3D_NEON(8, 8) +SAD_WXH_3D_NEON(8, 16) + +SAD_WXH_3D_NEON(16, 8) +SAD_WXH_3D_NEON(16, 16) +SAD_WXH_3D_NEON(16, 32) + +SAD_WXH_3D_NEON(32, 16) +SAD_WXH_3D_NEON(32, 32) +SAD_WXH_3D_NEON(32, 64) + +SAD_WXH_3D_NEON(64, 32) +SAD_WXH_3D_NEON(64, 64) +SAD_WXH_3D_NEON(64, 128) + +SAD_WXH_3D_NEON(128, 64) +SAD_WXH_3D_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_3D_NEON(4, 16) +SAD_WXH_3D_NEON(8, 32) +SAD_WXH_3D_NEON(16, 4) +SAD_WXH_3D_NEON(16, 64) +SAD_WXH_3D_NEON(32, 8) +SAD_WXH_3D_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_3D_NEON + +static INLINE void sadwxhx4d_large_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], int w, + int h, int h_overflow) { + uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + int h_limit = h > h_overflow ? 
h_overflow : h; + + int ref_offset = 0; + int i = 0; + do { + uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + do { + int j = 0; + do { + const uint8x16_t s0 = vld1q_u8(src + j); + sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]); + sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]); + sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]); + sad16_neon(s0, vld1q_u8(ref[3] + ref_offset + j), &sum_lo[3]); + + const uint8x16_t s1 = vld1q_u8(src + j + 16); + sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]); + sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]); + sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]); + sad16_neon(s1, vld1q_u8(ref[3] + ref_offset + j + 16), &sum_hi[3]); + + j += 32; + } while (j < w); + + src += src_stride; + ref_offset += ref_stride; + } while (++i < h_limit); + + sum[0] = vpadalq_u16(sum[0], sum_lo[0]); + sum[0] = vpadalq_u16(sum[0], sum_hi[0]); + sum[1] = vpadalq_u16(sum[1], sum_lo[1]); + sum[1] = vpadalq_u16(sum[1], sum_hi[1]); + sum[2] = vpadalq_u16(sum[2], sum_lo[2]); + sum[2] = vpadalq_u16(sum[2], sum_hi[2]); + sum[3] = vpadalq_u16(sum[3], sum_lo[3]); + sum[3] = vpadalq_u16(sum[3], sum_hi[3]); + + h_limit += h_overflow; + } while (i < h); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} + +static INLINE void sad128xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t res[4], int h) { + sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, res, 128, h, 32); +} + +static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t res[4], int h) { + sadwxhx4d_large_neon(src, src_stride, ref, ref_stride, res, 64, h, 64); +} + +static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t res[4], int h) { + uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + + int ref_offset = 0; + int i = h; + do { + const uint8x16_t s0 = vld1q_u8(src); + sad16_neon(s0, vld1q_u8(ref[0] + ref_offset), &sum_lo[0]); + sad16_neon(s0, vld1q_u8(ref[1] + ref_offset), &sum_lo[1]); + sad16_neon(s0, vld1q_u8(ref[2] + ref_offset), &sum_lo[2]); + sad16_neon(s0, vld1q_u8(ref[3] + ref_offset), &sum_lo[3]); + + const uint8x16_t s1 = vld1q_u8(src + 16); + sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + 16), &sum_hi[0]); + sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + 16), &sum_hi[1]); + sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + 16), &sum_hi[2]); + sad16_neon(s1, vld1q_u8(ref[3] + ref_offset + 16), &sum_hi[3]); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_long_add_4d_u16x8(sum_lo, sum_hi)); +} + +static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t res[4], int h) { + uint16x8_t sum_u16[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0), + vdupq_n_u16(0) }; + uint32x4_t sum_u32[4]; + + int ref_offset = 0; + int i = h; + do { + const uint8x16_t s = vld1q_u8(src); + sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum_u16[0]); + sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum_u16[1]); + sad16_neon(s, vld1q_u8(ref[2] + 
ref_offset), &sum_u16[2]); + sad16_neon(s, vld1q_u8(ref[3] + ref_offset), &sum_u16[3]); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + sum_u32[0] = vpaddlq_u16(sum_u16[0]); + sum_u32[1] = vpaddlq_u16(sum_u16[1]); + sum_u32[2] = vpaddlq_u16(sum_u16[2]); + sum_u32[3] = vpaddlq_u16(sum_u16[3]); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum_u32)); +} + +static INLINE void sad8xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t res[4], int h) { + uint16x8_t sum[4]; + + uint8x8_t s = vld1_u8(src); + sum[0] = vabdl_u8(s, vld1_u8(ref[0])); + sum[1] = vabdl_u8(s, vld1_u8(ref[1])); + sum[2] = vabdl_u8(s, vld1_u8(ref[2])); + sum[3] = vabdl_u8(s, vld1_u8(ref[3])); + + src += src_stride; + int ref_offset = ref_stride; + int i = h - 1; + do { + s = vld1_u8(src); + sum[0] = vabal_u8(sum[0], s, vld1_u8(ref[0] + ref_offset)); + sum[1] = vabal_u8(sum[1], s, vld1_u8(ref[1] + ref_offset)); + sum[2] = vabal_u8(sum[2], s, vld1_u8(ref[2] + ref_offset)); + sum[3] = vabal_u8(sum[3], s, vld1_u8(ref[3] + ref_offset)); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_add_4d_u16x8(sum)); +} + +static INLINE void sad4xhx4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t res[4], int h) { + uint16x8_t sum[4]; + + uint8x8_t s = load_unaligned_u8(src, src_stride); + uint8x8_t r0 = load_unaligned_u8(ref[0], ref_stride); + uint8x8_t r1 = load_unaligned_u8(ref[1], ref_stride); + uint8x8_t r2 = load_unaligned_u8(ref[2], ref_stride); + uint8x8_t r3 = load_unaligned_u8(ref[3], ref_stride); + + sum[0] = vabdl_u8(s, r0); + sum[1] = vabdl_u8(s, r1); + sum[2] = vabdl_u8(s, r2); + sum[3] = vabdl_u8(s, r3); + + src += 2 * src_stride; + int ref_offset = 2 * ref_stride; + int i = h / 2; + while (--i != 0) { + s = load_unaligned_u8(src, src_stride); + r0 = load_unaligned_u8(ref[0] + ref_offset, ref_stride); + r1 = load_unaligned_u8(ref[1] + ref_offset, ref_stride); + r2 = load_unaligned_u8(ref[2] + ref_offset, ref_stride); + r3 = load_unaligned_u8(ref[3] + ref_offset, ref_stride); + + sum[0] = vabal_u8(sum[0], s, r0); + sum[1] = vabal_u8(sum[1], s, r1); + sum[2] = vabal_u8(sum[2], s, r2); + sum[3] = vabal_u8(sum[3], s, r3); + + src += 2 * src_stride; + ref_offset += 2 * ref_stride; + } + + vst1q_u32(res, horizontal_add_4d_u16x8(sum)); +} + +#define SAD_WXH_4D_NEON(w, h) \ + void aom_sad##w##x##h##x4d_neon(const uint8_t *src, int src_stride, \ + const uint8_t *const ref[4], int ref_stride, \ + uint32_t res[4]) { \ + sad##w##xhx4d_neon(src, src_stride, ref, ref_stride, res, (h)); \ + } + +SAD_WXH_4D_NEON(4, 4) +SAD_WXH_4D_NEON(4, 8) + +SAD_WXH_4D_NEON(8, 4) +SAD_WXH_4D_NEON(8, 8) +SAD_WXH_4D_NEON(8, 16) + +SAD_WXH_4D_NEON(16, 8) +SAD_WXH_4D_NEON(16, 16) +SAD_WXH_4D_NEON(16, 32) + +SAD_WXH_4D_NEON(32, 16) +SAD_WXH_4D_NEON(32, 32) +SAD_WXH_4D_NEON(32, 64) + +SAD_WXH_4D_NEON(64, 32) +SAD_WXH_4D_NEON(64, 64) +SAD_WXH_4D_NEON(64, 128) + +SAD_WXH_4D_NEON(128, 64) +SAD_WXH_4D_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_4D_NEON(4, 16) +SAD_WXH_4D_NEON(8, 32) +SAD_WXH_4D_NEON(16, 4) +SAD_WXH_4D_NEON(16, 64) +SAD_WXH_4D_NEON(32, 8) +SAD_WXH_4D_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_4D_NEON + +#define SAD_SKIP_WXH_4D_NEON(w, h) \ + void aom_sad_skip_##w##x##h##x4d_neon(const uint8_t *src, int src_stride, \ + const uint8_t *const ref[4], \ + int ref_stride, uint32_t res[4]) { \ + sad##w##xhx4d_neon(src, 2 * src_stride, ref, 2 * 
ref_stride, res, \ + ((h) >> 1)); \ + res[0] <<= 1; \ + res[1] <<= 1; \ + res[2] <<= 1; \ + res[3] <<= 1; \ + } + +SAD_SKIP_WXH_4D_NEON(4, 4) +SAD_SKIP_WXH_4D_NEON(4, 8) + +SAD_SKIP_WXH_4D_NEON(8, 4) +SAD_SKIP_WXH_4D_NEON(8, 8) +SAD_SKIP_WXH_4D_NEON(8, 16) + +SAD_SKIP_WXH_4D_NEON(16, 8) +SAD_SKIP_WXH_4D_NEON(16, 16) +SAD_SKIP_WXH_4D_NEON(16, 32) + +SAD_SKIP_WXH_4D_NEON(32, 16) +SAD_SKIP_WXH_4D_NEON(32, 32) +SAD_SKIP_WXH_4D_NEON(32, 64) + +SAD_SKIP_WXH_4D_NEON(64, 32) +SAD_SKIP_WXH_4D_NEON(64, 64) +SAD_SKIP_WXH_4D_NEON(64, 128) + +SAD_SKIP_WXH_4D_NEON(128, 64) +SAD_SKIP_WXH_4D_NEON(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_SKIP_WXH_4D_NEON(4, 16) +SAD_SKIP_WXH_4D_NEON(8, 32) +SAD_SKIP_WXH_4D_NEON(16, 4) +SAD_SKIP_WXH_4D_NEON(16, 64) +SAD_SKIP_WXH_4D_NEON(32, 8) +SAD_SKIP_WXH_4D_NEON(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_SKIP_WXH_4D_NEON diff --git a/third_party/aom/aom_dsp/arm/sadxd_neon_dotprod.c b/third_party/aom/aom_dsp/arm/sadxd_neon_dotprod.c new file mode 100644 index 0000000000..3d11d1cb96 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sadxd_neon_dotprod.c @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref, + uint32x4_t *const sad_sum) { + uint8x16_t abs_diff = vabdq_u8(src, ref); + *sad_sum = vdotq_u32(*sad_sum, abs_diff, vdupq_n_u8(1)); +} + +static INLINE void sadwxhx3d_large_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int w, int h) { + uint32x4_t sum_lo[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + uint32x4_t sum_hi[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + + int ref_offset = 0; + int i = h; + do { + int j = 0; + do { + const uint8x16_t s0 = vld1q_u8(src + j); + sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]); + sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]); + sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]); + + const uint8x16_t s1 = vld1q_u8(src + j + 16); + sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]); + sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]); + sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]); + + j += 32; + } while (j < w); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + res[0] = horizontal_add_u32x4(vaddq_u32(sum_lo[0], sum_hi[0])); + res[1] = horizontal_add_u32x4(vaddq_u32(sum_lo[1], sum_hi[1])); + res[2] = horizontal_add_u32x4(vaddq_u32(sum_lo[2], sum_hi[2])); +} + +static INLINE void sad128xhx3d_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + sadwxhx3d_large_neon_dotprod(src, src_stride, ref, ref_stride, res, 128, h); +} + +static INLINE void sad64xhx3d_neon_dotprod(const uint8_t *src, int 
src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + sadwxhx3d_large_neon_dotprod(src, src_stride, ref, ref_stride, res, 64, h); +} + +static INLINE void sad32xhx3d_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + sadwxhx3d_large_neon_dotprod(src, src_stride, ref, ref_stride, res, 32, h); +} + +static INLINE void sad16xhx3d_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + uint32x4_t sum[3] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0) }; + + int ref_offset = 0; + int i = h; + do { + const uint8x16_t s = vld1q_u8(src); + sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]); + sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]); + sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + res[0] = horizontal_add_u32x4(sum[0]); + res[1] = horizontal_add_u32x4(sum[1]); + res[2] = horizontal_add_u32x4(sum[2]); +} + +#define SAD_WXH_3D_NEON_DOTPROD(w, h) \ + void aom_sad##w##x##h##x3d_neon_dotprod(const uint8_t *src, int src_stride, \ + const uint8_t *const ref[4], \ + int ref_stride, uint32_t res[4]) { \ + sad##w##xhx3d_neon_dotprod(src, src_stride, ref, ref_stride, res, (h)); \ + } + +SAD_WXH_3D_NEON_DOTPROD(16, 8) +SAD_WXH_3D_NEON_DOTPROD(16, 16) +SAD_WXH_3D_NEON_DOTPROD(16, 32) + +SAD_WXH_3D_NEON_DOTPROD(32, 16) +SAD_WXH_3D_NEON_DOTPROD(32, 32) +SAD_WXH_3D_NEON_DOTPROD(32, 64) + +SAD_WXH_3D_NEON_DOTPROD(64, 32) +SAD_WXH_3D_NEON_DOTPROD(64, 64) +SAD_WXH_3D_NEON_DOTPROD(64, 128) + +SAD_WXH_3D_NEON_DOTPROD(128, 64) +SAD_WXH_3D_NEON_DOTPROD(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_3D_NEON_DOTPROD(16, 4) +SAD_WXH_3D_NEON_DOTPROD(16, 64) +SAD_WXH_3D_NEON_DOTPROD(32, 8) +SAD_WXH_3D_NEON_DOTPROD(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_3D_NEON_DOTPROD + +static INLINE void sadwxhx4d_large_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int w, int h) { + uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + uint32x4_t sum[4]; + + int ref_offset = 0; + int i = h; + do { + int j = 0; + do { + const uint8x16_t s0 = vld1q_u8(src + j); + sad16_neon(s0, vld1q_u8(ref[0] + ref_offset + j), &sum_lo[0]); + sad16_neon(s0, vld1q_u8(ref[1] + ref_offset + j), &sum_lo[1]); + sad16_neon(s0, vld1q_u8(ref[2] + ref_offset + j), &sum_lo[2]); + sad16_neon(s0, vld1q_u8(ref[3] + ref_offset + j), &sum_lo[3]); + + const uint8x16_t s1 = vld1q_u8(src + j + 16); + sad16_neon(s1, vld1q_u8(ref[0] + ref_offset + j + 16), &sum_hi[0]); + sad16_neon(s1, vld1q_u8(ref[1] + ref_offset + j + 16), &sum_hi[1]); + sad16_neon(s1, vld1q_u8(ref[2] + ref_offset + j + 16), &sum_hi[2]); + sad16_neon(s1, vld1q_u8(ref[3] + ref_offset + j + 16), &sum_hi[3]); + + j += 32; + } while (j < w); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + sum[0] = vaddq_u32(sum_lo[0], sum_hi[0]); + sum[1] = vaddq_u32(sum_lo[1], sum_hi[1]); + sum[2] = vaddq_u32(sum_lo[2], sum_hi[2]); + sum[3] = vaddq_u32(sum_lo[3], sum_hi[3]); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} + +static INLINE void sad128xhx4d_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + 
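+  // Widths of 128, 64 and 32 all route through the shared helper, which
+  // processes 32 bytes per inner-loop iteration for each of the references.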
sadwxhx4d_large_neon_dotprod(src, src_stride, ref, ref_stride, res, 128, h); +} + +static INLINE void sad64xhx4d_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + sadwxhx4d_large_neon_dotprod(src, src_stride, ref, ref_stride, res, 64, h); +} + +static INLINE void sad32xhx4d_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + sadwxhx4d_large_neon_dotprod(src, src_stride, ref, ref_stride, res, 32, h); +} + +static INLINE void sad16xhx4d_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], + int ref_stride, uint32_t res[4], + int h) { + uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + + int ref_offset = 0; + int i = h; + do { + const uint8x16_t s = vld1q_u8(src); + sad16_neon(s, vld1q_u8(ref[0] + ref_offset), &sum[0]); + sad16_neon(s, vld1q_u8(ref[1] + ref_offset), &sum[1]); + sad16_neon(s, vld1q_u8(ref[2] + ref_offset), &sum[2]); + sad16_neon(s, vld1q_u8(ref[3] + ref_offset), &sum[3]); + + src += src_stride; + ref_offset += ref_stride; + } while (--i != 0); + + vst1q_u32(res, horizontal_add_4d_u32x4(sum)); +} + +#define SAD_WXH_4D_NEON_DOTPROD(w, h) \ + void aom_sad##w##x##h##x4d_neon_dotprod(const uint8_t *src, int src_stride, \ + const uint8_t *const ref[4], \ + int ref_stride, uint32_t res[4]) { \ + sad##w##xhx4d_neon_dotprod(src, src_stride, ref, ref_stride, res, (h)); \ + } + +SAD_WXH_4D_NEON_DOTPROD(16, 8) +SAD_WXH_4D_NEON_DOTPROD(16, 16) +SAD_WXH_4D_NEON_DOTPROD(16, 32) + +SAD_WXH_4D_NEON_DOTPROD(32, 16) +SAD_WXH_4D_NEON_DOTPROD(32, 32) +SAD_WXH_4D_NEON_DOTPROD(32, 64) + +SAD_WXH_4D_NEON_DOTPROD(64, 32) +SAD_WXH_4D_NEON_DOTPROD(64, 64) +SAD_WXH_4D_NEON_DOTPROD(64, 128) + +SAD_WXH_4D_NEON_DOTPROD(128, 64) +SAD_WXH_4D_NEON_DOTPROD(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_WXH_4D_NEON_DOTPROD(16, 4) +SAD_WXH_4D_NEON_DOTPROD(16, 64) +SAD_WXH_4D_NEON_DOTPROD(32, 8) +SAD_WXH_4D_NEON_DOTPROD(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_WXH_4D_NEON_DOTPROD + +#define SAD_SKIP_WXH_4D_NEON_DOTPROD(w, h) \ + void aom_sad_skip_##w##x##h##x4d_neon_dotprod( \ + const uint8_t *src, int src_stride, const uint8_t *const ref[4], \ + int ref_stride, uint32_t res[4]) { \ + sad##w##xhx4d_neon_dotprod(src, 2 * src_stride, ref, 2 * ref_stride, res, \ + ((h) >> 1)); \ + res[0] <<= 1; \ + res[1] <<= 1; \ + res[2] <<= 1; \ + res[3] <<= 1; \ + } + +SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 8) +SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 16) +SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 32) + +SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 16) +SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 32) +SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 64) + +SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 32) +SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 64) +SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 128) + +SAD_SKIP_WXH_4D_NEON_DOTPROD(128, 64) +SAD_SKIP_WXH_4D_NEON_DOTPROD(128, 128) + +#if !CONFIG_REALTIME_ONLY +SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 4) +SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 64) +SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 8) +SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +#undef SAD_SKIP_WXH_4D_NEON_DOTPROD diff --git a/third_party/aom/aom_dsp/arm/sse_neon.c b/third_party/aom/aom_dsp/arm/sse_neon.c new file mode 100644 index 0000000000..ec8f0ee183 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sse_neon.c @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All Rights Reserved. 
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include + +#include "config/aom_dsp_rtcd.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void sse_16x1_neon(const uint8_t *src, const uint8_t *ref, + uint32x4_t *sse) { + uint8x16_t s = vld1q_u8(src); + uint8x16_t r = vld1q_u8(ref); + + uint8x16_t abs_diff = vabdq_u8(s, r); + uint8x8_t abs_diff_lo = vget_low_u8(abs_diff); + uint8x8_t abs_diff_hi = vget_high_u8(abs_diff); + + *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_lo, abs_diff_lo)); + *sse = vpadalq_u16(*sse, vmull_u8(abs_diff_hi, abs_diff_hi)); +} + +static INLINE void sse_8x1_neon(const uint8_t *src, const uint8_t *ref, + uint32x4_t *sse) { + uint8x8_t s = vld1_u8(src); + uint8x8_t r = vld1_u8(ref); + + uint8x8_t abs_diff = vabd_u8(s, r); + + *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff)); +} + +static INLINE void sse_4x2_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + uint32x4_t *sse) { + uint8x8_t s = load_unaligned_u8(src, src_stride); + uint8x8_t r = load_unaligned_u8(ref, ref_stride); + + uint8x8_t abs_diff = vabd_u8(s, r); + + *sse = vpadalq_u16(*sse, vmull_u8(abs_diff, abs_diff)); +} + +static INLINE uint32_t sse_wxh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int width, int height) { + uint32x4_t sse = vdupq_n_u32(0); + + if ((width & 0x07) && ((width & 0x07) < 5)) { + int i = height; + do { + int j = 0; + do { + sse_8x1_neon(src + j, ref + j, &sse); + sse_8x1_neon(src + j + src_stride, ref + j + ref_stride, &sse); + j += 8; + } while (j + 4 < width); + + sse_4x2_neon(src + j, src_stride, ref + j, ref_stride, &sse); + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + } else { + int i = height; + do { + int j = 0; + do { + sse_8x1_neon(src + j, ref + j, &sse); + j += 8; + } while (j < width); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + } + return horizontal_add_u32x4(sse); +} + +static INLINE uint32_t sse_128xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon(src, ref, &sse[0]); + sse_16x1_neon(src + 16, ref + 16, &sse[1]); + sse_16x1_neon(src + 32, ref + 32, &sse[0]); + sse_16x1_neon(src + 48, ref + 48, &sse[1]); + sse_16x1_neon(src + 64, ref + 64, &sse[0]); + sse_16x1_neon(src + 80, ref + 80, &sse[1]); + sse_16x1_neon(src + 96, ref + 96, &sse[0]); + sse_16x1_neon(src + 112, ref + 112, &sse[1]); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_64xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon(src, ref, &sse[0]); + sse_16x1_neon(src + 16, ref + 16, &sse[1]); + sse_16x1_neon(src + 32, ref + 32, &sse[0]); + sse_16x1_neon(src + 48, ref + 48, &sse[1]); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_32xh_neon(const 
uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon(src, ref, &sse[0]); + sse_16x1_neon(src + 16, ref + 16, &sse[1]); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_16xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon(src, ref, &sse[0]); + src += src_stride; + ref += ref_stride; + sse_16x1_neon(src, ref, &sse[1]); + src += src_stride; + ref += ref_stride; + i -= 2; + } while (i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_8xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse = vdupq_n_u32(0); + + int i = height; + do { + sse_8x1_neon(src, ref, &sse); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(sse); +} + +static INLINE uint32_t sse_4xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse = vdupq_n_u32(0); + + int i = height; + do { + sse_4x2_neon(src, src_stride, ref, ref_stride, &sse); + + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + + return horizontal_add_u32x4(sse); +} + +int64_t aom_sse_neon(const uint8_t *src, int src_stride, const uint8_t *ref, + int ref_stride, int width, int height) { + switch (width) { + case 4: return sse_4xh_neon(src, src_stride, ref, ref_stride, height); + case 8: return sse_8xh_neon(src, src_stride, ref, ref_stride, height); + case 16: return sse_16xh_neon(src, src_stride, ref, ref_stride, height); + case 32: return sse_32xh_neon(src, src_stride, ref, ref_stride, height); + case 64: return sse_64xh_neon(src, src_stride, ref, ref_stride, height); + case 128: return sse_128xh_neon(src, src_stride, ref, ref_stride, height); + default: + return sse_wxh_neon(src, src_stride, ref, ref_stride, width, height); + } +} diff --git a/third_party/aom/aom_dsp/arm/sse_neon_dotprod.c b/third_party/aom/aom_dsp/arm/sse_neon_dotprod.c new file mode 100644 index 0000000000..979049780b --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sse_neon_dotprod.c @@ -0,0 +1,223 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include + +#include "config/aom_dsp_rtcd.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref, + uint32x4_t *sse) { + uint8x16_t s = vld1q_u8(src); + uint8x16_t r = vld1q_u8(ref); + + uint8x16_t abs_diff = vabdq_u8(s, r); + + *sse = vdotq_u32(*sse, abs_diff, abs_diff); +} + +static INLINE void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref, + uint32x2_t *sse) { + uint8x8_t s = vld1_u8(src); + uint8x8_t r = vld1_u8(ref); + + uint8x8_t abs_diff = vabd_u8(s, r); + + *sse = vdot_u32(*sse, abs_diff, abs_diff); +} + +static INLINE void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + uint32x2_t *sse) { + uint8x8_t s = load_unaligned_u8(src, src_stride); + uint8x8_t r = load_unaligned_u8(ref, ref_stride); + + uint8x8_t abs_diff = vabd_u8(s, r); + + *sse = vdot_u32(*sse, abs_diff, abs_diff); +} + +static INLINE uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int width, int height) { + uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) }; + + if ((width & 0x07) && ((width & 0x07) < 5)) { + int i = height; + do { + int j = 0; + do { + sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]); + sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride, + &sse[1]); + j += 8; + } while (j + 4 < width); + + sse_4x2_neon_dotprod(src + j, src_stride, ref + j, ref_stride, &sse[0]); + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + } else { + int i = height; + do { + int j = 0; + do { + sse_8x1_neon_dotprod(src + j, ref + j, &sse[0]); + sse_8x1_neon_dotprod(src + j + src_stride, ref + j + ref_stride, + &sse[1]); + j += 8; + } while (j < width); + + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + } + return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_128xh_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon_dotprod(src, ref, &sse[0]); + sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]); + sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]); + sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]); + sse_16x1_neon_dotprod(src + 64, ref + 64, &sse[0]); + sse_16x1_neon_dotprod(src + 80, ref + 80, &sse[1]); + sse_16x1_neon_dotprod(src + 96, ref + 96, &sse[0]); + sse_16x1_neon_dotprod(src + 112, ref + 112, &sse[1]); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon_dotprod(src, ref, &sse[0]); + sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]); + sse_16x1_neon_dotprod(src + 32, ref + 32, &sse[0]); + sse_16x1_neon_dotprod(src + 48, ref + 48, &sse[1]); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon_dotprod(src, 
ref, &sse[0]); + sse_16x1_neon_dotprod(src + 16, ref + 16, &sse[1]); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x4_t sse[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = height; + do { + sse_16x1_neon_dotprod(src, ref, &sse[0]); + src += src_stride; + ref += ref_stride; + sse_16x1_neon_dotprod(src, ref, &sse[1]); + src += src_stride; + ref += ref_stride; + i -= 2; + } while (i != 0); + + return horizontal_add_u32x4(vaddq_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) }; + + int i = height; + do { + sse_8x1_neon_dotprod(src, ref, &sse[0]); + src += src_stride; + ref += ref_stride; + sse_8x1_neon_dotprod(src, ref, &sse[1]); + src += src_stride; + ref += ref_stride; + i -= 2; + } while (i != 0); + + return horizontal_add_u32x4(vcombine_u32(sse[0], sse[1])); +} + +static INLINE uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int height) { + uint32x2_t sse = vdup_n_u32(0); + + int i = height; + do { + sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &sse); + + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + + return horizontal_add_u32x2(sse); +} + +int64_t aom_sse_neon_dotprod(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, int width, + int height) { + switch (width) { + case 4: + return sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height); + case 8: + return sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height); + case 16: + return sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height); + case 32: + return sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height); + case 64: + return sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height); + case 128: + return sse_128xh_neon_dotprod(src, src_stride, ref, ref_stride, height); + default: + return sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width, + height); + } +} diff --git a/third_party/aom/aom_dsp/arm/subpel_variance_neon.c b/third_party/aom/aom_dsp/arm/subpel_variance_neon.c new file mode 100644 index 0000000000..2e6e738853 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/subpel_variance_neon.c @@ -0,0 +1,1103 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include + +#include "config/aom_dsp_rtcd.h" +#include "config/aom_config.h" + +#include "aom_ports/mem.h" +#include "aom/aom_integer.h" + +#include "aom_dsp/variance.h" +#include "aom_dsp/arm/dist_wtd_avg_neon.h" +#include "aom_dsp/arm/mem_neon.h" + +static void var_filter_block2d_bil_w4(const uint8_t *src_ptr, uint8_t *dst_ptr, + int src_stride, int pixel_step, + int dst_height, int filter_offset) { + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride); + uint8x8_t s1 = load_unaligned_u8(src_ptr + pixel_step, src_stride); + uint16x8_t blend = vmull_u8(s0, f0); + blend = vmlal_u8(blend, s1, f1); + uint8x8_t blend_u8 = vrshrn_n_u16(blend, 3); + vst1_u8(dst_ptr, blend_u8); + + src_ptr += 2 * src_stride; + dst_ptr += 2 * 4; + i -= 2; + } while (i != 0); +} + +static void var_filter_block2d_bil_w8(const uint8_t *src_ptr, uint8_t *dst_ptr, + int src_stride, int pixel_step, + int dst_height, int filter_offset) { + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + uint8x8_t s0 = vld1_u8(src_ptr); + uint8x8_t s1 = vld1_u8(src_ptr + pixel_step); + uint16x8_t blend = vmull_u8(s0, f0); + blend = vmlal_u8(blend, s1, f1); + uint8x8_t blend_u8 = vrshrn_n_u16(blend, 3); + vst1_u8(dst_ptr, blend_u8); + + src_ptr += src_stride; + dst_ptr += 8; + } while (--i != 0); +} + +static void var_filter_block2d_bil_large(const uint8_t *src_ptr, + uint8_t *dst_ptr, int src_stride, + int pixel_step, int dst_width, + int dst_height, int filter_offset) { + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src_ptr + j); + uint8x16_t s1 = vld1q_u8(src_ptr + j + pixel_step); + uint16x8_t blend_l = vmull_u8(vget_low_u8(s0), f0); + blend_l = vmlal_u8(blend_l, vget_low_u8(s1), f1); + uint16x8_t blend_h = vmull_u8(vget_high_u8(s0), f0); + blend_h = vmlal_u8(blend_h, vget_high_u8(s1), f1); + uint8x16_t blend_u8 = + vcombine_u8(vrshrn_n_u16(blend_l, 3), vrshrn_n_u16(blend_h, 3)); + vst1q_u8(dst_ptr + j, blend_u8); + + j += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +static void var_filter_block2d_bil_w16(const uint8_t *src_ptr, uint8_t *dst_ptr, + int src_stride, int pixel_step, + int dst_height, int filter_offset) { + var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, 16, + dst_height, filter_offset); +} + +static void var_filter_block2d_bil_w32(const uint8_t *src_ptr, uint8_t *dst_ptr, + int src_stride, int pixel_step, + int dst_height, int filter_offset) { + var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, 32, + dst_height, filter_offset); +} + +static void var_filter_block2d_bil_w64(const uint8_t *src_ptr, uint8_t *dst_ptr, + int src_stride, int pixel_step, + int dst_height, int filter_offset) { + var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, 64, + dst_height, filter_offset); +} + +static void var_filter_block2d_bil_w128(const uint8_t *src_ptr, + uint8_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset) { + var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, pixel_step, 128, + dst_height, filter_offset); +} + +static void var_filter_block2d_avg(const uint8_t *src_ptr, uint8_t *dst_ptr, + int src_stride, int 
pixel_step, + int dst_width, int dst_height) { + // We only specialise on the filter values for large block sizes (>= 16x16.) + assert(dst_width >= 16 && dst_width % 16 == 0); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src_ptr + j); + uint8x16_t s1 = vld1q_u8(src_ptr + j + pixel_step); + uint8x16_t avg = vrhaddq_u8(s0, s1); + vst1q_u8(dst_ptr + j, avg); + + j += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +#define SUBPEL_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, uint32_t *sse) { \ + uint8_t tmp0[w * (h + padding)]; \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding), \ + xoffset); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } + +#define SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, unsigned int *sse) { \ + if (xoffset == 0) { \ + if (yoffset == 0) { \ + return aom_variance##w##x##h(src, src_stride, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp[w * h]; \ + var_filter_block2d_avg(src, tmp, src_stride, src_stride, w, h); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp, src_stride, src_stride, h, \ + yoffset); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + var_filter_block2d_avg(src, tmp0, src_stride, 1, w, h); \ + return aom_variance##w##x##h(tmp0, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(src, tmp0, src_stride, 1, w, (h + padding)); \ + var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(src, tmp0, src_stride, 1, w, (h + padding)); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } \ + } else { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, h, xoffset); \ + return aom_variance##w##x##h(tmp0, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding), \ + xoffset); \ + var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding), \ + xoffset); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } \ + } \ + } + +SUBPEL_VARIANCE_WXH_NEON(4, 4, 2) +SUBPEL_VARIANCE_WXH_NEON(4, 8, 2) + +SUBPEL_VARIANCE_WXH_NEON(8, 4, 1) +SUBPEL_VARIANCE_WXH_NEON(8, 8, 1) +SUBPEL_VARIANCE_WXH_NEON(8, 16, 1) + +SUBPEL_VARIANCE_WXH_NEON(16, 8, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(16, 16, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(16, 32, 1) + 
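+// Blocks of width 32 and above always use the specialized kernels above,
+// which skip or simplify the bilinear filter for the common xoffset/yoffset
+// values of 0 (full-pel) and 4 (half-pel).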
+SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(32, 16, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(32, 32, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(32, 64, 1) + +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(64, 32, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(64, 64, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(64, 128, 1) + +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(128, 64, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(128, 128, 1) + +// Realtime mode doesn't use 4x rectangular blocks. +#if !CONFIG_REALTIME_ONLY + +SUBPEL_VARIANCE_WXH_NEON(4, 16, 2) + +SUBPEL_VARIANCE_WXH_NEON(8, 32, 1) + +SUBPEL_VARIANCE_WXH_NEON(16, 4, 1) +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(16, 64, 1) + +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(32, 8, 1) + +SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON(64, 16, 1) + +#endif // !CONFIG_REALTIME_ONLY + +#undef SUBPEL_VARIANCE_WXH_NEON +#undef SPECIALIZED_SUBPEL_VARIANCE_WXH_NEON + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 4. +static void avg_pred_var_filter_block2d_bil_w4(const uint8_t *src_ptr, + uint8_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset, + const uint8_t *second_pred) { + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride); + uint8x8_t s1 = load_unaligned_u8(src_ptr + pixel_step, src_stride); + uint16x8_t blend = vmull_u8(s0, f0); + blend = vmlal_u8(blend, s1, f1); + uint8x8_t blend_u8 = vrshrn_n_u16(blend, 3); + + uint8x8_t p = vld1_u8(second_pred); + uint8x8_t avg = vrhadd_u8(blend_u8, p); + + vst1_u8(dst_ptr, avg); + + src_ptr += 2 * src_stride; + dst_ptr += 2 * 4; + second_pred += 2 * 4; + i -= 2; + } while (i != 0); +} + +// Combine bilinear filter with aom_dist_wtd_comp_avg_pred for blocks having +// width 4. +static void dist_wtd_avg_pred_var_filter_block2d_bil_w4( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x8_t fwd_offset = vdup_n_u8(jcp_param->fwd_offset); + const uint8x8_t bck_offset = vdup_n_u8(jcp_param->bck_offset); + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride); + uint8x8_t s1 = load_unaligned_u8(src_ptr + pixel_step, src_stride); + uint8x8_t p = vld1_u8(second_pred); + uint16x8_t blend = vmull_u8(s0, f0); + blend = vmlal_u8(blend, s1, f1); + uint8x8_t blend_u8 = vrshrn_n_u16(blend, 3); + uint8x8_t avg = dist_wtd_avg_u8x8(blend_u8, p, fwd_offset, bck_offset); + + vst1_u8(dst_ptr, avg); + + src_ptr += 2 * src_stride; + dst_ptr += 2 * 4; + second_pred += 2 * 4; + i -= 2; + } while (i != 0); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 8. 
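+// Each output pixel is bilinearly filtered,
+//   blend = (s0 * (8 - filter_offset) + s1 * filter_offset + 4) >> 3
+// (vmull_u8/vmlal_u8 followed by the rounding narrow vrshrn_n_u16), then
+// averaged with the matching second_pred pixel using the rounding halving
+// add vrhadd_u8: avg = (blend + p + 1) >> 1.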
+static void avg_pred_var_filter_block2d_bil_w8(const uint8_t *src_ptr, + uint8_t *dst_ptr, int src_stride, + int pixel_step, int dst_height, + int filter_offset, + const uint8_t *second_pred) { + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + uint8x8_t s0 = vld1_u8(src_ptr); + uint8x8_t s1 = vld1_u8(src_ptr + pixel_step); + uint16x8_t blend = vmull_u8(s0, f0); + blend = vmlal_u8(blend, s1, f1); + uint8x8_t blend_u8 = vrshrn_n_u16(blend, 3); + + uint8x8_t p = vld1_u8(second_pred); + uint8x8_t avg = vrhadd_u8(blend_u8, p); + + vst1_u8(dst_ptr, avg); + + src_ptr += src_stride; + dst_ptr += 8; + second_pred += 8; + } while (--i > 0); +} + +// Combine bilinear filter with aom_dist_wtd_comp_avg_pred for blocks having +// width 8. +static void dist_wtd_avg_pred_var_filter_block2d_bil_w8( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x8_t fwd_offset = vdup_n_u8(jcp_param->fwd_offset); + const uint8x8_t bck_offset = vdup_n_u8(jcp_param->bck_offset); + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + uint8x8_t s0 = vld1_u8(src_ptr); + uint8x8_t s1 = vld1_u8(src_ptr + pixel_step); + uint8x8_t p = vld1_u8(second_pred); + uint16x8_t blend = vmull_u8(s0, f0); + blend = vmlal_u8(blend, s1, f1); + uint8x8_t blend_u8 = vrshrn_n_u16(blend, 3); + uint8x8_t avg = dist_wtd_avg_u8x8(blend_u8, p, fwd_offset, bck_offset); + + vst1_u8(dst_ptr, avg); + + src_ptr += src_stride; + dst_ptr += 8; + second_pred += 8; + } while (--i > 0); +} + +// Combine bilinear filter with aom_comp_avg_pred for large blocks. +static void avg_pred_var_filter_block2d_bil_large( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_width, int dst_height, int filter_offset, + const uint8_t *second_pred) { + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src_ptr + j); + uint8x16_t s1 = vld1q_u8(src_ptr + j + pixel_step); + uint16x8_t blend_l = vmull_u8(vget_low_u8(s0), f0); + blend_l = vmlal_u8(blend_l, vget_low_u8(s1), f1); + uint16x8_t blend_h = vmull_u8(vget_high_u8(s0), f0); + blend_h = vmlal_u8(blend_h, vget_high_u8(s1), f1); + uint8x16_t blend_u8 = + vcombine_u8(vrshrn_n_u16(blend_l, 3), vrshrn_n_u16(blend_h, 3)); + + uint8x16_t p = vld1q_u8(second_pred); + uint8x16_t avg = vrhaddq_u8(blend_u8, p); + + vst1q_u8(dst_ptr + j, avg); + + j += 16; + second_pred += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +// Combine bilinear filter with aom_dist_wtd_comp_avg_pred for large blocks. 
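+// Identical in structure to avg_pred_var_filter_block2d_bil_large(), except
+// that the filtered result and second_pred are combined with the
+// distance-weighted compound average dist_wtd_avg_u8x16() (weighted by
+// jcp_param->fwd_offset and jcp_param->bck_offset) rather than the plain
+// rounding average.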
+static void dist_wtd_avg_pred_var_filter_block2d_bil_large( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_width, int dst_height, int filter_offset, + const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + const uint8x8_t f0 = vdup_n_u8(8 - filter_offset); + const uint8x8_t f1 = vdup_n_u8(filter_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src_ptr + j); + uint8x16_t s1 = vld1q_u8(src_ptr + j + pixel_step); + uint16x8_t blend_l = vmull_u8(vget_low_u8(s0), f0); + blend_l = vmlal_u8(blend_l, vget_low_u8(s1), f1); + uint16x8_t blend_h = vmull_u8(vget_high_u8(s0), f0); + blend_h = vmlal_u8(blend_h, vget_high_u8(s1), f1); + uint8x16_t blend_u8 = + vcombine_u8(vrshrn_n_u16(blend_l, 3), vrshrn_n_u16(blend_h, 3)); + + uint8x16_t p = vld1q_u8(second_pred); + uint8x16_t avg = dist_wtd_avg_u8x16(blend_u8, p, fwd_offset, bck_offset); + + vst1q_u8(dst_ptr + j, avg); + + j += 16; + second_pred += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 16. +static void avg_pred_var_filter_block2d_bil_w16( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred) { + avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 16, dst_height, + filter_offset, second_pred); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 32. +static void avg_pred_var_filter_block2d_bil_w32( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred) { + avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 32, dst_height, + filter_offset, second_pred); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 64. +static void avg_pred_var_filter_block2d_bil_w64( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred) { + avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 64, dst_height, + filter_offset, second_pred); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 128. +static void avg_pred_var_filter_block2d_bil_w128( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred) { + avg_pred_var_filter_block2d_bil_large(src_ptr, dst_ptr, src_stride, + pixel_step, 128, dst_height, + filter_offset, second_pred); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 16. +static void dist_wtd_avg_pred_var_filter_block2d_bil_w16( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 16, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 32. 
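+// Like the width-16 wrapper above, this simply forwards to
+// dist_wtd_avg_pred_var_filter_block2d_bil_large() with a fixed width of 32.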
+static void dist_wtd_avg_pred_var_filter_block2d_bil_w32( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 32, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 64. +static void dist_wtd_avg_pred_var_filter_block2d_bil_w64( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 64, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine bilinear filter with aom_comp_avg_pred for blocks having width 128. +static void dist_wtd_avg_pred_var_filter_block2d_bil_w128( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_height, int filter_offset, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + dist_wtd_avg_pred_var_filter_block2d_bil_large( + src_ptr, dst_ptr, src_stride, pixel_step, 128, dst_height, filter_offset, + second_pred, jcp_param); +} + +// Combine averaging subpel filter with aom_comp_avg_pred. +static void avg_pred_var_filter_block2d_avg(const uint8_t *src_ptr, + uint8_t *dst_ptr, int src_stride, + int pixel_step, int dst_width, + int dst_height, + const uint8_t *second_pred) { + // We only specialise on the filter values for large block sizes (>= 16x16.) + assert(dst_width >= 16 && dst_width % 16 == 0); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src_ptr + j); + uint8x16_t s1 = vld1q_u8(src_ptr + j + pixel_step); + uint8x16_t avg = vrhaddq_u8(s0, s1); + + uint8x16_t p = vld1q_u8(second_pred); + avg = vrhaddq_u8(avg, p); + + vst1q_u8(dst_ptr + j, avg); + + j += 16; + second_pred += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +// Combine averaging subpel filter with aom_dist_wtd_comp_avg_pred. +static void dist_wtd_avg_pred_var_filter_block2d_avg( + const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, int pixel_step, + int dst_width, int dst_height, const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + // We only specialise on the filter values for large block sizes (>= 16x16.) + assert(dst_width >= 16 && dst_width % 16 == 0); + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s0 = vld1q_u8(src_ptr + j); + uint8x16_t s1 = vld1q_u8(src_ptr + j + pixel_step); + uint8x16_t p = vld1q_u8(second_pred); + uint8x16_t avg = vrhaddq_u8(s0, s1); + avg = dist_wtd_avg_u8x16(avg, p, fwd_offset, bck_offset); + + vst1q_u8(dst_ptr + j, avg); + + j += 16; + second_pred += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +// Implementation of aom_comp_avg_pred for blocks having width >= 16. +static void avg_pred(const uint8_t *src_ptr, uint8_t *dst_ptr, int src_stride, + int dst_width, int dst_height, + const uint8_t *second_pred) { + // We only specialise on the filter values for large block sizes (>= 16x16.) 
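+  // The inner loop handles 16 pixels per iteration, so the width must be a
+  // non-zero multiple of 16.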
+ assert(dst_width >= 16 && dst_width % 16 == 0); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s = vld1q_u8(src_ptr + j); + uint8x16_t p = vld1q_u8(second_pred); + + uint8x16_t avg = vrhaddq_u8(s, p); + + vst1q_u8(dst_ptr + j, avg); + + j += 16; + second_pred += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +// Implementation of aom_dist_wtd_comp_avg_pred for blocks having width >= 16. +static void dist_wtd_avg_pred(const uint8_t *src_ptr, uint8_t *dst_ptr, + int src_stride, int dst_width, int dst_height, + const uint8_t *second_pred, + const DIST_WTD_COMP_PARAMS *jcp_param) { + // We only specialise on the filter values for large block sizes (>= 16x16.) + assert(dst_width >= 16 && dst_width % 16 == 0); + const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset); + const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset); + + int i = dst_height; + do { + int j = 0; + do { + uint8x16_t s = vld1q_u8(src_ptr + j); + uint8x16_t p = vld1q_u8(second_pred); + + uint8x16_t avg = dist_wtd_avg_u8x16(s, p, fwd_offset, bck_offset); + + vst1q_u8(dst_ptr + j, avg); + + j += 16; + second_pred += 16; + } while (j < dst_width); + + src_ptr += src_stride; + dst_ptr += dst_width; + } while (--i != 0); +} + +#define SUBPEL_AVG_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src, int source_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, uint32_t *sse, \ + const uint8_t *second_pred) { \ + uint8_t tmp0[w * (h + padding)]; \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, (h + padding), \ + xoffset); \ + avg_pred_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset, \ + second_pred); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } + +#define SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src, int source_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, unsigned int *sse, \ + const uint8_t *second_pred) { \ + if (xoffset == 0) { \ + uint8_t tmp[w * h]; \ + if (yoffset == 0) { \ + avg_pred(src, tmp, source_stride, w, h, second_pred); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + avg_pred_var_filter_block2d_avg(src, tmp, source_stride, \ + source_stride, w, h, second_pred); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } else { \ + avg_pred_var_filter_block2d_bil_w##w( \ + src, tmp, source_stride, source_stride, h, yoffset, second_pred); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + avg_pred_var_filter_block2d_avg(src, tmp0, source_stride, 1, w, h, \ + second_pred); \ + return aom_variance##w##x##h(tmp0, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(src, tmp0, source_stride, 1, w, (h + padding)); \ + avg_pred_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h, second_pred); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(src, tmp0, source_stride, 1, w, (h + padding)); \ + avg_pred_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset, \ + second_pred); \ + return aom_variance##w##x##h(tmp1, w, ref, 
ref_stride, sse); \ + } \ + } else { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + avg_pred_var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, h, \ + xoffset, second_pred); \ + return aom_variance##w##x##h(tmp0, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, \ + (h + padding), xoffset); \ + avg_pred_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h, second_pred); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, \ + (h + padding), xoffset); \ + avg_pred_var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset, \ + second_pred); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } \ + } \ + } + +SUBPEL_AVG_VARIANCE_WXH_NEON(4, 4, 2) +SUBPEL_AVG_VARIANCE_WXH_NEON(4, 8, 2) + +SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 1) +SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 1) +SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 1) + +SUBPEL_AVG_VARIANCE_WXH_NEON(16, 8, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 16, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 32, 1) + +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 16, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 32, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 64, 1) + +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 32, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 64, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 128, 1) + +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(128, 64, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(128, 128, 1) + +#if !CONFIG_REALTIME_ONLY + +SUBPEL_AVG_VARIANCE_WXH_NEON(4, 16, 2) + +SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 1) + +SUBPEL_AVG_VARIANCE_WXH_NEON(16, 4, 1) +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 64, 1) + +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 8, 1) + +SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 16, 1) + +#endif // !CONFIG_REALTIME_ONLY + +#undef SUBPEL_AVG_VARIANCE_WXH_NEON +#undef SPECIALIZED_SUBPEL_AVG_VARIANCE_WXH_NEON + +#define DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_dist_wtd_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src, int source_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, uint32_t *sse, \ + const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { \ + uint8_t tmp0[w * (h + padding)]; \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, (h + padding), \ + xoffset); \ + dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } + +#define SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_dist_wtd_sub_pixel_avg_variance##w##x##h##_neon( \ + const uint8_t *src, int source_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, unsigned int *sse, \ + const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) { \ + if (xoffset == 0) { \ + uint8_t tmp[w * h]; \ + if (yoffset == 0) { \ + dist_wtd_avg_pred(src, tmp, source_stride, w, h, second_pred, \ + jcp_param); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + dist_wtd_avg_pred_var_filter_block2d_avg(src, tmp, source_stride, \ + source_stride, w, h, \ + second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } else { \ + 
dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + src, tmp, source_stride, source_stride, h, yoffset, second_pred, \ + jcp_param); \ + return aom_variance##w##x##h(tmp, w, ref, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + dist_wtd_avg_pred_var_filter_block2d_avg( \ + src, tmp0, source_stride, 1, w, h, second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp0, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(src, tmp0, source_stride, 1, w, (h + padding)); \ + dist_wtd_avg_pred_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h, \ + second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(src, tmp0, source_stride, 1, w, (h + padding)); \ + dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } \ + } else { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + src, tmp0, source_stride, 1, h, xoffset, second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp0, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, \ + (h + padding), xoffset); \ + dist_wtd_avg_pred_var_filter_block2d_avg(tmp0, tmp1, w, w, w, h, \ + second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, source_stride, 1, \ + (h + padding), xoffset); \ + dist_wtd_avg_pred_var_filter_block2d_bil_w##w( \ + tmp0, tmp1, w, w, h, yoffset, second_pred, jcp_param); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } \ + } \ + } + +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(4, 4, 2) +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(4, 8, 2) + +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 4, 1) +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 8, 1) +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 16, 1) + +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 8, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 16, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 32, 1) + +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 16, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 32, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 64, 1) + +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 32, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 64, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 128, 1) + +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(128, 64, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(128, 128, 1) + +#if !CONFIG_REALTIME_ONLY + +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(4, 16, 2) + +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(8, 32, 1) + +DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 4, 1) +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(16, 64, 1) + +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(32, 8, 1) + +SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON(64, 16, 1) + +#endif // !CONFIG_REALTIME_ONLY + +#undef DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON +#undef SPECIALIZED_DIST_WTD_SUBPEL_AVG_VARIANCE_WXH_NEON + +#if !CONFIG_REALTIME_ONLY + +#define OBMC_SUBPEL_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int 
aom_obmc_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *pre, int pre_stride, int xoffset, int yoffset, \ + const int32_t *wsrc, const int32_t *mask, unsigned int *sse) { \ + uint8_t tmp0[w * (h + padding)]; \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h + padding, \ + xoffset); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_obmc_variance##w##x##h(tmp1, w, wsrc, mask, sse); \ + } + +#define SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_obmc_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *pre, int pre_stride, int xoffset, int yoffset, \ + const int32_t *wsrc, const int32_t *mask, unsigned int *sse) { \ + if (xoffset == 0) { \ + if (yoffset == 0) { \ + return aom_obmc_variance##w##x##h##_neon(pre, pre_stride, wsrc, mask, \ + sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp[w * h]; \ + var_filter_block2d_avg(pre, tmp, pre_stride, pre_stride, w, h); \ + return aom_obmc_variance##w##x##h##_neon(tmp, w, wsrc, mask, sse); \ + } else { \ + uint8_t tmp[w * h]; \ + var_filter_block2d_bil_w##w(pre, tmp, pre_stride, pre_stride, h, \ + yoffset); \ + return aom_obmc_variance##w##x##h##_neon(tmp, w, wsrc, mask, sse); \ + } \ + } else if (xoffset == 4) { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + var_filter_block2d_avg(pre, tmp0, pre_stride, 1, w, h); \ + return aom_obmc_variance##w##x##h##_neon(tmp0, w, wsrc, mask, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(pre, tmp0, pre_stride, 1, w, h + padding); \ + var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse); \ + } else { \ + uint8_t tmp1[w * (h + padding)]; \ + var_filter_block2d_avg(pre, tmp0, pre_stride, 1, w, h + padding); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse); \ + } \ + } else { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h, xoffset); \ + return aom_obmc_variance##w##x##h##_neon(tmp0, w, wsrc, mask, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h + padding, \ + xoffset); \ + var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse); \ + } else { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(pre, tmp0, pre_stride, 1, h + padding, \ + xoffset); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + return aom_obmc_variance##w##x##h##_neon(tmp1, w, wsrc, mask, sse); \ + } \ + } \ + } + +OBMC_SUBPEL_VARIANCE_WXH_NEON(4, 4, 2) +OBMC_SUBPEL_VARIANCE_WXH_NEON(4, 8, 2) +OBMC_SUBPEL_VARIANCE_WXH_NEON(4, 16, 2) + +OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 4, 1) +OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 8, 1) +OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 16, 1) +OBMC_SUBPEL_VARIANCE_WXH_NEON(8, 32, 1) + +OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 4, 1) +OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 8, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 16, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 32, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(16, 64, 1) + +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 8, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 16, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 32, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(32, 64, 1) + +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 16, 
1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 32, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 64, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(64, 128, 1) + +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(128, 64, 1) +SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON(128, 128, 1) + +#undef OBMC_SUBPEL_VARIANCE_WXH_NEON +#undef SPECIALIZED_OBMC_SUBPEL_VARIANCE_WXH_NEON +#endif // !CONFIG_REALTIME_ONLY + +#define MASKED_SUBPEL_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_masked_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \ + const uint8_t *msk, int msk_stride, int invert_mask, \ + unsigned int *sse) { \ + uint8_t tmp0[w * (h + padding)]; \ + uint8_t tmp1[w * h]; \ + uint8_t tmp2[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding), \ + xoffset); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk, msk_stride, \ + invert_mask); \ + return aom_variance##w##x##h(tmp2, w, ref, ref_stride, sse); \ + } + +#define SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(w, h, padding) \ + unsigned int aom_masked_sub_pixel_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, int xoffset, int yoffset, \ + const uint8_t *ref, int ref_stride, const uint8_t *second_pred, \ + const uint8_t *msk, int msk_stride, int invert_mask, \ + unsigned int *sse) { \ + if (xoffset == 0) { \ + uint8_t tmp0[w * h]; \ + if (yoffset == 0) { \ + aom_comp_mask_pred_neon(tmp0, second_pred, w, h, src, src_stride, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp0, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_avg(src, tmp0, src_stride, src_stride, w, h); \ + aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, src_stride, h, \ + yoffset); \ + aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } \ + } else if (xoffset == 4) { \ + uint8_t tmp0[w * (h + padding)]; \ + if (yoffset == 0) { \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_avg(src, tmp0, src_stride, 1, w, h); \ + aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp1[w * h]; \ + uint8_t tmp2[w * h]; \ + var_filter_block2d_avg(src, tmp0, src_stride, 1, w, (h + padding)); \ + var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp2, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp1[w * h]; \ + uint8_t tmp2[w * h]; \ + var_filter_block2d_avg(src, tmp0, src_stride, 1, w, (h + padding)); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp2, w, ref, ref_stride, sse); \ + } \ + } else { \ + if (yoffset == 0) { \ + uint8_t tmp0[w * h]; \ + uint8_t tmp1[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, h, xoffset); \ + 
aom_comp_mask_pred_neon(tmp1, second_pred, w, h, tmp0, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp1, w, ref, ref_stride, sse); \ + } else if (yoffset == 4) { \ + uint8_t tmp0[w * (h + padding)]; \ + uint8_t tmp1[w * h]; \ + uint8_t tmp2[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding), \ + xoffset); \ + var_filter_block2d_avg(tmp0, tmp1, w, w, w, h); \ + aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp2, w, ref, ref_stride, sse); \ + } else { \ + uint8_t tmp0[w * (h + padding)]; \ + uint8_t tmp1[w * (h + padding)]; \ + uint8_t tmp2[w * h]; \ + var_filter_block2d_bil_w##w(src, tmp0, src_stride, 1, (h + padding), \ + xoffset); \ + var_filter_block2d_bil_w##w(tmp0, tmp1, w, w, h, yoffset); \ + aom_comp_mask_pred_neon(tmp2, second_pred, w, h, tmp1, w, msk, \ + msk_stride, invert_mask); \ + return aom_variance##w##x##h(tmp2, w, ref, ref_stride, sse); \ + } \ + } \ + } + +MASKED_SUBPEL_VARIANCE_WXH_NEON(4, 4, 2) +MASKED_SUBPEL_VARIANCE_WXH_NEON(4, 8, 2) + +MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 4, 1) +MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 8, 1) +MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 16, 1) + +MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 8, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 16, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 32, 1) + +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 16, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 32, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 64, 1) + +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 32, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 64, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 128, 1) + +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(128, 64, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(128, 128, 1) + +// Realtime mode doesn't use 4x rectangular blocks. +#if !CONFIG_REALTIME_ONLY +MASKED_SUBPEL_VARIANCE_WXH_NEON(4, 16, 2) +MASKED_SUBPEL_VARIANCE_WXH_NEON(8, 32, 1) +MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 4, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(16, 64, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(32, 8, 1) +SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON(64, 16, 1) +#endif // !CONFIG_REALTIME_ONLY + +#undef MASKED_SUBPEL_VARIANCE_WXH_NEON +#undef SPECIALIZED_MASKED_SUBPEL_VARIANCE_WXH_NEON diff --git a/third_party/aom/aom_dsp/arm/subtract_neon.c b/third_party/aom/aom_dsp/arm/subtract_neon.c new file mode 100644 index 0000000000..a195c40d19 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/subtract_neon.c @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include + +#include "config/aom_config.h" + +#include "aom/aom_integer.h" +#include "aom_ports/mem.h" + +void aom_subtract_block_neon(int rows, int cols, int16_t *diff, + ptrdiff_t diff_stride, const uint8_t *src, + ptrdiff_t src_stride, const uint8_t *pred, + ptrdiff_t pred_stride) { + if (cols > 16) { + int r = rows; + do { + int c = 0; + do { + const uint8x16_t v_src_00 = vld1q_u8(&src[c + 0]); + const uint8x16_t v_src_16 = vld1q_u8(&src[c + 16]); + const uint8x16_t v_pred_00 = vld1q_u8(&pred[c + 0]); + const uint8x16_t v_pred_16 = vld1q_u8(&pred[c + 16]); + const uint16x8_t v_diff_lo_00 = + vsubl_u8(vget_low_u8(v_src_00), vget_low_u8(v_pred_00)); + const uint16x8_t v_diff_hi_00 = + vsubl_u8(vget_high_u8(v_src_00), vget_high_u8(v_pred_00)); + const uint16x8_t v_diff_lo_16 = + vsubl_u8(vget_low_u8(v_src_16), vget_low_u8(v_pred_16)); + const uint16x8_t v_diff_hi_16 = + vsubl_u8(vget_high_u8(v_src_16), vget_high_u8(v_pred_16)); + vst1q_s16(&diff[c + 0], vreinterpretq_s16_u16(v_diff_lo_00)); + vst1q_s16(&diff[c + 8], vreinterpretq_s16_u16(v_diff_hi_00)); + vst1q_s16(&diff[c + 16], vreinterpretq_s16_u16(v_diff_lo_16)); + vst1q_s16(&diff[c + 24], vreinterpretq_s16_u16(v_diff_hi_16)); + c += 32; + } while (c < cols); + diff += diff_stride; + pred += pred_stride; + src += src_stride; + } while (--r != 0); + } else if (cols > 8) { + int r = rows; + do { + const uint8x16_t v_src = vld1q_u8(&src[0]); + const uint8x16_t v_pred = vld1q_u8(&pred[0]); + const uint16x8_t v_diff_lo = + vsubl_u8(vget_low_u8(v_src), vget_low_u8(v_pred)); + const uint16x8_t v_diff_hi = + vsubl_u8(vget_high_u8(v_src), vget_high_u8(v_pred)); + vst1q_s16(&diff[0], vreinterpretq_s16_u16(v_diff_lo)); + vst1q_s16(&diff[8], vreinterpretq_s16_u16(v_diff_hi)); + diff += diff_stride; + pred += pred_stride; + src += src_stride; + } while (--r != 0); + } else if (cols > 4) { + int r = rows; + do { + const uint8x8_t v_src = vld1_u8(&src[0]); + const uint8x8_t v_pred = vld1_u8(&pred[0]); + const uint16x8_t v_diff = vsubl_u8(v_src, v_pred); + vst1q_s16(&diff[0], vreinterpretq_s16_u16(v_diff)); + diff += diff_stride; + pred += pred_stride; + src += src_stride; + } while (--r != 0); + } else { + int r = rows; + do { + int c = 0; + do { + diff[c] = src[c] - pred[c]; + } while (++c < cols); + diff += diff_stride; + pred += pred_stride; + src += src_stride; + } while (--r != 0); + } +} + +#if CONFIG_AV1_HIGHBITDEPTH +void aom_highbd_subtract_block_neon(int rows, int cols, int16_t *diff, + ptrdiff_t diff_stride, const uint8_t *src8, + ptrdiff_t src_stride, const uint8_t *pred8, + ptrdiff_t pred_stride) { + uint16_t *src = CONVERT_TO_SHORTPTR(src8); + uint16_t *pred = CONVERT_TO_SHORTPTR(pred8); + + if (cols > 16) { + int r = rows; + do { + int c = 0; + do { + const uint16x8_t v_src_00 = vld1q_u16(&src[c + 0]); + const uint16x8_t v_pred_00 = vld1q_u16(&pred[c + 0]); + const uint16x8_t v_diff_00 = vsubq_u16(v_src_00, v_pred_00); + const uint16x8_t v_src_08 = vld1q_u16(&src[c + 8]); + const uint16x8_t v_pred_08 = vld1q_u16(&pred[c + 8]); + const uint16x8_t v_diff_08 = vsubq_u16(v_src_08, v_pred_08); + vst1q_s16(&diff[c + 0], vreinterpretq_s16_u16(v_diff_00)); + vst1q_s16(&diff[c + 8], vreinterpretq_s16_u16(v_diff_08)); + c += 16; + } while (c < cols); + diff += diff_stride; + pred += pred_stride; + src += src_stride; + } while (--r != 0); + } else if (cols > 8) { + int r = rows; + do { + const uint16x8_t v_src_00 = vld1q_u16(&src[0]); + const uint16x8_t v_pred_00 = vld1q_u16(&pred[0]); + const uint16x8_t v_diff_00 = vsubq_u16(v_src_00, 
v_pred_00); + const uint16x8_t v_src_08 = vld1q_u16(&src[8]); + const uint16x8_t v_pred_08 = vld1q_u16(&pred[8]); + const uint16x8_t v_diff_08 = vsubq_u16(v_src_08, v_pred_08); + vst1q_s16(&diff[0], vreinterpretq_s16_u16(v_diff_00)); + vst1q_s16(&diff[8], vreinterpretq_s16_u16(v_diff_08)); + diff += diff_stride; + pred += pred_stride; + src += src_stride; + } while (--r != 0); + } else if (cols > 4) { + int r = rows; + do { + const uint16x8_t v_src_r0 = vld1q_u16(&src[0]); + const uint16x8_t v_src_r1 = vld1q_u16(&src[src_stride]); + const uint16x8_t v_pred_r0 = vld1q_u16(&pred[0]); + const uint16x8_t v_pred_r1 = vld1q_u16(&pred[pred_stride]); + const uint16x8_t v_diff_r0 = vsubq_u16(v_src_r0, v_pred_r0); + const uint16x8_t v_diff_r1 = vsubq_u16(v_src_r1, v_pred_r1); + vst1q_s16(&diff[0], vreinterpretq_s16_u16(v_diff_r0)); + vst1q_s16(&diff[diff_stride], vreinterpretq_s16_u16(v_diff_r1)); + diff += diff_stride << 1; + pred += pred_stride << 1; + src += src_stride << 1; + r -= 2; + } while (r != 0); + } else { + int r = rows; + do { + const uint16x4_t v_src_r0 = vld1_u16(&src[0]); + const uint16x4_t v_src_r1 = vld1_u16(&src[src_stride]); + const uint16x4_t v_pred_r0 = vld1_u16(&pred[0]); + const uint16x4_t v_pred_r1 = vld1_u16(&pred[pred_stride]); + const uint16x4_t v_diff_r0 = vsub_u16(v_src_r0, v_pred_r0); + const uint16x4_t v_diff_r1 = vsub_u16(v_src_r1, v_pred_r1); + vst1_s16(&diff[0], vreinterpret_s16_u16(v_diff_r0)); + vst1_s16(&diff[diff_stride], vreinterpret_s16_u16(v_diff_r1)); + diff += diff_stride << 1; + pred += pred_stride << 1; + src += src_stride << 1; + r -= 2; + } while (r != 0); + } +} +#endif // CONFIG_AV1_HIGHBITDEPTH diff --git a/third_party/aom/aom_dsp/arm/sum_neon.h b/third_party/aom/aom_dsp/arm/sum_neon.h new file mode 100644 index 0000000000..30a108e70a --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sum_neon.h @@ -0,0 +1,311 @@ +/* + * Copyright (c) 2019, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef AOM_AOM_DSP_ARM_SUM_NEON_H_ +#define AOM_AOM_DSP_ARM_SUM_NEON_H_ + +#include "config/aom_dsp_rtcd.h" +#include "config/aom_config.h" + +#include "aom/aom_integer.h" +#include "aom_ports/mem.h" + +static INLINE int horizontal_add_u8x8(const uint8x8_t a) { +#if AOM_ARCH_AARCH64 + return vaddlv_u8(a); +#else + uint16x4_t b = vpaddl_u8(a); + uint32x2_t c = vpaddl_u16(b); + return vget_lane_u32(c, 0) + vget_lane_u32(c, 1); +#endif +} + +static INLINE int horizontal_add_s16x8(const int16x8_t a) { +#if AOM_ARCH_AARCH64 + return vaddlvq_s16(a); +#else + const int32x4_t b = vpaddlq_s16(a); + const int64x2_t c = vpaddlq_s32(b); + const int32x2_t d = vadd_s32(vreinterpret_s32_s64(vget_low_s64(c)), + vreinterpret_s32_s64(vget_high_s64(c))); + return vget_lane_s32(d, 0); +#endif +} + +static INLINE int horizontal_add_s32x4(const int32x4_t a) { +#if AOM_ARCH_AARCH64 + return vaddvq_s32(a); +#else + const int64x2_t b = vpaddlq_s32(a); + const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)), + vreinterpret_s32_s64(vget_high_s64(b))); + return vget_lane_s32(c, 0); +#endif +} + +static INLINE int64_t horizontal_add_s64x2(const int64x2_t a) { +#if AOM_ARCH_AARCH64 + return vaddvq_s64(a); +#else + return vgetq_lane_s64(a, 0) + vgetq_lane_s64(a, 1); +#endif +} + +static INLINE uint64_t horizontal_add_u64x2(const uint64x2_t a) { +#if AOM_ARCH_AARCH64 + return vaddvq_u64(a); +#else + return vgetq_lane_u64(a, 0) + vgetq_lane_u64(a, 1); +#endif +} + +static INLINE uint64_t horizontal_long_add_u32x4(const uint32x4_t a) { +#if AOM_ARCH_AARCH64 + return vaddlvq_u32(a); +#else + const uint64x2_t b = vpaddlq_u32(a); + return vgetq_lane_u64(b, 0) + vgetq_lane_u64(b, 1); +#endif +} + +static INLINE int64_t horizontal_long_add_s32x4(const int32x4_t a) { +#if AOM_ARCH_AARCH64 + return vaddlvq_s32(a); +#else + const int64x2_t b = vpaddlq_s32(a); + return vgetq_lane_s64(b, 0) + vgetq_lane_s64(b, 1); +#endif +} + +static INLINE uint32_t horizontal_add_u32x4(const uint32x4_t a) { +#if AOM_ARCH_AARCH64 + return vaddvq_u32(a); +#else + const uint64x2_t b = vpaddlq_u32(a); + const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)), + vreinterpret_u32_u64(vget_high_u64(b))); + return vget_lane_u32(c, 0); +#endif +} + +static INLINE uint32x4_t horizontal_add_4d_u32x4(const uint32x4_t sum[4]) { +#if AOM_ARCH_AARCH64 + uint32x4_t res01 = vpaddq_u32(sum[0], sum[1]); + uint32x4_t res23 = vpaddq_u32(sum[2], sum[3]); + return vpaddq_u32(res01, res23); +#else + uint32x4_t res = vdupq_n_u32(0); + res = vsetq_lane_u32(horizontal_add_u32x4(sum[0]), res, 0); + res = vsetq_lane_u32(horizontal_add_u32x4(sum[1]), res, 1); + res = vsetq_lane_u32(horizontal_add_u32x4(sum[2]), res, 2); + res = vsetq_lane_u32(horizontal_add_u32x4(sum[3]), res, 3); + return res; +#endif +} + +static INLINE int32x4_t horizontal_add_4d_s32x4(const int32x4_t sum[4]) { +#if AOM_ARCH_AARCH64 + int32x4_t res01 = vpaddq_s32(sum[0], sum[1]); + int32x4_t res23 = vpaddq_s32(sum[2], sum[3]); + return vpaddq_s32(res01, res23); +#else + int32x4_t res = vdupq_n_s32(0); + res = vsetq_lane_s32(horizontal_add_s32x4(sum[0]), res, 0); + res = vsetq_lane_s32(horizontal_add_s32x4(sum[1]), res, 1); + res = vsetq_lane_s32(horizontal_add_s32x4(sum[2]), res, 2); + res = vsetq_lane_s32(horizontal_add_s32x4(sum[3]), res, 3); + return res; +#endif +} + +static INLINE uint32_t horizontal_long_add_u16x8(const uint16x8_t vec_lo, + const uint16x8_t vec_hi) { +#if AOM_ARCH_AARCH64 + return vaddlvq_u16(vec_lo) + vaddlvq_u16(vec_hi); +#else + const uint32x4_t 
vec_l_lo = + vaddl_u16(vget_low_u16(vec_lo), vget_high_u16(vec_lo)); + const uint32x4_t vec_l_hi = + vaddl_u16(vget_low_u16(vec_hi), vget_high_u16(vec_hi)); + const uint32x4_t a = vaddq_u32(vec_l_lo, vec_l_hi); + const uint64x2_t b = vpaddlq_u32(a); + const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)), + vreinterpret_u32_u64(vget_high_u64(b))); + return vget_lane_u32(c, 0); +#endif +} + +static INLINE uint32x4_t horizontal_long_add_4d_u16x8( + const uint16x8_t sum_lo[4], const uint16x8_t sum_hi[4]) { + const uint32x4_t a0 = vpaddlq_u16(sum_lo[0]); + const uint32x4_t a1 = vpaddlq_u16(sum_lo[1]); + const uint32x4_t a2 = vpaddlq_u16(sum_lo[2]); + const uint32x4_t a3 = vpaddlq_u16(sum_lo[3]); + const uint32x4_t b0 = vpadalq_u16(a0, sum_hi[0]); + const uint32x4_t b1 = vpadalq_u16(a1, sum_hi[1]); + const uint32x4_t b2 = vpadalq_u16(a2, sum_hi[2]); + const uint32x4_t b3 = vpadalq_u16(a3, sum_hi[3]); +#if AOM_ARCH_AARCH64 + const uint32x4_t c0 = vpaddq_u32(b0, b1); + const uint32x4_t c1 = vpaddq_u32(b2, b3); + return vpaddq_u32(c0, c1); +#else + const uint32x2_t c0 = vadd_u32(vget_low_u32(b0), vget_high_u32(b0)); + const uint32x2_t c1 = vadd_u32(vget_low_u32(b1), vget_high_u32(b1)); + const uint32x2_t c2 = vadd_u32(vget_low_u32(b2), vget_high_u32(b2)); + const uint32x2_t c3 = vadd_u32(vget_low_u32(b3), vget_high_u32(b3)); + const uint32x2_t d0 = vpadd_u32(c0, c1); + const uint32x2_t d1 = vpadd_u32(c2, c3); + return vcombine_u32(d0, d1); +#endif +} + +static INLINE uint32_t horizontal_add_u16x8(const uint16x8_t a) { +#if AOM_ARCH_AARCH64 + return vaddlvq_u16(a); +#else + const uint32x4_t b = vpaddlq_u16(a); + const uint64x2_t c = vpaddlq_u32(b); + const uint32x2_t d = vadd_u32(vreinterpret_u32_u64(vget_low_u64(c)), + vreinterpret_u32_u64(vget_high_u64(c))); + return vget_lane_u32(d, 0); +#endif +} + +static INLINE uint32x4_t horizontal_add_4d_u16x8(const uint16x8_t sum[4]) { +#if AOM_ARCH_AARCH64 + const uint16x8_t a0 = vpaddq_u16(sum[0], sum[1]); + const uint16x8_t a1 = vpaddq_u16(sum[2], sum[3]); + const uint16x8_t b0 = vpaddq_u16(a0, a1); + return vpaddlq_u16(b0); +#else + const uint16x4_t a0 = vadd_u16(vget_low_u16(sum[0]), vget_high_u16(sum[0])); + const uint16x4_t a1 = vadd_u16(vget_low_u16(sum[1]), vget_high_u16(sum[1])); + const uint16x4_t a2 = vadd_u16(vget_low_u16(sum[2]), vget_high_u16(sum[2])); + const uint16x4_t a3 = vadd_u16(vget_low_u16(sum[3]), vget_high_u16(sum[3])); + const uint16x4_t b0 = vpadd_u16(a0, a1); + const uint16x4_t b1 = vpadd_u16(a2, a3); + return vpaddlq_u16(vcombine_u16(b0, b1)); +#endif +} + +static INLINE int32x4_t horizontal_add_4d_s16x8(const int16x8_t sum[4]) { +#if AOM_ARCH_AARCH64 + const int16x8_t a0 = vpaddq_s16(sum[0], sum[1]); + const int16x8_t a1 = vpaddq_s16(sum[2], sum[3]); + const int16x8_t b0 = vpaddq_s16(a0, a1); + return vpaddlq_s16(b0); +#else + const int16x4_t a0 = vadd_s16(vget_low_s16(sum[0]), vget_high_s16(sum[0])); + const int16x4_t a1 = vadd_s16(vget_low_s16(sum[1]), vget_high_s16(sum[1])); + const int16x4_t a2 = vadd_s16(vget_low_s16(sum[2]), vget_high_s16(sum[2])); + const int16x4_t a3 = vadd_s16(vget_low_s16(sum[3]), vget_high_s16(sum[3])); + const int16x4_t b0 = vpadd_s16(a0, a1); + const int16x4_t b1 = vpadd_s16(a2, a3); + return vpaddlq_s16(vcombine_s16(b0, b1)); +#endif +} + +static INLINE uint32_t horizontal_add_u32x2(const uint32x2_t a) { +#if AOM_ARCH_AARCH64 + return vaddv_u32(a); +#else + const uint64x1_t b = vpaddl_u32(a); + return vget_lane_u32(vreinterpret_u32_u64(b), 0); +#endif +} + +static INLINE uint64_t 
horizontal_long_add_u32x2(const uint32x2_t a) { +#if AOM_ARCH_AARCH64 + return vaddlv_u32(a); +#else + const uint64x1_t b = vpaddl_u32(a); + return vget_lane_u64(b, 0); +#endif +} + +static INLINE uint32_t horizontal_add_u16x4(const uint16x4_t a) { +#if AOM_ARCH_AARCH64 + return vaddlv_u16(a); +#else + const uint32x2_t b = vpaddl_u16(a); + const uint64x1_t c = vpaddl_u32(b); + return vget_lane_u32(vreinterpret_u32_u64(c), 0); +#endif +} + +static INLINE int32x4_t horizontal_add_2d_s32(int32x4_t a, int32x4_t b) { +#if AOM_ARCH_AARCH64 + return vpaddq_s32(a, b); +#else + const int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + const int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +#endif +} + +static INLINE int32x2_t add_pairwise_s32x4(int32x4_t a) { +#if AOM_ARCH_AARCH64 + return vget_low_s32(vpaddq_s32(a, a)); +#else + return vpadd_s32(vget_low_s32(a), vget_high_s32(a)); +#endif +} + +static INLINE uint64_t horizontal_long_add_u32x4_x2(const uint32x4_t a[2]) { + return horizontal_long_add_u32x4(a[0]) + horizontal_long_add_u32x4(a[1]); +} + +static INLINE uint64_t horizontal_long_add_u32x4_x4(const uint32x4_t a[4]) { + uint64x2_t sum = vpaddlq_u32(a[0]); + sum = vpadalq_u32(sum, a[1]); + sum = vpadalq_u32(sum, a[2]); + sum = vpadalq_u32(sum, a[3]); + + return horizontal_add_u64x2(sum); +} + +static INLINE uint64_t horizontal_long_add_u32x4_x8(const uint32x4_t a[8]) { + uint64x2_t sum[2]; + sum[0] = vpaddlq_u32(a[0]); + sum[1] = vpaddlq_u32(a[1]); + sum[0] = vpadalq_u32(sum[0], a[2]); + sum[1] = vpadalq_u32(sum[1], a[3]); + sum[0] = vpadalq_u32(sum[0], a[4]); + sum[1] = vpadalq_u32(sum[1], a[5]); + sum[0] = vpadalq_u32(sum[0], a[6]); + sum[1] = vpadalq_u32(sum[1], a[7]); + + return horizontal_add_u64x2(vaddq_u64(sum[0], sum[1])); +} + +static INLINE uint64_t horizontal_long_add_u32x4_x16(const uint32x4_t a[16]) { + uint64x2_t sum[2]; + sum[0] = vpaddlq_u32(a[0]); + sum[1] = vpaddlq_u32(a[1]); + sum[0] = vpadalq_u32(sum[0], a[2]); + sum[1] = vpadalq_u32(sum[1], a[3]); + sum[0] = vpadalq_u32(sum[0], a[4]); + sum[1] = vpadalq_u32(sum[1], a[5]); + sum[0] = vpadalq_u32(sum[0], a[6]); + sum[1] = vpadalq_u32(sum[1], a[7]); + sum[0] = vpadalq_u32(sum[0], a[8]); + sum[1] = vpadalq_u32(sum[1], a[9]); + sum[0] = vpadalq_u32(sum[0], a[10]); + sum[1] = vpadalq_u32(sum[1], a[11]); + sum[0] = vpadalq_u32(sum[0], a[12]); + sum[1] = vpadalq_u32(sum[1], a[13]); + sum[0] = vpadalq_u32(sum[0], a[14]); + sum[1] = vpadalq_u32(sum[1], a[15]); + + return horizontal_add_u64x2(vaddq_u64(sum[0], sum[1])); +} + +#endif // AOM_AOM_DSP_ARM_SUM_NEON_H_ diff --git a/third_party/aom/aom_dsp/arm/sum_squares_neon.c b/third_party/aom/aom_dsp/arm/sum_squares_neon.c new file mode 100644 index 0000000000..424b2b4445 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sum_squares_neon.c @@ -0,0 +1,574 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include +#include + +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "config/aom_dsp_rtcd.h" + +static INLINE uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src, + int stride) { + int16x4_t s0 = vld1_s16(src + 0 * stride); + int16x4_t s1 = vld1_s16(src + 1 * stride); + int16x4_t s2 = vld1_s16(src + 2 * stride); + int16x4_t s3 = vld1_s16(src + 3 * stride); + + int32x4_t sum_squares = vmull_s16(s0, s0); + sum_squares = vmlal_s16(sum_squares, s1, s1); + sum_squares = vmlal_s16(sum_squares, s2, s2); + sum_squares = vmlal_s16(sum_squares, s3, s3); + + return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sum_squares)); +} + +static INLINE uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src, + int stride, int height) { + int32x4_t sum_squares[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + int h = height; + do { + int16x4_t s0 = vld1_s16(src + 0 * stride); + int16x4_t s1 = vld1_s16(src + 1 * stride); + int16x4_t s2 = vld1_s16(src + 2 * stride); + int16x4_t s3 = vld1_s16(src + 3 * stride); + + sum_squares[0] = vmlal_s16(sum_squares[0], s0, s0); + sum_squares[0] = vmlal_s16(sum_squares[0], s1, s1); + sum_squares[1] = vmlal_s16(sum_squares[1], s2, s2); + sum_squares[1] = vmlal_s16(sum_squares[1], s3, s3); + + src += 4 * stride; + h -= 4; + } while (h != 0); + + return horizontal_long_add_u32x4( + vreinterpretq_u32_s32(vaddq_s32(sum_squares[0], sum_squares[1]))); +} + +static INLINE uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src, + int stride, int width, + int height) { + uint64x2_t sum_squares = vdupq_n_u64(0); + + int h = height; + do { + int32x4_t ss_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int w = 0; + do { + const int16_t *s = src + w; + int16x8_t s0 = vld1q_s16(s + 0 * stride); + int16x8_t s1 = vld1q_s16(s + 1 * stride); + int16x8_t s2 = vld1q_s16(s + 2 * stride); + int16x8_t s3 = vld1q_s16(s + 3 * stride); + + ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s0), vget_low_s16(s0)); + ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s1), vget_low_s16(s1)); + ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s2), vget_low_s16(s2)); + ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s3), vget_low_s16(s3)); + ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s0), vget_high_s16(s0)); + ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s1), vget_high_s16(s1)); + ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s2), vget_high_s16(s2)); + ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s3), vget_high_s16(s3)); + w += 8; + } while (w < width); + + sum_squares = vpadalq_u32( + sum_squares, vreinterpretq_u32_s32(vaddq_s32(ss_row[0], ss_row[1]))); + + src += 4 * stride; + h -= 4; + } while (h != 0); + + return horizontal_add_u64x2(sum_squares); +} + +uint64_t aom_sum_squares_2d_i16_neon(const int16_t *src, int stride, int width, + int height) { + // 4 elements per row only requires half an SIMD register, so this + // must be a special case, but also note that over 75% of all calls + // are with size == 4, so it is also the common case. 
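+  // Dispatch: the width == 4 cases use the dedicated 4x4 / 4xN kernels
+  // above; otherwise the vector path requires width to be a multiple of 8
+  // and height a multiple of 4, and any other shape falls back to the
+  // generic C implementation.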
+ if (LIKELY(width == 4 && height == 4)) { + return aom_sum_squares_2d_i16_4x4_neon(src, stride); + } else if (LIKELY(width == 4 && (height & 3) == 0)) { + return aom_sum_squares_2d_i16_4xn_neon(src, stride, height); + } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) { + // Generic case + return aom_sum_squares_2d_i16_nxn_neon(src, stride, width, height); + } else { + return aom_sum_squares_2d_i16_c(src, stride, width, height); + } +} + +static INLINE uint64_t aom_sum_sse_2d_i16_4x4_neon(const int16_t *src, + int stride, int *sum) { + int16x4_t s0 = vld1_s16(src + 0 * stride); + int16x4_t s1 = vld1_s16(src + 1 * stride); + int16x4_t s2 = vld1_s16(src + 2 * stride); + int16x4_t s3 = vld1_s16(src + 3 * stride); + + int32x4_t sse = vmull_s16(s0, s0); + sse = vmlal_s16(sse, s1, s1); + sse = vmlal_s16(sse, s2, s2); + sse = vmlal_s16(sse, s3, s3); + + int32x4_t sum_01 = vaddl_s16(s0, s1); + int32x4_t sum_23 = vaddl_s16(s2, s3); + *sum += horizontal_add_s32x4(vaddq_s32(sum_01, sum_23)); + + return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sse)); +} + +static INLINE uint64_t aom_sum_sse_2d_i16_4xn_neon(const int16_t *src, + int stride, int height, + int *sum) { + int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int32x2_t sum_acc[2] = { vdup_n_s32(0), vdup_n_s32(0) }; + + int h = height; + do { + int16x4_t s0 = vld1_s16(src + 0 * stride); + int16x4_t s1 = vld1_s16(src + 1 * stride); + int16x4_t s2 = vld1_s16(src + 2 * stride); + int16x4_t s3 = vld1_s16(src + 3 * stride); + + sse[0] = vmlal_s16(sse[0], s0, s0); + sse[0] = vmlal_s16(sse[0], s1, s1); + sse[1] = vmlal_s16(sse[1], s2, s2); + sse[1] = vmlal_s16(sse[1], s3, s3); + + sum_acc[0] = vpadal_s16(sum_acc[0], s0); + sum_acc[0] = vpadal_s16(sum_acc[0], s1); + sum_acc[1] = vpadal_s16(sum_acc[1], s2); + sum_acc[1] = vpadal_s16(sum_acc[1], s3); + + src += 4 * stride; + h -= 4; + } while (h != 0); + + *sum += horizontal_add_s32x4(vcombine_s32(sum_acc[0], sum_acc[1])); + return horizontal_long_add_u32x4( + vreinterpretq_u32_s32(vaddq_s32(sse[0], sse[1]))); +} + +static INLINE uint64_t aom_sum_sse_2d_i16_nxn_neon(const int16_t *src, + int stride, int width, + int height, int *sum) { + uint64x2_t sse = vdupq_n_u64(0); + int32x4_t sum_acc = vdupq_n_s32(0); + + int h = height; + do { + int32x4_t sse_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int w = 0; + do { + const int16_t *s = src + w; + int16x8_t s0 = vld1q_s16(s + 0 * stride); + int16x8_t s1 = vld1q_s16(s + 1 * stride); + int16x8_t s2 = vld1q_s16(s + 2 * stride); + int16x8_t s3 = vld1q_s16(s + 3 * stride); + + sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s0), vget_low_s16(s0)); + sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s1), vget_low_s16(s1)); + sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s2), vget_low_s16(s2)); + sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s3), vget_low_s16(s3)); + sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s0), vget_high_s16(s0)); + sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s1), vget_high_s16(s1)); + sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s2), vget_high_s16(s2)); + sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s3), vget_high_s16(s3)); + + sum_acc = vpadalq_s16(sum_acc, s0); + sum_acc = vpadalq_s16(sum_acc, s1); + sum_acc = vpadalq_s16(sum_acc, s2); + sum_acc = vpadalq_s16(sum_acc, s3); + + w += 8; + } while (w < width); + + sse = vpadalq_u32(sse, + vreinterpretq_u32_s32(vaddq_s32(sse_row[0], sse_row[1]))); + + src += 4 * stride; + h -= 4; + } while (h != 0); + + *sum += horizontal_add_s32x4(sum_acc); + return 
horizontal_add_u64x2(sse); +} + +uint64_t aom_sum_sse_2d_i16_neon(const int16_t *src, int stride, int width, + int height, int *sum) { + uint64_t sse; + + if (LIKELY(width == 4 && height == 4)) { + sse = aom_sum_sse_2d_i16_4x4_neon(src, stride, sum); + } else if (LIKELY(width == 4 && (height & 3) == 0)) { + // width = 4, height is a multiple of 4. + sse = aom_sum_sse_2d_i16_4xn_neon(src, stride, height, sum); + } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) { + // Generic case - width is multiple of 8, height is multiple of 4. + sse = aom_sum_sse_2d_i16_nxn_neon(src, stride, width, height, sum); + } else { + sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum); + } + + return sse; +} + +static INLINE uint64_t aom_sum_squares_i16_4xn_neon(const int16_t *src, + uint32_t n) { + uint64x2_t sum_u64 = vdupq_n_u64(0); + + int i = n; + do { + uint32x4_t sum; + int16x4_t s0 = vld1_s16(src); + + sum = vreinterpretq_u32_s32(vmull_s16(s0, s0)); + + sum_u64 = vpadalq_u32(sum_u64, sum); + + src += 4; + i -= 4; + } while (i >= 4); + + if (i > 0) { + return horizontal_add_u64x2(sum_u64) + aom_sum_squares_i16_c(src, i); + } + return horizontal_add_u64x2(sum_u64); +} + +static INLINE uint64_t aom_sum_squares_i16_8xn_neon(const int16_t *src, + uint32_t n) { + uint64x2_t sum_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; + + int i = n; + do { + uint32x4_t sum[2]; + int16x8_t s0 = vld1q_s16(src); + + sum[0] = + vreinterpretq_u32_s32(vmull_s16(vget_low_s16(s0), vget_low_s16(s0))); + sum[1] = + vreinterpretq_u32_s32(vmull_s16(vget_high_s16(s0), vget_high_s16(s0))); + + sum_u64[0] = vpadalq_u32(sum_u64[0], sum[0]); + sum_u64[1] = vpadalq_u32(sum_u64[1], sum[1]); + + src += 8; + i -= 8; + } while (i >= 8); + + if (i > 0) { + return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1])) + + aom_sum_squares_i16_c(src, i); + } + return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1])); +} + +uint64_t aom_sum_squares_i16_neon(const int16_t *src, uint32_t n) { + // This function seems to be called only for values of N >= 64. See + // av1/encoder/compound_type.c. + if (LIKELY(n >= 8)) { + return aom_sum_squares_i16_8xn_neon(src, n); + } + if (n >= 4) { + return aom_sum_squares_i16_4xn_neon(src, n); + } + return aom_sum_squares_i16_c(src, n); +} + +static INLINE uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride, + int width, int height) { + uint64_t sum = 0; + uint64_t sse = 0; + uint32x2_t sum_u32 = vdup_n_u32(0); + uint32x4_t sse_u32 = vdupq_n_u32(0); + + // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit + // element before we need to accumulate to 32-bit elements. Since we're + // accumulating in uint16x4_t vectors, this means we can accumulate up to 4 + // rows of 256 elements. Therefore the limit can be computed as: h_limit = (4 + // * 256) / width. + int h_limit = (4 * 256) / width; + int h_tmp = height > h_limit ? h_limit : height; + + int h = 0; + do { + uint16x4_t sum_u16 = vdup_n_u16(0); + do { + uint8_t *src_ptr = src; + int w = width; + do { + uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride); + + sum_u16 = vpadal_u8(sum_u16, s0); + + uint16x8_t sse_u16 = vmull_u8(s0, s0); + + sse_u32 = vpadalq_u16(sse_u32, sse_u16); + + src_ptr += 8; + w -= 8; + } while (w >= 8); + + // Process remaining columns in the row using C. 
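+      // The same pattern is used by all of the aom_var_2d_* kernels below:
+      // vector loops accumulate the pixel sum and the sum of squares (sse),
+      // a scalar tail picks up any columns past the last full vector, and
+      // the result sse - sum * sum / (width * height) is the sum of squared
+      // deviations from the mean (width * height times the variance), up to
+      // integer truncation.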
+ while (w > 0) { + int idx = width - w; + const uint8_t v = src[idx]; + sum += v; + sse += v * v; + w--; + } + + src += 2 * src_stride; + h += 2; + } while (h < h_tmp && h < height); + + sum_u32 = vpadal_u16(sum_u32, sum_u16); + h_tmp += h_limit; + } while (h < height); + + sum += horizontal_long_add_u32x2(sum_u32); + sse += horizontal_long_add_u32x4(sse_u32); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride, + int width, int height) { + uint64_t sum = 0; + uint64_t sse = 0; + uint32x2_t sum_u32 = vdup_n_u32(0); + uint32x4_t sse_u32 = vdupq_n_u32(0); + + // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit + // element before we need to accumulate to 32-bit elements. Since we're + // accumulating in uint16x4_t vectors, this means we can accumulate up to 4 + // rows of 256 elements. Therefore the limit can be computed as: h_limit = (4 + // * 256) / width. + int h_limit = (4 * 256) / width; + int h_tmp = height > h_limit ? h_limit : height; + + int h = 0; + do { + uint16x4_t sum_u16 = vdup_n_u16(0); + do { + uint8_t *src_ptr = src; + int w = width; + do { + uint8x8_t s0 = vld1_u8(src_ptr); + + sum_u16 = vpadal_u8(sum_u16, s0); + + uint16x8_t sse_u16 = vmull_u8(s0, s0); + + sse_u32 = vpadalq_u16(sse_u32, sse_u16); + + src_ptr += 8; + w -= 8; + } while (w >= 8); + + // Process remaining columns in the row using C. + while (w > 0) { + int idx = width - w; + const uint8_t v = src[idx]; + sum += v; + sse += v * v; + w--; + } + + src += src_stride; + ++h; + } while (h < h_tmp && h < height); + + sum_u32 = vpadal_u16(sum_u32, sum_u16); + h_tmp += h_limit; + } while (h < height); + + sum += horizontal_long_add_u32x2(sum_u32); + sse += horizontal_long_add_u32x4(sse_u32); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride, + int width, int height) { + uint64_t sum = 0; + uint64_t sse = 0; + uint32x4_t sum_u32 = vdupq_n_u32(0); + uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a 16-bit + // element before we need to accumulate to 32-bit elements. Since we're + // accumulating in uint16x8_t vectors, this means we can accumulate up to 8 + // rows of 256 elements. Therefore the limit can be computed as: h_limit = (8 + // * 256) / width. + int h_limit = (8 * 256) / width; + int h_tmp = height > h_limit ? h_limit : height; + + int h = 0; + do { + uint16x8_t sum_u16 = vdupq_n_u16(0); + do { + int w = width; + uint8_t *src_ptr = src; + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + + sum_u16 = vpadalq_u8(sum_u16, s0); + + uint16x8_t sse_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(s0)); + uint16x8_t sse_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(s0)); + + sse_u32[0] = vpadalq_u16(sse_u32[0], sse_u16_lo); + sse_u32[1] = vpadalq_u16(sse_u32[1], sse_u16_hi); + + src_ptr += 16; + w -= 16; + } while (w >= 16); + + // Process remaining columns in the row using C. 
+ while (w > 0) { + int idx = width - w; + const uint8_t v = src[idx]; + sum += v; + sse += v * v; + w--; + } + + src += src_stride; + ++h; + } while (h < h_tmp && h < height); + + sum_u32 = vpadalq_u16(sum_u32, sum_u16); + h_tmp += h_limit; + } while (h < height); + + sum += horizontal_long_add_u32x4(sum_u32); + sse += horizontal_long_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1])); + + return sse - sum * sum / (width * height); +} + +uint64_t aom_var_2d_u8_neon(uint8_t *src, int src_stride, int width, + int height) { + if (width >= 16) { + return aom_var_2d_u8_16xh_neon(src, src_stride, width, height); + } + if (width >= 8) { + return aom_var_2d_u8_8xh_neon(src, src_stride, width, height); + } + if (width >= 4 && height % 2 == 0) { + return aom_var_2d_u8_4xh_neon(src, src_stride, width, height); + } + return aom_var_2d_u8_c(src, src_stride, width, height); +} + +static INLINE uint64_t aom_var_2d_u16_4xh_neon(uint8_t *src, int src_stride, + int width, int height) { + uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src); + uint64_t sum = 0; + uint64_t sse = 0; + uint32x2_t sum_u32 = vdup_n_u32(0); + uint64x2_t sse_u64 = vdupq_n_u64(0); + + int h = height; + do { + int w = width; + uint16_t *src_ptr = src_u16; + do { + uint16x4_t s0 = vld1_u16(src_ptr); + + sum_u32 = vpadal_u16(sum_u32, s0); + + uint32x4_t sse_u32 = vmull_u16(s0, s0); + + sse_u64 = vpadalq_u32(sse_u64, sse_u32); + + src_ptr += 4; + w -= 4; + } while (w >= 4); + + // Process remaining columns in the row using C. + while (w > 0) { + int idx = width - w; + const uint16_t v = src_u16[idx]; + sum += v; + sse += v * v; + w--; + } + + src_u16 += src_stride; + } while (--h != 0); + + sum += horizontal_long_add_u32x2(sum_u32); + sse += horizontal_add_u64x2(sse_u64); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u16_8xh_neon(uint8_t *src, int src_stride, + int width, int height) { + uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src); + uint64_t sum = 0; + uint64_t sse = 0; + uint32x4_t sum_u32 = vdupq_n_u32(0); + uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; + + int h = height; + do { + int w = width; + uint16_t *src_ptr = src_u16; + do { + uint16x8_t s0 = vld1q_u16(src_ptr); + + sum_u32 = vpadalq_u16(sum_u32, s0); + + uint32x4_t sse_u32_lo = vmull_u16(vget_low_u16(s0), vget_low_u16(s0)); + uint32x4_t sse_u32_hi = vmull_u16(vget_high_u16(s0), vget_high_u16(s0)); + + sse_u64[0] = vpadalq_u32(sse_u64[0], sse_u32_lo); + sse_u64[1] = vpadalq_u32(sse_u64[1], sse_u32_hi); + + src_ptr += 8; + w -= 8; + } while (w >= 8); + + // Process remaining columns in the row using C. 
+ while (w > 0) { + int idx = width - w; + const uint16_t v = src_u16[idx]; + sum += v; + sse += v * v; + w--; + } + + src_u16 += src_stride; + } while (--h != 0); + + sum += horizontal_long_add_u32x4(sum_u32); + sse += horizontal_add_u64x2(vaddq_u64(sse_u64[0], sse_u64[1])); + + return sse - sum * sum / (width * height); +} + +uint64_t aom_var_2d_u16_neon(uint8_t *src, int src_stride, int width, + int height) { + if (width >= 8) { + return aom_var_2d_u16_8xh_neon(src, src_stride, width, height); + } + if (width >= 4) { + return aom_var_2d_u16_4xh_neon(src, src_stride, width, height); + } + return aom_var_2d_u16_c(src, src_stride, width, height); +} diff --git a/third_party/aom/aom_dsp/arm/sum_squares_neon_dotprod.c b/third_party/aom/aom_dsp/arm/sum_squares_neon_dotprod.c new file mode 100644 index 0000000000..44462a693c --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sum_squares_neon_dotprod.c @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include + +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "config/aom_dsp_rtcd.h" + +static INLINE uint64_t aom_var_2d_u8_4xh_neon_dotprod(uint8_t *src, + int src_stride, int width, + int height) { + uint64_t sum = 0; + uint64_t sse = 0; + uint32x2_t sum_u32 = vdup_n_u32(0); + uint32x2_t sse_u32 = vdup_n_u32(0); + + int h = height / 2; + do { + int w = width; + uint8_t *src_ptr = src; + do { + uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride); + + sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1)); + + sse_u32 = vdot_u32(sse_u32, s0, s0); + + src_ptr += 8; + w -= 8; + } while (w >= 8); + + // Process remaining columns in the row using C. + while (w > 0) { + int idx = width - w; + const uint8_t v = src[idx]; + sum += v; + sse += v * v; + w--; + } + + src += 2 * src_stride; + } while (--h != 0); + + sum += horizontal_long_add_u32x2(sum_u32); + sse += horizontal_long_add_u32x2(sse_u32); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u8_8xh_neon_dotprod(uint8_t *src, + int src_stride, int width, + int height) { + uint64_t sum = 0; + uint64_t sse = 0; + uint32x2_t sum_u32 = vdup_n_u32(0); + uint32x2_t sse_u32 = vdup_n_u32(0); + + int h = height; + do { + int w = width; + uint8_t *src_ptr = src; + do { + uint8x8_t s0 = vld1_u8(src_ptr); + + sum_u32 = vdot_u32(sum_u32, s0, vdup_n_u8(1)); + + sse_u32 = vdot_u32(sse_u32, s0, s0); + + src_ptr += 8; + w -= 8; + } while (w >= 8); + + // Process remaining columns in the row using C. 
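+    // In these *_dotprod kernels the per-row byte sums are formed with
+    // vdot_u32 against an all-ones vector and the squares with
+    // vdot_u32(s0, s0), so no widening pairwise-add chain is needed; the
+    // scalar tail below handles the final width % 8 columns as before.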
+ while (w > 0) { + int idx = width - w; + const uint8_t v = src[idx]; + sum += v; + sse += v * v; + w--; + } + + src += src_stride; + } while (--h != 0); + + sum += horizontal_long_add_u32x2(sum_u32); + sse += horizontal_long_add_u32x2(sse_u32); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u8_16xh_neon_dotprod(uint8_t *src, + int src_stride, + int width, int height) { + uint64_t sum = 0; + uint64_t sse = 0; + uint32x4_t sum_u32 = vdupq_n_u32(0); + uint32x4_t sse_u32 = vdupq_n_u32(0); + + int h = height; + do { + int w = width; + uint8_t *src_ptr = src; + do { + uint8x16_t s0 = vld1q_u8(src_ptr); + + sum_u32 = vdotq_u32(sum_u32, s0, vdupq_n_u8(1)); + + sse_u32 = vdotq_u32(sse_u32, s0, s0); + + src_ptr += 16; + w -= 16; + } while (w >= 16); + + // Process remaining columns in the row using C. + while (w > 0) { + int idx = width - w; + const uint8_t v = src[idx]; + sum += v; + sse += v * v; + w--; + } + + src += src_stride; + } while (--h != 0); + + sum += horizontal_long_add_u32x4(sum_u32); + sse += horizontal_long_add_u32x4(sse_u32); + + return sse - sum * sum / (width * height); +} + +uint64_t aom_var_2d_u8_neon_dotprod(uint8_t *src, int src_stride, int width, + int height) { + if (width >= 16) { + return aom_var_2d_u8_16xh_neon_dotprod(src, src_stride, width, height); + } + if (width >= 8) { + return aom_var_2d_u8_8xh_neon_dotprod(src, src_stride, width, height); + } + if (width >= 4 && height % 2 == 0) { + return aom_var_2d_u8_4xh_neon_dotprod(src, src_stride, width, height); + } + return aom_var_2d_u8_c(src, src_stride, width, height); +} diff --git a/third_party/aom/aom_dsp/arm/sum_squares_sve.c b/third_party/aom/aom_dsp/arm/sum_squares_sve.c new file mode 100644 index 0000000000..724e43859e --- /dev/null +++ b/third_party/aom/aom_dsp/arm/sum_squares_sve.c @@ -0,0 +1,402 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include + +#include "aom_dsp/arm/dot_sve.h" +#include "aom_dsp/arm/mem_neon.h" +#include "config/aom_dsp_rtcd.h" + +static INLINE uint64_t aom_sum_squares_2d_i16_4xh_sve(const int16_t *src, + int stride, int height) { + int64x2_t sum_squares = vdupq_n_s64(0); + + do { + int16x8_t s = vcombine_s16(vld1_s16(src), vld1_s16(src + stride)); + + sum_squares = aom_sdotq_s16(sum_squares, s, s); + + src += 2 * stride; + height -= 2; + } while (height != 0); + + return (uint64_t)vaddvq_s64(sum_squares); +} + +static INLINE uint64_t aom_sum_squares_2d_i16_8xh_sve(const int16_t *src, + int stride, int height) { + int64x2_t sum_squares[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + + do { + int16x8_t s0 = vld1q_s16(src + 0 * stride); + int16x8_t s1 = vld1q_s16(src + 1 * stride); + + sum_squares[0] = aom_sdotq_s16(sum_squares[0], s0, s0); + sum_squares[1] = aom_sdotq_s16(sum_squares[1], s1, s1); + + src += 2 * stride; + height -= 2; + } while (height != 0); + + sum_squares[0] = vaddq_s64(sum_squares[0], sum_squares[1]); + return (uint64_t)vaddvq_s64(sum_squares[0]); +} + +static INLINE uint64_t aom_sum_squares_2d_i16_large_sve(const int16_t *src, + int stride, int width, + int height) { + int64x2_t sum_squares[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + + do { + const int16_t *src_ptr = src; + int w = width; + do { + int16x8_t s0 = vld1q_s16(src_ptr); + int16x8_t s1 = vld1q_s16(src_ptr + 8); + + sum_squares[0] = aom_sdotq_s16(sum_squares[0], s0, s0); + sum_squares[1] = aom_sdotq_s16(sum_squares[1], s1, s1); + + src_ptr += 16; + w -= 16; + } while (w != 0); + + src += stride; + } while (--height != 0); + + sum_squares[0] = vaddq_s64(sum_squares[0], sum_squares[1]); + return (uint64_t)vaddvq_s64(sum_squares[0]); +} + +static INLINE uint64_t aom_sum_squares_2d_i16_wxh_sve(const int16_t *src, + int stride, int width, + int height) { + svint64_t sum_squares = svdup_n_s64(0); + uint64_t step = svcnth(); + + do { + const int16_t *src_ptr = src; + int w = 0; + do { + svbool_t pred = svwhilelt_b16_u32(w, width); + svint16_t s0 = svld1_s16(pred, src_ptr); + + sum_squares = svdot_s64(sum_squares, s0, s0); + + src_ptr += step; + w += step; + } while (w < width); + + src += stride; + } while (--height != 0); + + return (uint64_t)svaddv_s64(svptrue_b64(), sum_squares); +} + +uint64_t aom_sum_squares_2d_i16_sve(const int16_t *src, int stride, int width, + int height) { + if (width == 4) { + return aom_sum_squares_2d_i16_4xh_sve(src, stride, height); + } + if (width == 8) { + return aom_sum_squares_2d_i16_8xh_sve(src, stride, height); + } + if (width % 16 == 0) { + return aom_sum_squares_2d_i16_large_sve(src, stride, width, height); + } + return aom_sum_squares_2d_i16_wxh_sve(src, stride, width, height); +} + +uint64_t aom_sum_squares_i16_sve(const int16_t *src, uint32_t n) { + // This function seems to be called only for values of N >= 64. See + // av1/encoder/compound_type.c. Additionally, because N = width x height for + // width and height between the standard block sizes, N will also be a + // multiple of 64. 
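+  // A multiple of 64 is also a multiple of 32, so the unrolled loop below
+  // can consume 32 elements per iteration with no tail; any other n is
+  // handled by the generic C implementation.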
+ if (LIKELY(n % 64 == 0)) { + int64x2_t sum[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0), + vdupq_n_s64(0) }; + + do { + int16x8_t s0 = vld1q_s16(src); + int16x8_t s1 = vld1q_s16(src + 8); + int16x8_t s2 = vld1q_s16(src + 16); + int16x8_t s3 = vld1q_s16(src + 24); + + sum[0] = aom_sdotq_s16(sum[0], s0, s0); + sum[1] = aom_sdotq_s16(sum[1], s1, s1); + sum[2] = aom_sdotq_s16(sum[2], s2, s2); + sum[3] = aom_sdotq_s16(sum[3], s3, s3); + + src += 32; + n -= 32; + } while (n != 0); + + sum[0] = vaddq_s64(sum[0], sum[1]); + sum[2] = vaddq_s64(sum[2], sum[3]); + sum[0] = vaddq_s64(sum[0], sum[2]); + return vaddvq_s64(sum[0]); + } + return aom_sum_squares_i16_c(src, n); +} + +static INLINE uint64_t aom_sum_sse_2d_i16_4xh_sve(const int16_t *src, + int stride, int height, + int *sum) { + int64x2_t sse = vdupq_n_s64(0); + int32x4_t sum_s32 = vdupq_n_s32(0); + + do { + int16x8_t s = vcombine_s16(vld1_s16(src), vld1_s16(src + stride)); + + sse = aom_sdotq_s16(sse, s, s); + + sum_s32 = vpadalq_s16(sum_s32, s); + + src += 2 * stride; + height -= 2; + } while (height != 0); + + *sum += vaddvq_s32(sum_s32); + return vaddvq_s64(sse); +} + +static INLINE uint64_t aom_sum_sse_2d_i16_8xh_sve(const int16_t *src, + int stride, int height, + int *sum) { + int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + int32x4_t sum_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + do { + int16x8_t s0 = vld1q_s16(src); + int16x8_t s1 = vld1q_s16(src + stride); + + sse[0] = aom_sdotq_s16(sse[0], s0, s0); + sse[1] = aom_sdotq_s16(sse[1], s1, s1); + + sum_acc[0] = vpadalq_s16(sum_acc[0], s0); + sum_acc[1] = vpadalq_s16(sum_acc[1], s1); + + src += 2 * stride; + height -= 2; + } while (height != 0); + + *sum += vaddvq_s32(vaddq_s32(sum_acc[0], sum_acc[1])); + return vaddvq_s64(vaddq_s64(sse[0], sse[1])); +} + +static INLINE uint64_t aom_sum_sse_2d_i16_16xh_sve(const int16_t *src, + int stride, int width, + int height, int *sum) { + int64x2_t sse[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + int32x4_t sum_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + do { + int w = 0; + do { + int16x8_t s0 = vld1q_s16(src + w); + int16x8_t s1 = vld1q_s16(src + w + 8); + + sse[0] = aom_sdotq_s16(sse[0], s0, s0); + sse[1] = aom_sdotq_s16(sse[1], s1, s1); + + sum_acc[0] = vpadalq_s16(sum_acc[0], s0); + sum_acc[1] = vpadalq_s16(sum_acc[1], s1); + + w += 16; + } while (w < width); + + src += stride; + } while (--height != 0); + + *sum += vaddvq_s32(vaddq_s32(sum_acc[0], sum_acc[1])); + return vaddvq_s64(vaddq_s64(sse[0], sse[1])); +} + +uint64_t aom_sum_sse_2d_i16_sve(const int16_t *src, int stride, int width, + int height, int *sum) { + uint64_t sse; + + if (width == 4) { + sse = aom_sum_sse_2d_i16_4xh_sve(src, stride, height, sum); + } else if (width == 8) { + sse = aom_sum_sse_2d_i16_8xh_sve(src, stride, height, sum); + } else if (width % 16 == 0) { + sse = aom_sum_sse_2d_i16_16xh_sve(src, stride, width, height, sum); + } else { + sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum); + } + + return sse; +} + +static INLINE uint64_t aom_var_2d_u16_4xh_sve(uint8_t *src, int src_stride, + int width, int height) { + uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src); + uint64_t sum = 0; + uint64_t sse = 0; + uint32x4_t sum_u32 = vdupq_n_u32(0); + uint64x2_t sse_u64 = vdupq_n_u64(0); + + int h = height; + do { + uint16x8_t s0 = + vcombine_u16(vld1_u16(src_u16), vld1_u16(src_u16 + src_stride)); + + sum_u32 = vpadalq_u16(sum_u32, s0); + + sse_u64 = aom_udotq_u16(sse_u64, s0, s0); + + src_u16 += 2 * src_stride; + h -= 2; + } while (h != 0); + + 
sum += vaddlvq_u32(sum_u32); + sse += vaddvq_u64(sse_u64); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u16_8xh_sve(uint8_t *src, int src_stride, + int width, int height) { + uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src); + uint64_t sum = 0; + uint64_t sse = 0; + uint32x4_t sum_u32 = vdupq_n_u32(0); + uint64x2_t sse_u64 = vdupq_n_u64(0); + + int h = height; + do { + int w = width; + uint16_t *src_ptr = src_u16; + do { + uint16x8_t s0 = vld1q_u16(src_ptr); + + sum_u32 = vpadalq_u16(sum_u32, s0); + + sse_u64 = aom_udotq_u16(sse_u64, s0, s0); + + src_ptr += 8; + w -= 8; + } while (w != 0); + + src_u16 += src_stride; + } while (--h != 0); + + sum += vaddlvq_u32(sum_u32); + sse += vaddvq_u64(sse_u64); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u16_16xh_sve(uint8_t *src, int src_stride, + int width, int height) { + uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src); + uint64_t sum = 0; + uint64_t sse = 0; + uint32x4_t sum_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; + + int h = height; + do { + int w = width; + uint16_t *src_ptr = src_u16; + do { + uint16x8_t s0 = vld1q_u16(src_ptr); + uint16x8_t s1 = vld1q_u16(src_ptr + 8); + + sum_u32[0] = vpadalq_u16(sum_u32[0], s0); + sum_u32[1] = vpadalq_u16(sum_u32[1], s1); + + sse_u64[0] = aom_udotq_u16(sse_u64[0], s0, s0); + sse_u64[1] = aom_udotq_u16(sse_u64[1], s1, s1); + + src_ptr += 16; + w -= 16; + } while (w != 0); + + src_u16 += src_stride; + } while (--h != 0); + + sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[1]); + sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[1]); + + sum += vaddlvq_u32(sum_u32[0]); + sse += vaddvq_u64(sse_u64[0]); + + return sse - sum * sum / (width * height); +} + +static INLINE uint64_t aom_var_2d_u16_large_sve(uint8_t *src, int src_stride, + int width, int height) { + uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src); + uint64_t sum = 0; + uint64_t sse = 0; + uint32x4_t sum_u32[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), + vdupq_n_u32(0) }; + uint64x2_t sse_u64[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0), + vdupq_n_u64(0) }; + + int h = height; + do { + int w = width; + uint16_t *src_ptr = src_u16; + do { + uint16x8_t s0 = vld1q_u16(src_ptr); + uint16x8_t s1 = vld1q_u16(src_ptr + 8); + uint16x8_t s2 = vld1q_u16(src_ptr + 16); + uint16x8_t s3 = vld1q_u16(src_ptr + 24); + + sum_u32[0] = vpadalq_u16(sum_u32[0], s0); + sum_u32[1] = vpadalq_u16(sum_u32[1], s1); + sum_u32[2] = vpadalq_u16(sum_u32[2], s2); + sum_u32[3] = vpadalq_u16(sum_u32[3], s3); + + sse_u64[0] = aom_udotq_u16(sse_u64[0], s0, s0); + sse_u64[1] = aom_udotq_u16(sse_u64[1], s1, s1); + sse_u64[2] = aom_udotq_u16(sse_u64[2], s2, s2); + sse_u64[3] = aom_udotq_u16(sse_u64[3], s3, s3); + + src_ptr += 32; + w -= 32; + } while (w != 0); + + src_u16 += src_stride; + } while (--h != 0); + + sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[1]); + sum_u32[2] = vaddq_u32(sum_u32[2], sum_u32[3]); + sum_u32[0] = vaddq_u32(sum_u32[0], sum_u32[2]); + sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[1]); + sse_u64[2] = vaddq_u64(sse_u64[2], sse_u64[3]); + sse_u64[0] = vaddq_u64(sse_u64[0], sse_u64[2]); + + sum += vaddlvq_u32(sum_u32[0]); + sse += vaddvq_u64(sse_u64[0]); + + return sse - sum * sum / (width * height); +} + +uint64_t aom_var_2d_u16_sve(uint8_t *src, int src_stride, int width, + int height) { + if (width == 4) { + return aom_var_2d_u16_4xh_sve(src, src_stride, width, height); + } + if (width == 8) { + return 
aom_var_2d_u16_8xh_sve(src, src_stride, width, height); + } + if (width == 16) { + return aom_var_2d_u16_16xh_sve(src, src_stride, width, height); + } + if (width % 32 == 0) { + return aom_var_2d_u16_large_sve(src, src_stride, width, height); + } + return aom_var_2d_u16_neon(src, src_stride, width, height); +} diff --git a/third_party/aom/aom_dsp/arm/transpose_neon.h b/third_party/aom/aom_dsp/arm/transpose_neon.h new file mode 100644 index 0000000000..8027018235 --- /dev/null +++ b/third_party/aom/aom_dsp/arm/transpose_neon.h @@ -0,0 +1,1263 @@ +/* + * Copyright (c) 2018, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef AOM_AOM_DSP_ARM_TRANSPOSE_NEON_H_ +#define AOM_AOM_DSP_ARM_TRANSPOSE_NEON_H_ + +#include + +#include "aom/aom_integer.h" // For AOM_FORCE_INLINE. +#include "config/aom_config.h" + +static INLINE void transpose_elems_u8_8x8( + uint8x8_t a0, uint8x8_t a1, uint8x8_t a2, uint8x8_t a3, uint8x8_t a4, + uint8x8_t a5, uint8x8_t a6, uint8x8_t a7, uint8x8_t *o0, uint8x8_t *o1, + uint8x8_t *o2, uint8x8_t *o3, uint8x8_t *o4, uint8x8_t *o5, uint8x8_t *o6, + uint8x8_t *o7) { + // Swap 8 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // a4: 40 41 42 43 44 45 46 47 + // a5: 50 51 52 53 54 55 56 57 + // a6: 60 61 62 63 64 65 66 67 + // a7: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 40 50 42 52 44 54 46 56 + // b0.val[1]: 01 11 03 13 05 15 07 17 41 51 43 53 45 55 47 57 + // b1.val[0]: 20 30 22 32 24 34 26 36 60 70 62 72 64 74 66 76 + // b1.val[1]: 21 31 23 33 25 35 27 37 61 71 63 73 65 75 67 77 + + const uint8x16x2_t b0 = vtrnq_u8(vcombine_u8(a0, a4), vcombine_u8(a1, a5)); + const uint8x16x2_t b1 = vtrnq_u8(vcombine_u8(a2, a6), vcombine_u8(a3, a7)); + + // Swap 16 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 40 50 60 70 44 54 64 74 + // c0.val[1]: 02 12 22 32 06 16 26 36 42 52 62 72 46 56 66 76 + // c1.val[0]: 01 11 21 31 05 15 25 35 41 51 61 71 45 55 65 75 + // c1.val[1]: 03 13 23 33 07 17 27 37 43 53 63 73 47 57 67 77 + + const uint16x8x2_t c0 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[0]), + vreinterpretq_u16_u8(b1.val[0])); + const uint16x8x2_t c1 = vtrnq_u16(vreinterpretq_u16_u8(b0.val[1]), + vreinterpretq_u16_u8(b1.val[1])); + + // Unzip 32 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71 + // d0.val[1]: 04 14 24 34 44 54 64 74 05 15 25 35 45 55 65 75 + // d1.val[0]: 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73 + // d1.val[1]: 06 16 26 36 46 56 66 76 07 17 27 37 47 57 67 77 + const uint32x4x2_t d0 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c1.val[0])); + const uint32x4x2_t d1 = vuzpq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c1.val[1])); + + *o0 = vreinterpret_u8_u32(vget_low_u32(d0.val[0])); + *o1 = vreinterpret_u8_u32(vget_high_u32(d0.val[0])); + *o2 = vreinterpret_u8_u32(vget_low_u32(d1.val[0])); + *o3 = vreinterpret_u8_u32(vget_high_u32(d1.val[0])); + *o4 = vreinterpret_u8_u32(vget_low_u32(d0.val[1])); + *o5 = vreinterpret_u8_u32(vget_high_u32(d0.val[1])); + *o6 = 
vreinterpret_u8_u32(vget_low_u32(d1.val[1])); + *o7 = vreinterpret_u8_u32(vget_high_u32(d1.val[1])); +} + +static INLINE void transpose_elems_inplace_u8_8x8(uint8x8_t *a0, uint8x8_t *a1, + uint8x8_t *a2, uint8x8_t *a3, + uint8x8_t *a4, uint8x8_t *a5, + uint8x8_t *a6, + uint8x8_t *a7) { + transpose_elems_u8_8x8(*a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7, a0, a1, a2, a3, + a4, a5, a6, a7); +} + +static INLINE void transpose_arrays_u8_8x8(const uint8x8_t *in, + uint8x8_t *out) { + transpose_elems_u8_8x8(in[0], in[1], in[2], in[3], in[4], in[5], in[6], in[7], + &out[0], &out[1], &out[2], &out[3], &out[4], &out[5], + &out[6], &out[7]); +} + +static AOM_FORCE_INLINE void transpose_arrays_u8_8x16(const uint8x8_t *x, + uint8x16_t *d) { + uint8x8x2_t w0 = vzip_u8(x[0], x[1]); + uint8x8x2_t w1 = vzip_u8(x[2], x[3]); + uint8x8x2_t w2 = vzip_u8(x[4], x[5]); + uint8x8x2_t w3 = vzip_u8(x[6], x[7]); + + uint8x8x2_t w8 = vzip_u8(x[8], x[9]); + uint8x8x2_t w9 = vzip_u8(x[10], x[11]); + uint8x8x2_t w10 = vzip_u8(x[12], x[13]); + uint8x8x2_t w11 = vzip_u8(x[14], x[15]); + + uint16x4x2_t w4 = + vzip_u16(vreinterpret_u16_u8(w0.val[0]), vreinterpret_u16_u8(w1.val[0])); + uint16x4x2_t w5 = + vzip_u16(vreinterpret_u16_u8(w2.val[0]), vreinterpret_u16_u8(w3.val[0])); + uint16x4x2_t w12 = + vzip_u16(vreinterpret_u16_u8(w8.val[0]), vreinterpret_u16_u8(w9.val[0])); + uint16x4x2_t w13 = vzip_u16(vreinterpret_u16_u8(w10.val[0]), + vreinterpret_u16_u8(w11.val[0])); + + uint32x2x2_t w6 = vzip_u32(vreinterpret_u32_u16(w4.val[0]), + vreinterpret_u32_u16(w5.val[0])); + uint32x2x2_t w7 = vzip_u32(vreinterpret_u32_u16(w4.val[1]), + vreinterpret_u32_u16(w5.val[1])); + uint32x2x2_t w14 = vzip_u32(vreinterpret_u32_u16(w12.val[0]), + vreinterpret_u32_u16(w13.val[0])); + uint32x2x2_t w15 = vzip_u32(vreinterpret_u32_u16(w12.val[1]), + vreinterpret_u32_u16(w13.val[1])); + + // Store first 4-line result + d[0] = vreinterpretq_u8_u32(vcombine_u32(w6.val[0], w14.val[0])); + d[1] = vreinterpretq_u8_u32(vcombine_u32(w6.val[1], w14.val[1])); + d[2] = vreinterpretq_u8_u32(vcombine_u32(w7.val[0], w15.val[0])); + d[3] = vreinterpretq_u8_u32(vcombine_u32(w7.val[1], w15.val[1])); + + w4 = vzip_u16(vreinterpret_u16_u8(w0.val[1]), vreinterpret_u16_u8(w1.val[1])); + w5 = vzip_u16(vreinterpret_u16_u8(w2.val[1]), vreinterpret_u16_u8(w3.val[1])); + w12 = + vzip_u16(vreinterpret_u16_u8(w8.val[1]), vreinterpret_u16_u8(w9.val[1])); + w13 = vzip_u16(vreinterpret_u16_u8(w10.val[1]), + vreinterpret_u16_u8(w11.val[1])); + + w6 = vzip_u32(vreinterpret_u32_u16(w4.val[0]), + vreinterpret_u32_u16(w5.val[0])); + w7 = vzip_u32(vreinterpret_u32_u16(w4.val[1]), + vreinterpret_u32_u16(w5.val[1])); + w14 = vzip_u32(vreinterpret_u32_u16(w12.val[0]), + vreinterpret_u32_u16(w13.val[0])); + w15 = vzip_u32(vreinterpret_u32_u16(w12.val[1]), + vreinterpret_u32_u16(w13.val[1])); + + // Store second 4-line result + d[4] = vreinterpretq_u8_u32(vcombine_u32(w6.val[0], w14.val[0])); + d[5] = vreinterpretq_u8_u32(vcombine_u32(w6.val[1], w14.val[1])); + d[6] = vreinterpretq_u8_u32(vcombine_u32(w7.val[0], w15.val[0])); + d[7] = vreinterpretq_u8_u32(vcombine_u32(w7.val[1], w15.val[1])); +} + +static AOM_FORCE_INLINE void transpose_arrays_u8_16x8(const uint8x16_t *x, + uint8x8_t *d) { + uint8x16x2_t w0 = vzipq_u8(x[0], x[1]); + uint8x16x2_t w1 = vzipq_u8(x[2], x[3]); + uint8x16x2_t w2 = vzipq_u8(x[4], x[5]); + uint8x16x2_t w3 = vzipq_u8(x[6], x[7]); + + uint16x8x2_t w4 = vzipq_u16(vreinterpretq_u16_u8(w0.val[0]), + vreinterpretq_u16_u8(w1.val[0])); + uint16x8x2_t w5 = 
vzipq_u16(vreinterpretq_u16_u8(w2.val[0]), + vreinterpretq_u16_u8(w3.val[0])); + uint16x8x2_t w6 = vzipq_u16(vreinterpretq_u16_u8(w0.val[1]), + vreinterpretq_u16_u8(w1.val[1])); + uint16x8x2_t w7 = vzipq_u16(vreinterpretq_u16_u8(w2.val[1]), + vreinterpretq_u16_u8(w3.val[1])); + + uint32x4x2_t w8 = vzipq_u32(vreinterpretq_u32_u16(w4.val[0]), + vreinterpretq_u32_u16(w5.val[0])); + uint32x4x2_t w9 = vzipq_u32(vreinterpretq_u32_u16(w6.val[0]), + vreinterpretq_u32_u16(w7.val[0])); + uint32x4x2_t w10 = vzipq_u32(vreinterpretq_u32_u16(w4.val[1]), + vreinterpretq_u32_u16(w5.val[1])); + uint32x4x2_t w11 = vzipq_u32(vreinterpretq_u32_u16(w6.val[1]), + vreinterpretq_u32_u16(w7.val[1])); + + d[0] = vreinterpret_u8_u32(vget_low_u32(w8.val[0])); + d[1] = vreinterpret_u8_u32(vget_high_u32(w8.val[0])); + d[2] = vreinterpret_u8_u32(vget_low_u32(w8.val[1])); + d[3] = vreinterpret_u8_u32(vget_high_u32(w8.val[1])); + d[4] = vreinterpret_u8_u32(vget_low_u32(w10.val[0])); + d[5] = vreinterpret_u8_u32(vget_high_u32(w10.val[0])); + d[6] = vreinterpret_u8_u32(vget_low_u32(w10.val[1])); + d[7] = vreinterpret_u8_u32(vget_high_u32(w10.val[1])); + d[8] = vreinterpret_u8_u32(vget_low_u32(w9.val[0])); + d[9] = vreinterpret_u8_u32(vget_high_u32(w9.val[0])); + d[10] = vreinterpret_u8_u32(vget_low_u32(w9.val[1])); + d[11] = vreinterpret_u8_u32(vget_high_u32(w9.val[1])); + d[12] = vreinterpret_u8_u32(vget_low_u32(w11.val[0])); + d[13] = vreinterpret_u8_u32(vget_high_u32(w11.val[0])); + d[14] = vreinterpret_u8_u32(vget_low_u32(w11.val[1])); + d[15] = vreinterpret_u8_u32(vget_high_u32(w11.val[1])); +} + +static INLINE uint16x8x2_t aom_vtrnq_u64_to_u16(uint32x4_t a0, uint32x4_t a1) { + uint16x8x2_t b0; +#if AOM_ARCH_AARCH64 + b0.val[0] = vreinterpretq_u16_u64( + vtrn1q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1))); + b0.val[1] = vreinterpretq_u16_u64( + vtrn2q_u64(vreinterpretq_u64_u32(a0), vreinterpretq_u64_u32(a1))); +#else + b0.val[0] = vcombine_u16(vreinterpret_u16_u32(vget_low_u32(a0)), + vreinterpret_u16_u32(vget_low_u32(a1))); + b0.val[1] = vcombine_u16(vreinterpret_u16_u32(vget_high_u32(a0)), + vreinterpret_u16_u32(vget_high_u32(a1))); +#endif + return b0; +} + +static INLINE void transpose_arrays_u8_16x16(const uint8x16_t *x, + uint8x16_t *d) { + uint8x16x2_t w0 = vzipq_u8(x[0], x[1]); + uint8x16x2_t w1 = vzipq_u8(x[2], x[3]); + uint8x16x2_t w2 = vzipq_u8(x[4], x[5]); + uint8x16x2_t w3 = vzipq_u8(x[6], x[7]); + + uint8x16x2_t w4 = vzipq_u8(x[8], x[9]); + uint8x16x2_t w5 = vzipq_u8(x[10], x[11]); + uint8x16x2_t w6 = vzipq_u8(x[12], x[13]); + uint8x16x2_t w7 = vzipq_u8(x[14], x[15]); + + uint16x8x2_t w8 = vzipq_u16(vreinterpretq_u16_u8(w0.val[0]), + vreinterpretq_u16_u8(w1.val[0])); + uint16x8x2_t w9 = vzipq_u16(vreinterpretq_u16_u8(w2.val[0]), + vreinterpretq_u16_u8(w3.val[0])); + uint16x8x2_t w10 = vzipq_u16(vreinterpretq_u16_u8(w4.val[0]), + vreinterpretq_u16_u8(w5.val[0])); + uint16x8x2_t w11 = vzipq_u16(vreinterpretq_u16_u8(w6.val[0]), + vreinterpretq_u16_u8(w7.val[0])); + + uint32x4x2_t w12 = vzipq_u32(vreinterpretq_u32_u16(w8.val[0]), + vreinterpretq_u32_u16(w9.val[0])); + uint32x4x2_t w13 = vzipq_u32(vreinterpretq_u32_u16(w10.val[0]), + vreinterpretq_u32_u16(w11.val[0])); + uint32x4x2_t w14 = vzipq_u32(vreinterpretq_u32_u16(w8.val[1]), + vreinterpretq_u32_u16(w9.val[1])); + uint32x4x2_t w15 = vzipq_u32(vreinterpretq_u32_u16(w10.val[1]), + vreinterpretq_u32_u16(w11.val[1])); + + uint16x8x2_t d01 = aom_vtrnq_u64_to_u16(w12.val[0], w13.val[0]); + d[0] = vreinterpretq_u8_u16(d01.val[0]); + d[1] = 
vreinterpretq_u8_u16(d01.val[1]); + uint16x8x2_t d23 = aom_vtrnq_u64_to_u16(w12.val[1], w13.val[1]); + d[2] = vreinterpretq_u8_u16(d23.val[0]); + d[3] = vreinterpretq_u8_u16(d23.val[1]); + uint16x8x2_t d45 = aom_vtrnq_u64_to_u16(w14.val[0], w15.val[0]); + d[4] = vreinterpretq_u8_u16(d45.val[0]); + d[5] = vreinterpretq_u8_u16(d45.val[1]); + uint16x8x2_t d67 = aom_vtrnq_u64_to_u16(w14.val[1], w15.val[1]); + d[6] = vreinterpretq_u8_u16(d67.val[0]); + d[7] = vreinterpretq_u8_u16(d67.val[1]); + + // upper half + w8 = vzipq_u16(vreinterpretq_u16_u8(w0.val[1]), + vreinterpretq_u16_u8(w1.val[1])); + w9 = vzipq_u16(vreinterpretq_u16_u8(w2.val[1]), + vreinterpretq_u16_u8(w3.val[1])); + w10 = vzipq_u16(vreinterpretq_u16_u8(w4.val[1]), + vreinterpretq_u16_u8(w5.val[1])); + w11 = vzipq_u16(vreinterpretq_u16_u8(w6.val[1]), + vreinterpretq_u16_u8(w7.val[1])); + + w12 = vzipq_u32(vreinterpretq_u32_u16(w8.val[0]), + vreinterpretq_u32_u16(w9.val[0])); + w13 = vzipq_u32(vreinterpretq_u32_u16(w10.val[0]), + vreinterpretq_u32_u16(w11.val[0])); + w14 = vzipq_u32(vreinterpretq_u32_u16(w8.val[1]), + vreinterpretq_u32_u16(w9.val[1])); + w15 = vzipq_u32(vreinterpretq_u32_u16(w10.val[1]), + vreinterpretq_u32_u16(w11.val[1])); + + d01 = aom_vtrnq_u64_to_u16(w12.val[0], w13.val[0]); + d[8] = vreinterpretq_u8_u16(d01.val[0]); + d[9] = vreinterpretq_u8_u16(d01.val[1]); + d23 = aom_vtrnq_u64_to_u16(w12.val[1], w13.val[1]); + d[10] = vreinterpretq_u8_u16(d23.val[0]); + d[11] = vreinterpretq_u8_u16(d23.val[1]); + d45 = aom_vtrnq_u64_to_u16(w14.val[0], w15.val[0]); + d[12] = vreinterpretq_u8_u16(d45.val[0]); + d[13] = vreinterpretq_u8_u16(d45.val[1]); + d67 = aom_vtrnq_u64_to_u16(w14.val[1], w15.val[1]); + d[14] = vreinterpretq_u8_u16(d67.val[0]); + d[15] = vreinterpretq_u8_u16(d67.val[1]); +} + +static AOM_FORCE_INLINE void transpose_arrays_u8_32x16(const uint8x16x2_t *x, + uint8x16_t *d) { + uint8x16_t x2[32]; + for (int i = 0; i < 16; ++i) { + x2[i] = x[i].val[0]; + x2[i + 16] = x[i].val[1]; + } + transpose_arrays_u8_16x16(x2, d); + transpose_arrays_u8_16x16(x2 + 16, d + 16); +} + +static INLINE void transpose_elems_inplace_u8_8x4(uint8x8_t *a0, uint8x8_t *a1, + uint8x8_t *a2, + uint8x8_t *a3) { + // Swap 8 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + + const uint8x8x2_t b0 = vtrn_u8(*a0, *a1); + const uint8x8x2_t b1 = vtrn_u8(*a2, *a3); + + // Swap 16 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + + const uint16x4x2_t c0 = + vtrn_u16(vreinterpret_u16_u8(b0.val[0]), vreinterpret_u16_u8(b1.val[0])); + const uint16x4x2_t c1 = + vtrn_u16(vreinterpret_u16_u8(b0.val[1]), vreinterpret_u16_u8(b1.val[1])); + + *a0 = vreinterpret_u8_u16(c0.val[0]); + *a1 = vreinterpret_u8_u16(c1.val[0]); + *a2 = vreinterpret_u8_u16(c0.val[1]); + *a3 = vreinterpret_u8_u16(c1.val[1]); +} + +static INLINE void transpose_elems_inplace_u8_4x4(uint8x8_t *a0, + uint8x8_t *a1) { + // Swap 16 bit elements. 
Goes from: + // a0: 00 01 02 03 10 11 12 13 + // a1: 20 21 22 23 30 31 32 33 + // to: + // b0.val[0]: 00 01 20 21 10 11 30 31 + // b0.val[1]: 02 03 22 23 12 13 32 33 + + const uint16x4x2_t b0 = + vtrn_u16(vreinterpret_u16_u8(*a0), vreinterpret_u16_u8(*a1)); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 01 20 21 02 03 22 23 + // c0.val[1]: 10 11 30 31 12 13 32 33 + + const uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), + vreinterpret_u32_u16(b0.val[1])); + + // Swap 8 bit elements resulting in: + // d0.val[0]: 00 10 20 30 02 12 22 32 + // d0.val[1]: 01 11 21 31 03 13 23 33 + + const uint8x8x2_t d0 = + vtrn_u8(vreinterpret_u8_u32(c0.val[0]), vreinterpret_u8_u32(c0.val[1])); + + *a0 = d0.val[0]; + *a1 = d0.val[1]; +} + +static INLINE void transpose_elems_u8_4x8(uint8x8_t a0, uint8x8_t a1, + uint8x8_t a2, uint8x8_t a3, + uint8x8_t a4, uint8x8_t a5, + uint8x8_t a6, uint8x8_t a7, + uint8x8_t *o0, uint8x8_t *o1, + uint8x8_t *o2, uint8x8_t *o3) { + // Swap 32 bit elements. Goes from: + // a0: 00 01 02 03 XX XX XX XX + // a1: 10 11 12 13 XX XX XX XX + // a2: 20 21 22 23 XX XX XX XX + // a3; 30 31 32 33 XX XX XX XX + // a4: 40 41 42 43 XX XX XX XX + // a5: 50 51 52 53 XX XX XX XX + // a6: 60 61 62 63 XX XX XX XX + // a7: 70 71 72 73 XX XX XX XX + // to: + // b0.val[0]: 00 01 02 03 40 41 42 43 + // b1.val[0]: 10 11 12 13 50 51 52 53 + // b2.val[0]: 20 21 22 23 60 61 62 63 + // b3.val[0]: 30 31 32 33 70 71 72 73 + + const uint32x2x2_t b0 = + vtrn_u32(vreinterpret_u32_u8(a0), vreinterpret_u32_u8(a4)); + const uint32x2x2_t b1 = + vtrn_u32(vreinterpret_u32_u8(a1), vreinterpret_u32_u8(a5)); + const uint32x2x2_t b2 = + vtrn_u32(vreinterpret_u32_u8(a2), vreinterpret_u32_u8(a6)); + const uint32x2x2_t b3 = + vtrn_u32(vreinterpret_u32_u8(a3), vreinterpret_u32_u8(a7)); + + // Swap 16 bit elements resulting in: + // c0.val[0]: 00 01 20 21 40 41 60 61 + // c0.val[1]: 02 03 22 23 42 43 62 63 + // c1.val[0]: 10 11 30 31 50 51 70 71 + // c1.val[1]: 12 13 32 33 52 53 72 73 + + const uint16x4x2_t c0 = vtrn_u16(vreinterpret_u16_u32(b0.val[0]), + vreinterpret_u16_u32(b2.val[0])); + const uint16x4x2_t c1 = vtrn_u16(vreinterpret_u16_u32(b1.val[0]), + vreinterpret_u16_u32(b3.val[0])); + + // Swap 8 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 01 11 21 31 41 51 61 71 + // d1.val[0]: 02 12 22 32 42 52 62 72 + // d1.val[1]: 03 13 23 33 43 53 63 73 + + const uint8x8x2_t d0 = + vtrn_u8(vreinterpret_u8_u16(c0.val[0]), vreinterpret_u8_u16(c1.val[0])); + const uint8x8x2_t d1 = + vtrn_u8(vreinterpret_u8_u16(c0.val[1]), vreinterpret_u8_u16(c1.val[1])); + + *o0 = d0.val[0]; + *o1 = d0.val[1]; + *o2 = d1.val[0]; + *o3 = d1.val[1]; +} + +static INLINE void transpose_array_inplace_u16_4x4(uint16x4_t a[4]) { + // Input: + // 00 01 02 03 + // 10 11 12 13 + // 20 21 22 23 + // 30 31 32 33 + + // b: + // 00 10 02 12 + // 01 11 03 13 + const uint16x4x2_t b = vtrn_u16(a[0], a[1]); + // c: + // 20 30 22 32 + // 21 31 23 33 + const uint16x4x2_t c = vtrn_u16(a[2], a[3]); + // d: + // 00 10 20 30 + // 02 12 22 32 + const uint32x2x2_t d = + vtrn_u32(vreinterpret_u32_u16(b.val[0]), vreinterpret_u32_u16(c.val[0])); + // e: + // 01 11 21 31 + // 03 13 23 33 + const uint32x2x2_t e = + vtrn_u32(vreinterpret_u32_u16(b.val[1]), vreinterpret_u32_u16(c.val[1])); + + // Output: + // 00 10 20 30 + // 01 11 21 31 + // 02 12 22 32 + // 03 13 23 33 + a[0] = vreinterpret_u16_u32(d.val[0]); + a[1] = vreinterpret_u16_u32(e.val[0]); + a[2] = vreinterpret_u16_u32(d.val[1]); + a[3] = 
vreinterpret_u16_u32(e.val[1]); +} + +static INLINE void transpose_array_inplace_u16_4x8(uint16x8_t a[4]) { + // 4x8 Input: + // a[0]: 00 01 02 03 04 05 06 07 + // a[1]: 10 11 12 13 14 15 16 17 + // a[2]: 20 21 22 23 24 25 26 27 + // a[3]: 30 31 32 33 34 35 36 37 + + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]); + const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]); + + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]), + vreinterpretq_u32_u16(b1.val[0])); + const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]), + vreinterpretq_u32_u16(b1.val[1])); + + // 8x4 Output: + // a[0]: 00 10 20 30 04 14 24 34 + // a[1]: 01 11 21 31 05 15 25 35 + // a[2]: 02 12 22 32 06 16 26 36 + // a[3]: 03 13 23 33 07 17 27 37 + a[0] = vreinterpretq_u16_u32(c0.val[0]); + a[1] = vreinterpretq_u16_u32(c1.val[0]); + a[2] = vreinterpretq_u16_u32(c0.val[1]); + a[3] = vreinterpretq_u16_u32(c1.val[1]); +} + +// Special transpose for loop filter. +// 4x8 Input: +// p_q: p3 p2 p1 p0 q0 q1 q2 q3 +// a[0]: 00 01 02 03 04 05 06 07 +// a[1]: 10 11 12 13 14 15 16 17 +// a[2]: 20 21 22 23 24 25 26 27 +// a[3]: 30 31 32 33 34 35 36 37 +// 8x4 Output: +// a[0]: 03 13 23 33 04 14 24 34 p0q0 +// a[1]: 02 12 22 32 05 15 25 35 p1q1 +// a[2]: 01 11 21 31 06 16 26 36 p2q2 +// a[3]: 00 10 20 30 07 17 27 37 p3q3 +// Direct reapplication of the function will reset the high halves, but +// reverse the low halves: +// p_q: p0 p1 p2 p3 q0 q1 q2 q3 +// a[0]: 33 32 31 30 04 05 06 07 +// a[1]: 23 22 21 20 14 15 16 17 +// a[2]: 13 12 11 10 24 25 26 27 +// a[3]: 03 02 01 00 34 35 36 37 +// Simply reordering the inputs (3, 2, 1, 0) will reset the low halves, but +// reverse the high halves. +// The standard transpose_u16_4x8q will produce the same reversals, but with the +// order of the low halves also restored relative to the high halves. This is +// preferable because it puts all values from the same source row back together, +// but some post-processing is inevitable. +static INLINE void loop_filter_transpose_u16_4x8q(uint16x8_t a[4]) { + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]); + const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]); + + // Reverse odd vectors to bring the appropriate items to the front of zips. + // b0.val[0]: 00 10 02 12 04 14 06 16 + // r0 : 03 13 01 11 07 17 05 15 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // r1 : 23 33 21 31 27 37 25 35 + const uint32x4_t r0 = vrev64q_u32(vreinterpretq_u32_u16(b0.val[1])); + const uint32x4_t r1 = vrev64q_u32(vreinterpretq_u32_u16(b1.val[1])); + + // Zip to complete the halves. 
+ // c0.val[0]: 00 10 20 30 02 12 22 32 p3p1 + // c0.val[1]: 04 14 24 34 06 16 26 36 q0q2 + // c1.val[0]: 03 13 23 33 01 11 21 31 p0p2 + // c1.val[1]: 07 17 27 37 05 15 25 35 q3q1 + const uint32x4x2_t c0 = vzipq_u32(vreinterpretq_u32_u16(b0.val[0]), + vreinterpretq_u32_u16(b1.val[0])); + const uint32x4x2_t c1 = vzipq_u32(r0, r1); + + // d0.val[0]: 00 10 20 30 07 17 27 37 p3q3 + // d0.val[1]: 02 12 22 32 05 15 25 35 p1q1 + // d1.val[0]: 03 13 23 33 04 14 24 34 p0q0 + // d1.val[1]: 01 11 21 31 06 16 26 36 p2q2 + const uint16x8x2_t d0 = aom_vtrnq_u64_to_u16(c0.val[0], c1.val[1]); + // The third row of c comes first here to swap p2 with q0. + const uint16x8x2_t d1 = aom_vtrnq_u64_to_u16(c1.val[0], c0.val[1]); + + // 8x4 Output: + // a[0]: 03 13 23 33 04 14 24 34 p0q0 + // a[1]: 02 12 22 32 05 15 25 35 p1q1 + // a[2]: 01 11 21 31 06 16 26 36 p2q2 + // a[3]: 00 10 20 30 07 17 27 37 p3q3 + a[0] = d1.val[0]; // p0q0 + a[1] = d0.val[1]; // p1q1 + a[2] = d1.val[1]; // p2q2 + a[3] = d0.val[0]; // p3q3 +} + +static INLINE void transpose_elems_u16_4x8( + const uint16x4_t a0, const uint16x4_t a1, const uint16x4_t a2, + const uint16x4_t a3, const uint16x4_t a4, const uint16x4_t a5, + const uint16x4_t a6, const uint16x4_t a7, uint16x8_t *o0, uint16x8_t *o1, + uint16x8_t *o2, uint16x8_t *o3) { + // Combine rows. Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // a4: 40 41 42 43 + // a5: 50 51 52 53 + // a6: 60 61 62 63 + // a7: 70 71 72 73 + // to: + // b0: 00 01 02 03 40 41 42 43 + // b1: 10 11 12 13 50 51 52 53 + // b2: 20 21 22 23 60 61 62 63 + // b3: 30 31 32 33 70 71 72 73 + + const uint16x8_t b0 = vcombine_u16(a0, a4); + const uint16x8_t b1 = vcombine_u16(a1, a5); + const uint16x8_t b2 = vcombine_u16(a2, a6); + const uint16x8_t b3 = vcombine_u16(a3, a7); + + // Swap 16 bit elements resulting in: + // c0.val[0]: 00 10 02 12 40 50 42 52 + // c0.val[1]: 01 11 03 13 41 51 43 53 + // c1.val[0]: 20 30 22 32 60 70 62 72 + // c1.val[1]: 21 31 23 33 61 71 63 73 + + const uint16x8x2_t c0 = vtrnq_u16(b0, b1); + const uint16x8x2_t c1 = vtrnq_u16(b2, b3); + + // Swap 32 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 02 12 22 32 42 52 62 72 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 03 13 23 33 43 53 63 73 + + const uint32x4x2_t d0 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[0]), + vreinterpretq_u32_u16(c1.val[0])); + const uint32x4x2_t d1 = vtrnq_u32(vreinterpretq_u32_u16(c0.val[1]), + vreinterpretq_u32_u16(c1.val[1])); + + *o0 = vreinterpretq_u16_u32(d0.val[0]); + *o1 = vreinterpretq_u16_u32(d1.val[0]); + *o2 = vreinterpretq_u16_u32(d0.val[1]); + *o3 = vreinterpretq_u16_u32(d1.val[1]); +} + +static INLINE void transpose_elems_s16_4x8( + const int16x4_t a0, const int16x4_t a1, const int16x4_t a2, + const int16x4_t a3, const int16x4_t a4, const int16x4_t a5, + const int16x4_t a6, const int16x4_t a7, int16x8_t *o0, int16x8_t *o1, + int16x8_t *o2, int16x8_t *o3) { + // Combine rows. 
Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // a4: 40 41 42 43 + // a5: 50 51 52 53 + // a6: 60 61 62 63 + // a7: 70 71 72 73 + // to: + // b0: 00 01 02 03 40 41 42 43 + // b1: 10 11 12 13 50 51 52 53 + // b2: 20 21 22 23 60 61 62 63 + // b3: 30 31 32 33 70 71 72 73 + + const int16x8_t b0 = vcombine_s16(a0, a4); + const int16x8_t b1 = vcombine_s16(a1, a5); + const int16x8_t b2 = vcombine_s16(a2, a6); + const int16x8_t b3 = vcombine_s16(a3, a7); + + // Swap 16 bit elements resulting in: + // c0.val[0]: 00 10 02 12 40 50 42 52 + // c0.val[1]: 01 11 03 13 41 51 43 53 + // c1.val[0]: 20 30 22 32 60 70 62 72 + // c1.val[1]: 21 31 23 33 61 71 63 73 + + const int16x8x2_t c0 = vtrnq_s16(b0, b1); + const int16x8x2_t c1 = vtrnq_s16(b2, b3); + + // Swap 32 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 02 12 22 32 42 52 62 72 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 03 13 23 33 43 53 63 73 + + const int32x4x2_t d0 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[0]), + vreinterpretq_s32_s16(c1.val[0])); + const int32x4x2_t d1 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[1]), + vreinterpretq_s32_s16(c1.val[1])); + + *o0 = vreinterpretq_s16_s32(d0.val[0]); + *o1 = vreinterpretq_s16_s32(d1.val[0]); + *o2 = vreinterpretq_s16_s32(d0.val[1]); + *o3 = vreinterpretq_s16_s32(d1.val[1]); +} + +static INLINE void transpose_elems_inplace_u16_8x8( + uint16x8_t *a0, uint16x8_t *a1, uint16x8_t *a2, uint16x8_t *a3, + uint16x8_t *a4, uint16x8_t *a5, uint16x8_t *a6, uint16x8_t *a7) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // a4: 40 41 42 43 44 45 46 47 + // a5: 50 51 52 53 54 55 56 57 + // a6: 60 61 62 63 64 65 66 67 + // a7: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + // b2.val[0]: 40 50 42 52 44 54 46 56 + // b2.val[1]: 41 51 43 53 45 55 47 57 + // b3.val[0]: 60 70 62 72 64 74 66 76 + // b3.val[1]: 61 71 63 73 65 75 67 77 + + const uint16x8x2_t b0 = vtrnq_u16(*a0, *a1); + const uint16x8x2_t b1 = vtrnq_u16(*a2, *a3); + const uint16x8x2_t b2 = vtrnq_u16(*a4, *a5); + const uint16x8x2_t b3 = vtrnq_u16(*a6, *a7); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + // c2.val[0]: 40 50 60 70 44 54 64 74 + // c2.val[1]: 42 52 62 72 46 56 66 76 + // c3.val[0]: 41 51 61 71 45 55 65 75 + // c3.val[1]: 43 53 63 73 47 57 67 77 + + const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]), + vreinterpretq_u32_u16(b1.val[0])); + const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]), + vreinterpretq_u32_u16(b1.val[1])); + const uint32x4x2_t c2 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[0]), + vreinterpretq_u32_u16(b3.val[0])); + const uint32x4x2_t c3 = vtrnq_u32(vreinterpretq_u32_u16(b2.val[1]), + vreinterpretq_u32_u16(b3.val[1])); + + // Swap 64 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 04 14 24 34 44 54 64 74 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 05 15 25 35 45 55 65 75 + // d2.val[0]: 02 12 22 32 42 52 62 72 + // d2.val[1]: 06 16 26 36 46 56 66 76 + // d3.val[0]: 03 13 23 33 43 53 63 73 + // d3.val[1]: 07 17 27 37 47 57 67 77 + + 
const uint16x8x2_t d0 = aom_vtrnq_u64_to_u16(c0.val[0], c2.val[0]); + const uint16x8x2_t d1 = aom_vtrnq_u64_to_u16(c1.val[0], c3.val[0]); + const uint16x8x2_t d2 = aom_vtrnq_u64_to_u16(c0.val[1], c2.val[1]); + const uint16x8x2_t d3 = aom_vtrnq_u64_to_u16(c1.val[1], c3.val[1]); + + *a0 = d0.val[0]; + *a1 = d1.val[0]; + *a2 = d2.val[0]; + *a3 = d3.val[0]; + *a4 = d0.val[1]; + *a5 = d1.val[1]; + *a6 = d2.val[1]; + *a7 = d3.val[1]; +} + +static INLINE int16x8x2_t aom_vtrnq_s64_to_s16(int32x4_t a0, int32x4_t a1) { + int16x8x2_t b0; +#if AOM_ARCH_AARCH64 + b0.val[0] = vreinterpretq_s16_s64( + vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1))); + b0.val[1] = vreinterpretq_s16_s64( + vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1))); +#else + b0.val[0] = vcombine_s16(vreinterpret_s16_s32(vget_low_s32(a0)), + vreinterpret_s16_s32(vget_low_s32(a1))); + b0.val[1] = vcombine_s16(vreinterpret_s16_s32(vget_high_s32(a0)), + vreinterpret_s16_s32(vget_high_s32(a1))); +#endif + return b0; +} + +static INLINE void transpose_elems_inplace_s16_8x8(int16x8_t *a0, int16x8_t *a1, + int16x8_t *a2, int16x8_t *a3, + int16x8_t *a4, int16x8_t *a5, + int16x8_t *a6, + int16x8_t *a7) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // a4: 40 41 42 43 44 45 46 47 + // a5: 50 51 52 53 54 55 56 57 + // a6: 60 61 62 63 64 65 66 67 + // a7: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + // b2.val[0]: 40 50 42 52 44 54 46 56 + // b2.val[1]: 41 51 43 53 45 55 47 57 + // b3.val[0]: 60 70 62 72 64 74 66 76 + // b3.val[1]: 61 71 63 73 65 75 67 77 + + const int16x8x2_t b0 = vtrnq_s16(*a0, *a1); + const int16x8x2_t b1 = vtrnq_s16(*a2, *a3); + const int16x8x2_t b2 = vtrnq_s16(*a4, *a5); + const int16x8x2_t b3 = vtrnq_s16(*a6, *a7); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + // c2.val[0]: 40 50 60 70 44 54 64 74 + // c2.val[1]: 42 52 62 72 46 56 66 76 + // c3.val[0]: 41 51 61 71 45 55 65 75 + // c3.val[1]: 43 53 63 73 47 57 67 77 + + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]), + vreinterpretq_s32_s16(b3.val[0])); + const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]), + vreinterpretq_s32_s16(b3.val[1])); + + // Swap 64 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 04 14 24 34 44 54 64 74 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 05 15 25 35 45 55 65 75 + // d2.val[0]: 02 12 22 32 42 52 62 72 + // d2.val[1]: 06 16 26 36 46 56 66 76 + // d3.val[0]: 03 13 23 33 43 53 63 73 + // d3.val[1]: 07 17 27 37 47 57 67 77 + + const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]); + const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]); + const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]); + const int16x8x2_t d3 = aom_vtrnq_s64_to_s16(c1.val[1], c3.val[1]); + + *a0 = d0.val[0]; + *a1 = d1.val[0]; + *a2 = d2.val[0]; + *a3 = d3.val[0]; + *a4 = d0.val[1]; + *a5 = 
d1.val[1]; + *a6 = d2.val[1]; + *a7 = d3.val[1]; +} + +static INLINE void transpose_arrays_s16_8x8(const int16x8_t *a, + int16x8_t *out) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 04 05 06 07 + // a1: 10 11 12 13 14 15 16 17 + // a2: 20 21 22 23 24 25 26 27 + // a3: 30 31 32 33 34 35 36 37 + // a4: 40 41 42 43 44 45 46 47 + // a5: 50 51 52 53 54 55 56 57 + // a6: 60 61 62 63 64 65 66 67 + // a7: 70 71 72 73 74 75 76 77 + // to: + // b0.val[0]: 00 10 02 12 04 14 06 16 + // b0.val[1]: 01 11 03 13 05 15 07 17 + // b1.val[0]: 20 30 22 32 24 34 26 36 + // b1.val[1]: 21 31 23 33 25 35 27 37 + // b2.val[0]: 40 50 42 52 44 54 46 56 + // b2.val[1]: 41 51 43 53 45 55 47 57 + // b3.val[0]: 60 70 62 72 64 74 66 76 + // b3.val[1]: 61 71 63 73 65 75 67 77 + + const int16x8x2_t b0 = vtrnq_s16(a[0], a[1]); + const int16x8x2_t b1 = vtrnq_s16(a[2], a[3]); + const int16x8x2_t b2 = vtrnq_s16(a[4], a[5]); + const int16x8x2_t b3 = vtrnq_s16(a[6], a[7]); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 04 14 24 34 + // c0.val[1]: 02 12 22 32 06 16 26 36 + // c1.val[0]: 01 11 21 31 05 15 25 35 + // c1.val[1]: 03 13 23 33 07 17 27 37 + // c2.val[0]: 40 50 60 70 44 54 64 74 + // c2.val[1]: 42 52 62 72 46 56 66 76 + // c3.val[0]: 41 51 61 71 45 55 65 75 + // c3.val[1]: 43 53 63 73 47 57 67 77 + + const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]), + vreinterpretq_s32_s16(b1.val[0])); + const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]), + vreinterpretq_s32_s16(b1.val[1])); + const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]), + vreinterpretq_s32_s16(b3.val[0])); + const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]), + vreinterpretq_s32_s16(b3.val[1])); + + // Swap 64 bit elements resulting in: + // d0.val[0]: 00 10 20 30 40 50 60 70 + // d0.val[1]: 04 14 24 34 44 54 64 74 + // d1.val[0]: 01 11 21 31 41 51 61 71 + // d1.val[1]: 05 15 25 35 45 55 65 75 + // d2.val[0]: 02 12 22 32 42 52 62 72 + // d2.val[1]: 06 16 26 36 46 56 66 76 + // d3.val[0]: 03 13 23 33 43 53 63 73 + // d3.val[1]: 07 17 27 37 47 57 67 77 + + const int16x8x2_t d0 = aom_vtrnq_s64_to_s16(c0.val[0], c2.val[0]); + const int16x8x2_t d1 = aom_vtrnq_s64_to_s16(c1.val[0], c3.val[0]); + const int16x8x2_t d2 = aom_vtrnq_s64_to_s16(c0.val[1], c2.val[1]); + const int16x8x2_t d3 = aom_vtrnq_s64_to_s16(c1.val[1], c3.val[1]); + + out[0] = d0.val[0]; + out[1] = d1.val[0]; + out[2] = d2.val[0]; + out[3] = d3.val[0]; + out[4] = d0.val[1]; + out[5] = d1.val[1]; + out[6] = d2.val[1]; + out[7] = d3.val[1]; +} + +static INLINE void transpose_elems_inplace_u16_4x4(uint16x4_t *a0, + uint16x4_t *a1, + uint16x4_t *a2, + uint16x4_t *a3) { + // Swap 16 bit elements. 
Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // to: + // b0.val[0]: 00 10 02 12 + // b0.val[1]: 01 11 03 13 + // b1.val[0]: 20 30 22 32 + // b1.val[1]: 21 31 23 33 + + const uint16x4x2_t b0 = vtrn_u16(*a0, *a1); + const uint16x4x2_t b1 = vtrn_u16(*a2, *a3); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 + // c0.val[1]: 02 12 22 32 + // c1.val[0]: 01 11 21 31 + // c1.val[1]: 03 13 23 33 + + const uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), + vreinterpret_u32_u16(b1.val[0])); + const uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), + vreinterpret_u32_u16(b1.val[1])); + + *a0 = vreinterpret_u16_u32(c0.val[0]); + *a1 = vreinterpret_u16_u32(c1.val[0]); + *a2 = vreinterpret_u16_u32(c0.val[1]); + *a3 = vreinterpret_u16_u32(c1.val[1]); +} + +static INLINE void transpose_elems_inplace_s16_4x4(int16x4_t *a0, int16x4_t *a1, + int16x4_t *a2, + int16x4_t *a3) { + // Swap 16 bit elements. Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // to: + // b0.val[0]: 00 10 02 12 + // b0.val[1]: 01 11 03 13 + // b1.val[0]: 20 30 22 32 + // b1.val[1]: 21 31 23 33 + + const int16x4x2_t b0 = vtrn_s16(*a0, *a1); + const int16x4x2_t b1 = vtrn_s16(*a2, *a3); + + // Swap 32 bit elements resulting in: + // c0.val[0]: 00 10 20 30 + // c0.val[1]: 02 12 22 32 + // c1.val[0]: 01 11 21 31 + // c1.val[1]: 03 13 23 33 + + const int32x2x2_t c0 = vtrn_s32(vreinterpret_s32_s16(b0.val[0]), + vreinterpret_s32_s16(b1.val[0])); + const int32x2x2_t c1 = vtrn_s32(vreinterpret_s32_s16(b0.val[1]), + vreinterpret_s32_s16(b1.val[1])); + + *a0 = vreinterpret_s16_s32(c0.val[0]); + *a1 = vreinterpret_s16_s32(c1.val[0]); + *a2 = vreinterpret_s16_s32(c0.val[1]); + *a3 = vreinterpret_s16_s32(c1.val[1]); +} + +static INLINE int32x4x2_t aom_vtrnq_s64_to_s32(int32x4_t a0, int32x4_t a1) { + int32x4x2_t b0; +#if AOM_ARCH_AARCH64 + b0.val[0] = vreinterpretq_s32_s64( + vtrn1q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1))); + b0.val[1] = vreinterpretq_s32_s64( + vtrn2q_s64(vreinterpretq_s64_s32(a0), vreinterpretq_s64_s32(a1))); +#else + b0.val[0] = vcombine_s32(vget_low_s32(a0), vget_low_s32(a1)); + b0.val[1] = vcombine_s32(vget_high_s32(a0), vget_high_s32(a1)); +#endif + return b0; +} + +static INLINE void transpose_elems_s32_4x4(const int32x4_t a0, + const int32x4_t a1, + const int32x4_t a2, + const int32x4_t a3, int32x4_t *o0, + int32x4_t *o1, int32x4_t *o2, + int32x4_t *o3) { + // Swap 32 bit elements. 
Goes from: + // a0: 00 01 02 03 + // a1: 10 11 12 13 + // a2: 20 21 22 23 + // a3: 30 31 32 33 + // to: + // b0.val[0]: 00 10 02 12 + // b0.val[1]: 01 11 03 13 + // b1.val[0]: 20 30 22 32 + // b1.val[1]: 21 31 23 33 + + const int32x4x2_t b0 = vtrnq_s32(a0, a1); + const int32x4x2_t b1 = vtrnq_s32(a2, a3); + + // Swap 64 bit elements resulting in: + // c0.val[0]: 00 10 20 30 + // c0.val[1]: 02 12 22 32 + // c1.val[0]: 01 11 21 31 + // c1.val[1]: 03 13 23 33 + + const int32x4x2_t c0 = aom_vtrnq_s64_to_s32(b0.val[0], b1.val[0]); + const int32x4x2_t c1 = aom_vtrnq_s64_to_s32(b0.val[1], b1.val[1]); + + *o0 = c0.val[0]; + *o1 = c1.val[0]; + *o2 = c0.val[1]; + *o3 = c1.val[1]; +} + +static INLINE void transpose_elems_inplace_s32_4x4(int32x4_t *a0, int32x4_t *a1, + int32x4_t *a2, + int32x4_t *a3) { + transpose_elems_s32_4x4(*a0, *a1, *a2, *a3, a0, a1, a2, a3); +} + +static INLINE void transpose_arrays_s32_4x4(const int32x4_t *in, + int32x4_t *out) { + transpose_elems_s32_4x4(in[0], in[1], in[2], in[3], &out[0], &out[1], &out[2], + &out[3]); +} + +static AOM_FORCE_INLINE void transpose_arrays_s32_4nx4n(const int32x4_t *in, + int32x4_t *out, + const int width, + const int height) { + const int h = height >> 2; + const int w = width >> 2; + for (int j = 0; j < w; j++) { + for (int i = 0; i < h; i++) { + transpose_arrays_s32_4x4(in + j * height + i * 4, + out + i * width + j * 4); + } + } +} + +#define TRANSPOSE_ARRAYS_S32_WXH_NEON(w, h) \ + static AOM_FORCE_INLINE void transpose_arrays_s32_##w##x##h( \ + const int32x4_t *in, int32x4_t *out) { \ + transpose_arrays_s32_4nx4n(in, out, w, h); \ + } + +TRANSPOSE_ARRAYS_S32_WXH_NEON(4, 8) +TRANSPOSE_ARRAYS_S32_WXH_NEON(4, 16) +TRANSPOSE_ARRAYS_S32_WXH_NEON(8, 4) +TRANSPOSE_ARRAYS_S32_WXH_NEON(8, 8) +TRANSPOSE_ARRAYS_S32_WXH_NEON(8, 16) +TRANSPOSE_ARRAYS_S32_WXH_NEON(8, 32) +TRANSPOSE_ARRAYS_S32_WXH_NEON(16, 8) +TRANSPOSE_ARRAYS_S32_WXH_NEON(16, 16) +TRANSPOSE_ARRAYS_S32_WXH_NEON(16, 32) +TRANSPOSE_ARRAYS_S32_WXH_NEON(16, 64) +TRANSPOSE_ARRAYS_S32_WXH_NEON(32, 8) +TRANSPOSE_ARRAYS_S32_WXH_NEON(32, 16) +TRANSPOSE_ARRAYS_S32_WXH_NEON(32, 32) +TRANSPOSE_ARRAYS_S32_WXH_NEON(32, 64) +TRANSPOSE_ARRAYS_S32_WXH_NEON(64, 16) +TRANSPOSE_ARRAYS_S32_WXH_NEON(64, 32) + +#undef TRANSPOSE_ARRAYS_S32_WXH_NEON + +static INLINE int64x2_t aom_vtrn1q_s64(int64x2_t a, int64x2_t b) { +#if AOM_ARCH_AARCH64 + return vtrn1q_s64(a, b); +#else + return vcombine_s64(vget_low_s64(a), vget_low_s64(b)); +#endif +} + +static INLINE int64x2_t aom_vtrn2q_s64(int64x2_t a, int64x2_t b) { +#if AOM_ARCH_AARCH64 + return vtrn2q_s64(a, b); +#else + return vcombine_s64(vget_high_s64(a), vget_high_s64(b)); +#endif +} + +static INLINE void transpose_elems_s32_4x8(int32x4_t a0, int32x4_t a1, + int32x4_t a2, int32x4_t a3, + int32x4_t a4, int32x4_t a5, + int32x4_t a6, int32x4_t a7, + int32x4x2_t *o0, int32x4x2_t *o1, + int32x4x2_t *o2, int32x4x2_t *o3) { + // Perform a 4 x 8 matrix transpose by building on top of the existing 4 x 4 + // matrix transpose implementation: + // [ A ]^T => [ A^T B^T ] + // [ B ] + + transpose_elems_inplace_s32_4x4(&a0, &a1, &a2, &a3); // A^T + transpose_elems_inplace_s32_4x4(&a4, &a5, &a6, &a7); // B^T + + o0->val[0] = a0; + o1->val[0] = a1; + o2->val[0] = a2; + o3->val[0] = a3; + + o0->val[1] = a4; + o1->val[1] = a5; + o2->val[1] = a6; + o3->val[1] = a7; +} + +static INLINE void transpose_elems_inplace_s32_8x8( + int32x4x2_t *a0, int32x4x2_t *a1, int32x4x2_t *a2, int32x4x2_t *a3, + int32x4x2_t *a4, int32x4x2_t *a5, int32x4x2_t *a6, int32x4x2_t *a7) { + // Perform an 8 x 8 
matrix transpose by building on top of the existing 4 x 4 + // matrix transpose implementation: + // [ A B ]^T => [ A^T C^T ] + // [ C D ] [ B^T D^T ] + + int32x4_t q0_v1 = a0->val[0]; + int32x4_t q0_v2 = a1->val[0]; + int32x4_t q0_v3 = a2->val[0]; + int32x4_t q0_v4 = a3->val[0]; + + int32x4_t q1_v1 = a0->val[1]; + int32x4_t q1_v2 = a1->val[1]; + int32x4_t q1_v3 = a2->val[1]; + int32x4_t q1_v4 = a3->val[1]; + + int32x4_t q2_v1 = a4->val[0]; + int32x4_t q2_v2 = a5->val[0]; + int32x4_t q2_v3 = a6->val[0]; + int32x4_t q2_v4 = a7->val[0]; + + int32x4_t q3_v1 = a4->val[1]; + int32x4_t q3_v2 = a5->val[1]; + int32x4_t q3_v3 = a6->val[1]; + int32x4_t q3_v4 = a7->val[1]; + + transpose_elems_inplace_s32_4x4(&q0_v1, &q0_v2, &q0_v3, &q0_v4); // A^T + transpose_elems_inplace_s32_4x4(&q1_v1, &q1_v2, &q1_v3, &q1_v4); // B^T + transpose_elems_inplace_s32_4x4(&q2_v1, &q2_v2, &q2_v3, &q2_v4); // C^T + transpose_elems_inplace_s32_4x4(&q3_v1, &q3_v2, &q3_v3, &q3_v4); // D^T + + a0->val[0] = q0_v1; + a1->val[0] = q0_v2; + a2->val[0] = q0_v3; + a3->val[0] = q0_v4; + + a0->val[1] = q2_v1; + a1->val[1] = q2_v2; + a2->val[1] = q2_v3; + a3->val[1] = q2_v4; + + a4->val[0] = q1_v1; + a5->val[0] = q1_v2; + a6->val[0] = q1_v3; + a7->val[0] = q1_v4; + + a4->val[1] = q3_v1; + a5->val[1] = q3_v2; + a6->val[1] = q3_v3; + a7->val[1] = q3_v4; +} + +static INLINE void transpose_arrays_s16_4x4(const int16x4_t *const in, + int16x4_t *const out) { + int16x4_t a0 = in[0]; + int16x4_t a1 = in[1]; + int16x4_t a2 = in[2]; + int16x4_t a3 = in[3]; + + transpose_elems_inplace_s16_4x4(&a0, &a1, &a2, &a3); + + out[0] = a0; + out[1] = a1; + out[2] = a2; + out[3] = a3; +} + +static INLINE void transpose_arrays_s16_4x8(const int16x4_t *const in, + int16x8_t *const out) { +#if AOM_ARCH_AARCH64 + const int16x8_t a0 = vzip1q_s16(vcombine_s16(in[0], vdup_n_s16(0)), + vcombine_s16(in[1], vdup_n_s16(0))); + const int16x8_t a1 = vzip1q_s16(vcombine_s16(in[2], vdup_n_s16(0)), + vcombine_s16(in[3], vdup_n_s16(0))); + const int16x8_t a2 = vzip1q_s16(vcombine_s16(in[4], vdup_n_s16(0)), + vcombine_s16(in[5], vdup_n_s16(0))); + const int16x8_t a3 = vzip1q_s16(vcombine_s16(in[6], vdup_n_s16(0)), + vcombine_s16(in[7], vdup_n_s16(0))); +#else + int16x4x2_t temp; + temp = vzip_s16(in[0], in[1]); + const int16x8_t a0 = vcombine_s16(temp.val[0], temp.val[1]); + temp = vzip_s16(in[2], in[3]); + const int16x8_t a1 = vcombine_s16(temp.val[0], temp.val[1]); + temp = vzip_s16(in[4], in[5]); + const int16x8_t a2 = vcombine_s16(temp.val[0], temp.val[1]); + temp = vzip_s16(in[6], in[7]); + const int16x8_t a3 = vcombine_s16(temp.val[0], temp.val[1]); +#endif + + const int32x4x2_t b02 = + vzipq_s32(vreinterpretq_s32_s16(a0), vreinterpretq_s32_s16(a1)); + const int32x4x2_t b13 = + vzipq_s32(vreinterpretq_s32_s16(a2), vreinterpretq_s32_s16(a3)); + +#if AOM_ARCH_AARCH64 + out[0] = vreinterpretq_s16_s64(vzip1q_s64(vreinterpretq_s64_s32(b02.val[0]), + vreinterpretq_s64_s32(b13.val[0]))); + out[1] = vreinterpretq_s16_s64(vzip2q_s64(vreinterpretq_s64_s32(b02.val[0]), + vreinterpretq_s64_s32(b13.val[0]))); + out[2] = vreinterpretq_s16_s64(vzip1q_s64(vreinterpretq_s64_s32(b02.val[1]), + vreinterpretq_s64_s32(b13.val[1]))); + out[3] = vreinterpretq_s16_s64(vzip2q_s64(vreinterpretq_s64_s32(b02.val[1]), + vreinterpretq_s64_s32(b13.val[1]))); +#else + out[0] = vreinterpretq_s16_s32( + vextq_s32(vextq_s32(b02.val[0], b02.val[0], 2), b13.val[0], 2)); + out[2] = vreinterpretq_s16_s32( + vextq_s32(vextq_s32(b02.val[1], b02.val[1], 2), b13.val[1], 2)); + out[1] = vreinterpretq_s16_s32( 
+ vextq_s32(b02.val[0], vextq_s32(b13.val[0], b13.val[0], 2), 2));
+ out[3] = vreinterpretq_s16_s32(
+ vextq_s32(b02.val[1], vextq_s32(b13.val[1], b13.val[1], 2), 2));
+#endif
+}
+
+static INLINE void transpose_arrays_s16_8x4(const int16x8_t *const in,
+ int16x4_t *const out) {
+ // Swap 16 bit elements. Goes from:
+ // in[0]: 00 01 02 03 04 05 06 07
+ // in[1]: 10 11 12 13 14 15 16 17
+ // in[2]: 20 21 22 23 24 25 26 27
+ // in[3]: 30 31 32 33 34 35 36 37
+ // to:
+ // b0.val[0]: 00 10 02 12 04 14 06 16
+ // b0.val[1]: 01 11 03 13 05 15 07 17
+ // b1.val[0]: 20 30 22 32 24 34 26 36
+ // b1.val[1]: 21 31 23 33 25 35 27 37
+
+ const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]);
+ const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]);
+
+ // Swap 32 bit elements resulting in:
+ // c0.val[0]: 00 10 20 30 04 14 24 34
+ // c0.val[1]: 02 12 22 32 06 16 26 36
+ // c1.val[0]: 01 11 21 31 05 15 25 35
+ // c1.val[1]: 03 13 23 33 07 17 27 37
+
+ const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_s16(b0.val[0]),
+ vreinterpretq_u32_s16(b1.val[0]));
+ const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_s16(b0.val[1]),
+ vreinterpretq_u32_s16(b1.val[1]));
+
+ // Unpack 64 bit elements resulting in:
+ // out[0]: 00 10 20 30
+ // out[1]: 01 11 21 31
+ // out[2]: 02 12 22 32
+ // out[3]: 03 13 23 33
+ // out[4]: 04 14 24 34
+ // out[5]: 05 15 25 35
+ // out[6]: 06 16 26 36
+ // out[7]: 07 17 27 37
+
+ out[0] = vget_low_s16(vreinterpretq_s16_u32(c0.val[0]));
+ out[1] = vget_low_s16(vreinterpretq_s16_u32(c1.val[0]));
+ out[2] = vget_low_s16(vreinterpretq_s16_u32(c0.val[1]));
+ out[3] = vget_low_s16(vreinterpretq_s16_u32(c1.val[1]));
+ out[4] = vget_high_s16(vreinterpretq_s16_u32(c0.val[0]));
+ out[5] = vget_high_s16(vreinterpretq_s16_u32(c1.val[0]));
+ out[6] = vget_high_s16(vreinterpretq_s16_u32(c0.val[1]));
+ out[7] = vget_high_s16(vreinterpretq_s16_u32(c1.val[1]));
+}
+
+#endif // AOM_AOM_DSP_ARM_TRANSPOSE_NEON_H_
diff --git a/third_party/aom/aom_dsp/arm/variance_neon.c b/third_party/aom/aom_dsp/arm/variance_neon.c
new file mode 100644
index 0000000000..9e4e8c0cf0
--- /dev/null
+++ b/third_party/aom/aom_dsp/arm/variance_neon.c
@@ -0,0 +1,470 @@
+/*
+ * Copyright (c) 2016, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+
+#include "aom/aom_integer.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
+#include "aom_ports/mem.h"
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+static INLINE void variance_4xh_neon(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride, int h,
+ uint32_t *sse, int *sum) {
+ int16x8_t sum_s16 = vdupq_n_s16(0);
+ int32x4_t sse_s32 = vdupq_n_s32(0);
+
+ // Number of rows we can process before 'sum_s16' overflows:
+ // 32767 / 255 ~= 128, but we use an 8-wide accumulator; so 256 4-wide rows.
+ assert(h <= 256); + + int i = h; + do { + uint8x8_t s = load_unaligned_u8(src, src_stride); + uint8x8_t r = load_unaligned_u8(ref, ref_stride); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s, r)); + + sum_s16 = vaddq_s16(sum_s16, diff); + + sse_s32 = vmlal_s16(sse_s32, vget_low_s16(diff), vget_low_s16(diff)); + sse_s32 = vmlal_s16(sse_s32, vget_high_s16(diff), vget_high_s16(diff)); + + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + + *sum = horizontal_add_s16x8(sum_s16); + *sse = (uint32_t)horizontal_add_s32x4(sse_s32); +} + +static INLINE void variance_8xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, int h, + uint32_t *sse, int *sum) { + int16x8_t sum_s16 = vdupq_n_s16(0); + int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + // Number of rows we can process before 'sum_s16' overflows: + // 32767 / 255 ~= 128 + assert(h <= 128); + + int i = h; + do { + uint8x8_t s = vld1_u8(src); + uint8x8_t r = vld1_u8(ref); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s, r)); + + sum_s16 = vaddq_s16(sum_s16, diff); + + sse_s32[0] = vmlal_s16(sse_s32[0], vget_low_s16(diff), vget_low_s16(diff)); + sse_s32[1] = + vmlal_s16(sse_s32[1], vget_high_s16(diff), vget_high_s16(diff)); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + *sum = horizontal_add_s16x8(sum_s16); + *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1])); +} + +static INLINE void variance_16xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, int h, + uint32_t *sse, int *sum) { + int16x8_t sum_s16[2] = { vdupq_n_s16(0), vdupq_n_s16(0) }; + int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + // Number of rows we can process before 'sum_s16' accumulators overflow: + // 32767 / 255 ~= 128, so 128 16-wide rows. + assert(h <= 128); + + int i = h; + do { + uint8x16_t s = vld1q_u8(src); + uint8x16_t r = vld1q_u8(ref); + + int16x8_t diff_l = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(s), vget_low_u8(r))); + int16x8_t diff_h = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(s), vget_high_u8(r))); + + sum_s16[0] = vaddq_s16(sum_s16[0], diff_l); + sum_s16[1] = vaddq_s16(sum_s16[1], diff_h); + + sse_s32[0] = + vmlal_s16(sse_s32[0], vget_low_s16(diff_l), vget_low_s16(diff_l)); + sse_s32[1] = + vmlal_s16(sse_s32[1], vget_high_s16(diff_l), vget_high_s16(diff_l)); + sse_s32[0] = + vmlal_s16(sse_s32[0], vget_low_s16(diff_h), vget_low_s16(diff_h)); + sse_s32[1] = + vmlal_s16(sse_s32[1], vget_high_s16(diff_h), vget_high_s16(diff_h)); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + *sum = horizontal_add_s16x8(vaddq_s16(sum_s16[0], sum_s16[1])); + *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1])); +} + +static INLINE void variance_large_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int w, int h, int h_limit, uint32_t *sse, + int *sum) { + int32x4_t sum_s32 = vdupq_n_s32(0); + int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + // 'h_limit' is the number of 'w'-width rows we can process before our 16-bit + // accumulator overflows. After hitting this limit we accumulate into 32-bit + // elements. + int h_tmp = h > h_limit ? 
h_limit : h; + + int i = 0; + do { + int16x8_t sum_s16[2] = { vdupq_n_s16(0), vdupq_n_s16(0) }; + do { + int j = 0; + do { + uint8x16_t s = vld1q_u8(src + j); + uint8x16_t r = vld1q_u8(ref + j); + + int16x8_t diff_l = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(s), vget_low_u8(r))); + int16x8_t diff_h = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(s), vget_high_u8(r))); + + sum_s16[0] = vaddq_s16(sum_s16[0], diff_l); + sum_s16[1] = vaddq_s16(sum_s16[1], diff_h); + + sse_s32[0] = + vmlal_s16(sse_s32[0], vget_low_s16(diff_l), vget_low_s16(diff_l)); + sse_s32[1] = + vmlal_s16(sse_s32[1], vget_high_s16(diff_l), vget_high_s16(diff_l)); + sse_s32[0] = + vmlal_s16(sse_s32[0], vget_low_s16(diff_h), vget_low_s16(diff_h)); + sse_s32[1] = + vmlal_s16(sse_s32[1], vget_high_s16(diff_h), vget_high_s16(diff_h)); + + j += 16; + } while (j < w); + + src += src_stride; + ref += ref_stride; + i++; + } while (i < h_tmp); + + sum_s32 = vpadalq_s16(sum_s32, sum_s16[0]); + sum_s32 = vpadalq_s16(sum_s32, sum_s16[1]); + + h_tmp += h_limit; + } while (i < h); + + *sum = horizontal_add_s32x4(sum_s32); + *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_s32[0], sse_s32[1])); +} + +static INLINE void variance_32xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, int h, + uint32_t *sse, int *sum) { + variance_large_neon(src, src_stride, ref, ref_stride, 32, h, 64, sse, sum); +} + +static INLINE void variance_64xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, int h, + uint32_t *sse, int *sum) { + variance_large_neon(src, src_stride, ref, ref_stride, 64, h, 32, sse, sum); +} + +static INLINE void variance_128xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + int h, uint32_t *sse, int *sum) { + variance_large_neon(src, src_stride, ref, ref_stride, 128, h, 16, sse, sum); +} + +#define VARIANCE_WXH_NEON(w, h, shift) \ + unsigned int aom_variance##w##x##h##_neon( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + unsigned int *sse) { \ + int sum; \ + variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, sse, &sum); \ + return *sse - (uint32_t)(((int64_t)sum * sum) >> shift); \ + } + +VARIANCE_WXH_NEON(4, 4, 4) +VARIANCE_WXH_NEON(4, 8, 5) +VARIANCE_WXH_NEON(4, 16, 6) + +VARIANCE_WXH_NEON(8, 4, 5) +VARIANCE_WXH_NEON(8, 8, 6) +VARIANCE_WXH_NEON(8, 16, 7) +VARIANCE_WXH_NEON(8, 32, 8) + +VARIANCE_WXH_NEON(16, 4, 6) +VARIANCE_WXH_NEON(16, 8, 7) +VARIANCE_WXH_NEON(16, 16, 8) +VARIANCE_WXH_NEON(16, 32, 9) +VARIANCE_WXH_NEON(16, 64, 10) + +VARIANCE_WXH_NEON(32, 8, 8) +VARIANCE_WXH_NEON(32, 16, 9) +VARIANCE_WXH_NEON(32, 32, 10) +VARIANCE_WXH_NEON(32, 64, 11) + +VARIANCE_WXH_NEON(64, 16, 10) +VARIANCE_WXH_NEON(64, 32, 11) +VARIANCE_WXH_NEON(64, 64, 12) +VARIANCE_WXH_NEON(64, 128, 13) + +VARIANCE_WXH_NEON(128, 64, 13) +VARIANCE_WXH_NEON(128, 128, 14) + +#undef VARIANCE_WXH_NEON + +// TODO(yunqingwang): Perform variance of two/four 8x8 blocks similar to that of +// AVX2. Also, implement the NEON for variance computation present in this +// function. +void aom_get_var_sse_sum_8x8_quad_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + uint32_t *sse8x8, int *sum8x8, + unsigned int *tot_sse, int *tot_sum, + uint32_t *var8x8) { + // Loop over four 8x8 blocks. Process one 8x32 block. 
+ for (int k = 0; k < 4; k++) { + variance_8xh_neon(src + (k * 8), src_stride, ref + (k * 8), ref_stride, 8, + &sse8x8[k], &sum8x8[k]); + } + + *tot_sse += sse8x8[0] + sse8x8[1] + sse8x8[2] + sse8x8[3]; + *tot_sum += sum8x8[0] + sum8x8[1] + sum8x8[2] + sum8x8[3]; + for (int i = 0; i < 4; i++) { + var8x8[i] = sse8x8[i] - (uint32_t)(((int64_t)sum8x8[i] * sum8x8[i]) >> 6); + } +} + +void aom_get_var_sse_sum_16x16_dual_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + uint32_t *sse16x16, + unsigned int *tot_sse, int *tot_sum, + uint32_t *var16x16) { + int sum16x16[2] = { 0 }; + // Loop over two 16x16 blocks. Process one 16x32 block. + for (int k = 0; k < 2; k++) { + variance_16xh_neon(src + (k * 16), src_stride, ref + (k * 16), ref_stride, + 16, &sse16x16[k], &sum16x16[k]); + } + + *tot_sse += sse16x16[0] + sse16x16[1]; + *tot_sum += sum16x16[0] + sum16x16[1]; + for (int i = 0; i < 2; i++) { + var16x16[i] = + sse16x16[i] - (uint32_t)(((int64_t)sum16x16[i] * sum16x16[i]) >> 8); + } +} + +static INLINE unsigned int mse8xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + unsigned int *sse, int h) { + uint8x8_t s[2], r[2]; + int16x4_t diff_lo[2], diff_hi[2]; + uint16x8_t diff[2]; + int32x4_t sse_s32[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + + int i = h; + do { + s[0] = vld1_u8(src); + src += src_stride; + s[1] = vld1_u8(src); + src += src_stride; + r[0] = vld1_u8(ref); + ref += ref_stride; + r[1] = vld1_u8(ref); + ref += ref_stride; + + diff[0] = vsubl_u8(s[0], r[0]); + diff[1] = vsubl_u8(s[1], r[1]); + + diff_lo[0] = vreinterpret_s16_u16(vget_low_u16(diff[0])); + diff_lo[1] = vreinterpret_s16_u16(vget_low_u16(diff[1])); + sse_s32[0] = vmlal_s16(sse_s32[0], diff_lo[0], diff_lo[0]); + sse_s32[1] = vmlal_s16(sse_s32[1], diff_lo[1], diff_lo[1]); + + diff_hi[0] = vreinterpret_s16_u16(vget_high_u16(diff[0])); + diff_hi[1] = vreinterpret_s16_u16(vget_high_u16(diff[1])); + sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]); + sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]); + + i -= 2; + } while (i != 0); + + sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]); + + *sse = horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0])); + return horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0])); +} + +static INLINE unsigned int mse16xh_neon(const uint8_t *src, int src_stride, + const uint8_t *ref, int ref_stride, + unsigned int *sse, int h) { + uint8x16_t s[2], r[2]; + int16x4_t diff_lo[4], diff_hi[4]; + uint16x8_t diff[4]; + int32x4_t sse_s32[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), + vdupq_n_s32(0) }; + + int i = h; + do { + s[0] = vld1q_u8(src); + src += src_stride; + s[1] = vld1q_u8(src); + src += src_stride; + r[0] = vld1q_u8(ref); + ref += ref_stride; + r[1] = vld1q_u8(ref); + ref += ref_stride; + + diff[0] = vsubl_u8(vget_low_u8(s[0]), vget_low_u8(r[0])); + diff[1] = vsubl_u8(vget_high_u8(s[0]), vget_high_u8(r[0])); + diff[2] = vsubl_u8(vget_low_u8(s[1]), vget_low_u8(r[1])); + diff[3] = vsubl_u8(vget_high_u8(s[1]), vget_high_u8(r[1])); + + diff_lo[0] = vreinterpret_s16_u16(vget_low_u16(diff[0])); + diff_lo[1] = vreinterpret_s16_u16(vget_low_u16(diff[1])); + sse_s32[0] = vmlal_s16(sse_s32[0], diff_lo[0], diff_lo[0]); + sse_s32[1] = vmlal_s16(sse_s32[1], diff_lo[1], diff_lo[1]); + + diff_lo[2] = vreinterpret_s16_u16(vget_low_u16(diff[2])); + diff_lo[3] = vreinterpret_s16_u16(vget_low_u16(diff[3])); + sse_s32[2] = vmlal_s16(sse_s32[2], diff_lo[2], diff_lo[2]); + sse_s32[3] = vmlal_s16(sse_s32[3], diff_lo[3], 
diff_lo[3]); + + diff_hi[0] = vreinterpret_s16_u16(vget_high_u16(diff[0])); + diff_hi[1] = vreinterpret_s16_u16(vget_high_u16(diff[1])); + sse_s32[0] = vmlal_s16(sse_s32[0], diff_hi[0], diff_hi[0]); + sse_s32[1] = vmlal_s16(sse_s32[1], diff_hi[1], diff_hi[1]); + + diff_hi[2] = vreinterpret_s16_u16(vget_high_u16(diff[2])); + diff_hi[3] = vreinterpret_s16_u16(vget_high_u16(diff[3])); + sse_s32[2] = vmlal_s16(sse_s32[2], diff_hi[2], diff_hi[2]); + sse_s32[3] = vmlal_s16(sse_s32[3], diff_hi[3], diff_hi[3]); + + i -= 2; + } while (i != 0); + + sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[1]); + sse_s32[2] = vaddq_s32(sse_s32[2], sse_s32[3]); + sse_s32[0] = vaddq_s32(sse_s32[0], sse_s32[2]); + + *sse = horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0])); + return horizontal_add_u32x4(vreinterpretq_u32_s32(sse_s32[0])); +} + +#define MSE_WXH_NEON(w, h) \ + unsigned int aom_mse##w##x##h##_neon(const uint8_t *src, int src_stride, \ + const uint8_t *ref, int ref_stride, \ + unsigned int *sse) { \ + return mse##w##xh_neon(src, src_stride, ref, ref_stride, sse, h); \ + } + +MSE_WXH_NEON(8, 8) +MSE_WXH_NEON(8, 16) + +MSE_WXH_NEON(16, 8) +MSE_WXH_NEON(16, 16) + +#undef MSE_WXH_NEON + +static INLINE uint64x2_t mse_accumulate_u16_u8_8x2(uint64x2_t sum, + uint16x8_t s0, uint16x8_t s1, + uint8x8_t d0, uint8x8_t d1) { + int16x8_t e0 = vreinterpretq_s16_u16(vsubw_u8(s0, d0)); + int16x8_t e1 = vreinterpretq_s16_u16(vsubw_u8(s1, d1)); + + int32x4_t mse = vmull_s16(vget_low_s16(e0), vget_low_s16(e0)); + mse = vmlal_s16(mse, vget_high_s16(e0), vget_high_s16(e0)); + mse = vmlal_s16(mse, vget_low_s16(e1), vget_low_s16(e1)); + mse = vmlal_s16(mse, vget_high_s16(e1), vget_high_s16(e1)); + + return vpadalq_u32(sum, vreinterpretq_u32_s32(mse)); +} + +static uint64x2_t mse_wxh_16bit(uint8_t *dst, int dstride, const uint16_t *src, + int sstride, int w, int h) { + assert((w == 8 || w == 4) && (h == 8 || h == 4)); + + uint64x2_t sum = vdupq_n_u64(0); + + if (w == 8) { + do { + uint8x8_t d0 = vld1_u8(dst + 0 * dstride); + uint8x8_t d1 = vld1_u8(dst + 1 * dstride); + uint16x8_t s0 = vld1q_u16(src + 0 * sstride); + uint16x8_t s1 = vld1q_u16(src + 1 * sstride); + + sum = mse_accumulate_u16_u8_8x2(sum, s0, s1, d0, d1); + + dst += 2 * dstride; + src += 2 * sstride; + h -= 2; + } while (h != 0); + } else { + do { + uint8x8_t d0 = load_unaligned_u8_4x2(dst + 0 * dstride, dstride); + uint8x8_t d1 = load_unaligned_u8_4x2(dst + 2 * dstride, dstride); + uint16x8_t s0 = load_unaligned_u16_4x2(src + 0 * sstride, sstride); + uint16x8_t s1 = load_unaligned_u16_4x2(src + 2 * sstride, sstride); + + sum = mse_accumulate_u16_u8_8x2(sum, s0, s1, d0, d1); + + dst += 4 * dstride; + src += 4 * sstride; + h -= 4; + } while (h != 0); + } + + return sum; +} + +// Computes mse for a given block size. This function gets called for specific +// block sizes, which are 8x8, 8x4, 4x8 and 4x4. 
+uint64_t aom_mse_wxh_16bit_neon(uint8_t *dst, int dstride, uint16_t *src,
+ int sstride, int w, int h) {
+ return horizontal_add_u64x2(mse_wxh_16bit(dst, dstride, src, sstride, w, h));
+}
+
+uint32_t aom_get_mb_ss_neon(const int16_t *a) {
+ int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+
+ for (int i = 0; i < 256; i = i + 8) {
+ int16x8_t a_s16 = vld1q_s16(a + i);
+
+ sse[0] = vmlal_s16(sse[0], vget_low_s16(a_s16), vget_low_s16(a_s16));
+ sse[1] = vmlal_s16(sse[1], vget_high_s16(a_s16), vget_high_s16(a_s16));
+ }
+
+ return horizontal_add_s32x4(vaddq_s32(sse[0], sse[1]));
+}
+
+uint64_t aom_mse_16xh_16bit_neon(uint8_t *dst, int dstride, uint16_t *src,
+ int w, int h) {
+ uint64x2_t sum = vdupq_n_u64(0);
+
+ int num_blks = 16 / w;
+ do {
+ sum = vaddq_u64(sum, mse_wxh_16bit(dst, dstride, src, w, w, h));
+ dst += w;
+ src += w * h;
+ } while (--num_blks != 0);
+
+ return horizontal_add_u64x2(sum);
+}
diff --git a/third_party/aom/aom_dsp/arm/variance_neon_dotprod.c b/third_party/aom/aom_dsp/arm/variance_neon_dotprod.c
new file mode 100644
index 0000000000..9fb52e1df7
--- /dev/null
+++ b/third_party/aom/aom_dsp/arm/variance_neon_dotprod.c
@@ -0,0 +1,314 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <arm_neon.h>
+
+#include "aom/aom_integer.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
+#include "aom_ports/mem.h"
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+static INLINE void variance_4xh_neon_dotprod(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int h, uint32_t *sse, int *sum) {
+ uint32x4_t src_sum = vdupq_n_u32(0);
+ uint32x4_t ref_sum = vdupq_n_u32(0);
+ uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+ int i = h;
+ do {
+ uint8x16_t s = load_unaligned_u8q(src, src_stride);
+ uint8x16_t r = load_unaligned_u8q(ref, ref_stride);
+
+ src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
+ ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));
+
+ uint8x16_t abs_diff = vabdq_u8(s, r);
+ sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
+ src += 4 * src_stride;
+ ref += 4 * ref_stride;
+ i -= 4;
+ } while (i != 0);
+
+ int32x4_t sum_diff =
+ vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum));
+ *sum = horizontal_add_s32x4(sum_diff);
+ *sse = horizontal_add_u32x4(sse_u32);
+}
+
+static INLINE void variance_8xh_neon_dotprod(const uint8_t *src, int src_stride,
+ const uint8_t *ref, int ref_stride,
+ int h, uint32_t *sse, int *sum) {
+ uint32x4_t src_sum = vdupq_n_u32(0);
+ uint32x4_t ref_sum = vdupq_n_u32(0);
+ uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+ int i = h;
+ do {
+ uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride));
+ uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride));
+
+ src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1));
+ ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1));
+
+ uint8x16_t abs_diff = vabdq_u8(s, r);
+ sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
+ src += 2 * src_stride;
+ ref += 2 * ref_stride;
+ i -= 2;
+ } while (i != 0);
+
+ int32x4_t sum_diff =
+
vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum)); + *sum = horizontal_add_s32x4(sum_diff); + *sse = horizontal_add_u32x4(sse_u32); +} + +static INLINE void variance_16xh_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, int h, + uint32_t *sse, int *sum) { + uint32x4_t src_sum = vdupq_n_u32(0); + uint32x4_t ref_sum = vdupq_n_u32(0); + uint32x4_t sse_u32 = vdupq_n_u32(0); + + int i = h; + do { + uint8x16_t s = vld1q_u8(src); + uint8x16_t r = vld1q_u8(ref); + + src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1)); + ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1)); + + uint8x16_t abs_diff = vabdq_u8(s, r); + sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + int32x4_t sum_diff = + vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum)); + *sum = horizontal_add_s32x4(sum_diff); + *sse = horizontal_add_u32x4(sse_u32); +} + +static INLINE void variance_large_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, int w, int h, + uint32_t *sse, int *sum) { + uint32x4_t src_sum = vdupq_n_u32(0); + uint32x4_t ref_sum = vdupq_n_u32(0); + uint32x4_t sse_u32 = vdupq_n_u32(0); + + int i = h; + do { + int j = 0; + do { + uint8x16_t s = vld1q_u8(src + j); + uint8x16_t r = vld1q_u8(ref + j); + + src_sum = vdotq_u32(src_sum, s, vdupq_n_u8(1)); + ref_sum = vdotq_u32(ref_sum, r, vdupq_n_u8(1)); + + uint8x16_t abs_diff = vabdq_u8(s, r); + sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff); + + j += 16; + } while (j < w); + + src += src_stride; + ref += ref_stride; + } while (--i != 0); + + int32x4_t sum_diff = + vsubq_s32(vreinterpretq_s32_u32(src_sum), vreinterpretq_s32_u32(ref_sum)); + *sum = horizontal_add_s32x4(sum_diff); + *sse = horizontal_add_u32x4(sse_u32); +} + +static INLINE void variance_32xh_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, int h, + uint32_t *sse, int *sum) { + variance_large_neon_dotprod(src, src_stride, ref, ref_stride, 32, h, sse, + sum); +} + +static INLINE void variance_64xh_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, int h, + uint32_t *sse, int *sum) { + variance_large_neon_dotprod(src, src_stride, ref, ref_stride, 64, h, sse, + sum); +} + +static INLINE void variance_128xh_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, int h, + uint32_t *sse, int *sum) { + variance_large_neon_dotprod(src, src_stride, ref, ref_stride, 128, h, sse, + sum); +} + +#define VARIANCE_WXH_NEON_DOTPROD(w, h, shift) \ + unsigned int aom_variance##w##x##h##_neon_dotprod( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + unsigned int *sse) { \ + int sum; \ + variance_##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, h, sse, \ + &sum); \ + return *sse - (uint32_t)(((int64_t)sum * sum) >> shift); \ + } + +VARIANCE_WXH_NEON_DOTPROD(4, 4, 4) +VARIANCE_WXH_NEON_DOTPROD(4, 8, 5) +VARIANCE_WXH_NEON_DOTPROD(4, 16, 6) + +VARIANCE_WXH_NEON_DOTPROD(8, 4, 5) +VARIANCE_WXH_NEON_DOTPROD(8, 8, 6) +VARIANCE_WXH_NEON_DOTPROD(8, 16, 7) +VARIANCE_WXH_NEON_DOTPROD(8, 32, 8) + +VARIANCE_WXH_NEON_DOTPROD(16, 4, 6) +VARIANCE_WXH_NEON_DOTPROD(16, 8, 7) +VARIANCE_WXH_NEON_DOTPROD(16, 16, 8) +VARIANCE_WXH_NEON_DOTPROD(16, 32, 9) +VARIANCE_WXH_NEON_DOTPROD(16, 64, 10) + +VARIANCE_WXH_NEON_DOTPROD(32, 8, 8) +VARIANCE_WXH_NEON_DOTPROD(32, 16, 9) +VARIANCE_WXH_NEON_DOTPROD(32, 32, 10) 
+VARIANCE_WXH_NEON_DOTPROD(32, 64, 11) + +VARIANCE_WXH_NEON_DOTPROD(64, 16, 10) +VARIANCE_WXH_NEON_DOTPROD(64, 32, 11) +VARIANCE_WXH_NEON_DOTPROD(64, 64, 12) +VARIANCE_WXH_NEON_DOTPROD(64, 128, 13) + +VARIANCE_WXH_NEON_DOTPROD(128, 64, 13) +VARIANCE_WXH_NEON_DOTPROD(128, 128, 14) + +#undef VARIANCE_WXH_NEON_DOTPROD + +void aom_get_var_sse_sum_8x8_quad_neon_dotprod( + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, + uint32_t *sse8x8, int *sum8x8, unsigned int *tot_sse, int *tot_sum, + uint32_t *var8x8) { + // Loop over four 8x8 blocks. Process one 8x32 block. + for (int k = 0; k < 4; k++) { + variance_8xh_neon_dotprod(src + (k * 8), src_stride, ref + (k * 8), + ref_stride, 8, &sse8x8[k], &sum8x8[k]); + } + + *tot_sse += sse8x8[0] + sse8x8[1] + sse8x8[2] + sse8x8[3]; + *tot_sum += sum8x8[0] + sum8x8[1] + sum8x8[2] + sum8x8[3]; + for (int i = 0; i < 4; i++) { + var8x8[i] = sse8x8[i] - (uint32_t)(((int64_t)sum8x8[i] * sum8x8[i]) >> 6); + } +} + +void aom_get_var_sse_sum_16x16_dual_neon_dotprod( + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, + uint32_t *sse16x16, unsigned int *tot_sse, int *tot_sum, + uint32_t *var16x16) { + int sum16x16[2] = { 0 }; + // Loop over two 16x16 blocks. Process one 16x32 block. + for (int k = 0; k < 2; k++) { + variance_16xh_neon_dotprod(src + (k * 16), src_stride, ref + (k * 16), + ref_stride, 16, &sse16x16[k], &sum16x16[k]); + } + + *tot_sse += sse16x16[0] + sse16x16[1]; + *tot_sum += sum16x16[0] + sum16x16[1]; + for (int i = 0; i < 2; i++) { + var16x16[i] = + sse16x16[i] - (uint32_t)(((int64_t)sum16x16[i] * sum16x16[i]) >> 8); + } +} + +static INLINE unsigned int mse8xh_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, + unsigned int *sse, int h) { + uint32x4_t sse_u32 = vdupq_n_u32(0); + + int i = h; + do { + uint8x16_t s = vcombine_u8(vld1_u8(src), vld1_u8(src + src_stride)); + uint8x16_t r = vcombine_u8(vld1_u8(ref), vld1_u8(ref + ref_stride)); + + uint8x16_t abs_diff = vabdq_u8(s, r); + + sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff); + + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + + *sse = horizontal_add_u32x4(sse_u32); + return horizontal_add_u32x4(sse_u32); +} + +static INLINE unsigned int mse16xh_neon_dotprod(const uint8_t *src, + int src_stride, + const uint8_t *ref, + int ref_stride, + unsigned int *sse, int h) { + uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int i = h; + do { + uint8x16_t s0 = vld1q_u8(src); + uint8x16_t s1 = vld1q_u8(src + src_stride); + uint8x16_t r0 = vld1q_u8(ref); + uint8x16_t r1 = vld1q_u8(ref + ref_stride); + + uint8x16_t abs_diff0 = vabdq_u8(s0, r0); + uint8x16_t abs_diff1 = vabdq_u8(s1, r1); + + sse_u32[0] = vdotq_u32(sse_u32[0], abs_diff0, abs_diff0); + sse_u32[1] = vdotq_u32(sse_u32[1], abs_diff1, abs_diff1); + + src += 2 * src_stride; + ref += 2 * ref_stride; + i -= 2; + } while (i != 0); + + *sse = horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1])); + return horizontal_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1])); +} + +#define MSE_WXH_NEON_DOTPROD(w, h) \ + unsigned int aom_mse##w##x##h##_neon_dotprod( \ + const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \ + unsigned int *sse) { \ + return mse##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, sse, h); \ + } + +MSE_WXH_NEON_DOTPROD(8, 8) +MSE_WXH_NEON_DOTPROD(8, 16) + +MSE_WXH_NEON_DOTPROD(16, 8) +MSE_WXH_NEON_DOTPROD(16, 16) + +#undef MSE_WXH_NEON_DOTPROD -- cgit v1.2.3
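
Note on the variance kernels added by this patch: every function wired up through the VARIANCE_WXH_NEON and VARIANCE_WXH_NEON_DOTPROD macros accumulates the same two quantities, the sum of differences and the sum of squared differences, and returns sse - sum^2 / (w * h). The `shift` argument is log2(w * h), which is why, for example, the 16x16 variants pass 8 (16 * 16 = 256 = 2^8). The scalar sketch below is not part of the patch; it is a minimal reference (the name `variance_ref_scalar` is invented here for illustration) that mirrors that arithmetic and can be used to sanity-check the NEON kernels on small blocks.

#include <stdint.h>

// Scalar reference for the block variance computed by the NEON kernels in
// this patch. 'log2_wh' plays the role of the 'shift' macro argument: the
// division by w * h is a right shift because block dimensions are powers of
// two. Writes the sum of squared differences to *sse and returns the
// variance, matching the aom_variance* calling convention.
static uint32_t variance_ref_scalar(const uint8_t *src, int src_stride,
                                    const uint8_t *ref, int ref_stride,
                                    int w, int h, int log2_wh,
                                    uint32_t *sse) {
  int64_t sum = 0;
  uint64_t sse_acc = 0;
  for (int i = 0; i < h; i++) {
    for (int j = 0; j < w; j++) {
      const int diff = src[i * src_stride + j] - ref[i * ref_stride + j];
      sum += diff;
      sse_acc += (uint64_t)(diff * diff);
    }
  }
  *sse = (uint32_t)sse_acc;
  return (uint32_t)(sse_acc - (uint64_t)((sum * sum) >> log2_wh));
}

Under these assumptions, variance_ref_scalar(src, src_stride, ref, ref_stride, 16, 16, 8, &sse) should agree with aom_variance16x16_neon(src, src_stride, ref, ref_stride, &sse) for any 16x16 block of 8-bit pixels.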