From a90a5cba08fdf6c0ceb95101c275108a152a3aed Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 12 Jun 2024 07:35:37 +0200 Subject: Merging upstream version 127.0. Signed-off-by: Daniel Baumann --- third_party/aom/aom/aom_image.h | 36 +- third_party/aom/aom/src/aom_image.c | 43 +- third_party/aom/aom_dsp/aom_dsp.cmake | 3 + third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl | 2 +- third_party/aom/aom_dsp/arm/aom_convolve8_neon.c | 401 +++++++++----- .../aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c | 428 +++++++-------- .../aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c | 334 ++++++------ .../aom/aom_dsp/flow_estimation/arm/disflow_neon.c | 104 +--- .../aom/aom_dsp/flow_estimation/arm/disflow_neon.h | 127 +++++ .../aom/aom_dsp/flow_estimation/arm/disflow_sve.c | 268 ++++++++++ third_party/aom/aom_dsp/pyramid.c | 31 +- third_party/aom/aom_dsp/x86/synonyms.h | 1 - third_party/aom/aom_util/aom_pthread.h | 1 + third_party/aom/aom_util/aom_thread.h | 2 - third_party/aom/av1/av1.cmake | 2 + third_party/aom/av1/av1_cx_iface.c | 1 + .../common/arm/compound_convolve_neon_dotprod.c | 55 +- .../aom/av1/common/arm/convolve_neon_dotprod.c | 49 +- third_party/aom/av1/common/av1_rtcd_defs.pl | 7 +- third_party/aom/av1/common/resize.c | 58 +- third_party/aom/av1/common/resize.h | 10 + third_party/aom/av1/common/x86/resize_avx2.c | 411 ++++++++++++++ .../aom/av1/encoder/arm/neon/highbd_pickrst_neon.c | 5 +- third_party/aom/av1/encoder/arm/neon/pickrst_sve.c | 590 +++++++++++++++++++++ third_party/aom/av1/encoder/enc_enums.h | 4 + third_party/aom/av1/encoder/encodeframe.c | 4 +- third_party/aom/av1/encoder/encoder.h | 2 +- third_party/aom/av1/encoder/ethread.c | 7 +- third_party/aom/av1/encoder/global_motion.h | 7 +- third_party/aom/av1/encoder/nonrd_pickmode.c | 34 +- third_party/aom/av1/encoder/partition_search.c | 20 +- third_party/aom/av1/encoder/picklpf.c | 2 + third_party/aom/av1/encoder/pickrst.c | 21 +- third_party/aom/av1/encoder/speed_features.c | 2 +- third_party/aom/av1/encoder/tune_vmaf.c | 4 +- third_party/aom/av1/encoder/x86/pickrst_avx2.c | 12 +- third_party/aom/av1/encoder/x86/pickrst_sse4.c | 18 +- third_party/aom/test/aom_image_test.cc | 65 +++ third_party/aom/test/disflow_test.cc | 5 + third_party/aom/test/ethread_test.cc | 5 +- third_party/aom/test/frame_resize_test.cc | 157 ++++++ third_party/aom/test/test.cmake | 1 + third_party/aom/test/wiener_test.cc | 61 ++- 43 files changed, 2581 insertions(+), 819 deletions(-) create mode 100644 third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h create mode 100644 third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c create mode 100644 third_party/aom/av1/common/x86/resize_avx2.c create mode 100644 third_party/aom/av1/encoder/arm/neon/pickrst_sve.c create mode 100644 third_party/aom/test/frame_resize_test.cc (limited to 'third_party/aom') diff --git a/third_party/aom/aom/aom_image.h b/third_party/aom/aom/aom_image.h index d5f0c087e6..68fb312222 100644 --- a/third_party/aom/aom/aom_image.h +++ b/third_party/aom/aom/aom_image.h @@ -103,7 +103,8 @@ typedef enum aom_transfer_characteristics { AOM_CICP_TC_SMPTE_428 = 17, /**< SMPTE ST 428 */ AOM_CICP_TC_HLG = 18, /**< BT.2100 HLG, ARIB STD-B67 */ AOM_CICP_TC_RESERVED_19 = 19 /**< For future use (values 19-255) */ -} aom_transfer_characteristics_t; /**< alias for enum aom_transfer_function */ +} aom_transfer_characteristics_t; /**< alias for enum + aom_transfer_characteristics */ /*!\brief List of supported matrix coefficients */ typedef enum aom_matrix_coefficients { @@ -125,7 +126,7 @@ typedef enum aom_matrix_coefficients { AOM_CICP_MC_CHROMAT_CL = 13, /**< Chromaticity-derived constant luminance */ AOM_CICP_MC_ICTCP = 14, /**< BT.2100 ICtCp */ AOM_CICP_MC_RESERVED_15 = 15 /**< For future use (values 15-255) */ -} aom_matrix_coefficients_t; +} aom_matrix_coefficients_t; /**< alias for enum aom_matrix_coefficients */ /*!\brief List of supported color range */ typedef enum aom_color_range { @@ -144,7 +145,8 @@ typedef enum aom_chroma_sample_position { /**< sample, between two vertical samples */ AOM_CSP_COLOCATED = 2, /**< Co-located with luma(0, 0) sample */ AOM_CSP_RESERVED = 3 /**< Reserved value */ -} aom_chroma_sample_position_t; /**< alias for enum aom_transfer_function */ +} aom_chroma_sample_position_t; /**< alias for enum aom_chroma_sample_position + */ /*!\brief List of insert flags for Metadata * @@ -244,10 +246,13 @@ typedef struct aom_image { * is NULL, the storage for the descriptor will be * allocated on the heap. * \param[in] fmt Format for the image - * \param[in] d_w Width of the image - * \param[in] d_h Height of the image + * \param[in] d_w Width of the image. Must not exceed 0x08000000 + * (2^27). + * \param[in] d_h Height of the image. Must not exceed 0x08000000 + * (2^27). * \param[in] align Alignment, in bytes, of the image buffer and - * each row in the image (stride). + * each row in the image (stride). Must not exceed + * 65536. * * \return Returns a pointer to the initialized image descriptor. If the img * parameter is non-null, the value of the img parameter will be @@ -267,10 +272,12 @@ aom_image_t *aom_img_alloc(aom_image_t *img, aom_img_fmt_t fmt, * is NULL, the storage for the descriptor will be * allocated on the heap. * \param[in] fmt Format for the image - * \param[in] d_w Width of the image - * \param[in] d_h Height of the image + * \param[in] d_w Width of the image. Must not exceed 0x08000000 + * (2^27). + * \param[in] d_h Height of the image. Must not exceed 0x08000000 + * (2^27). * \param[in] align Alignment, in bytes, of each row in the image - * (stride). + * (stride). Must not exceed 65536. * \param[in] img_data Storage to use for the image * * \return Returns a pointer to the initialized image descriptor. If the img @@ -291,12 +298,17 @@ aom_image_t *aom_img_wrap(aom_image_t *img, aom_img_fmt_t fmt, unsigned int d_w, * is NULL, the storage for the descriptor will be * allocated on the heap. * \param[in] fmt Format for the image - * \param[in] d_w Width of the image - * \param[in] d_h Height of the image + * \param[in] d_w Width of the image. Must not exceed 0x08000000 + * (2^27). + * \param[in] d_h Height of the image. Must not exceed 0x08000000 + * (2^27). * \param[in] align Alignment, in bytes, of the image buffer and - * each row in the image (stride). + * each row in the image (stride). Must not exceed + * 65536. * \param[in] size_align Alignment, in pixels, of the image width and height. + * Must not exceed 65536. * \param[in] border A border that is padded on four sides of the image. + * Must not exceed 65536. * * \return Returns a pointer to the initialized image descriptor. If the img * parameter is non-null, the value of the img parameter will be diff --git a/third_party/aom/aom/src/aom_image.c b/third_party/aom/aom/src/aom_image.c index 3b1c33d056..1d3b7df245 100644 --- a/third_party/aom/aom/src/aom_image.c +++ b/third_party/aom/aom/src/aom_image.c @@ -9,6 +9,7 @@ * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ +#include #include #include #include @@ -36,13 +37,20 @@ static aom_image_t *img_alloc_helper( /* NOTE: In this function, bit_depth is either 8 or 16 (if * AOM_IMG_FMT_HIGHBITDEPTH is set), never 10 or 12. */ - unsigned int h, w, s, xcs, ycs, bps, bit_depth; - unsigned int stride_in_bytes; + unsigned int xcs, ycs, bps, bit_depth; if (img != NULL) memset(img, 0, sizeof(aom_image_t)); if (fmt == AOM_IMG_FMT_NONE) goto fail; + /* Impose maximum values on input parameters so that this function can + * perform arithmetic operations without worrying about overflows. + */ + if (d_w > 0x08000000 || d_h > 0x08000000 || buf_align > 65536 || + stride_align > 65536 || size_align > 65536 || border > 65536) { + goto fail; + } + /* Treat align==0 like align==1 */ if (!buf_align) buf_align = 1; @@ -105,12 +113,17 @@ static aom_image_t *img_alloc_helper( } /* Calculate storage sizes given the chroma subsampling */ - w = align_image_dimension(d_w, xcs, size_align); - h = align_image_dimension(d_h, ycs, size_align); - - s = (fmt & AOM_IMG_FMT_PLANAR) ? w : bps * w / bit_depth; - s = (s + 2 * border + stride_align - 1) & ~(stride_align - 1); - stride_in_bytes = s * bit_depth / 8; + const unsigned int w = align_image_dimension(d_w, xcs, size_align); + assert(d_w <= w); + const unsigned int h = align_image_dimension(d_h, ycs, size_align); + assert(d_h <= h); + + uint64_t s = (uint64_t)w + 2 * border; + s = (fmt & AOM_IMG_FMT_PLANAR) ? s : s * bps / bit_depth; + s = s * bit_depth / 8; + s = (s + stride_align - 1) & ~((uint64_t)stride_align - 1); + if (s > INT_MAX) goto fail; + const int stride_in_bytes = (int)s; /* Allocate the new image */ if (!img) { @@ -232,7 +245,7 @@ int aom_img_set_rect(aom_image_t *img, unsigned int x, unsigned int y, img->planes[AOM_PLANE_Y] = data + x * bytes_per_sample + y * img->stride[AOM_PLANE_Y]; - data += (img->h + 2 * border) * img->stride[AOM_PLANE_Y]; + data += ((size_t)img->h + 2 * border) * img->stride[AOM_PLANE_Y]; unsigned int uv_border_h = border >> img->y_chroma_shift; unsigned int uv_x = x >> img->x_chroma_shift; @@ -244,14 +257,14 @@ int aom_img_set_rect(aom_image_t *img, unsigned int x, unsigned int y, } else if (!(img->fmt & AOM_IMG_FMT_UV_FLIP)) { img->planes[AOM_PLANE_U] = data + uv_x * bytes_per_sample + uv_y * img->stride[AOM_PLANE_U]; - data += ((img->h >> img->y_chroma_shift) + 2 * uv_border_h) * + data += ((size_t)(img->h >> img->y_chroma_shift) + 2 * uv_border_h) * img->stride[AOM_PLANE_U]; img->planes[AOM_PLANE_V] = data + uv_x * bytes_per_sample + uv_y * img->stride[AOM_PLANE_V]; } else { img->planes[AOM_PLANE_V] = data + uv_x * bytes_per_sample + uv_y * img->stride[AOM_PLANE_V]; - data += ((img->h >> img->y_chroma_shift) + 2 * uv_border_h) * + data += ((size_t)(img->h >> img->y_chroma_shift) + 2 * uv_border_h) * img->stride[AOM_PLANE_V]; img->planes[AOM_PLANE_U] = data + uv_x * bytes_per_sample + uv_y * img->stride[AOM_PLANE_U]; @@ -291,15 +304,15 @@ void aom_img_free(aom_image_t *img) { } int aom_img_plane_width(const aom_image_t *img, int plane) { - if (plane > 0 && img->x_chroma_shift > 0) - return (img->d_w + 1) >> img->x_chroma_shift; + if (plane > 0) + return (img->d_w + img->x_chroma_shift) >> img->x_chroma_shift; else return img->d_w; } int aom_img_plane_height(const aom_image_t *img, int plane) { - if (plane > 0 && img->y_chroma_shift > 0) - return (img->d_h + 1) >> img->y_chroma_shift; + if (plane > 0) + return (img->d_h + img->y_chroma_shift) >> img->y_chroma_shift; else return img->d_h; } diff --git a/third_party/aom/aom_dsp/aom_dsp.cmake b/third_party/aom/aom_dsp/aom_dsp.cmake index de987cbd23..27099d36b2 100644 --- a/third_party/aom/aom_dsp/aom_dsp.cmake +++ b/third_party/aom/aom_dsp/aom_dsp.cmake @@ -205,6 +205,9 @@ if(CONFIG_AV1_ENCODER) list(APPEND AOM_DSP_ENCODER_INTRIN_NEON "${AOM_ROOT}/aom_dsp/flow_estimation/arm/disflow_neon.c") + + list(APPEND AOM_DSP_ENCODER_INTRIN_SVE + "${AOM_ROOT}/aom_dsp/flow_estimation/arm/disflow_sve.c") endif() list(APPEND AOM_DSP_ENCODER_ASM_SSE2 "${AOM_ROOT}/aom_dsp/x86/sad4d_sse2.asm" diff --git a/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl b/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl index 7e746e9cb9..b75bdc5a19 100755 --- a/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl +++ b/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl @@ -1799,7 +1799,7 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") { specialize qw/aom_compute_correlation sse4_1 avx2/; add_proto qw/void aom_compute_flow_at_point/, "const uint8_t *src, const uint8_t *ref, int x, int y, int width, int height, int stride, double *u, double *v"; - specialize qw/aom_compute_flow_at_point sse4_1 avx2 neon/; + specialize qw/aom_compute_flow_at_point sse4_1 avx2 neon sve/; } } # CONFIG_AV1_ENCODER diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c index 7441108b01..6a177b2e6b 100644 --- a/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c +++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c @@ -20,6 +20,7 @@ #include "aom/aom_integer.h" #include "aom_dsp/aom_dsp_common.h" #include "aom_dsp/aom_filter.h" +#include "aom_dsp/arm/aom_filter.h" #include "aom_dsp/arm/mem_neon.h" #include "aom_dsp/arm/transpose_neon.h" #include "aom_ports/mem.h" @@ -31,14 +32,14 @@ static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1, const int16x8_t filter) { const int16x4_t filter_lo = vget_low_s16(filter); const int16x4_t filter_hi = vget_high_s16(filter); - int16x4_t sum; - sum = vmul_lane_s16(s0, filter_lo, 0); + int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0); sum = vmla_lane_s16(sum, s1, filter_lo, 1); sum = vmla_lane_s16(sum, s2, filter_lo, 2); sum = vmla_lane_s16(sum, s5, filter_hi, 1); sum = vmla_lane_s16(sum, s6, filter_hi, 2); sum = vmla_lane_s16(sum, s7, filter_hi, 3); + sum = vqadd_s16(sum, vmul_lane_s16(s3, filter_lo, 3)); sum = vqadd_s16(sum, vmul_lane_s16(s4, filter_hi, 0)); return sum; @@ -51,65 +52,56 @@ static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1, const int16x8_t filter) { const int16x4_t filter_lo = vget_low_s16(filter); const int16x4_t filter_hi = vget_high_s16(filter); - int16x8_t sum; - sum = vmulq_lane_s16(s0, filter_lo, 0); + int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0); sum = vmlaq_lane_s16(sum, s1, filter_lo, 1); sum = vmlaq_lane_s16(sum, s2, filter_lo, 2); sum = vmlaq_lane_s16(sum, s5, filter_hi, 1); sum = vmlaq_lane_s16(sum, s6, filter_hi, 2); sum = vmlaq_lane_s16(sum, s7, filter_hi, 3); + sum = vqaddq_s16(sum, vmulq_lane_s16(s3, filter_lo, 3)); sum = vqaddq_s16(sum, vmulq_lane_s16(s4, filter_hi, 0)); return vqrshrun_n_s16(sum, FILTER_BITS); } -void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, - uint8_t *dst, ptrdiff_t dst_stride, - const int16_t *filter_x, int x_step_q4, - const int16_t *filter_y, int y_step_q4, int w, - int h) { +static INLINE void convolve8_horiz_8tap_neon(const uint8_t *src, + ptrdiff_t src_stride, uint8_t *dst, + ptrdiff_t dst_stride, + const int16_t *filter_x, int w, + int h) { const int16x8_t filter = vld1q_s16(filter_x); - assert((intptr_t)dst % 4 == 0); - assert(dst_stride % 4 == 0); - - (void)x_step_q4; - (void)filter_y; - (void)y_step_q4; - - src -= ((SUBPEL_TAPS / 2) - 1); - if (h == 4) { - uint8x8_t t0, t1, t2, t3, d01, d23; - int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3; - + uint8x8_t t0, t1, t2, t3; load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3); transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3); - s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); - s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); - s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); - s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); - s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); - s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); - s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + + int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + int16x4_t s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + int16x4_t s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + int16x4_t s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + int16x4_t s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); src += 7; do { load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3); transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3); - s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); - s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); - s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); - s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); - - d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter); - d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter); - d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter); - d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter); - d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); - d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); + + int16x4_t s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + int16x4_t s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + int16x4_t s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + int16x4_t s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + + int16x4_t d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter); + int16x4_t d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter); + int16x4_t d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter); + int16x4_t d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter); + uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); transpose_elems_inplace_u8_4x4(&d01, &d23); @@ -123,39 +115,40 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, s4 = s8; s5 = s9; s6 = s10; + src += 4; dst += 4; w -= 4; } while (w != 0); } else { - uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7, d0, d1, d2, d3; - int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10; - if (w == 4) { do { + uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7; load_u8_8x8(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); - s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); - s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); - s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); - s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); - s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); - s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); - s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); + + int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); + int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); + int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); + int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); + int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); + int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); + int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); load_u8_8x8(src + 7, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); transpose_elems_u8_4x8(t0, t1, t2, t3, t4, t5, t6, t7, &t0, &t1, &t2, &t3); - s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); - s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); - s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); - s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); - d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); - d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); - d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); - d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); + int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); + int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); + int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); + int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); + + uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); + uint8x8_t d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); + uint8x8_t d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); + uint8x8_t d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); transpose_elems_inplace_u8_8x4(&d0, &d1, &d2, &d3); @@ -169,48 +162,49 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, h -= 8; } while (h > 0); } else { - uint8x8_t d4, d5, d6, d7; - int16x8_t s11, s12, s13, s14; - int width; - const uint8_t *s; - uint8_t *d; - do { - load_u8_8x8(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); + int width = w; + const uint8_t *s = src; + uint8_t *d = dst; + + uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7; + load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); - s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); - s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); - s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); - s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); - s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); - s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); - s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); - - width = w; - s = src + 7; - d = dst; + + int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); + int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); + int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); + int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); + int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); + int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); + int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); + + s += 7; do { load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7); - s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); - s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); - s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); - s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); - s11 = vreinterpretq_s16_u16(vmovl_u8(t4)); - s12 = vreinterpretq_s16_u16(vmovl_u8(t5)); - s13 = vreinterpretq_s16_u16(vmovl_u8(t6)); - s14 = vreinterpretq_s16_u16(vmovl_u8(t7)); - - d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); - d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); - d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); - d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); - d4 = convolve8_8(s4, s5, s6, s7, s8, s9, s10, s11, filter); - d5 = convolve8_8(s5, s6, s7, s8, s9, s10, s11, s12, filter); - d6 = convolve8_8(s6, s7, s8, s9, s10, s11, s12, s13, filter); - d7 = convolve8_8(s7, s8, s9, s10, s11, s12, s13, s14, filter); + + int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); + int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); + int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); + int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); + int16x8_t s11 = vreinterpretq_s16_u16(vmovl_u8(t4)); + int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t5)); + int16x8_t s13 = vreinterpretq_s16_u16(vmovl_u8(t6)); + int16x8_t s14 = vreinterpretq_s16_u16(vmovl_u8(t7)); + + uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); + uint8x8_t d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); + uint8x8_t d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); + uint8x8_t d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); + uint8x8_t d4 = convolve8_8(s4, s5, s6, s7, s8, s9, s10, s11, filter); + uint8x8_t d5 = convolve8_8(s5, s6, s7, s8, s9, s10, s11, s12, filter); + uint8x8_t d6 = + convolve8_8(s6, s7, s8, s9, s10, s11, s12, s13, filter); + uint8x8_t d7 = + convolve8_8(s7, s8, s9, s10, s11, s12, s13, s14, filter); transpose_elems_inplace_u8_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7); @@ -224,6 +218,7 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, s4 = s12; s5 = s13; s6 = s14; + s += 8; d += 8; width -= 8; @@ -236,6 +231,137 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, } } +static INLINE int16x4_t convolve4_4(const int16x4_t s0, const int16x4_t s1, + const int16x4_t s2, const int16x4_t s3, + const int16x4_t filter) { + int16x4_t sum = vmul_lane_s16(s0, filter, 0); + sum = vmla_lane_s16(sum, s1, filter, 1); + sum = vmla_lane_s16(sum, s2, filter, 2); + sum = vmla_lane_s16(sum, s3, filter, 3); + + return sum; +} + +static INLINE uint8x8_t convolve4_8(const int16x8_t s0, const int16x8_t s1, + const int16x8_t s2, const int16x8_t s3, + const int16x4_t filter) { + int16x8_t sum = vmulq_lane_s16(s0, filter, 0); + sum = vmlaq_lane_s16(sum, s1, filter, 1); + sum = vmlaq_lane_s16(sum, s2, filter, 2); + sum = vmlaq_lane_s16(sum, s3, filter, 3); + + // We halved the filter values so -1 from right shift. + return vqrshrun_n_s16(sum, FILTER_BITS - 1); +} + +static INLINE void convolve8_horiz_4tap_neon(const uint8_t *src, + ptrdiff_t src_stride, uint8_t *dst, + ptrdiff_t dst_stride, + const int16_t *filter_x, int w, + int h) { + // All filter values are even, halve to reduce intermediate precision + // requirements. + const int16x4_t filter = vshr_n_s16(vld1_s16(filter_x + 2), 1); + + if (w == 4) { + do { + int16x8_t t0 = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(src + 0 * src_stride))); + int16x8_t t1 = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(src + 1 * src_stride))); + + int16x4_t s0[4], s1[4]; + s0[0] = vget_low_s16(t0); + s0[1] = vget_low_s16(vextq_s16(t0, t0, 1)); + s0[2] = vget_low_s16(vextq_s16(t0, t0, 2)); + s0[3] = vget_low_s16(vextq_s16(t0, t0, 3)); + + s1[0] = vget_low_s16(t1); + s1[1] = vget_low_s16(vextq_s16(t1, t1, 1)); + s1[2] = vget_low_s16(vextq_s16(t1, t1, 2)); + s1[3] = vget_low_s16(vextq_s16(t1, t1, 3)); + + int16x4_t d0 = convolve4_4(s0[0], s0[1], s0[2], s0[3], filter); + int16x4_t d1 = convolve4_4(s1[0], s1[1], s1[2], s1[3], filter); + // We halved the filter values so -1 from right shift. + uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS - 1); + + store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); + + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h > 0); + } else { + do { + int width = w; + const uint8_t *s = src; + uint8_t *d = dst; + + int16x8_t t0 = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 0 * src_stride))); + int16x8_t t1 = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 1 * src_stride))); + + s += 8; + do { + int16x8_t t2 = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 0 * src_stride))); + int16x8_t t3 = + vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 1 * src_stride))); + + int16x8_t s0[4], s1[4]; + s0[0] = t0; + s0[1] = vextq_s16(t0, t2, 1); + s0[2] = vextq_s16(t0, t2, 2); + s0[3] = vextq_s16(t0, t2, 3); + + s1[0] = t1; + s1[1] = vextq_s16(t1, t3, 1); + s1[2] = vextq_s16(t1, t3, 2); + s1[3] = vextq_s16(t1, t3, 3); + + uint8x8_t d0 = convolve4_8(s0[0], s0[1], s0[2], s0[3], filter); + uint8x8_t d1 = convolve4_8(s1[0], s1[1], s1[2], s1[3], filter); + + store_u8_8x2(d, dst_stride, d0, d1); + + t0 = t2; + t1 = t3; + + s += 8; + d += 8; + width -= 8; + } while (width != 0); + src += 2 * src_stride; + dst += 2 * dst_stride; + h -= 2; + } while (h > 0); + } +} + +void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const int16_t *filter_x, int x_step_q4, + const int16_t *filter_y, int y_step_q4, int w, + int h) { + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + + (void)x_step_q4; + (void)filter_y; + (void)y_step_q4; + + src -= ((SUBPEL_TAPS / 2) - 1); + + if (get_filter_taps_convolve8(filter_x) <= 4) { + convolve8_horiz_4tap_neon(src + 2, src_stride, dst, dst_stride, filter_x, w, + h); + } else { + convolve8_horiz_8tap_neon(src, src_stride, dst, dst_stride, filter_x, w, h); + } +} + void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const int16_t *filter_x, int x_step_q4, @@ -253,33 +379,33 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride, src -= ((SUBPEL_TAPS / 2) - 1) * src_stride; if (w == 4) { - uint8x8_t t0, t1, t2, t3, t4, t5, t6, d01, d23; - int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3; - + uint8x8_t t0, t1, t2, t3, t4, t5, t6; load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); - s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); - s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); - s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); - s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); - s4 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t4))); - s5 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t5))); - s6 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t6))); + + int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + int16x4_t s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + int16x4_t s4 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t4))); + int16x4_t s5 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t5))); + int16x4_t s6 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t6))); src += 7 * src_stride; do { load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3); - s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); - s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); - s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); - s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); - - d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter); - d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter); - d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter); - d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter); - d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); - d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); + + int16x4_t s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0))); + int16x4_t s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1))); + int16x4_t s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2))); + int16x4_t s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3))); + + int16x4_t d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter); + int16x4_t d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter); + int16x4_t d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter); + int16x4_t d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter); + uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); @@ -291,42 +417,40 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride, s4 = s8; s5 = s9; s6 = s10; + src += 4 * src_stride; dst += 4 * dst_stride; h -= 4; } while (h != 0); } else { - uint8x8_t t0, t1, t2, t3, t4, t5, t6, d0, d1, d2, d3; - int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10; - int height; - const uint8_t *s; - uint8_t *d; - do { + uint8x8_t t0, t1, t2, t3, t4, t5, t6; load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); - s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); - s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); - s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); - s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); - s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); - s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); - s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); - - height = h; - s = src + 7 * src_stride; - d = dst; + + int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0)); + int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1)); + int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2)); + int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3)); + int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4)); + int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5)); + int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6)); + + int height = h; + const uint8_t *s = src + 7 * src_stride; + uint8_t *d = dst; do { load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3); - s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); - s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); - s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); - s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); - d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); - d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); - d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); - d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); + int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0)); + int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1)); + int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2)); + int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3)); + + uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter); + uint8x8_t d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter); + uint8x8_t d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter); + uint8x8_t d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter); store_u8_8x4(d, dst_stride, d0, d1, d2, d3); @@ -337,6 +461,7 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride, s4 = s8; s5 = s9; s6 = s10; + s += 4 * src_stride; d += 4 * dst_stride; height -= 4; diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c index c82125ba17..120c479798 100644 --- a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c +++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c @@ -24,81 +24,72 @@ #include "aom_dsp/arm/transpose_neon.h" #include "aom_ports/mem.h" -DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = { +// Filter values always sum to 128. +#define FILTER_WEIGHT 128 + +DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6, 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10, 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }; -DECLARE_ALIGNED(16, static const uint8_t, dot_prod_tran_concat_tbl[32]) = { - 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, - 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 -}; - -DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = { - /* Shift left and insert new last column in transposed 4x4 block. */ +DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = { + // Shift left and insert new last column in transposed 4x4 block. 1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28, - /* Shift left and insert two new columns in transposed 4x4 block. */ + // Shift left and insert two new columns in transposed 4x4 block. 2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29, - /* Shift left and insert three new columns in transposed 4x4 block. */ + // Shift left and insert three new columns in transposed 4x4 block. 3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30 }; -static INLINE int16x4_t convolve8_4_sdot(uint8x16_t samples, - const int8x8_t filter, - const int32x4_t correction, - const uint8x16_t range_limit, - const uint8x16x2_t permute_tbl) { - int8x16_t clamped_samples, permuted_samples[2]; - int32x4_t sum; - - /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ - clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit)); - - /* Permute samples ready for dot product. */ - /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ - permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]); - /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ - permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]); - - /* Accumulate dot product into 'correction' to account for range clamp. */ - sum = vdotq_lane_s32(correction, permuted_samples[0], filter, 0); - sum = vdotq_lane_s32(sum, permuted_samples[1], filter, 1); - - /* Further narrowing and packing is performed by the caller. */ +static INLINE int16x4_t convolve8_4_h(const uint8x16_t samples, + const int8x8_t filters, + const uint8x16x2_t permute_tbl) { + // Transform sample range to [-128, 127] for 8-bit signed dot product. + int8x16_t samples_128 = + vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128))); + + // Permute samples ready for dot product. + // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } + // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } + int8x16_t perm_samples[2] = { vqtbl1q_s8(samples_128, permute_tbl.val[0]), + vqtbl1q_s8(samples_128, permute_tbl.val[1]) }; + + // Accumulate into 128 * FILTER_WEIGHT to account for range transform. + int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT); + int32x4_t sum = vdotq_lane_s32(acc, perm_samples[0], filters, 0); + sum = vdotq_lane_s32(sum, perm_samples[1], filters, 1); + + // Further narrowing and packing is performed by the caller. return vqmovn_s32(sum); } -static INLINE uint8x8_t convolve8_8_sdot(uint8x16_t samples, - const int8x8_t filter, - const int32x4_t correction, - const uint8x16_t range_limit, - const uint8x16x3_t permute_tbl) { - int8x16_t clamped_samples, permuted_samples[3]; - int32x4_t sum0, sum1; - int16x8_t sum; - - /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ - clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit)); - - /* Permute samples ready for dot product. */ - /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ - permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]); - /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ - permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]); - /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */ - permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]); - - /* Accumulate dot product into 'correction' to account for range clamp. */ - /* First 4 output values. */ - sum0 = vdotq_lane_s32(correction, permuted_samples[0], filter, 0); - sum0 = vdotq_lane_s32(sum0, permuted_samples[1], filter, 1); - /* Second 4 output values. */ - sum1 = vdotq_lane_s32(correction, permuted_samples[1], filter, 0); - sum1 = vdotq_lane_s32(sum1, permuted_samples[2], filter, 1); - - /* Narrow and re-pack. */ - sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); +static INLINE uint8x8_t convolve8_8_h(const uint8x16_t samples, + const int8x8_t filters, + const uint8x16x3_t permute_tbl) { + // Transform sample range to [-128, 127] for 8-bit signed dot product. + int8x16_t samples_128 = + vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128))); + + // Permute samples ready for dot product. + // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } + // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } + // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } + int8x16_t perm_samples[3] = { vqtbl1q_s8(samples_128, permute_tbl.val[0]), + vqtbl1q_s8(samples_128, permute_tbl.val[1]), + vqtbl1q_s8(samples_128, permute_tbl.val[2]) }; + + // Accumulate into 128 * FILTER_WEIGHT to account for range transform. + int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT); + // First 4 output values. + int32x4_t sum0 = vdotq_lane_s32(acc, perm_samples[0], filters, 0); + sum0 = vdotq_lane_s32(sum0, perm_samples[1], filters, 1); + // Second 4 output values. + int32x4_t sum1 = vdotq_lane_s32(acc, perm_samples[1], filters, 0); + sum1 = vdotq_lane_s32(sum1, perm_samples[2], filters, 1); + + // Narrow and re-pack. + int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); return vqrshrun_n_s16(sum, FILTER_BITS); } @@ -108,10 +99,6 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, const int16_t *filter_y, int y_step_q4, int w, int h) { const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x)); - const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_x), 128); - const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp)); - const uint8x16_t range_limit = vdupq_n_u8(128); - uint8x16_t s0, s1, s2, s3; assert((intptr_t)dst % 4 == 0); assert(dst_stride % 4 == 0); @@ -123,19 +110,17 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, src -= ((SUBPEL_TAPS / 2) - 1); if (w == 4) { - const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl); + const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl); do { - int16x4_t t0, t1, t2, t3; - uint8x8_t d01, d23; - + uint8x16_t s0, s1, s2, s3; load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3); - t0 = convolve8_4_sdot(s0, filter, correction, range_limit, perm_tbl); - t1 = convolve8_4_sdot(s1, filter, correction, range_limit, perm_tbl); - t2 = convolve8_4_sdot(s2, filter, correction, range_limit, perm_tbl); - t3 = convolve8_4_sdot(s3, filter, correction, range_limit, perm_tbl); - d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS); - d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS); + int16x4_t d0 = convolve8_4_h(s0, filter, perm_tbl); + int16x4_t d1 = convolve8_4_h(s1, filter, perm_tbl); + int16x4_t d2 = convolve8_4_h(s2, filter, perm_tbl); + int16x4_t d3 = convolve8_4_h(s3, filter, perm_tbl); + uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); @@ -145,23 +130,20 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, h -= 4; } while (h > 0); } else { - const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); - const uint8_t *s; - uint8_t *d; - int width; - uint8x8_t d0, d1, d2, d3; + const uint8x16x3_t perm_tbl = vld1q_u8_x3(kDotProdPermuteTbl); do { - width = w; - s = src; - d = dst; + int width = w; + const uint8_t *s = src; + uint8_t *d = dst; do { + uint8x16_t s0, s1, s2, s3; load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3); - d0 = convolve8_8_sdot(s0, filter, correction, range_limit, perm_tbl); - d1 = convolve8_8_sdot(s1, filter, correction, range_limit, perm_tbl); - d2 = convolve8_8_sdot(s2, filter, correction, range_limit, perm_tbl); - d3 = convolve8_8_sdot(s3, filter, correction, range_limit, perm_tbl); + uint8x8_t d0 = convolve8_8_h(s0, filter, perm_tbl); + uint8x8_t d1 = convolve8_8_h(s1, filter, perm_tbl); + uint8x8_t d2 = convolve8_8_h(s2, filter, perm_tbl); + uint8x8_t d3 = convolve8_8_h(s3, filter, perm_tbl); store_u8_8x4(d, dst_stride, d0, d1, d2, d3); @@ -177,83 +159,88 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, } static INLINE void transpose_concat_4x4(int8x8_t a0, int8x8_t a1, int8x8_t a2, - int8x8_t a3, int8x16_t *b, - const uint8x16_t permute_tbl) { - /* Transpose 8-bit elements and concatenate result rows as follows: - * a0: 00, 01, 02, 03, XX, XX, XX, XX - * a1: 10, 11, 12, 13, XX, XX, XX, XX - * a2: 20, 21, 22, 23, XX, XX, XX, XX - * a3: 30, 31, 32, 33, XX, XX, XX, XX - * - * b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 - * - * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it - * as an argument is preferable to loading it directly from memory as this - * inline helper is called many times from the same parent function. - */ - - int8x16x2_t samples = { { vcombine_s8(a0, a1), vcombine_s8(a2, a3) } }; - *b = vqtbl2q_s8(samples, permute_tbl); + int8x8_t a3, int8x16_t *b) { + // Transpose 8-bit elements and concatenate result rows as follows: + // a0: 00, 01, 02, 03, XX, XX, XX, XX + // a1: 10, 11, 12, 13, XX, XX, XX, XX + // a2: 20, 21, 22, 23, XX, XX, XX, XX + // a3: 30, 31, 32, 33, XX, XX, XX, XX + // + // b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + + int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0)); + int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0)); + int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0)); + int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0)); + + int8x16_t a01 = vzipq_s8(a0q, a1q).val[0]; + int8x16_t a23 = vzipq_s8(a2q, a3q).val[0]; + + int16x8_t a0123 = + vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23)).val[0]; + + *b = vreinterpretq_s8_s16(a0123); } static INLINE void transpose_concat_8x4(int8x8_t a0, int8x8_t a1, int8x8_t a2, int8x8_t a3, int8x16_t *b0, - int8x16_t *b1, - const uint8x16x2_t permute_tbl) { - /* Transpose 8-bit elements and concatenate result rows as follows: - * a0: 00, 01, 02, 03, 04, 05, 06, 07 - * a1: 10, 11, 12, 13, 14, 15, 16, 17 - * a2: 20, 21, 22, 23, 24, 25, 26, 27 - * a3: 30, 31, 32, 33, 34, 35, 36, 37 - * - * b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 - * b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37 - * - * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it - * as an argument is preferable to loading it directly from memory as this - * inline helper is called many times from the same parent function. - */ - - int8x16x2_t samples = { { vcombine_s8(a0, a1), vcombine_s8(a2, a3) } }; - *b0 = vqtbl2q_s8(samples, permute_tbl.val[0]); - *b1 = vqtbl2q_s8(samples, permute_tbl.val[1]); + int8x16_t *b1) { + // Transpose 8-bit elements and concatenate result rows as follows: + // a0: 00, 01, 02, 03, 04, 05, 06, 07 + // a1: 10, 11, 12, 13, 14, 15, 16, 17 + // a2: 20, 21, 22, 23, 24, 25, 26, 27 + // a3: 30, 31, 32, 33, 34, 35, 36, 37 + // + // b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + // b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37 + + int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0)); + int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0)); + int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0)); + int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0)); + + int8x16_t a01 = vzipq_s8(a0q, a1q).val[0]; + int8x16_t a23 = vzipq_s8(a2q, a3q).val[0]; + + int16x8x2_t a0123 = + vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23)); + + *b0 = vreinterpretq_s8_s16(a0123.val[0]); + *b1 = vreinterpretq_s8_s16(a0123.val[1]); } -static INLINE int16x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo, - const int8x16_t samples_hi, - const int32x4_t correction, - const int8x8_t filter) { - /* Sample range-clamping and permutation are performed by the caller. */ - int32x4_t sum; +static INLINE int16x4_t convolve8_4_v(const int8x16_t samples_lo, + const int8x16_t samples_hi, + const int8x8_t filters) { + // The sample range transform and permutation are performed by the caller. - /* Accumulate dot product into 'correction' to account for range clamp. */ - sum = vdotq_lane_s32(correction, samples_lo, filter, 0); - sum = vdotq_lane_s32(sum, samples_hi, filter, 1); + // Accumulate into 128 * FILTER_WEIGHT to account for range transform. + int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT); + int32x4_t sum = vdotq_lane_s32(acc, samples_lo, filters, 0); + sum = vdotq_lane_s32(sum, samples_hi, filters, 1); - /* Further narrowing and packing is performed by the caller. */ + // Further narrowing and packing is performed by the caller. return vqmovn_s32(sum); } -static INLINE uint8x8_t convolve8_8_sdot_partial(const int8x16_t samples0_lo, - const int8x16_t samples0_hi, - const int8x16_t samples1_lo, - const int8x16_t samples1_hi, - const int32x4_t correction, - const int8x8_t filter) { - /* Sample range-clamping and permutation are performed by the caller. */ - int32x4_t sum0, sum1; - int16x8_t sum; - - /* Accumulate dot product into 'correction' to account for range clamp. */ - /* First 4 output values. */ - sum0 = vdotq_lane_s32(correction, samples0_lo, filter, 0); - sum0 = vdotq_lane_s32(sum0, samples0_hi, filter, 1); - /* Second 4 output values. */ - sum1 = vdotq_lane_s32(correction, samples1_lo, filter, 0); - sum1 = vdotq_lane_s32(sum1, samples1_hi, filter, 1); - - /* Narrow and re-pack. */ - sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); +static INLINE uint8x8_t convolve8_8_v(const int8x16_t samples0_lo, + const int8x16_t samples0_hi, + const int8x16_t samples1_lo, + const int8x16_t samples1_hi, + const int8x8_t filters) { + // The sample range transform and permutation are performed by the caller. + + // Accumulate into 128 * FILTER_WEIGHT to account for range transform. + int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT); + // First 4 output values. + int32x4_t sum0 = vdotq_lane_s32(acc, samples0_lo, filters, 0); + sum0 = vdotq_lane_s32(sum0, samples0_hi, filters, 1); + // Second 4 output values. + int32x4_t sum1 = vdotq_lane_s32(acc, samples1_lo, filters, 0); + sum1 = vdotq_lane_s32(sum1, samples1_hi, filters, 1); + + // Narrow and re-pack. + int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); return vqrshrun_n_s16(sum, FILTER_BITS); } @@ -263,10 +250,7 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, const int16_t *filter_y, int y_step_q4, int w, int h) { const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y)); - const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_y), 128); - const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp)); - const uint8x8_t range_limit = vdup_n_u8(128); - const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl); + const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl); int8x16x2_t samples_LUT; assert((intptr_t)dst % 4 == 0); @@ -279,62 +263,58 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, src -= ((SUBPEL_TAPS / 2) - 1) * src_stride; if (w == 4) { - const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl); - uint8x8_t t0, t1, t2, t3, t4, t5, t6; load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); src += 7 * src_stride; - /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ - int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit)); - int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit)); - int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit)); - int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit)); - int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit)); - int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit)); - int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit)); - - /* This operation combines a conventional transpose and the sample permute - * (see horizontal case) required before computing the dot product. - */ + // Clamp sample range to [-128, 127] for 8-bit signed dot product. + int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128))); + int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128))); + int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128))); + int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128))); + int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128))); + int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128))); + int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128))); + + // This operation combines a conventional transpose and the sample permute + // (see horizontal case) required before computing the dot product. int8x16_t s0123, s1234, s2345, s3456; - transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl); - transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl); - transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl); - transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl); + transpose_concat_4x4(s0, s1, s2, s3, &s0123); + transpose_concat_4x4(s1, s2, s3, s4, &s1234); + transpose_concat_4x4(s2, s3, s4, s5, &s2345); + transpose_concat_4x4(s3, s4, s5, s6, &s3456); do { uint8x8_t t7, t8, t9, t10; load_u8_8x4(src, src_stride, &t7, &t8, &t9, &t10); - int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit)); - int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit)); - int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit)); - int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit)); + int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128))); + int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128))); + int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128))); + int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128))); int8x16_t s4567, s5678, s6789, s78910; - transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl); + transpose_concat_4x4(s7, s8, s9, s10, &s78910); - /* Merge new data into block from previous iteration. */ + // Merge new data into block from previous iteration. samples_LUT.val[0] = s3456; samples_LUT.val[1] = s78910; s4567 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]); s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]); s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]); - int16x4_t d0 = convolve8_4_sdot_partial(s0123, s4567, correction, filter); - int16x4_t d1 = convolve8_4_sdot_partial(s1234, s5678, correction, filter); - int16x4_t d2 = convolve8_4_sdot_partial(s2345, s6789, correction, filter); - int16x4_t d3 = - convolve8_4_sdot_partial(s3456, s78910, correction, filter); + int16x4_t d0 = convolve8_4_v(s0123, s4567, filter); + int16x4_t d1 = convolve8_4_v(s1234, s5678, filter); + int16x4_t d2 = convolve8_4_v(s2345, s6789, filter); + int16x4_t d3 = convolve8_4_v(s3456, s78910, filter); uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); - /* Prepare block for next iteration - re-using as much as possible. */ - /* Shuffle everything up four rows. */ + // Prepare block for next iteration - re-using as much as possible. + // Shuffle everything up four rows. s0123 = s4567; s1234 = s5678; s2345 = s6789; @@ -345,8 +325,6 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, h -= 4; } while (h != 0); } else { - const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl); - do { int height = h; const uint8_t *s = src; @@ -356,44 +334,38 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, load_u8_8x7(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6); s += 7 * src_stride; - /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */ - int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit)); - int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit)); - int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit)); - int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit)); - int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit)); - int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit)); - int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit)); - - /* This operation combines a conventional transpose and the sample permute - * (see horizontal case) required before computing the dot product. - */ + // Clamp sample range to [-128, 127] for 8-bit signed dot product. + int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128))); + int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128))); + int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128))); + int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128))); + int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128))); + int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128))); + int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128))); + + // This operation combines a conventional transpose and the sample permute + // (see horizontal case) required before computing the dot product. int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi, s3456_lo, s3456_hi; - transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi, - tran_concat_tbl); - transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi, - tran_concat_tbl); - transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi, - tran_concat_tbl); - transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi, - tran_concat_tbl); + transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi); + transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi); + transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi); + transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi); do { uint8x8_t t7, t8, t9, t10; load_u8_8x4(s, src_stride, &t7, &t8, &t9, &t10); - int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit)); - int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit)); - int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit)); - int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit)); + int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128))); + int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128))); + int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128))); + int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128))); int8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi, s78910_lo, s78910_hi; - transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi, - tran_concat_tbl); + transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi); - /* Merge new data into block from previous iteration. */ + // Merge new data into block from previous iteration. samples_LUT.val[0] = s3456_lo; samples_LUT.val[1] = s78910_lo; s4567_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]); @@ -406,19 +378,19 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride, s5678_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]); s6789_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]); - uint8x8_t d0 = convolve8_8_sdot_partial(s0123_lo, s4567_lo, s0123_hi, - s4567_hi, correction, filter); - uint8x8_t d1 = convolve8_8_sdot_partial(s1234_lo, s5678_lo, s1234_hi, - s5678_hi, correction, filter); - uint8x8_t d2 = convolve8_8_sdot_partial(s2345_lo, s6789_lo, s2345_hi, - s6789_hi, correction, filter); - uint8x8_t d3 = convolve8_8_sdot_partial(s3456_lo, s78910_lo, s3456_hi, - s78910_hi, correction, filter); + uint8x8_t d0 = + convolve8_8_v(s0123_lo, s4567_lo, s0123_hi, s4567_hi, filter); + uint8x8_t d1 = + convolve8_8_v(s1234_lo, s5678_lo, s1234_hi, s5678_hi, filter); + uint8x8_t d2 = + convolve8_8_v(s2345_lo, s6789_lo, s2345_hi, s6789_hi, filter); + uint8x8_t d3 = + convolve8_8_v(s3456_lo, s78910_lo, s3456_hi, s78910_hi, filter); store_u8_8x4(d, dst_stride, d0, d1, d2, d3); - /* Prepare block for next iteration - re-using as much as possible. */ - /* Shuffle everything up four rows. */ + // Prepare block for next iteration - re-using as much as possible. + // Shuffle everything up four rows. s0123_lo = s4567_lo; s0123_hi = s4567_hi; s1234_lo = s5678_lo; diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c index df6e4d2ab5..68e031461d 100644 --- a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c +++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c @@ -23,69 +23,60 @@ #include "aom_dsp/arm/transpose_neon.h" #include "aom_ports/mem.h" -DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = { +DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6, 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10, 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }; -DECLARE_ALIGNED(16, static const uint8_t, dot_prod_tran_concat_tbl[32]) = { - 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, - 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 -}; - -DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = { - /* Shift left and insert new last column in transposed 4x4 block. */ +DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = { + // Shift left and insert new last column in transposed 4x4 block. 1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28, - /* Shift left and insert two new columns in transposed 4x4 block. */ + // Shift left and insert two new columns in transposed 4x4 block. 2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29, - /* Shift left and insert three new columns in transposed 4x4 block. */ + // Shift left and insert three new columns in transposed 4x4 block. 3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30 }; -static INLINE int16x4_t convolve8_4_usdot(const uint8x16_t samples, - const int8x8_t filter, - const uint8x16x2_t permute_tbl) { - uint8x16_t permuted_samples[2]; - int32x4_t sum; - - /* Permute samples ready for dot product. */ - /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ - permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]); - /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ - permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]); +static INLINE int16x4_t convolve8_4_h(const uint8x16_t samples, + const int8x8_t filters, + const uint8x16x2_t permute_tbl) { + // Permute samples ready for dot product. + // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } + // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } + uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]), + vqtbl1q_u8(samples, permute_tbl.val[1]) }; - sum = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filter, 0); - sum = vusdotq_lane_s32(sum, permuted_samples[1], filter, 1); + int32x4_t sum = + vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0); + sum = vusdotq_lane_s32(sum, permuted_samples[1], filters, 1); - /* Further narrowing and packing is performed by the caller. */ + // Further narrowing and packing is performed by the caller. return vqmovn_s32(sum); } -static INLINE uint8x8_t convolve8_8_usdot(const uint8x16_t samples, - const int8x8_t filter, - const uint8x16x3_t permute_tbl) { - uint8x16_t permuted_samples[3]; - int32x4_t sum0, sum1; - int16x8_t sum; - - /* Permute samples ready for dot product. */ - /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */ - permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]); - /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */ - permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]); - /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */ - permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]); - - /* First 4 output values. */ - sum0 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filter, 0); - sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filter, 1); - /* Second 4 output values. */ - sum1 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filter, 0); - sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filter, 1); - - /* Narrow and re-pack. */ - sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); +static INLINE uint8x8_t convolve8_8_h(const uint8x16_t samples, + const int8x8_t filters, + const uint8x16x3_t permute_tbl) { + // Permute samples ready for dot product. + // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } + // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } + // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } + uint8x16_t permuted_samples[3] = { vqtbl1q_u8(samples, permute_tbl.val[0]), + vqtbl1q_u8(samples, permute_tbl.val[1]), + vqtbl1q_u8(samples, permute_tbl.val[2]) }; + + // First 4 output values. + int32x4_t sum0 = + vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0); + sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filters, 1); + // Second 4 output values. + int32x4_t sum1 = + vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0); + sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filters, 1); + + // Narrow and re-pack. + int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); return vqrshrun_n_s16(sum, FILTER_BITS); } @@ -95,7 +86,6 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, const int16_t *filter_y, int y_step_q4, int w, int h) { const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x)); - uint8x16_t s0, s1, s2, s3; assert((intptr_t)dst % 4 == 0); assert(dst_stride % 4 == 0); @@ -107,19 +97,17 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, src -= ((SUBPEL_TAPS / 2) - 1); if (w == 4) { - const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl); + const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl); do { - int16x4_t t0, t1, t2, t3; - uint8x8_t d01, d23; - + uint8x16_t s0, s1, s2, s3; load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3); - t0 = convolve8_4_usdot(s0, filter, perm_tbl); - t1 = convolve8_4_usdot(s1, filter, perm_tbl); - t2 = convolve8_4_usdot(s2, filter, perm_tbl); - t3 = convolve8_4_usdot(s3, filter, perm_tbl); - d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS); - d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS); + int16x4_t d0 = convolve8_4_h(s0, filter, perm_tbl); + int16x4_t d1 = convolve8_4_h(s1, filter, perm_tbl); + int16x4_t d2 = convolve8_4_h(s2, filter, perm_tbl); + int16x4_t d3 = convolve8_4_h(s3, filter, perm_tbl); + uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); @@ -129,23 +117,20 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, h -= 4; } while (h > 0); } else { - const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); - const uint8_t *s; - uint8_t *d; - int width; - uint8x8_t d0, d1, d2, d3; + const uint8x16x3_t perm_tbl = vld1q_u8_x3(kDotProdPermuteTbl); do { - width = w; - s = src; - d = dst; + int width = w; + const uint8_t *s = src; + uint8_t *d = dst; do { + uint8x16_t s0, s1, s2, s3; load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3); - d0 = convolve8_8_usdot(s0, filter, perm_tbl); - d1 = convolve8_8_usdot(s1, filter, perm_tbl); - d2 = convolve8_8_usdot(s2, filter, perm_tbl); - d3 = convolve8_8_usdot(s3, filter, perm_tbl); + uint8x8_t d0 = convolve8_8_h(s0, filter, perm_tbl); + uint8x8_t d1 = convolve8_8_h(s1, filter, perm_tbl); + uint8x8_t d2 = convolve8_8_h(s2, filter, perm_tbl); + uint8x8_t d3 = convolve8_8_h(s3, filter, perm_tbl); store_u8_8x4(d, dst_stride, d0, d1, d2, d3); @@ -162,79 +147,83 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, static INLINE void transpose_concat_4x4(uint8x8_t a0, uint8x8_t a1, uint8x8_t a2, uint8x8_t a3, - uint8x16_t *b, - const uint8x16_t permute_tbl) { - /* Transpose 8-bit elements and concatenate result rows as follows: - * a0: 00, 01, 02, 03, XX, XX, XX, XX - * a1: 10, 11, 12, 13, XX, XX, XX, XX - * a2: 20, 21, 22, 23, XX, XX, XX, XX - * a3: 30, 31, 32, 33, XX, XX, XX, XX - * - * b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 - * - * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it - * as an argument is preferable to loading it directly from memory as this - * inline helper is called many times from the same parent function. - */ - - uint8x16x2_t samples = { { vcombine_u8(a0, a1), vcombine_u8(a2, a3) } }; - *b = vqtbl2q_u8(samples, permute_tbl); + uint8x16_t *b) { + // Transpose 8-bit elements and concatenate result rows as follows: + // a0: 00, 01, 02, 03, XX, XX, XX, XX + // a1: 10, 11, 12, 13, XX, XX, XX, XX + // a2: 20, 21, 22, 23, XX, XX, XX, XX + // a3: 30, 31, 32, 33, XX, XX, XX, XX + // + // b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + + uint8x16_t a0q = vcombine_u8(a0, vdup_n_u8(0)); + uint8x16_t a1q = vcombine_u8(a1, vdup_n_u8(0)); + uint8x16_t a2q = vcombine_u8(a2, vdup_n_u8(0)); + uint8x16_t a3q = vcombine_u8(a3, vdup_n_u8(0)); + + uint8x16_t a01 = vzipq_u8(a0q, a1q).val[0]; + uint8x16_t a23 = vzipq_u8(a2q, a3q).val[0]; + + uint16x8_t a0123 = + vzipq_u16(vreinterpretq_u16_u8(a01), vreinterpretq_u16_u8(a23)).val[0]; + + *b = vreinterpretq_u8_u16(a0123); } static INLINE void transpose_concat_8x4(uint8x8_t a0, uint8x8_t a1, uint8x8_t a2, uint8x8_t a3, - uint8x16_t *b0, uint8x16_t *b1, - const uint8x16x2_t permute_tbl) { - /* Transpose 8-bit elements and concatenate result rows as follows: - * a0: 00, 01, 02, 03, 04, 05, 06, 07 - * a1: 10, 11, 12, 13, 14, 15, 16, 17 - * a2: 20, 21, 22, 23, 24, 25, 26, 27 - * a3: 30, 31, 32, 33, 34, 35, 36, 37 - * - * b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 - * b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37 - * - * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it - * as an argument is preferable to loading it directly from memory as this - * inline helper is called many times from the same parent function. - */ - - uint8x16x2_t samples = { { vcombine_u8(a0, a1), vcombine_u8(a2, a3) } }; - *b0 = vqtbl2q_u8(samples, permute_tbl.val[0]); - *b1 = vqtbl2q_u8(samples, permute_tbl.val[1]); + uint8x16_t *b0, uint8x16_t *b1) { + // Transpose 8-bit elements and concatenate result rows as follows: + // a0: 00, 01, 02, 03, 04, 05, 06, 07 + // a1: 10, 11, 12, 13, 14, 15, 16, 17 + // a2: 20, 21, 22, 23, 24, 25, 26, 27 + // a3: 30, 31, 32, 33, 34, 35, 36, 37 + // + // b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33 + // b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37 + + uint8x16_t a0q = vcombine_u8(a0, vdup_n_u8(0)); + uint8x16_t a1q = vcombine_u8(a1, vdup_n_u8(0)); + uint8x16_t a2q = vcombine_u8(a2, vdup_n_u8(0)); + uint8x16_t a3q = vcombine_u8(a3, vdup_n_u8(0)); + + uint8x16_t a01 = vzipq_u8(a0q, a1q).val[0]; + uint8x16_t a23 = vzipq_u8(a2q, a3q).val[0]; + + uint16x8x2_t a0123 = + vzipq_u16(vreinterpretq_u16_u8(a01), vreinterpretq_u16_u8(a23)); + + *b0 = vreinterpretq_u8_u16(a0123.val[0]); + *b1 = vreinterpretq_u8_u16(a0123.val[1]); } -static INLINE int16x4_t convolve8_4_usdot_partial(const uint8x16_t samples_lo, - const uint8x16_t samples_hi, - const int8x8_t filter) { - /* Sample permutation is performed by the caller. */ - int32x4_t sum; - - sum = vusdotq_lane_s32(vdupq_n_s32(0), samples_lo, filter, 0); - sum = vusdotq_lane_s32(sum, samples_hi, filter, 1); +static INLINE int16x4_t convolve8_4_v(const uint8x16_t samples_lo, + const uint8x16_t samples_hi, + const int8x8_t filters) { + // Sample permutation is performed by the caller. + int32x4_t sum = vusdotq_lane_s32(vdupq_n_s32(0), samples_lo, filters, 0); + sum = vusdotq_lane_s32(sum, samples_hi, filters, 1); - /* Further narrowing and packing is performed by the caller. */ + // Further narrowing and packing is performed by the caller. return vqmovn_s32(sum); } -static INLINE uint8x8_t convolve8_8_usdot_partial(const uint8x16_t samples0_lo, - const uint8x16_t samples0_hi, - const uint8x16_t samples1_lo, - const uint8x16_t samples1_hi, - const int8x8_t filter) { - /* Sample permutation is performed by the caller. */ - int32x4_t sum0, sum1; - int16x8_t sum; - - /* First 4 output values. */ - sum0 = vusdotq_lane_s32(vdupq_n_s32(0), samples0_lo, filter, 0); - sum0 = vusdotq_lane_s32(sum0, samples0_hi, filter, 1); - /* Second 4 output values. */ - sum1 = vusdotq_lane_s32(vdupq_n_s32(0), samples1_lo, filter, 0); - sum1 = vusdotq_lane_s32(sum1, samples1_hi, filter, 1); - - /* Narrow and re-pack. */ - sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); +static INLINE uint8x8_t convolve8_8_v(const uint8x16_t samples0_lo, + const uint8x16_t samples0_hi, + const uint8x16_t samples1_lo, + const uint8x16_t samples1_hi, + const int8x8_t filters) { + // Sample permutation is performed by the caller. + + // First 4 output values. + int32x4_t sum0 = vusdotq_lane_s32(vdupq_n_s32(0), samples0_lo, filters, 0); + sum0 = vusdotq_lane_s32(sum0, samples0_hi, filters, 1); + // Second 4 output values. + int32x4_t sum1 = vusdotq_lane_s32(vdupq_n_s32(0), samples1_lo, filters, 0); + sum1 = vusdotq_lane_s32(sum1, samples1_hi, filters, 1); + + // Narrow and re-pack. + int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1)); return vqrshrun_n_s16(sum, FILTER_BITS); } @@ -244,7 +233,7 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, const int16_t *filter_y, int y_step_q4, int w, int h) { const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y)); - const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl); + const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl); uint8x16x2_t samples_LUT; assert((intptr_t)dst % 4 == 0); @@ -257,47 +246,44 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, src -= ((SUBPEL_TAPS / 2) - 1) * src_stride; if (w == 4) { - const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl); - uint8x8_t s0, s1, s2, s3, s4, s5, s6; load_u8_8x7(src, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6); src += 7 * src_stride; - /* This operation combines a conventional transpose and the sample permute - * (see horizontal case) required before computing the dot product. - */ + // This operation combines a conventional transpose and the sample permute + // (see horizontal case) required before computing the dot product. uint8x16_t s0123, s1234, s2345, s3456; - transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl); - transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl); - transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl); - transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl); + transpose_concat_4x4(s0, s1, s2, s3, &s0123); + transpose_concat_4x4(s1, s2, s3, s4, &s1234); + transpose_concat_4x4(s2, s3, s4, s5, &s2345); + transpose_concat_4x4(s3, s4, s5, s6, &s3456); do { uint8x8_t s7, s8, s9, s10; load_u8_8x4(src, src_stride, &s7, &s8, &s9, &s10); uint8x16_t s4567, s5678, s6789, s78910; - transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl); + transpose_concat_4x4(s7, s8, s9, s10, &s78910); - /* Merge new data into block from previous iteration. */ + // Merge new data into block from previous iteration. samples_LUT.val[0] = s3456; samples_LUT.val[1] = s78910; s4567 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[0]); s5678 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]); s6789 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]); - int16x4_t d0 = convolve8_4_usdot_partial(s0123, s4567, filter); - int16x4_t d1 = convolve8_4_usdot_partial(s1234, s5678, filter); - int16x4_t d2 = convolve8_4_usdot_partial(s2345, s6789, filter); - int16x4_t d3 = convolve8_4_usdot_partial(s3456, s78910, filter); + int16x4_t d0 = convolve8_4_v(s0123, s4567, filter); + int16x4_t d1 = convolve8_4_v(s1234, s5678, filter); + int16x4_t d2 = convolve8_4_v(s2345, s6789, filter); + int16x4_t d3 = convolve8_4_v(s3456, s78910, filter); uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01); store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23); - /* Prepare block for next iteration - re-using as much as possible. */ - /* Shuffle everything up four rows. */ + // Prepare block for next iteration - re-using as much as possible. + // Shuffle everything up four rows. s0123 = s4567; s1234 = s5678; s2345 = s6789; @@ -308,8 +294,6 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, h -= 4; } while (h != 0); } else { - const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl); - do { int height = h; const uint8_t *s = src; @@ -319,19 +303,14 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, load_u8_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6); s += 7 * src_stride; - /* This operation combines a conventional transpose and the sample permute - * (see horizontal case) required before computing the dot product. - */ + // This operation combines a conventional transpose and the sample permute + // (see horizontal case) required before computing the dot product. uint8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi, s3456_lo, s3456_hi; - transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi, - tran_concat_tbl); - transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi, - tran_concat_tbl); - transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi, - tran_concat_tbl); - transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi, - tran_concat_tbl); + transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi); + transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi); + transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi); + transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi); do { uint8x8_t s7, s8, s9, s10; @@ -339,10 +318,9 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, uint8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi, s78910_lo, s78910_hi; - transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi, - tran_concat_tbl); + transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi); - /* Merge new data into block from previous iteration. */ + // Merge new data into block from previous iteration. samples_LUT.val[0] = s3456_lo; samples_LUT.val[1] = s78910_lo; s4567_lo = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[0]); @@ -355,19 +333,19 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride, s5678_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]); s6789_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]); - uint8x8_t d0 = convolve8_8_usdot_partial(s0123_lo, s4567_lo, s0123_hi, - s4567_hi, filter); - uint8x8_t d1 = convolve8_8_usdot_partial(s1234_lo, s5678_lo, s1234_hi, - s5678_hi, filter); - uint8x8_t d2 = convolve8_8_usdot_partial(s2345_lo, s6789_lo, s2345_hi, - s6789_hi, filter); - uint8x8_t d3 = convolve8_8_usdot_partial(s3456_lo, s78910_lo, s3456_hi, - s78910_hi, filter); + uint8x8_t d0 = + convolve8_8_v(s0123_lo, s4567_lo, s0123_hi, s4567_hi, filter); + uint8x8_t d1 = + convolve8_8_v(s1234_lo, s5678_lo, s1234_hi, s5678_hi, filter); + uint8x8_t d2 = + convolve8_8_v(s2345_lo, s6789_lo, s2345_hi, s6789_hi, filter); + uint8x8_t d3 = + convolve8_8_v(s3456_lo, s78910_lo, s3456_hi, s78910_hi, filter); store_u8_8x4(d, dst_stride, d0, d1, d2, d3); - /* Prepare block for next iteration - re-using as much as possible. */ - /* Shuffle everything up four rows. */ + // Prepare block for next iteration - re-using as much as possible. + // Shuffle everything up four rows. s0123_lo = s4567_lo; s0123_hi = s4567_hi; s1234_lo = s5678_lo; diff --git a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c index 62729133e3..5758d2887f 100644 --- a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c +++ b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c @@ -16,36 +16,10 @@ #include "aom_dsp/arm/mem_neon.h" #include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/flow_estimation/arm/disflow_neon.h" #include "config/aom_config.h" #include "config/aom_dsp_rtcd.h" -static INLINE void get_cubic_kernel_dbl(double x, double kernel[4]) { - // Check that the fractional position is in range. - // - // Note: x is calculated from, e.g., `u_frac = u - floor(u)`. - // Mathematically, this implies that 0 <= x < 1. However, in practice it is - // possible to have x == 1 due to floating point rounding. This is fine, - // and we still interpolate correctly if we allow x = 1. - assert(0 <= x && x <= 1); - - double x2 = x * x; - double x3 = x2 * x; - kernel[0] = -0.5 * x + x2 - 0.5 * x3; - kernel[1] = 1.0 - 2.5 * x2 + 1.5 * x3; - kernel[2] = 0.5 * x + 2.0 * x2 - 1.5 * x3; - kernel[3] = -0.5 * x2 + 0.5 * x3; -} - -static INLINE void get_cubic_kernel_int(double x, int kernel[4]) { - double kernel_dbl[4]; - get_cubic_kernel_dbl(x, kernel_dbl); - - kernel[0] = (int)rint(kernel_dbl[0] * (1 << DISFLOW_INTERP_BITS)); - kernel[1] = (int)rint(kernel_dbl[1] * (1 << DISFLOW_INTERP_BITS)); - kernel[2] = (int)rint(kernel_dbl[2] * (1 << DISFLOW_INTERP_BITS)); - kernel[3] = (int)rint(kernel_dbl[3] * (1 << DISFLOW_INTERP_BITS)); -} - // Compare two regions of width x height pixels, one rooted at position // (x, y) in src and the other at (x + u, y + v) in ref. // This function returns the sum of squared pixel differences between @@ -157,82 +131,6 @@ static INLINE void compute_flow_error(const uint8_t *src, const uint8_t *ref, } } -static INLINE void sobel_filter_x(const uint8_t *src, int src_stride, - int16_t *dst, int dst_stride) { - int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)]; - - // Horizontal filter, using kernel {1, 0, -1}. - const uint8_t *src_start = src - 1 * src_stride - 1; - - for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) { - uint8x16_t s = vld1q_u8(src_start + i * src_stride); - uint8x8_t s0 = vget_low_u8(s); - uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2)); - - // Given that the kernel is {1, 0, -1} the convolution is a simple - // subtraction. - int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s0, s2)); - - vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, diff); - } - - // Vertical filter, using kernel {1, 2, 1}. - // This kernel can be split into two 2-taps kernels of value {1, 1}. - // That way we need only 3 add operations to perform the convolution, one of - // which can be reused for the next line. - int16x8_t s0 = vld1q_s16(tmp); - int16x8_t s1 = vld1q_s16(tmp + DISFLOW_PATCH_SIZE); - int16x8_t sum01 = vaddq_s16(s0, s1); - for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) { - int16x8_t s2 = vld1q_s16(tmp + (i + 2) * DISFLOW_PATCH_SIZE); - - int16x8_t sum12 = vaddq_s16(s1, s2); - int16x8_t sum = vaddq_s16(sum01, sum12); - - vst1q_s16(dst + i * dst_stride, sum); - - sum01 = sum12; - s1 = s2; - } -} - -static INLINE void sobel_filter_y(const uint8_t *src, int src_stride, - int16_t *dst, int dst_stride) { - int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)]; - - // Horizontal filter, using kernel {1, 2, 1}. - // This kernel can be split into two 2-taps kernels of value {1, 1}. - // That way we need only 3 add operations to perform the convolution. - const uint8_t *src_start = src - 1 * src_stride - 1; - - for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) { - uint8x16_t s = vld1q_u8(src_start + i * src_stride); - uint8x8_t s0 = vget_low_u8(s); - uint8x8_t s1 = vget_low_u8(vextq_u8(s, s, 1)); - uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2)); - - uint16x8_t sum01 = vaddl_u8(s0, s1); - uint16x8_t sum12 = vaddl_u8(s1, s2); - uint16x8_t sum = vaddq_u16(sum01, sum12); - - vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, vreinterpretq_s16_u16(sum)); - } - - // Vertical filter, using kernel {1, 0, -1}. - // Load the whole block at once to avoid redundant loads during convolution. - int16x8_t t[10]; - load_s16_8x10(tmp, DISFLOW_PATCH_SIZE, &t[0], &t[1], &t[2], &t[3], &t[4], - &t[5], &t[6], &t[7], &t[8], &t[9]); - - for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) { - // Given that the kernel is {1, 0, -1} the convolution is a simple - // subtraction. - int16x8_t diff = vsubq_s16(t[i], t[i + 2]); - - vst1q_s16(dst + i * dst_stride, diff); - } -} - // Computes the components of the system of equations used to solve for // a flow vector. // diff --git a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h new file mode 100644 index 0000000000..d991a13460 --- /dev/null +++ b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2024, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#ifndef AOM_AOM_DSP_FLOW_ESTIMATION_ARM_DISFLOW_NEON_H_ +#define AOM_AOM_DSP_FLOW_ESTIMATION_ARM_DISFLOW_NEON_H_ + +#include "aom_dsp/flow_estimation/disflow.h" + +#include +#include + +#include "aom_dsp/arm/mem_neon.h" +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +static INLINE void get_cubic_kernel_dbl(double x, double kernel[4]) { + // Check that the fractional position is in range. + // + // Note: x is calculated from, e.g., `u_frac = u - floor(u)`. + // Mathematically, this implies that 0 <= x < 1. However, in practice it is + // possible to have x == 1 due to floating point rounding. This is fine, + // and we still interpolate correctly if we allow x = 1. + assert(0 <= x && x <= 1); + + double x2 = x * x; + double x3 = x2 * x; + kernel[0] = -0.5 * x + x2 - 0.5 * x3; + kernel[1] = 1.0 - 2.5 * x2 + 1.5 * x3; + kernel[2] = 0.5 * x + 2.0 * x2 - 1.5 * x3; + kernel[3] = -0.5 * x2 + 0.5 * x3; +} + +static INLINE void get_cubic_kernel_int(double x, int kernel[4]) { + double kernel_dbl[4]; + get_cubic_kernel_dbl(x, kernel_dbl); + + kernel[0] = (int)rint(kernel_dbl[0] * (1 << DISFLOW_INTERP_BITS)); + kernel[1] = (int)rint(kernel_dbl[1] * (1 << DISFLOW_INTERP_BITS)); + kernel[2] = (int)rint(kernel_dbl[2] * (1 << DISFLOW_INTERP_BITS)); + kernel[3] = (int)rint(kernel_dbl[3] * (1 << DISFLOW_INTERP_BITS)); +} + +static INLINE void sobel_filter_x(const uint8_t *src, int src_stride, + int16_t *dst, int dst_stride) { + int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)]; + + // Horizontal filter, using kernel {1, 0, -1}. + const uint8_t *src_start = src - 1 * src_stride - 1; + + for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) { + uint8x16_t s = vld1q_u8(src_start + i * src_stride); + uint8x8_t s0 = vget_low_u8(s); + uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2)); + + // Given that the kernel is {1, 0, -1} the convolution is a simple + // subtraction. + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s0, s2)); + + vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, diff); + } + + // Vertical filter, using kernel {1, 2, 1}. + // This kernel can be split into two 2-taps kernels of value {1, 1}. + // That way we need only 3 add operations to perform the convolution, one of + // which can be reused for the next line. + int16x8_t s0 = vld1q_s16(tmp); + int16x8_t s1 = vld1q_s16(tmp + DISFLOW_PATCH_SIZE); + int16x8_t sum01 = vaddq_s16(s0, s1); + for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) { + int16x8_t s2 = vld1q_s16(tmp + (i + 2) * DISFLOW_PATCH_SIZE); + + int16x8_t sum12 = vaddq_s16(s1, s2); + int16x8_t sum = vaddq_s16(sum01, sum12); + + vst1q_s16(dst + i * dst_stride, sum); + + sum01 = sum12; + s1 = s2; + } +} + +static INLINE void sobel_filter_y(const uint8_t *src, int src_stride, + int16_t *dst, int dst_stride) { + int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)]; + + // Horizontal filter, using kernel {1, 2, 1}. + // This kernel can be split into two 2-taps kernels of value {1, 1}. + // That way we need only 3 add operations to perform the convolution. + const uint8_t *src_start = src - 1 * src_stride - 1; + + for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) { + uint8x16_t s = vld1q_u8(src_start + i * src_stride); + uint8x8_t s0 = vget_low_u8(s); + uint8x8_t s1 = vget_low_u8(vextq_u8(s, s, 1)); + uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2)); + + uint16x8_t sum01 = vaddl_u8(s0, s1); + uint16x8_t sum12 = vaddl_u8(s1, s2); + uint16x8_t sum = vaddq_u16(sum01, sum12); + + vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, vreinterpretq_s16_u16(sum)); + } + + // Vertical filter, using kernel {1, 0, -1}. + // Load the whole block at once to avoid redundant loads during convolution. + int16x8_t t[10]; + load_s16_8x10(tmp, DISFLOW_PATCH_SIZE, &t[0], &t[1], &t[2], &t[3], &t[4], + &t[5], &t[6], &t[7], &t[8], &t[9]); + + for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) { + // Given that the kernel is {1, 0, -1} the convolution is a simple + // subtraction. + int16x8_t diff = vsubq_s16(t[i], t[i + 2]); + + vst1q_s16(dst + i * dst_stride, diff); + } +} + +#endif // AOM_AOM_DSP_FLOW_ESTIMATION_ARM_DISFLOW_NEON_H_ diff --git a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c new file mode 100644 index 0000000000..7b01e90d12 --- /dev/null +++ b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include "aom_dsp/flow_estimation/disflow.h" + +#include +#include +#include + +#include "aom_dsp/arm/aom_neon_sve_bridge.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/flow_estimation/arm/disflow_neon.h" +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = { + 0, 2, 4, 6, 1, 3, 5, 7, +}; + +// Compare two regions of width x height pixels, one rooted at position +// (x, y) in src and the other at (x + u, y + v) in ref. +// This function returns the sum of squared pixel differences between +// the two regions. +static INLINE void compute_flow_error(const uint8_t *src, const uint8_t *ref, + int width, int height, int stride, int x, + int y, double u, double v, int16_t *dt) { + // Split offset into integer and fractional parts, and compute cubic + // interpolation kernels + const int u_int = (int)floor(u); + const int v_int = (int)floor(v); + const double u_frac = u - floor(u); + const double v_frac = v - floor(v); + + int h_kernel[4]; + int v_kernel[4]; + get_cubic_kernel_int(u_frac, h_kernel); + get_cubic_kernel_int(v_frac, v_kernel); + + int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 3)]; + + // Clamp coordinates so that all pixels we fetch will remain within the + // allocated border region, but allow them to go far enough out that + // the border pixels' values do not change. + // Since we are calculating an 8x8 block, the bottom-right pixel + // in the block has coordinates (x0 + 7, y0 + 7). Then, the cubic + // interpolation has 4 taps, meaning that the output of pixel + // (x_w, y_w) depends on the pixels in the range + // ([x_w - 1, x_w + 2], [y_w - 1, y_w + 2]). + // + // Thus the most extreme coordinates which will be fetched are + // (x0 - 1, y0 - 1) and (x0 + 9, y0 + 9). + const int x0 = clamp(x + u_int, -9, width); + const int y0 = clamp(y + v_int, -9, height); + + // Horizontal convolution. + const uint8_t *ref_start = ref + (y0 - 1) * stride + (x0 - 1); + const int16x4_t h_kernel_s16 = vmovn_s32(vld1q_s32(h_kernel)); + const int16x8_t h_filter = vcombine_s16(h_kernel_s16, vdup_n_s16(0)); + const uint16x8_t idx = vld1q_u16(kDeinterleaveTbl); + + for (int i = 0; i < DISFLOW_PATCH_SIZE + 3; ++i) { + svuint16_t r0 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 0); + svuint16_t r1 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 1); + svuint16_t r2 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 2); + svuint16_t r3 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 3); + + int16x8_t s0 = vreinterpretq_s16_u16(svget_neonq_u16(r0)); + int16x8_t s1 = vreinterpretq_s16_u16(svget_neonq_u16(r1)); + int16x8_t s2 = vreinterpretq_s16_u16(svget_neonq_u16(r2)); + int16x8_t s3 = vreinterpretq_s16_u16(svget_neonq_u16(r3)); + + int64x2_t sum04 = aom_svdot_lane_s16(vdupq_n_s64(0), s0, h_filter, 0); + int64x2_t sum15 = aom_svdot_lane_s16(vdupq_n_s64(0), s1, h_filter, 0); + int64x2_t sum26 = aom_svdot_lane_s16(vdupq_n_s64(0), s2, h_filter, 0); + int64x2_t sum37 = aom_svdot_lane_s16(vdupq_n_s64(0), s3, h_filter, 0); + + int32x4_t res0 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15)); + int32x4_t res1 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37)); + + // 6 is the maximum allowable number of extra bits which will avoid + // the intermediate values overflowing an int16_t. The most extreme + // intermediate value occurs when: + // * The input pixels are [0, 255, 255, 0] + // * u_frac = 0.5 + // In this case, the un-scaled output is 255 * 1.125 = 286.875. + // As an integer with 6 fractional bits, that is 18360, which fits + // in an int16_t. But with 7 fractional bits it would be 36720, + // which is too large. + int16x8_t res = vcombine_s16(vrshrn_n_s32(res0, DISFLOW_INTERP_BITS - 6), + vrshrn_n_s32(res1, DISFLOW_INTERP_BITS - 6)); + + res = aom_tbl_s16(res, idx); + + vst1q_s16(tmp_ + i * DISFLOW_PATCH_SIZE, res); + } + + // Vertical convolution. + int16x4_t v_filter = vmovn_s32(vld1q_s32(v_kernel)); + int16_t *tmp_start = tmp_ + DISFLOW_PATCH_SIZE; + + for (int i = 0; i < DISFLOW_PATCH_SIZE; ++i) { + int16x8_t t0 = vld1q_s16(tmp_start + (i - 1) * DISFLOW_PATCH_SIZE); + int16x8_t t1 = vld1q_s16(tmp_start + i * DISFLOW_PATCH_SIZE); + int16x8_t t2 = vld1q_s16(tmp_start + (i + 1) * DISFLOW_PATCH_SIZE); + int16x8_t t3 = vld1q_s16(tmp_start + (i + 2) * DISFLOW_PATCH_SIZE); + + int32x4_t sum_lo = vmull_lane_s16(vget_low_s16(t0), v_filter, 0); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t1), v_filter, 1); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t2), v_filter, 2); + sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t3), v_filter, 3); + + int32x4_t sum_hi = vmull_lane_s16(vget_high_s16(t0), v_filter, 0); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t1), v_filter, 1); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t2), v_filter, 2); + sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t3), v_filter, 3); + + uint8x8_t s = vld1_u8(src + (i + y) * stride + x); + int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, 3)); + + // This time, we have to round off the 6 extra bits which were kept + // earlier, but we also want to keep DISFLOW_DERIV_SCALE_LOG2 extra bits + // of precision to match the scale of the dx and dy arrays. + sum_lo = vrshrq_n_s32(sum_lo, + DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2); + sum_hi = vrshrq_n_s32(sum_hi, + DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2); + int32x4_t err_lo = vsubw_s16(sum_lo, vget_low_s16(s_s16)); + int32x4_t err_hi = vsubw_s16(sum_hi, vget_high_s16(s_s16)); + vst1q_s16(dt + i * DISFLOW_PATCH_SIZE, + vcombine_s16(vmovn_s32(err_lo), vmovn_s32(err_hi))); + } +} + +// Computes the components of the system of equations used to solve for +// a flow vector. +// +// The flow equations are a least-squares system, derived as follows: +// +// For each pixel in the patch, we calculate the current error `dt`, +// and the x and y gradients `dx` and `dy` of the source patch. +// This means that, to first order, the squared error for this pixel is +// +// (dt + u * dx + v * dy)^2 +// +// where (u, v) are the incremental changes to the flow vector. +// +// We then want to find the values of u and v which minimize the sum +// of the squared error across all pixels. Conveniently, this fits exactly +// into the form of a least squares problem, with one equation +// +// u * dx + v * dy = -dt +// +// for each pixel. +// +// Summing across all pixels in a square window of size DISFLOW_PATCH_SIZE, +// and absorbing the - sign elsewhere, this results in the least squares system +// +// M = |sum(dx * dx) sum(dx * dy)| +// |sum(dx * dy) sum(dy * dy)| +// +// b = |sum(dx * dt)| +// |sum(dy * dt)| +static INLINE void compute_flow_matrix(const int16_t *dx, int dx_stride, + const int16_t *dy, int dy_stride, + double *M_inv) { + int64x2_t sum[3] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0) }; + + for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) { + int16x8_t x = vld1q_s16(dx + i * dx_stride); + int16x8_t y = vld1q_s16(dy + i * dy_stride); + + sum[0] = aom_sdotq_s16(sum[0], x, x); + sum[1] = aom_sdotq_s16(sum[1], x, y); + sum[2] = aom_sdotq_s16(sum[2], y, y); + } + + sum[0] = vpaddq_s64(sum[0], sum[1]); + sum[2] = vpaddq_s64(sum[1], sum[2]); + int32x4_t res = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2])); + + // Apply regularization + // We follow the standard regularization method of adding `k * I` before + // inverting. This ensures that the matrix will be invertible. + // + // Setting the regularization strength k to 1 seems to work well here, as + // typical values coming from the other equations are very large (1e5 to + // 1e6, with an upper limit of around 6e7, at the time of writing). + // It also preserves the property that all matrix values are whole numbers, + // which is convenient for integerized SIMD implementation. + + double M0 = (double)vgetq_lane_s32(res, 0) + 1; + double M1 = (double)vgetq_lane_s32(res, 1); + double M2 = (double)vgetq_lane_s32(res, 2); + double M3 = (double)vgetq_lane_s32(res, 3) + 1; + + // Invert matrix M. + double det = (M0 * M3) - (M1 * M2); + assert(det >= 1); + const double det_inv = 1 / det; + + M_inv[0] = M3 * det_inv; + M_inv[1] = -M1 * det_inv; + M_inv[2] = -M2 * det_inv; + M_inv[3] = M0 * det_inv; +} + +static INLINE void compute_flow_vector(const int16_t *dx, int dx_stride, + const int16_t *dy, int dy_stride, + const int16_t *dt, int dt_stride, + int *b) { + int64x2_t b_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + + for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) { + int16x8_t dx16 = vld1q_s16(dx + i * dx_stride); + int16x8_t dy16 = vld1q_s16(dy + i * dy_stride); + int16x8_t dt16 = vld1q_s16(dt + i * dt_stride); + + b_s64[0] = aom_sdotq_s16(b_s64[0], dx16, dt16); + b_s64[1] = aom_sdotq_s16(b_s64[1], dy16, dt16); + } + + b_s64[0] = vpaddq_s64(b_s64[0], b_s64[1]); + vst1_s32(b, vmovn_s64(b_s64[0])); +} + +void aom_compute_flow_at_point_sve(const uint8_t *src, const uint8_t *ref, + int x, int y, int width, int height, + int stride, double *u, double *v) { + double M_inv[4]; + int b[2]; + int16_t dt[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE]; + int16_t dx[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE]; + int16_t dy[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE]; + + // Compute gradients within this patch + const uint8_t *src_patch = &src[y * stride + x]; + sobel_filter_x(src_patch, stride, dx, DISFLOW_PATCH_SIZE); + sobel_filter_y(src_patch, stride, dy, DISFLOW_PATCH_SIZE); + + compute_flow_matrix(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, M_inv); + + for (int itr = 0; itr < DISFLOW_MAX_ITR; itr++) { + compute_flow_error(src, ref, width, height, stride, x, y, *u, *v, dt); + compute_flow_vector(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, dt, + DISFLOW_PATCH_SIZE, b); + + // Solve flow equations to find a better estimate for the flow vector + // at this point + const double step_u = M_inv[0] * b[0] + M_inv[1] * b[1]; + const double step_v = M_inv[2] * b[0] + M_inv[3] * b[1]; + *u += fclamp(step_u * DISFLOW_STEP_SIZE, -2, 2); + *v += fclamp(step_v * DISFLOW_STEP_SIZE, -2, 2); + + if (fabs(step_u) + fabs(step_v) < DISFLOW_STEP_SIZE_THRESOLD) { + // Stop iteration when we're close to convergence + break; + } + } +} diff --git a/third_party/aom/aom_dsp/pyramid.c b/third_party/aom/aom_dsp/pyramid.c index 5de001dbd5..05ddbb2f5f 100644 --- a/third_party/aom/aom_dsp/pyramid.c +++ b/third_party/aom/aom_dsp/pyramid.c @@ -305,6 +305,7 @@ static INLINE int fill_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth, // Fill in the remaining levels through progressive downsampling for (int level = already_filled_levels; level < n_levels; ++level) { + bool mem_status = false; PyramidLayer *prev_layer = &frame_pyr->layers[level - 1]; uint8_t *prev_buffer = prev_layer->buffer; int prev_stride = prev_layer->stride; @@ -315,6 +316,11 @@ static INLINE int fill_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth, int this_height = this_layer->height; int this_stride = this_layer->stride; + // The width and height of the previous layer that needs to be considered to + // derive the current layer frame. + const int input_layer_width = this_width << 1; + const int input_layer_height = this_height << 1; + // Compute the this pyramid level by downsampling the current level. // // We downsample by a factor of exactly 2, clipping the rightmost and @@ -329,13 +335,30 @@ static INLINE int fill_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth, // 2) Up/downsampling by a factor of 2 can be implemented much more // efficiently than up/downsampling by a generic ratio. // TODO(rachelbarker): Use optimized downsample-by-2 function - if (!av1_resize_plane(prev_buffer, this_height << 1, this_width << 1, - prev_stride, this_buffer, this_height, this_width, - this_stride)) { - // If we can't allocate memory, we'll have to terminate early + + // SIMD support has been added specifically for cases where the downsample + // factor is exactly 2. In such instances, horizontal and vertical resizing + // is performed utilizing the down2_symeven() function, which considers the + // even dimensions of the input layer. + if (should_resize_by_half(input_layer_height, input_layer_width, + this_height, this_width)) { + assert(input_layer_height % 2 == 0 && input_layer_width % 2 == 0 && + "Input width or height cannot be odd."); + mem_status = av1_resize_plane_to_half( + prev_buffer, input_layer_height, input_layer_width, prev_stride, + this_buffer, this_height, this_width, this_stride); + } else { + mem_status = av1_resize_plane(prev_buffer, input_layer_height, + input_layer_width, prev_stride, this_buffer, + this_height, this_width, this_stride); + } + + // Terminate early in cases of memory allocation failure. + if (!mem_status) { frame_pyr->filled_levels = n_levels; return -1; } + fill_border(this_buffer, this_width, this_height, this_stride); } diff --git a/third_party/aom/aom_dsp/x86/synonyms.h b/third_party/aom/aom_dsp/x86/synonyms.h index 74318de2e5..f9bc9ac733 100644 --- a/third_party/aom/aom_dsp/x86/synonyms.h +++ b/third_party/aom/aom_dsp/x86/synonyms.h @@ -46,7 +46,6 @@ static INLINE __m128i xx_loadu_128(const void *a) { return _mm_loadu_si128((const __m128i *)a); } - // _mm_loadu_si64 has been introduced in GCC 9, reimplement the function // manually on older compilers. #if !defined(__clang__) && __GNUC_MAJOR__ < 9 diff --git a/third_party/aom/aom_util/aom_pthread.h b/third_party/aom/aom_util/aom_pthread.h index 99deeb292a..e755487ae3 100644 --- a/third_party/aom/aom_util/aom_pthread.h +++ b/third_party/aom/aom_util/aom_pthread.h @@ -28,6 +28,7 @@ extern "C" { #define NOMINMAX #undef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN +#include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT diff --git a/third_party/aom/aom_util/aom_thread.h b/third_party/aom/aom_util/aom_thread.h index 92e162f121..80ed314752 100644 --- a/third_party/aom/aom_util/aom_thread.h +++ b/third_party/aom/aom_util/aom_thread.h @@ -21,8 +21,6 @@ extern "C" { #endif -#define MAX_NUM_THREADS 64 - // State of the worker thread object typedef enum { AVX_WORKER_STATUS_NOT_OK = 0, // object is unusable diff --git a/third_party/aom/av1/av1.cmake b/third_party/aom/av1/av1.cmake index 32645f6065..b6cf974aa7 100644 --- a/third_party/aom/av1/av1.cmake +++ b/third_party/aom/av1/av1.cmake @@ -302,6 +302,7 @@ list(APPEND AOM_AV1_COMMON_INTRIN_AVX2 "${AOM_ROOT}/av1/common/x86/highbd_inv_txfm_avx2.c" "${AOM_ROOT}/av1/common/x86/jnt_convolve_avx2.c" "${AOM_ROOT}/av1/common/x86/reconinter_avx2.c" + "${AOM_ROOT}/av1/common/x86/resize_avx2.c" "${AOM_ROOT}/av1/common/x86/selfguided_avx2.c" "${AOM_ROOT}/av1/common/x86/warp_plane_avx2.c" "${AOM_ROOT}/av1/common/x86/wiener_convolve_avx2.c") @@ -375,6 +376,7 @@ list(APPEND AOM_AV1_ENCODER_INTRIN_NEON_DOTPROD list(APPEND AOM_AV1_ENCODER_INTRIN_SVE "${AOM_ROOT}/av1/encoder/arm/neon/av1_error_sve.c" + "${AOM_ROOT}/av1/encoder/arm/neon/pickrst_sve.c" "${AOM_ROOT}/av1/encoder/arm/neon/wedge_utils_sve.c") list(APPEND AOM_AV1_ENCODER_INTRIN_ARM_CRC32 diff --git a/third_party/aom/av1/av1_cx_iface.c b/third_party/aom/av1/av1_cx_iface.c index 2b6b1504e6..39c03c9ecb 100644 --- a/third_party/aom/av1/av1_cx_iface.c +++ b/third_party/aom/av1/av1_cx_iface.c @@ -32,6 +32,7 @@ #include "av1/common/enums.h" #include "av1/common/scale.h" #include "av1/encoder/bitstream.h" +#include "av1/encoder/enc_enums.h" #include "av1/encoder/encoder.h" #include "av1/encoder/encoder_alloc.h" #include "av1/encoder/encoder_utils.h" diff --git a/third_party/aom/av1/common/arm/compound_convolve_neon_dotprod.c b/third_party/aom/av1/common/arm/compound_convolve_neon_dotprod.c index 3aeffbb0e6..40befdf44e 100644 --- a/third_party/aom/av1/common/arm/compound_convolve_neon_dotprod.c +++ b/third_party/aom/av1/common/arm/compound_convolve_neon_dotprod.c @@ -80,17 +80,15 @@ static INLINE void dist_wtd_convolve_2d_horiz_neon_dotprod( const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride, const int16_t *x_filter_ptr, const int im_h, int w) { const int bd = 8; - const int32_t horiz_const = (1 << (bd + FILTER_BITS - 2)); // Dot product constants and other shims. const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr); - const int32_t correction_s32 = - vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1)); - // Fold horiz_const into the dot-product filter correction constant. The - // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non- - // rounding shifts - which are generally faster than rounding shifts on - // modern CPUs. (The extra -1 is needed because we halved the filter values.) - const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const + - (1 << ((ROUND0_BITS - 1) - 1))); + // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts + // - which are generally faster than rounding shifts on modern CPUs. + const int32_t horiz_const = + ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1))); + // Halve the total because we will halve the filter values. + const int32x4_t correction = + vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2); const uint8x16_t range_limit = vdupq_n_u8(128); const uint8_t *src_ptr = src; @@ -334,15 +332,14 @@ static INLINE void dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod( // Dot-product constants and other shims. const uint8x16_t range_limit = vdupq_n_u8(128); - const int32_t correction_s32 = - vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1)); // Fold round_offset into the dot-product filter correction constant. The - // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non- - // rounding shifts - which are generally faster than rounding shifts on - // modern CPUs. (The extra -1 is needed because we halved the filter values.) + // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding + // shifts - which are generally faster than rounding shifts on modern CPUs. + // Halve the total because we will halve the filter values. int32x4_t correction = - vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) + - (1 << ((ROUND0_BITS - 1) - 1))); + vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) + + (1 << (ROUND0_BITS - 1))) / + 2); const int horiz_offset = filter_params_x->taps / 2 - 1; const uint8_t *src_ptr = src - horiz_offset; @@ -455,15 +452,14 @@ static INLINE void dist_wtd_convolve_x_avg_neon_dotprod( // Dot-product constants and other shims. const uint8x16_t range_limit = vdupq_n_u8(128); - const int32_t correction_s32 = - vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1)); // Fold round_offset into the dot-product filter correction constant. The - // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non- - // rounding shifts - which are generally faster than rounding shifts on - // modern CPUs. (The extra -1 is needed because we halved the filter values.) + // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding + // shifts - which are generally faster than rounding shifts on modern CPUs. + // Halve the total because we will halve the filter values. int32x4_t correction = - vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) + - (1 << ((ROUND0_BITS - 1) - 1))); + vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) + + (1 << (ROUND0_BITS - 1))) / + 2); const int horiz_offset = filter_params_x->taps / 2 - 1; const uint8_t *src_ptr = src - horiz_offset; @@ -574,15 +570,14 @@ static INLINE void dist_wtd_convolve_x_neon_dotprod( // Dot-product constants and other shims. const uint8x16_t range_limit = vdupq_n_u8(128); - const int32_t correction_s32 = - vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1)); // Fold round_offset into the dot-product filter correction constant. The - // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non- - // rounding shifts - which are generally faster than rounding shifts on - // modern CPUs. (The extra -1 is needed because we halved the filter values.) + // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding + // shifts - which are generally faster than rounding shifts on modern CPUs. + // Halve the total because we will halve the vilter values. int32x4_t correction = - vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) + - (1 << ((ROUND0_BITS - 1) - 1))); + vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) + + (1 << (ROUND0_BITS - 1))) / + 2); const int horiz_offset = filter_params_x->taps / 2 - 1; const uint8_t *src_ptr = src - horiz_offset; diff --git a/third_party/aom/av1/common/arm/convolve_neon_dotprod.c b/third_party/aom/av1/common/arm/convolve_neon_dotprod.c index c29229eb09..132da2442b 100644 --- a/third_party/aom/av1/common/arm/convolve_neon_dotprod.c +++ b/third_party/aom/av1/common/arm/convolve_neon_dotprod.c @@ -102,14 +102,12 @@ static INLINE void convolve_x_sr_12tap_neon_dotprod( const int8x16_t filter = vcombine_s8(vmovn_s16(filter_0_7), vmovn_s16(filter_8_15)); - const int32_t correction_s32 = - vaddvq_s32(vaddq_s32(vpaddlq_s16(vshlq_n_s16(filter_0_7, FILTER_BITS)), - vpaddlq_s16(vshlq_n_s16(filter_8_15, FILTER_BITS)))); - // A shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding right - // shift by FILTER_BITS - instead of a first rounding right shift by + // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding + // right shift by FILTER_BITS - instead of a first rounding right shift by // ROUND0_BITS, followed by second rounding right shift by FILTER_BITS - // ROUND0_BITS. - int32x4_t correction = vdupq_n_s32(correction_s32 + (1 << (ROUND0_BITS - 1))); + int32x4_t correction = + vdupq_n_s32((128 << FILTER_BITS) + (1 << (ROUND0_BITS - 1))); const uint8x16_t range_limit = vdupq_n_u8(128); const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl); @@ -274,16 +272,13 @@ void av1_convolve_x_sr_neon_dotprod(const uint8_t *src, int src_stride, } const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr); - // Dot product constants. - const int32_t correction_s32 = - vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1)); - // This shim of (1 << ((ROUND0_BITS - 1) - 1) enables us to use a single - // rounding right shift by FILTER_BITS - instead of a first rounding right - // shift by ROUND0_BITS, followed by second rounding right shift by - // FILTER_BITS - ROUND0_BITS. - // The outermost -1 is needed because we will halve the filter values. + // Dot product constants: + // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding + // right shift by FILTER_BITS - instead of a first rounding right shift by + // ROUND0_BITS, followed by second rounding right shift by FILTER_BITS - + // ROUND0_BITS. Halve the total because we will halve the filter values. const int32x4_t correction = - vdupq_n_s32(correction_s32 + (1 << ((ROUND0_BITS - 1) - 1))); + vdupq_n_s32(((128 << FILTER_BITS) + (1 << ((ROUND0_BITS - 1)))) / 2); const uint8x16_t range_limit = vdupq_n_u8(128); if (w <= 4) { @@ -465,16 +460,13 @@ static INLINE void convolve_2d_sr_horiz_12tap_neon_dotprod( const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]), vmovn_s16(x_filter_s16.val[1])); - // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts - // - which are generally faster than rounding shifts on modern CPUs. + // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding + // shifts - which are generally faster than rounding shifts on modern CPUs. const int32_t horiz_const = ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1))); // Dot product constants. - const int32x4_t correct_tmp = - vaddq_s32(vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[0], 7)), - vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[1], 7))); const int32x4_t correction = - vdupq_n_s32(vaddvq_s32(correct_tmp) + horiz_const); + vdupq_n_s32((128 << FILTER_BITS) + horiz_const); const uint8x16_t range_limit = vdupq_n_u8(128); const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl); @@ -621,16 +613,15 @@ static INLINE void convolve_2d_sr_horiz_neon_dotprod( const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w, int im_h, const int16_t *x_filter_ptr) { const int bd = 8; - // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding - // shifts - which are generally faster than rounding shifts on modern CPUs. - // The outermost -1 is needed because we halved the filter values. - const int32_t horiz_const = - ((1 << (bd + FILTER_BITS - 2)) + (1 << ((ROUND0_BITS - 1) - 1))); // Dot product constants. const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr); - const int32_t correction_s32 = - vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1)); - const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const); + // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding + // shifts - which are generally faster than rounding shifts on modern CPUs. + const int32_t horiz_const = + ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1))); + // Halve the total because we will halve the filter values. + const int32x4_t correction = + vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2); const uint8x16_t range_limit = vdupq_n_u8(128); const uint8_t *src_ptr = src; diff --git a/third_party/aom/av1/common/av1_rtcd_defs.pl b/third_party/aom/av1/common/av1_rtcd_defs.pl index c0831330d1..6a0043c761 100644 --- a/third_party/aom/av1/common/av1_rtcd_defs.pl +++ b/third_party/aom/av1/common/av1_rtcd_defs.pl @@ -458,7 +458,7 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") { if (aom_config("CONFIG_REALTIME_ONLY") ne "yes") { add_proto qw/void av1_compute_stats/, "int wiener_win, const uint8_t *dgd8, const uint8_t *src8, int16_t *dgd_avg, int16_t *src_avg, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, int use_downsampled_wiener_stats"; - specialize qw/av1_compute_stats sse4_1 avx2 neon/; + specialize qw/av1_compute_stats sse4_1 avx2 neon sve/; add_proto qw/void av1_calc_proj_params/, "const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2], const sgr_params_type *params"; specialize qw/av1_calc_proj_params sse4_1 avx2 neon/; add_proto qw/int64_t av1_lowbd_pixel_proj_error/, "const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params"; @@ -469,7 +469,7 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") { specialize qw/av1_calc_proj_params_high_bd sse4_1 avx2 neon/; add_proto qw/int64_t av1_highbd_pixel_proj_error/, "const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params"; specialize qw/av1_highbd_pixel_proj_error sse4_1 avx2 neon/; - add_proto qw/void av1_compute_stats_highbd/, "int wiener_win, const uint8_t *dgd8, const uint8_t *src8, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, aom_bit_depth_t bit_depth"; + add_proto qw/void av1_compute_stats_highbd/, "int wiener_win, const uint8_t *dgd8, const uint8_t *src8, int16_t *dgd_avg, int16_t *src_avg, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, aom_bit_depth_t bit_depth"; specialize qw/av1_compute_stats_highbd sse4_1 avx2 neon/; } } @@ -554,6 +554,9 @@ if (aom_config("CONFIG_AV1_HIGHBITDEPTH") eq "yes") { specialize qw/av1_highbd_warp_affine sse4_1 avx2 neon sve/; } +add_proto qw/bool resize_vert_dir/, "uint8_t *intbuf, uint8_t *output, int out_stride, int height, int height2, int width2, int start_col"; +specialize qw/resize_vert_dir avx2/; + add_proto qw/void av1_warp_affine/, "const int32_t *mat, const uint8_t *ref, int width, int height, int stride, uint8_t *pred, int p_col, int p_row, int p_width, int p_height, int p_stride, int subsampling_x, int subsampling_y, ConvolveParams *conv_params, int16_t alpha, int16_t beta, int16_t gamma, int16_t delta"; specialize qw/av1_warp_affine sse4_1 avx2 neon neon_i8mm sve/; diff --git a/third_party/aom/av1/common/resize.c b/third_party/aom/av1/common/resize.c index 441323ab1f..2b48b9fff4 100644 --- a/third_party/aom/av1/common/resize.c +++ b/third_party/aom/av1/common/resize.c @@ -18,6 +18,7 @@ #include #include "config/aom_config.h" +#include "config/av1_rtcd.h" #include "aom_dsp/aom_dsp_common.h" #include "aom_dsp/flow_estimation/corner_detect.h" @@ -216,10 +217,6 @@ const int16_t av1_resize_filter_normative[( // Filters for interpolation (full-band) - no filtering for integer pixels #define filteredinterp_filters1000 av1_resize_filter_normative -// Filters for factor of 2 downsampling. -static const int16_t av1_down2_symeven_half_filter[] = { 56, 12, -3, -1 }; -static const int16_t av1_down2_symodd_half_filter[] = { 64, 35, 0, -3 }; - static const InterpKernel *choose_interp_filter(int in_length, int out_length) { int out_length16 = out_length * 16; if (out_length16 >= in_length * 16) @@ -524,6 +521,59 @@ static void fill_arr_to_col(uint8_t *img, int stride, int len, uint8_t *arr) { } } +bool resize_vert_dir_c(uint8_t *intbuf, uint8_t *output, int out_stride, + int height, int height2, int width2, int start_col) { + bool mem_status = true; + uint8_t *arrbuf = (uint8_t *)aom_malloc(sizeof(*arrbuf) * height); + uint8_t *arrbuf2 = (uint8_t *)aom_malloc(sizeof(*arrbuf2) * height2); + if (arrbuf == NULL || arrbuf2 == NULL) { + mem_status = false; + goto Error; + } + + for (int i = start_col; i < width2; ++i) { + fill_col_to_arr(intbuf + i, width2, height, arrbuf); + down2_symeven(arrbuf, height, arrbuf2); + fill_arr_to_col(output + i, out_stride, height2, arrbuf2); + } + +Error: + aom_free(arrbuf); + aom_free(arrbuf2); + return mem_status; +} + +void resize_horz_dir(const uint8_t *const input, int in_stride, uint8_t *intbuf, + int height, int filtered_length, int width2) { + for (int i = 0; i < height; ++i) + down2_symeven(input + in_stride * i, filtered_length, intbuf + width2 * i); +} + +bool av1_resize_plane_to_half(const uint8_t *const input, int height, int width, + int in_stride, uint8_t *output, int height2, + int width2, int out_stride) { + uint8_t *intbuf = (uint8_t *)aom_malloc(sizeof(*intbuf) * width2 * height); + if (intbuf == NULL) { + return false; + } + + // Resize in the horizontal direction + resize_horz_dir(input, in_stride, intbuf, height, width, width2); + // Resize in the vertical direction + bool mem_status = resize_vert_dir(intbuf, output, out_stride, height, height2, + width2, 0 /*start_col*/); + aom_free(intbuf); + return mem_status; +} + +// Check if both the output width and height are half of input width and +// height respectively. +bool should_resize_by_half(int height, int width, int height2, int width2) { + const bool is_width_by_2 = get_down2_length(width, 1) == width2; + const bool is_height_by_2 = get_down2_length(height, 1) == height2; + return (is_width_by_2 && is_height_by_2); +} + bool av1_resize_plane(const uint8_t *input, int height, int width, int in_stride, uint8_t *output, int height2, int width2, int out_stride) { diff --git a/third_party/aom/av1/common/resize.h b/third_party/aom/av1/common/resize.h index d573a538bf..de71f5d539 100644 --- a/third_party/aom/av1/common/resize.h +++ b/third_party/aom/av1/common/resize.h @@ -20,6 +20,10 @@ extern "C" { #endif +// Filters for factor of 2 downsampling. +static const int16_t av1_down2_symeven_half_filter[] = { 56, 12, -3, -1 }; +static const int16_t av1_down2_symodd_half_filter[] = { 64, 35, 0, -3 }; + bool av1_resize_plane(const uint8_t *input, int height, int width, int in_stride, uint8_t *output, int height2, int width2, int out_stride); @@ -93,6 +97,12 @@ void av1_calculate_unscaled_superres_size(int *width, int *height, int denom); void av1_superres_upscale(AV1_COMMON *cm, BufferPool *const pool, bool alloc_pyramid); +bool av1_resize_plane_to_half(const uint8_t *const input, int height, int width, + int in_stride, uint8_t *output, int height2, + int width2, int out_stride); + +bool should_resize_by_half(int height, int width, int height2, int width2); + // Returns 1 if a superres upscaled frame is scaled and 0 otherwise. static INLINE int av1_superres_scaled(const AV1_COMMON *cm) { // Note: for some corner cases (e.g. cm->width of 1), there may be no scaling diff --git a/third_party/aom/av1/common/x86/resize_avx2.c b/third_party/aom/av1/common/x86/resize_avx2.c new file mode 100644 index 0000000000..c44edb88d9 --- /dev/null +++ b/third_party/aom/av1/common/x86/resize_avx2.c @@ -0,0 +1,411 @@ +/* + * Copyright (c) 2024, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ +#include +#include + +#include "config/av1_rtcd.h" + +#include "av1/common/resize.h" + +#include "aom_dsp/x86/synonyms.h" + +#define CAST_HI(x) _mm256_castsi128_si256(x) +#define CAST_LOW(x) _mm256_castsi256_si128(x) + +#define PROCESS_RESIZE_Y_WD16 \ + const int idx1 = AOMMIN(height - 1, i + 5); \ + const int idx2 = AOMMIN(height - 1, i + 6); \ + l6 = l10; \ + l7 = l11; \ + l8 = _mm_loadu_si128((__m128i *)(data + idx1 * stride)); \ + l9 = _mm_loadu_si128((__m128i *)(data + idx2 * stride)); \ + \ + /* g0... g15 | i0... i15 */ \ + const __m256i s68 = \ + _mm256_permute2x128_si256(CAST_HI(l6), CAST_HI(l8), 0x20); \ + /* h0... h15 | j0... j15 */ \ + const __m256i s79 = \ + _mm256_permute2x128_si256(CAST_HI(l7), CAST_HI(l9), 0x20); \ + \ + /* g0h0... g7g7 | i0j0... i7j */ \ + s[3] = _mm256_unpacklo_epi8(s68, s79); \ + /* g8h8... g15g15 | i8j8... i15j15 */ \ + s[8] = _mm256_unpackhi_epi8(s68, s79); \ + \ + __m256i res_out[2] = { 0 }; \ + resize_y_convolve(s, coeffs_y, res_out); \ + \ + /* r00... r07 */ \ + __m256i res_a_round_1 = _mm256_add_epi32(res_out[0], round_const_bits); \ + /* r20... r27 */ \ + __m256i res_a_round_2 = _mm256_add_epi32(res_out[1], round_const_bits); \ + \ + res_a_round_1 = _mm256_sra_epi32(res_a_round_1, round_shift_bits); \ + res_a_round_2 = _mm256_sra_epi32(res_a_round_2, round_shift_bits); \ + \ + __m256i res_out_b[2] = { 0 }; \ + resize_y_convolve(s + 5, coeffs_y, res_out_b); \ + \ + /* r08... r015 */ \ + __m256i res_b_round_1 = _mm256_add_epi32(res_out_b[0], round_const_bits); \ + /* r28... r215 */ \ + __m256i res_b_round_2 = _mm256_add_epi32(res_out_b[1], round_const_bits); \ + res_b_round_1 = _mm256_sra_epi32(res_b_round_1, round_shift_bits); \ + res_b_round_2 = _mm256_sra_epi32(res_b_round_2, round_shift_bits); \ + \ + /* r00... r03 r20... r23 | r04... r07 r24... r27 */ \ + __m256i res_8bit0 = _mm256_packus_epi32(res_a_round_1, res_a_round_2); \ + /* r08... r012 r28... r212 | r013... r015 r213... r215 */ \ + __m256i res_8bit1 = _mm256_packus_epi32(res_b_round_1, res_b_round_2); \ + /* r00... r07 | r20... r27 */ \ + res_8bit0 = _mm256_permute4x64_epi64(res_8bit0, 0xd8); \ + /* r08... r015 | r28... r215 */ \ + res_8bit1 = _mm256_permute4x64_epi64(res_8bit1, 0xd8); \ + /* r00... r015 | r20... r215 */ \ + res_8bit1 = _mm256_packus_epi16(res_8bit0, res_8bit1); \ + res_8bit0 = _mm256_min_epu8(res_8bit1, clip_pixel); \ + res_8bit0 = _mm256_max_epu8(res_8bit0, zero); + +#define PROCESS_RESIZE_Y_WD8 \ + const int idx1 = AOMMIN(height - 1, i + 5); \ + const int idx2 = AOMMIN(height - 1, i + 6); \ + l6 = l10; \ + l7 = l11; \ + l8 = _mm_loadl_epi64((__m128i *)(data + idx1 * stride)); \ + l9 = _mm_loadl_epi64((__m128i *)(data + idx2 * stride)); \ + \ + /* g0h0... g7h7 */ \ + s67 = _mm_unpacklo_epi8(l6, l7); \ + /* i0j0...i7j7 */ \ + __m128i s89 = _mm_unpacklo_epi8(l8, l9); \ + \ + /* g0h0...g7g7 | i0j0...i7j7 */ \ + s[3] = _mm256_permute2x128_si256(CAST_HI(s67), CAST_HI(s89), 0x20); \ + \ + __m256i res_out[2] = { 0 }; \ + resize_y_convolve(s, coeffs_y, res_out); \ + \ + /* r00... r07 */ \ + __m256i res_a_round_1 = _mm256_add_epi32(res_out[0], round_const_bits); \ + /* r20...r27 */ \ + __m256i res_a_round_2 = _mm256_add_epi32(res_out[1], round_const_bits); \ + res_a_round_1 = _mm256_sra_epi32(res_a_round_1, round_shift_bits); \ + res_a_round_2 = _mm256_sra_epi32(res_a_round_2, round_shift_bits); \ + \ + /* r00...r03 r20...r23 | r04...r07 r24...r27 */ \ + res_a_round_1 = _mm256_packus_epi32(res_a_round_1, res_a_round_2); \ + /* r00...r07 | r20...r27 */ \ + res_a_round_1 = _mm256_permute4x64_epi64(res_a_round_1, 0xd8); \ + res_a_round_1 = _mm256_packus_epi16(res_a_round_1, res_a_round_1); \ + res_a_round_1 = _mm256_min_epu8(res_a_round_1, clip_pixel); \ + res_a_round_1 = _mm256_max_epu8(res_a_round_1, zero); + +static INLINE void resize_y_convolve(const __m256i *const s, + const __m256i *const coeffs, + __m256i *res_out) { + const __m256i res_0 = _mm256_maddubs_epi16(s[0], coeffs[0]); + const __m256i res_1 = _mm256_maddubs_epi16(s[1], coeffs[1]); + const __m256i res_2 = _mm256_maddubs_epi16(s[2], coeffs[2]); + const __m256i res_3 = _mm256_maddubs_epi16(s[3], coeffs[3]); + + const __m256i dst_0 = _mm256_add_epi16(res_0, res_1); + const __m256i dst_1 = _mm256_add_epi16(res_2, res_3); + // The sum of convolve operation crosses signed 16bit. Hence, the addition + // should happen in 32bit. + const __m256i dst_00 = _mm256_cvtepi16_epi32(CAST_LOW(dst_0)); + const __m256i dst_01 = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(dst_0, 1)); + const __m256i dst_10 = _mm256_cvtepi16_epi32(CAST_LOW(dst_1)); + const __m256i dst_11 = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(dst_1, 1)); + + res_out[0] = _mm256_add_epi32(dst_00, dst_10); + res_out[1] = _mm256_add_epi32(dst_01, dst_11); +} + +static INLINE void prepare_filter_coeffs(const int16_t *filter, + __m256i *const coeffs /* [4] */) { + // f0 f1 f2 f3 x x x x + const __m128i sym_even_filter = _mm_loadl_epi64((__m128i *)filter); + // f0 f1 f2 f3 f0 f1 f2 f3 + const __m128i tmp0 = _mm_shuffle_epi32(sym_even_filter, 0x44); + // f0 f1 f2 f3 f1 f0 f3 f2 + const __m128i tmp1 = _mm_shufflehi_epi16(tmp0, 0xb1); + + const __m128i filter_8bit = _mm_packs_epi16(tmp1, tmp1); + + // f0 f1 f0 f1 .. + coeffs[2] = _mm256_broadcastw_epi16(filter_8bit); + // f2 f3 f2 f3 .. + coeffs[3] = _mm256_broadcastw_epi16(_mm_bsrli_si128(filter_8bit, 2)); + // f3 f2 f3 f2 .. + coeffs[0] = _mm256_broadcastw_epi16(_mm_bsrli_si128(filter_8bit, 6)); + // f1 f0 f1 f0 .. + coeffs[1] = _mm256_broadcastw_epi16(_mm_bsrli_si128(filter_8bit, 4)); +} + +bool resize_vert_dir_avx2(uint8_t *intbuf, uint8_t *output, int out_stride, + int height, int height2, int stride, int start_col) { + assert(start_col <= stride); + // For the GM tool, the input layer height or width is assured to be an even + // number. Hence the function 'down2_symodd()' is not invoked and SIMD + // optimization of the same is not implemented. + // When the input height is less than 8 and even, the potential input + // heights are limited to 2, 4, or 6. These scenarios require seperate + // handling due to padding requirements. Invoking the C function here will + // eliminate the need for conditional statements within the subsequent SIMD + // code to manage these cases. + if (height & 1 || height < 8) { + return resize_vert_dir_c(intbuf, output, out_stride, height, height2, + stride, start_col); + } + + __m256i s[10], coeffs_y[4]; + const int bits = FILTER_BITS; + + const __m128i round_shift_bits = _mm_cvtsi32_si128(bits); + const __m256i round_const_bits = _mm256_set1_epi32((1 << bits) >> 1); + const uint8_t max_pixel = 255; + const __m256i clip_pixel = _mm256_set1_epi8(max_pixel); + const __m256i zero = _mm256_setzero_si256(); + + prepare_filter_coeffs(av1_down2_symeven_half_filter, coeffs_y); + + const int num_col16 = stride / 16; + int remain_col = stride % 16; + // The core vertical SIMD processes 4 input rows simultaneously to generate + // output corresponding to 2 rows. To streamline the core loop and eliminate + // the need for conditional checks, the remaining rows (4 or 6) are processed + // separately. + const int remain_row = (height % 4 == 0) ? 4 : 6; + + for (int j = start_col; j < stride - remain_col; j += 16) { + const uint8_t *data = &intbuf[j]; + const __m128i l3 = _mm_loadu_si128((__m128i *)(data + 0 * stride)); + // Padding top 3 rows with the last available row at the top. + const __m128i l0 = l3; + const __m128i l1 = l3; + const __m128i l2 = l3; + const __m128i l4 = _mm_loadu_si128((__m128i *)(data + 1 * stride)); + + __m128i l6, l7, l8, l9; + __m128i l5 = _mm_loadu_si128((__m128i *)(data + 2 * stride)); + __m128i l10 = _mm_loadu_si128((__m128i *)(data + 3 * stride)); + __m128i l11 = _mm_loadu_si128((__m128i *)(data + 4 * stride)); + + // a0...a15 | c0...c15 + const __m256i s02 = + _mm256_permute2x128_si256(CAST_HI(l0), CAST_HI(l2), 0x20); + // b0...b15 | d0...d15 + const __m256i s13 = + _mm256_permute2x128_si256(CAST_HI(l1), CAST_HI(l3), 0x20); + // c0...c15 | e0...e15 + const __m256i s24 = + _mm256_permute2x128_si256(CAST_HI(l2), CAST_HI(l4), 0x20); + // d0...d15 | f0...f15 + const __m256i s35 = + _mm256_permute2x128_si256(CAST_HI(l3), CAST_HI(l5), 0x20); + // e0...e15 | g0...g15 + const __m256i s46 = + _mm256_permute2x128_si256(CAST_HI(l4), CAST_HI(l10), 0x20); + // f0...f15 | h0...h15 + const __m256i s57 = + _mm256_permute2x128_si256(CAST_HI(l5), CAST_HI(l11), 0x20); + + // a0b0...a7b7 | c0d0...c7d7 + s[0] = _mm256_unpacklo_epi8(s02, s13); + // c0d0...c7d7 | e0f0...e7f7 + s[1] = _mm256_unpacklo_epi8(s24, s35); + // e0f0...e7f7 | g0h0...g7h7 + s[2] = _mm256_unpacklo_epi8(s46, s57); + + // a8b8...a15b15 | c8d8...c15d15 + s[5] = _mm256_unpackhi_epi8(s02, s13); + // c8d8...c15d15 | e8f8...e15f15 + s[6] = _mm256_unpackhi_epi8(s24, s35); + // e8f8...e15f15 | g8h8...g15h15 + s[7] = _mm256_unpackhi_epi8(s46, s57); + + // height to be processed here + const int process_ht = height - remain_row; + for (int i = 0; i < process_ht; i += 4) { + PROCESS_RESIZE_Y_WD16 + + _mm_storeu_si128((__m128i *)&output[(i / 2) * out_stride + j], + CAST_LOW(res_8bit0)); + + _mm_storeu_si128( + (__m128i *)&output[(i / 2) * out_stride + j + out_stride], + _mm256_extracti128_si256(res_8bit0, 1)); + + // Load the required data for processing of next 4 input rows. + const int idx7 = AOMMIN(height - 1, i + 7); + const int idx8 = AOMMIN(height - 1, i + 8); + l10 = _mm_loadu_si128((__m128i *)(data + idx7 * stride)); + l11 = _mm_loadu_si128((__m128i *)(data + idx8 * stride)); + + const __m256i s810 = + _mm256_permute2x128_si256(CAST_HI(l8), CAST_HI(l10), 0x20); + const __m256i s911 = + _mm256_permute2x128_si256(CAST_HI(l9), CAST_HI(l11), 0x20); + // i0j0... i7j7 | k0l0... k7l7 + s[4] = _mm256_unpacklo_epi8(s810, s911); + // i8j8... i15j15 | k8l8... k15l15 + s[9] = _mm256_unpackhi_epi8(s810, s911); + + s[0] = s[2]; + s[1] = s[3]; + s[2] = s[4]; + + s[5] = s[7]; + s[6] = s[8]; + s[7] = s[9]; + } + + // Process the remaining last 4 or 6 rows here. + int i = process_ht; + while (i < height - 1) { + PROCESS_RESIZE_Y_WD16 + + _mm_storeu_si128((__m128i *)&output[(i / 2) * out_stride + j], + CAST_LOW(res_8bit0)); + i += 2; + + const int is_store_valid = (i < height - 1); + if (is_store_valid) + _mm_storeu_si128((__m128i *)&output[(i / 2) * out_stride + j], + _mm256_extracti128_si256(res_8bit0, 1)); + i += 2; + + // Check if there is any remaining height to process. If so, perform the + // necessary data loading for processing the next row. + if (i < height - 1) { + l10 = l11 = l9; + const __m256i s810 = + _mm256_permute2x128_si256(CAST_HI(l8), CAST_HI(l10), 0x20); + const __m256i s911 = + _mm256_permute2x128_si256(CAST_HI(l9), CAST_HI(l11), 0x20); + // i0j0... i7j7 | k0l0... k7l7 + s[4] = _mm256_unpacklo_epi8(s810, s911); + // i8j8... i15j15 | k8l8... k15l15 + s[9] = _mm256_unpackhi_epi8(s810, s911); + + s[0] = s[2]; + s[1] = s[3]; + s[2] = s[4]; + + s[5] = s[7]; + s[6] = s[8]; + s[7] = s[9]; + } + } + } + + if (remain_col > 7) { + const int processed_wd = num_col16 * 16; + remain_col = stride % 8; + + const uint8_t *data = &intbuf[processed_wd]; + + const __m128i l3 = _mm_loadl_epi64((__m128i *)(data + 0 * stride)); + // Padding top 3 rows with available top-most row. + const __m128i l0 = l3; + const __m128i l1 = l3; + const __m128i l2 = l3; + const __m128i l4 = _mm_loadl_epi64((__m128i *)(data + 1 * stride)); + + __m128i l6, l7, l8, l9; + __m128i l5 = _mm_loadl_epi64((__m128i *)(data + 2 * stride)); + __m128i l10 = _mm_loadl_epi64((__m128i *)(data + 3 * stride)); + __m128i l11 = _mm_loadl_epi64((__m128i *)(data + 4 * stride)); + + // a0b0...a7b7 + const __m128i s01 = _mm_unpacklo_epi8(l0, l1); + // c0d0...c7d7 + const __m128i s23 = _mm_unpacklo_epi8(l2, l3); + // e0f0...e7f7 + const __m128i s45 = _mm_unpacklo_epi8(l4, l5); + // g0h0...g7h7 + __m128i s67 = _mm_unpacklo_epi8(l10, l11); + + // a0b0...a7b7 | c0d0...c7d7 + s[0] = _mm256_permute2x128_si256(CAST_HI(s01), CAST_HI(s23), 0x20); + // c0d0...c7d7 | e0f0...e7f7 + s[1] = _mm256_permute2x128_si256(CAST_HI(s23), CAST_HI(s45), 0x20); + // e0f0...e7f7 | g0h0...g7h7 + s[2] = _mm256_permute2x128_si256(CAST_HI(s45), CAST_HI(s67), 0x20); + + // height to be processed here + const int process_ht = height - remain_row; + for (int i = 0; i < process_ht; i += 4) { + PROCESS_RESIZE_Y_WD8 + + _mm_storel_epi64((__m128i *)&output[(i / 2) * out_stride + processed_wd], + CAST_LOW(res_a_round_1)); + + _mm_storel_epi64( + (__m128i *)&output[(i / 2) * out_stride + processed_wd + out_stride], + _mm256_extracti128_si256(res_a_round_1, 1)); + + const int idx7 = AOMMIN(height - 1, i + 7); + const int idx8 = AOMMIN(height - 1, i + 8); + l10 = _mm_loadl_epi64((__m128i *)(data + idx7 * stride)); + l11 = _mm_loadl_epi64((__m128i *)(data + idx8 * stride)); + + // k0l0... k7l7 + const __m128i s10s11 = _mm_unpacklo_epi8(l10, l11); + // i0j0... i7j7 | k0l0... k7l7 + s[4] = _mm256_permute2x128_si256(CAST_HI(s89), CAST_HI(s10s11), 0x20); + + s[0] = s[2]; + s[1] = s[3]; + s[2] = s[4]; + } + + // Process the remaining last 4 or 6 rows here. + int i = process_ht; + while (i < height - 1) { + PROCESS_RESIZE_Y_WD8 + + _mm_storel_epi64((__m128i *)&output[(i / 2) * out_stride + processed_wd], + CAST_LOW(res_a_round_1)); + + i += 2; + + const int is_store_valid = (i < height - 1); + if (is_store_valid) + _mm_storel_epi64( + (__m128i *)&output[(i / 2) * out_stride + processed_wd], + _mm256_extracti128_si256(res_a_round_1, 1)); + i += 2; + + // Check rows are still remaining for processing. If yes do the required + // load of data for the next iteration. + if (i < height - 1) { + l10 = l11 = l9; + // k0l0... k7l7 + const __m128i s10s11 = _mm_unpacklo_epi8(l10, l11); + // i0j0... i7j7 | k0l0... k7l7 + s[4] = _mm256_permute2x128_si256(CAST_HI(s89), CAST_HI(s10s11), 0x20); + + s[0] = s[2]; + s[1] = s[3]; + s[2] = s[4]; + } + } + } + + if (remain_col) + return resize_vert_dir_c(intbuf, output, out_stride, height, height2, + stride, stride - remain_col); + + return true; +} diff --git a/third_party/aom/av1/encoder/arm/neon/highbd_pickrst_neon.c b/third_party/aom/av1/encoder/arm/neon/highbd_pickrst_neon.c index 47b5f5cfb7..8b0d3bcc7e 100644 --- a/third_party/aom/av1/encoder/arm/neon/highbd_pickrst_neon.c +++ b/third_party/aom/av1/encoder/arm/neon/highbd_pickrst_neon.c @@ -1008,10 +1008,13 @@ static uint16_t highbd_find_average_neon(const uint16_t *src, int src_stride, } void av1_compute_stats_highbd_neon(int wiener_win, const uint8_t *dgd8, - const uint8_t *src8, int h_start, int h_end, + const uint8_t *src8, int16_t *dgd_avg, + int16_t *src_avg, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, aom_bit_depth_t bit_depth) { + (void)dgd_avg; + (void)src_avg; assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_REDUCED); const int wiener_halfwin = wiener_win >> 1; diff --git a/third_party/aom/av1/encoder/arm/neon/pickrst_sve.c b/third_party/aom/av1/encoder/arm/neon/pickrst_sve.c new file mode 100644 index 0000000000..a519ecc5f5 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/pickrst_sve.c @@ -0,0 +1,590 @@ +/* + * Copyright (c) 2024, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include +#include + +#include "config/aom_config.h" +#include "config/av1_rtcd.h" + +#include "aom_dsp/arm/aom_neon_sve_bridge.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "av1/common/restoration.h" +#include "av1/encoder/pickrst.h" + +static INLINE uint8_t find_average_sve(const uint8_t *src, int src_stride, + int width, int height) { + uint32x4_t avg_u32 = vdupq_n_u32(0); + uint8x16_t ones = vdupq_n_u8(1); + + // Use a predicate to compute the last columns. + svbool_t pattern = svwhilelt_b8_u32(0, width % 16); + + int h = height; + do { + int j = width; + const uint8_t *src_ptr = src; + while (j >= 16) { + uint8x16_t s = vld1q_u8(src_ptr); + avg_u32 = vdotq_u32(avg_u32, s, ones); + + j -= 16; + src_ptr += 16; + } + uint8x16_t s_end = svget_neonq_u8(svld1_u8(pattern, src_ptr)); + avg_u32 = vdotq_u32(avg_u32, s_end, ones); + + src += src_stride; + } while (--h != 0); + return (uint8_t)(vaddlvq_u32(avg_u32) / (width * height)); +} + +static INLINE void compute_sub_avg(const uint8_t *buf, int buf_stride, int avg, + int16_t *buf_avg, int buf_avg_stride, + int width, int height, + int downsample_factor) { + uint8x8_t avg_u8 = vdup_n_u8(avg); + + // Use a predicate to compute the last columns. + svbool_t pattern = svwhilelt_b8_u32(0, width % 8); + + uint8x8_t avg_end = vget_low_u8(svget_neonq_u8(svdup_n_u8_z(pattern, avg))); + + do { + int j = width; + const uint8_t *buf_ptr = buf; + int16_t *buf_avg_ptr = buf_avg; + while (j >= 8) { + uint8x8_t d = vld1_u8(buf_ptr); + vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubl_u8(d, avg_u8))); + + j -= 8; + buf_ptr += 8; + buf_avg_ptr += 8; + } + uint8x8_t d_end = vget_low_u8(svget_neonq_u8(svld1_u8(pattern, buf_ptr))); + vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubl_u8(d_end, avg_end))); + + buf += buf_stride; + buf_avg += buf_avg_stride; + height -= downsample_factor; + } while (height > 0); +} + +static INLINE void copy_upper_triangle(int64_t *H, int64_t *H_tmp, + const int wiener_win2, const int scale) { + for (int i = 0; i < wiener_win2 - 2; i = i + 2) { + // Transpose the first 2x2 square. It needs a special case as the element + // of the bottom left is on the diagonal. + int64x2_t row0 = vld1q_s64(H_tmp + i * wiener_win2 + i + 1); + int64x2_t row1 = vld1q_s64(H_tmp + (i + 1) * wiener_win2 + i + 1); + + int64x2_t tr_row = aom_vtrn2q_s64(row0, row1); + + vst1_s64(H_tmp + (i + 1) * wiener_win2 + i, vget_low_s64(row0)); + vst1q_s64(H_tmp + (i + 2) * wiener_win2 + i, tr_row); + + // Transpose and store all the remaining 2x2 squares of the line. + for (int j = i + 3; j < wiener_win2; j = j + 2) { + row0 = vld1q_s64(H_tmp + i * wiener_win2 + j); + row1 = vld1q_s64(H_tmp + (i + 1) * wiener_win2 + j); + + int64x2_t tr_row0 = aom_vtrn1q_s64(row0, row1); + int64x2_t tr_row1 = aom_vtrn2q_s64(row0, row1); + + vst1q_s64(H_tmp + j * wiener_win2 + i, tr_row0); + vst1q_s64(H_tmp + (j + 1) * wiener_win2 + i, tr_row1); + } + } + for (int i = 0; i < wiener_win2 * wiener_win2; i++) { + H[i] += H_tmp[i] * scale; + } +} + +// Transpose the matrix that has just been computed and accumulate it in M. +static INLINE void acc_transpose_M(int64_t *M, const int64_t *M_trn, + const int wiener_win, int scale) { + for (int i = 0; i < wiener_win; ++i) { + for (int j = 0; j < wiener_win; ++j) { + int tr_idx = j * wiener_win + i; + *M++ += (int64_t)(M_trn[tr_idx] * scale); + } + } +} + +// Swap each half of the dgd vectors so that we can accumulate the result of +// the dot-products directly in the destination matrix. +static INLINE int16x8x2_t transpose_dgd(int16x8_t dgd0, int16x8_t dgd1) { + int16x8_t dgd_trn0 = vreinterpretq_s16_s64( + vzip1q_s64(vreinterpretq_s64_s16(dgd0), vreinterpretq_s64_s16(dgd1))); + int16x8_t dgd_trn1 = vreinterpretq_s16_s64( + vzip2q_s64(vreinterpretq_s64_s16(dgd0), vreinterpretq_s64_s16(dgd1))); + + return (struct int16x8x2_t){ dgd_trn0, dgd_trn1 }; +} + +static INLINE void compute_M_one_row_win5(int16x8_t src, int16x8_t dgd[5], + int64_t *M, int row) { + const int wiener_win = 5; + + int64x2_t m01 = vld1q_s64(M + row * wiener_win + 0); + int16x8x2_t dgd01 = transpose_dgd(dgd[0], dgd[1]); + + int64x2_t cross_corr01 = aom_svdot_lane_s16(m01, dgd01.val[0], src, 0); + cross_corr01 = aom_svdot_lane_s16(cross_corr01, dgd01.val[1], src, 1); + vst1q_s64(M + row * wiener_win + 0, cross_corr01); + + int64x2_t m23 = vld1q_s64(M + row * wiener_win + 2); + int16x8x2_t dgd23 = transpose_dgd(dgd[2], dgd[3]); + + int64x2_t cross_corr23 = aom_svdot_lane_s16(m23, dgd23.val[0], src, 0); + cross_corr23 = aom_svdot_lane_s16(cross_corr23, dgd23.val[1], src, 1); + vst1q_s64(M + row * wiener_win + 2, cross_corr23); + + int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), src, dgd[4]); + M[row * wiener_win + 4] += vaddvq_s64(m4); +} + +static INLINE void compute_M_one_row_win7(int16x8_t src, int16x8_t dgd[7], + int64_t *M, int row) { + const int wiener_win = 7; + + int64x2_t m01 = vld1q_s64(M + row * wiener_win + 0); + int16x8x2_t dgd01 = transpose_dgd(dgd[0], dgd[1]); + + int64x2_t cross_corr01 = aom_svdot_lane_s16(m01, dgd01.val[0], src, 0); + cross_corr01 = aom_svdot_lane_s16(cross_corr01, dgd01.val[1], src, 1); + vst1q_s64(M + row * wiener_win + 0, cross_corr01); + + int64x2_t m23 = vld1q_s64(M + row * wiener_win + 2); + int16x8x2_t dgd23 = transpose_dgd(dgd[2], dgd[3]); + + int64x2_t cross_corr23 = aom_svdot_lane_s16(m23, dgd23.val[0], src, 0); + cross_corr23 = aom_svdot_lane_s16(cross_corr23, dgd23.val[1], src, 1); + vst1q_s64(M + row * wiener_win + 2, cross_corr23); + + int64x2_t m45 = vld1q_s64(M + row * wiener_win + 4); + int16x8x2_t dgd45 = transpose_dgd(dgd[4], dgd[5]); + + int64x2_t cross_corr45 = aom_svdot_lane_s16(m45, dgd45.val[0], src, 0); + cross_corr45 = aom_svdot_lane_s16(cross_corr45, dgd45.val[1], src, 1); + vst1q_s64(M + row * wiener_win + 4, cross_corr45); + + int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), src, dgd[6]); + M[row * wiener_win + 6] += vaddvq_s64(m6); +} + +static INLINE void compute_H_one_col(int16x8_t *dgd, int col, int64_t *H, + const int wiener_win, + const int wiener_win2) { + for (int row0 = 0; row0 < wiener_win; row0++) { + for (int row1 = row0; row1 < wiener_win; row1++) { + int auto_cov_idx = + (col * wiener_win + row0) * wiener_win2 + (col * wiener_win) + row1; + + int64x2_t auto_cov = aom_sdotq_s16(vdupq_n_s64(0), dgd[row0], dgd[row1]); + H[auto_cov_idx] += vaddvq_s64(auto_cov); + } + } +} + +static INLINE void compute_H_two_rows_win5(int16x8_t *dgd0, int16x8_t *dgd1, + int row0, int row1, int64_t *H) { + for (int col0 = 0; col0 < 5; col0++) { + int auto_cov_idx = (row0 * 5 + col0) * 25 + (row1 * 5); + + int64x2_t h01 = vld1q_s64(H + auto_cov_idx); + int16x8x2_t dgd01 = transpose_dgd(dgd1[0], dgd1[1]); + + int64x2_t auto_cov01 = aom_svdot_lane_s16(h01, dgd01.val[0], dgd0[col0], 0); + auto_cov01 = aom_svdot_lane_s16(auto_cov01, dgd01.val[1], dgd0[col0], 1); + vst1q_s64(H + auto_cov_idx, auto_cov01); + + int64x2_t h23 = vld1q_s64(H + auto_cov_idx + 2); + int16x8x2_t dgd23 = transpose_dgd(dgd1[2], dgd1[3]); + + int64x2_t auto_cov23 = aom_svdot_lane_s16(h23, dgd23.val[0], dgd0[col0], 0); + auto_cov23 = aom_svdot_lane_s16(auto_cov23, dgd23.val[1], dgd0[col0], 1); + vst1q_s64(H + auto_cov_idx + 2, auto_cov23); + + int64x2_t auto_cov4 = aom_sdotq_s16(vdupq_n_s64(0), dgd0[col0], dgd1[4]); + H[auto_cov_idx + 4] += vaddvq_s64(auto_cov4); + } +} + +static INLINE void compute_H_two_rows_win7(int16x8_t *dgd0, int16x8_t *dgd1, + int row0, int row1, int64_t *H) { + for (int col0 = 0; col0 < 7; col0++) { + int auto_cov_idx = (row0 * 7 + col0) * 49 + (row1 * 7); + + int64x2_t h01 = vld1q_s64(H + auto_cov_idx); + int16x8x2_t dgd01 = transpose_dgd(dgd1[0], dgd1[1]); + + int64x2_t auto_cov01 = aom_svdot_lane_s16(h01, dgd01.val[0], dgd0[col0], 0); + auto_cov01 = aom_svdot_lane_s16(auto_cov01, dgd01.val[1], dgd0[col0], 1); + vst1q_s64(H + auto_cov_idx, auto_cov01); + + int64x2_t h23 = vld1q_s64(H + auto_cov_idx + 2); + int16x8x2_t dgd23 = transpose_dgd(dgd1[2], dgd1[3]); + + int64x2_t auto_cov23 = aom_svdot_lane_s16(h23, dgd23.val[0], dgd0[col0], 0); + auto_cov23 = aom_svdot_lane_s16(auto_cov23, dgd23.val[1], dgd0[col0], 1); + vst1q_s64(H + auto_cov_idx + 2, auto_cov23); + + int64x2_t h45 = vld1q_s64(H + auto_cov_idx + 4); + int16x8x2_t dgd45 = transpose_dgd(dgd1[4], dgd1[5]); + + int64x2_t auto_cov45 = aom_svdot_lane_s16(h45, dgd45.val[0], dgd0[col0], 0); + auto_cov45 = aom_svdot_lane_s16(auto_cov45, dgd45.val[1], dgd0[col0], 1); + vst1q_s64(H + auto_cov_idx + 4, auto_cov45); + + int64x2_t auto_cov6 = aom_sdotq_s16(vdupq_n_s64(0), dgd0[col0], dgd1[6]); + H[auto_cov_idx + 6] += vaddvq_s64(auto_cov6); + } +} + +// This function computes two matrices: the cross-correlation between the src +// buffer and dgd buffer (M), and the auto-covariance of the dgd buffer (H). +// +// M is of size 7 * 7. It needs to be filled such that multiplying one element +// from src with each element of a row of the wiener window will fill one +// column of M. However this is not very convenient in terms of memory +// accesses, as it means we do contiguous loads of dgd but strided stores to M. +// As a result, we use an intermediate matrix M_trn which is instead filled +// such that one row of the wiener window gives one row of M_trn. Once fully +// computed, M_trn is then transposed to return M. +// +// H is of size 49 * 49. It is filled by multiplying every pair of elements of +// the wiener window together. Since it is a symmetric matrix, we only compute +// the upper triangle, and then copy it down to the lower one. Here we fill it +// by taking each different pair of columns, and multiplying all the elements of +// the first one with all the elements of the second one, with a special case +// when multiplying a column by itself. +static INLINE void compute_stats_win7_sve(int16_t *dgd_avg, int dgd_avg_stride, + int16_t *src_avg, int src_avg_stride, + int width, int height, int64_t *M, + int64_t *H, int downsample_factor) { + const int wiener_win = 7; + const int wiener_win2 = wiener_win * wiener_win; + + // Use a predicate to compute the last columns of the block for H. + svbool_t pattern = svwhilelt_b16_u32(0, width % 8); + + // Use intermediate matrices for H and M to perform the computation, they + // will be accumulated into the original H and M at the end. + int64_t M_trn[49]; + memset(M_trn, 0, sizeof(M_trn)); + + int64_t H_tmp[49 * 49]; + memset(H_tmp, 0, sizeof(H_tmp)); + + do { + // Cross-correlation (M). + for (int row = 0; row < wiener_win; row++) { + int j = 0; + while (j < width) { + int16x8_t dgd[7]; + load_s16_8x7(dgd_avg + row * dgd_avg_stride + j, 1, &dgd[0], &dgd[1], + &dgd[2], &dgd[3], &dgd[4], &dgd[5], &dgd[6]); + int16x8_t s = vld1q_s16(src_avg + j); + + // Compute all the elements of one row of M. + compute_M_one_row_win7(s, dgd, M_trn, row); + + j += 8; + } + } + + // Auto-covariance (H). + int j = 0; + while (j <= width - 8) { + for (int col0 = 0; col0 < wiener_win; col0++) { + int16x8_t dgd0[7]; + load_s16_8x7(dgd_avg + j + col0, dgd_avg_stride, &dgd0[0], &dgd0[1], + &dgd0[2], &dgd0[3], &dgd0[4], &dgd0[5], &dgd0[6]); + + // Perform computation of the first column with itself (28 elements). + // For the first column this will fill the upper triangle of the 7x7 + // matrix at the top left of the H matrix. For the next columns this + // will fill the upper triangle of the other 7x7 matrices around H's + // diagonal. + compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2); + + // All computation next to the matrix diagonal has already been done. + for (int col1 = col0 + 1; col1 < wiener_win; col1++) { + // Load second column and scale based on downsampling factor. + int16x8_t dgd1[7]; + load_s16_8x7(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1], + &dgd1[2], &dgd1[3], &dgd1[4], &dgd1[5], &dgd1[6]); + + // Compute all elements from the combination of both columns (49 + // elements). + compute_H_two_rows_win7(dgd0, dgd1, col0, col1, H_tmp); + } + } + j += 8; + } + + if (j < width) { + // Process remaining columns using a predicate to discard excess elements. + for (int col0 = 0; col0 < wiener_win; col0++) { + // Load first column. + int16x8_t dgd0[7]; + dgd0[0] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 0 * dgd_avg_stride + j + col0)); + dgd0[1] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 1 * dgd_avg_stride + j + col0)); + dgd0[2] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 2 * dgd_avg_stride + j + col0)); + dgd0[3] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 3 * dgd_avg_stride + j + col0)); + dgd0[4] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 4 * dgd_avg_stride + j + col0)); + dgd0[5] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 5 * dgd_avg_stride + j + col0)); + dgd0[6] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 6 * dgd_avg_stride + j + col0)); + + // Perform computation of the first column with itself (28 elements). + // For the first column this will fill the upper triangle of the 7x7 + // matrix at the top left of the H matrix. For the next columns this + // will fill the upper triangle of the other 7x7 matrices around H's + // diagonal. + compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2); + + // All computation next to the matrix diagonal has already been done. + for (int col1 = col0 + 1; col1 < wiener_win; col1++) { + // Load second column and scale based on downsampling factor. + int16x8_t dgd1[7]; + load_s16_8x7(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1], + &dgd1[2], &dgd1[3], &dgd1[4], &dgd1[5], &dgd1[6]); + + // Compute all elements from the combination of both columns (49 + // elements). + compute_H_two_rows_win7(dgd0, dgd1, col0, col1, H_tmp); + } + } + } + dgd_avg += downsample_factor * dgd_avg_stride; + src_avg += src_avg_stride; + } while (--height != 0); + + // Transpose M_trn. + acc_transpose_M(M, M_trn, 7, downsample_factor); + + // Copy upper triangle of H in the lower one. + copy_upper_triangle(H, H_tmp, wiener_win2, downsample_factor); +} + +// This function computes two matrices: the cross-correlation between the src +// buffer and dgd buffer (M), and the auto-covariance of the dgd buffer (H). +// +// M is of size 5 * 5. It needs to be filled such that multiplying one element +// from src with each element of a row of the wiener window will fill one +// column of M. However this is not very convenient in terms of memory +// accesses, as it means we do contiguous loads of dgd but strided stores to M. +// As a result, we use an intermediate matrix M_trn which is instead filled +// such that one row of the wiener window gives one row of M_trn. Once fully +// computed, M_trn is then transposed to return M. +// +// H is of size 25 * 25. It is filled by multiplying every pair of elements of +// the wiener window together. Since it is a symmetric matrix, we only compute +// the upper triangle, and then copy it down to the lower one. Here we fill it +// by taking each different pair of columns, and multiplying all the elements of +// the first one with all the elements of the second one, with a special case +// when multiplying a column by itself. +static INLINE void compute_stats_win5_sve(int16_t *dgd_avg, int dgd_avg_stride, + int16_t *src_avg, int src_avg_stride, + int width, int height, int64_t *M, + int64_t *H, int downsample_factor) { + const int wiener_win = 5; + const int wiener_win2 = wiener_win * wiener_win; + + // Use a predicate to compute the last columns of the block for H. + svbool_t pattern = svwhilelt_b16_u32(0, width % 8); + + // Use intermediate matrices for H and M to perform the computation, they + // will be accumulated into the original H and M at the end. + int64_t M_trn[25]; + memset(M_trn, 0, sizeof(M_trn)); + + int64_t H_tmp[25 * 25]; + memset(H_tmp, 0, sizeof(H_tmp)); + + do { + // Cross-correlation (M). + for (int row = 0; row < wiener_win; row++) { + int j = 0; + while (j < width) { + int16x8_t dgd[5]; + load_s16_8x5(dgd_avg + row * dgd_avg_stride + j, 1, &dgd[0], &dgd[1], + &dgd[2], &dgd[3], &dgd[4]); + int16x8_t s = vld1q_s16(src_avg + j); + + // Compute all the elements of one row of M. + compute_M_one_row_win5(s, dgd, M_trn, row); + + j += 8; + } + } + + // Auto-covariance (H). + int j = 0; + while (j <= width - 8) { + for (int col0 = 0; col0 < wiener_win; col0++) { + // Load first column. + int16x8_t dgd0[5]; + load_s16_8x5(dgd_avg + j + col0, dgd_avg_stride, &dgd0[0], &dgd0[1], + &dgd0[2], &dgd0[3], &dgd0[4]); + + // Perform computation of the first column with itself (15 elements). + // For the first column this will fill the upper triangle of the 5x5 + // matrix at the top left of the H matrix. For the next columns this + // will fill the upper triangle of the other 5x5 matrices around H's + // diagonal. + compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2); + + // All computation next to the matrix diagonal has already been done. + for (int col1 = col0 + 1; col1 < wiener_win; col1++) { + // Load second column and scale based on downsampling factor. + int16x8_t dgd1[5]; + load_s16_8x5(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1], + &dgd1[2], &dgd1[3], &dgd1[4]); + + // Compute all elements from the combination of both columns (25 + // elements). + compute_H_two_rows_win5(dgd0, dgd1, col0, col1, H_tmp); + } + } + j += 8; + } + + // Process remaining columns using a predicate to discard excess elements. + if (j < width) { + for (int col0 = 0; col0 < wiener_win; col0++) { + int16x8_t dgd0[5]; + dgd0[0] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 0 * dgd_avg_stride + j + col0)); + dgd0[1] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 1 * dgd_avg_stride + j + col0)); + dgd0[2] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 2 * dgd_avg_stride + j + col0)); + dgd0[3] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 3 * dgd_avg_stride + j + col0)); + dgd0[4] = svget_neonq_s16( + svld1_s16(pattern, dgd_avg + 4 * dgd_avg_stride + j + col0)); + + // Perform computation of the first column with itself (15 elements). + // For the first column this will fill the upper triangle of the 5x5 + // matrix at the top left of the H matrix. For the next columns this + // will fill the upper triangle of the other 5x5 matrices around H's + // diagonal. + compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2); + + // All computation next to the matrix diagonal has already been done. + for (int col1 = col0 + 1; col1 < wiener_win; col1++) { + // Load second column and scale based on downsampling factor. + int16x8_t dgd1[5]; + load_s16_8x5(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1], + &dgd1[2], &dgd1[3], &dgd1[4]); + + // Compute all elements from the combination of both columns (25 + // elements). + compute_H_two_rows_win5(dgd0, dgd1, col0, col1, H_tmp); + } + } + } + dgd_avg += downsample_factor * dgd_avg_stride; + src_avg += src_avg_stride; + } while (--height != 0); + + // Transpose M_trn. + acc_transpose_M(M, M_trn, 5, downsample_factor); + + // Copy upper triangle of H in the lower one. + copy_upper_triangle(H, H_tmp, wiener_win2, downsample_factor); +} + +void av1_compute_stats_sve(int wiener_win, const uint8_t *dgd, + const uint8_t *src, int16_t *dgd_avg, + int16_t *src_avg, int h_start, int h_end, + int v_start, int v_end, int dgd_stride, + int src_stride, int64_t *M, int64_t *H, + int use_downsampled_wiener_stats) { + assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA); + + const int wiener_win2 = wiener_win * wiener_win; + const int wiener_halfwin = wiener_win >> 1; + const int32_t width = h_end - h_start; + const int32_t height = v_end - v_start; + const uint8_t *dgd_start = &dgd[v_start * dgd_stride + h_start]; + memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2); + memset(M, 0, sizeof(*M) * wiener_win * wiener_win); + + const uint8_t avg = find_average_sve(dgd_start, dgd_stride, width, height); + const int downsample_factor = + use_downsampled_wiener_stats ? WIENER_STATS_DOWNSAMPLE_FACTOR : 1; + + // dgd_avg and src_avg have been memset to zero before calling this + // function, so round up the stride to the next multiple of 8 so that we + // don't have to worry about a tail loop when computing M. + const int dgd_avg_stride = ((width + 2 * wiener_halfwin) & ~7) + 8; + const int src_avg_stride = (width & ~7) + 8; + + // Compute (dgd - avg) and store it in dgd_avg. + // The wiener window will slide along the dgd frame, centered on each pixel. + // For the top left pixel and all the pixels on the side of the frame this + // means half of the window will be outside of the frame. As such the actual + // buffer that we need to subtract the avg from will be 2 * wiener_halfwin + // wider and 2 * wiener_halfwin higher than the original dgd buffer. + const int vert_offset = v_start - wiener_halfwin; + const int horiz_offset = h_start - wiener_halfwin; + const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride; + compute_sub_avg(dgd_win, dgd_stride, avg, dgd_avg, dgd_avg_stride, + width + 2 * wiener_halfwin, height + 2 * wiener_halfwin, 1); + + // Compute (src - avg), downsample if necessary and store in src-avg. + const uint8_t *src_start = src + h_start + v_start * src_stride; + compute_sub_avg(src_start, src_stride * downsample_factor, avg, src_avg, + src_avg_stride, width, height, downsample_factor); + + const int downsample_height = height / downsample_factor; + + // Since the height is not necessarily a multiple of the downsample factor, + // the last line of src will be scaled according to how many rows remain. + const int downsample_remainder = height % downsample_factor; + + if (wiener_win == WIENER_WIN) { + compute_stats_win7_sve(dgd_avg, dgd_avg_stride, src_avg, src_avg_stride, + width, downsample_height, M, H, downsample_factor); + } else { + compute_stats_win5_sve(dgd_avg, dgd_avg_stride, src_avg, src_avg_stride, + width, downsample_height, M, H, downsample_factor); + } + + if (downsample_remainder > 0) { + const int remainder_offset = height - downsample_remainder; + if (wiener_win == WIENER_WIN) { + compute_stats_win7_sve( + dgd_avg + remainder_offset * dgd_avg_stride, dgd_avg_stride, + src_avg + downsample_height * src_avg_stride, src_avg_stride, width, + 1, M, H, downsample_remainder); + } else { + compute_stats_win5_sve( + dgd_avg + remainder_offset * dgd_avg_stride, dgd_avg_stride, + src_avg + downsample_height * src_avg_stride, src_avg_stride, width, + 1, M, H, downsample_remainder); + } + } +} diff --git a/third_party/aom/av1/encoder/enc_enums.h b/third_party/aom/av1/encoder/enc_enums.h index 20cefa16a5..0a8b0f258a 100644 --- a/third_party/aom/av1/encoder/enc_enums.h +++ b/third_party/aom/av1/encoder/enc_enums.h @@ -12,10 +12,14 @@ #ifndef AOM_AV1_ENCODER_ENC_ENUMS_H_ #define AOM_AV1_ENCODER_ENC_ENUMS_H_ +#include "aom_ports/mem.h" + #ifdef __cplusplus extern "C" { #endif +#define MAX_NUM_THREADS 64 + // This enumerator type needs to be kept aligned with the mode order in // const MODE_DEFINITION av1_mode_defs[MAX_MODES] used in the rd code. enum { diff --git a/third_party/aom/av1/encoder/encodeframe.c b/third_party/aom/av1/encoder/encodeframe.c index a9214f77c2..07382eb6cc 100644 --- a/third_party/aom/av1/encoder/encodeframe.c +++ b/third_party/aom/av1/encoder/encodeframe.c @@ -537,7 +537,9 @@ static AOM_INLINE void encode_nonrd_sb(AV1_COMP *cpi, ThreadData *td, // Set the partition if (sf->part_sf.partition_search_type == FIXED_PARTITION || seg_skip || (sf->rt_sf.use_fast_fixed_part && x->sb_force_fixed_part == 1 && - !frame_is_intra_only(cm))) { + (!frame_is_intra_only(cm) && + (!cpi->ppi->use_svc || + !cpi->svc.layer_context[cpi->svc.temporal_layer_id].is_key_frame)))) { // set a fixed-size partition av1_set_offsets(cpi, tile_info, x, mi_row, mi_col, sb_size); BLOCK_SIZE bsize_select = sf->part_sf.fixed_partition_size; diff --git a/third_party/aom/av1/encoder/encoder.h b/third_party/aom/av1/encoder/encoder.h index 4de5d426ce..a919bd906a 100644 --- a/third_party/aom/av1/encoder/encoder.h +++ b/third_party/aom/av1/encoder/encoder.h @@ -37,6 +37,7 @@ #include "av1/encoder/av1_quantize.h" #include "av1/encoder/block.h" #include "av1/encoder/context_tree.h" +#include "av1/encoder/enc_enums.h" #include "av1/encoder/encodemb.h" #include "av1/encoder/external_partition.h" #include "av1/encoder/firstpass.h" @@ -74,7 +75,6 @@ #endif #include "aom/internal/aom_codec_internal.h" -#include "aom_util/aom_thread.h" #ifdef __cplusplus extern "C" { diff --git a/third_party/aom/av1/encoder/ethread.c b/third_party/aom/av1/encoder/ethread.c index 755535ba51..1d0092a5ed 100644 --- a/third_party/aom/av1/encoder/ethread.c +++ b/third_party/aom/av1/encoder/ethread.c @@ -19,6 +19,7 @@ #include "av1/encoder/allintra_vis.h" #include "av1/encoder/bitstream.h" +#include "av1/encoder/enc_enums.h" #include "av1/encoder/encodeframe.h" #include "av1/encoder/encodeframe_utils.h" #include "av1/encoder/encoder.h" @@ -2520,7 +2521,7 @@ void av1_tf_do_filtering_mt(AV1_COMP *cpi) { static AOM_INLINE int get_next_gm_job(AV1_COMP *cpi, int *frame_idx, int cur_dir) { GlobalMotionInfo *gm_info = &cpi->gm_info; - JobInfo *job_info = &cpi->mt_info.gm_sync.job_info; + GlobalMotionJobInfo *job_info = &cpi->mt_info.gm_sync.job_info; int total_refs = gm_info->num_ref_frames[cur_dir]; int8_t cur_frame_to_process = job_info->next_frame_to_process[cur_dir]; @@ -2551,7 +2552,7 @@ static int gm_mt_worker_hook(void *arg1, void *unused) { AV1_COMP *cpi = thread_data->cpi; GlobalMotionInfo *gm_info = &cpi->gm_info; AV1GlobalMotionSync *gm_sync = &cpi->mt_info.gm_sync; - JobInfo *job_info = &gm_sync->job_info; + GlobalMotionJobInfo *job_info = &gm_sync->job_info; int thread_id = thread_data->thread_id; GlobalMotionData *gm_thread_data = &thread_data->td->gm_data; #if CONFIG_MULTITHREAD @@ -2689,7 +2690,7 @@ static AOM_INLINE void gm_dealloc_thread_data(AV1_COMP *cpi, int num_workers) { // Implements multi-threading for global motion. void av1_global_motion_estimation_mt(AV1_COMP *cpi) { - JobInfo *job_info = &cpi->mt_info.gm_sync.job_info; + GlobalMotionJobInfo *job_info = &cpi->mt_info.gm_sync.job_info; av1_zero(*job_info); diff --git a/third_party/aom/av1/encoder/global_motion.h b/third_party/aom/av1/encoder/global_motion.h index de46a0e1f2..2645f93e3c 100644 --- a/third_party/aom/av1/encoder/global_motion.h +++ b/third_party/aom/av1/encoder/global_motion.h @@ -14,9 +14,8 @@ #include "aom/aom_integer.h" #include "aom_dsp/flow_estimation/flow_estimation.h" -#include "aom_scale/yv12config.h" #include "aom_util/aom_pthread.h" -#include "aom_util/aom_thread.h" +#include "av1/encoder/enc_enums.h" #ifdef __cplusplus extern "C" { @@ -58,11 +57,11 @@ typedef struct { // next_frame_to_process[i] will hold the count of next reference frame to be // processed in the direction 'i'. int8_t next_frame_to_process[MAX_DIRECTIONS]; -} JobInfo; +} GlobalMotionJobInfo; typedef struct { // Data related to assigning jobs for global motion multi-threading. - JobInfo job_info; + GlobalMotionJobInfo job_info; #if CONFIG_MULTITHREAD // Mutex lock used while dispatching jobs. diff --git a/third_party/aom/av1/encoder/nonrd_pickmode.c b/third_party/aom/av1/encoder/nonrd_pickmode.c index 57c74f66d5..08ecb8495a 100644 --- a/third_party/aom/av1/encoder/nonrd_pickmode.c +++ b/third_party/aom/av1/encoder/nonrd_pickmode.c @@ -1886,14 +1886,17 @@ static AOM_INLINE int skip_mode_by_low_temp( static AOM_INLINE int skip_mode_by_bsize_and_ref_frame( PREDICTION_MODE mode, MV_REFERENCE_FRAME ref_frame, BLOCK_SIZE bsize, - int extra_prune, unsigned int sse_zeromv_norm, int more_prune) { + int extra_prune, unsigned int sse_zeromv_norm, int more_prune, + int skip_nearmv) { const unsigned int thresh_skip_golden = 500; if (ref_frame != LAST_FRAME && sse_zeromv_norm < thresh_skip_golden && mode == NEWMV) return 1; - if (bsize == BLOCK_128X128 && mode == NEWMV) return 1; + if ((bsize == BLOCK_128X128 && mode == NEWMV) || + (skip_nearmv && mode == NEARMV)) + return 1; // Skip testing non-LAST if this flag is set. if (extra_prune) { @@ -2361,6 +2364,18 @@ static AOM_FORCE_INLINE bool skip_inter_mode_nonrd( (*this_mode != GLOBALMV || *ref_frame != LAST_FRAME)) return true; + // Skip the mode if use reference frame mask flag is not set. + if (!search_state->use_ref_frame_mask[*ref_frame]) return true; + + // Skip mode for some modes and reference frames when + // force_zeromv_skip_for_blk flag is true. + if (x->force_zeromv_skip_for_blk && + ((!(*this_mode == NEARESTMV && + search_state->frame_mv[*this_mode][*ref_frame].as_int == 0) && + *this_mode != GLOBALMV) || + *ref_frame != LAST_FRAME)) + return true; + if (x->sb_me_block && *ref_frame == LAST_FRAME) { // We want to make sure to test the superblock MV: // so don't skip (return false) for NEAREST_LAST or NEAR_LAST if they @@ -2400,18 +2415,6 @@ static AOM_FORCE_INLINE bool skip_inter_mode_nonrd( mi->ref_frame[0] = *ref_frame; mi->ref_frame[1] = *ref_frame2; - // Skip the mode if use reference frame mask flag is not set. - if (!search_state->use_ref_frame_mask[*ref_frame]) return true; - - // Skip mode for some modes and reference frames when - // force_zeromv_skip_for_blk flag is true. - if (x->force_zeromv_skip_for_blk && - ((!(*this_mode == NEARESTMV && - search_state->frame_mv[*this_mode][*ref_frame].as_int == 0) && - *this_mode != GLOBALMV) || - *ref_frame != LAST_FRAME)) - return true; - // Skip compound mode based on variance of previously evaluated single // reference modes. if (rt_sf->prune_compoundmode_with_singlemode_var && !*is_single_pred && @@ -2478,7 +2481,8 @@ static AOM_FORCE_INLINE bool skip_inter_mode_nonrd( // properties. if (skip_mode_by_bsize_and_ref_frame( *this_mode, *ref_frame, bsize, x->nonrd_prune_ref_frame_search, - sse_zeromv_norm, rt_sf->nonrd_aggressive_skip)) + sse_zeromv_norm, rt_sf->nonrd_aggressive_skip, + rt_sf->increase_source_sad_thresh)) return true; // Skip mode based on low temporal variance and souce sad. diff --git a/third_party/aom/av1/encoder/partition_search.c b/third_party/aom/av1/encoder/partition_search.c index 61d49a23f2..30ea7d9140 100644 --- a/third_party/aom/av1/encoder/partition_search.c +++ b/third_party/aom/av1/encoder/partition_search.c @@ -2323,8 +2323,9 @@ static void pick_sb_modes_nonrd(AV1_COMP *const cpi, TileDataEnc *tile_data, } if (cpi->sf.rt_sf.skip_cdef_sb) { // cdef_strength is initialized to 1 which means skip_cdef, and is updated - // here. Check to see is skipping cdef is allowed. - // Always allow cdef_skip for seg_skip = 1. + // here. Check to see is skipping cdef is allowed. Never skip on slide/scene + // change, near a key frame, or when color sensitivity is set. Always allow + // cdef_skip for seg_skip = 1. const int allow_cdef_skipping = seg_skip || (cpi->rc.frames_since_key > 10 && !cpi->rc.high_source_sad && @@ -2338,8 +2339,16 @@ static void pick_sb_modes_nonrd(AV1_COMP *const cpi, TileDataEnc *tile_data, MB_MODE_INFO **mi_sb = cm->mi_params.mi_grid_base + get_mi_grid_idx(&cm->mi_params, mi_row_sb, mi_col_sb); - // Do not skip if intra or new mv is picked, or color sensitivity is set. - // Never skip on slide/scene change. + const int is_720p_or_larger = AOMMIN(cm->width, cm->height) >= 720; + unsigned int thresh_spatial_var = + (cpi->oxcf.speed >= 11 && !is_720p_or_larger && + cpi->oxcf.tune_cfg.content != AOM_CONTENT_SCREEN) + ? 400 + : UINT_MAX; + // For skip_cdef_sb = 1: do not skip if allow_cdef_skipping is false or + // intra or new mv is picked, with possible conidition on spatial variance. + // For skip_cdef_sb >= 2: more aggressive mode to always skip unless + // allow_cdef_skipping is false and source_variance is non-zero. if (cpi->sf.rt_sf.skip_cdef_sb >= 2) { mi_sb[0]->cdef_strength = mi_sb[0]->cdef_strength && @@ -2347,7 +2356,8 @@ static void pick_sb_modes_nonrd(AV1_COMP *const cpi, TileDataEnc *tile_data, } else { mi_sb[0]->cdef_strength = mi_sb[0]->cdef_strength && allow_cdef_skipping && - !(mbmi->mode < INTRA_MODES || mbmi->mode == NEWMV); + !(x->source_variance < thresh_spatial_var && + (mbmi->mode < INTRA_MODES || mbmi->mode == NEWMV)); } // Store in the pickmode context. ctx->mic.cdef_strength = mi_sb[0]->cdef_strength; diff --git a/third_party/aom/av1/encoder/picklpf.c b/third_party/aom/av1/encoder/picklpf.c index a504535028..ce0357163d 100644 --- a/third_party/aom/av1/encoder/picklpf.c +++ b/third_party/aom/av1/encoder/picklpf.c @@ -257,6 +257,8 @@ void av1_pick_filter_level(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi, inter_frame_multiplier = inter_frame_multiplier << 1; else if (cpi->rc.frame_source_sad > 50000) inter_frame_multiplier = 3 * (inter_frame_multiplier >> 1); + } else if (cpi->sf.rt_sf.use_fast_fixed_part) { + inter_frame_multiplier = inter_frame_multiplier << 1; } // These values were determined by linear fitting the result of the // searched level for 8 bit depth: diff --git a/third_party/aom/av1/encoder/pickrst.c b/third_party/aom/av1/encoder/pickrst.c index b0d0d0bb78..a431c4dada 100644 --- a/third_party/aom/av1/encoder/pickrst.c +++ b/third_party/aom/av1/encoder/pickrst.c @@ -1044,10 +1044,13 @@ void av1_compute_stats_c(int wiener_win, const uint8_t *dgd, const uint8_t *src, #if CONFIG_AV1_HIGHBITDEPTH void av1_compute_stats_highbd_c(int wiener_win, const uint8_t *dgd8, - const uint8_t *src8, int h_start, int h_end, + const uint8_t *src8, int16_t *dgd_avg, + int16_t *src_avg, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, aom_bit_depth_t bit_depth) { + (void)dgd_avg; + (void)src_avg; int i, j, k, l; int32_t Y[WIENER_WIN2]; const int wiener_win2 = wiener_win * wiener_win; @@ -1659,9 +1662,10 @@ static AOM_INLINE void search_wiener( // functions. Optimize intrinsics of HBD design similar to LBD (i.e., // pre-calculate d and s buffers and avoid most of the C operations). av1_compute_stats_highbd(reduced_wiener_win, rsc->dgd_buffer, - rsc->src_buffer, limits->h_start, limits->h_end, - limits->v_start, limits->v_end, rsc->dgd_stride, - rsc->src_stride, M, H, cm->seq_params->bit_depth); + rsc->src_buffer, rsc->dgd_avg, rsc->src_avg, + limits->h_start, limits->h_end, limits->v_start, + limits->v_end, rsc->dgd_stride, rsc->src_stride, M, + H, cm->seq_params->bit_depth); } else { av1_compute_stats(reduced_wiener_win, rsc->dgd_buffer, rsc->src_buffer, rsc->dgd_avg, rsc->src_avg, limits->h_start, @@ -2081,10 +2085,9 @@ void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) { // and height aligned to multiple of 16 is considered for intrinsic purpose. rsc.dgd_avg = NULL; rsc.src_avg = NULL; -#if HAVE_AVX2 || HAVE_NEON - // The buffers allocated below are used during Wiener filter processing of low - // bitdepth path. Hence, allocate the same when Wiener filter is enabled in - // low bitdepth path. +#if HAVE_AVX2 + // The buffers allocated below are used during Wiener filter processing. + // Hence, allocate the same when Wiener filter is enabled. if (!cpi->sf.lpf_sf.disable_wiener_filter && !highbd) { const int buf_size = sizeof(*cpi->pick_lr_ctxt.dgd_avg) * 6 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX; @@ -2221,7 +2224,7 @@ void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) { best_luma_unit_size); } -#if HAVE_AVX || HAVE_NEON +#if HAVE_AVX2 if (!cpi->sf.lpf_sf.disable_wiener_filter && !highbd) { aom_free(cpi->pick_lr_ctxt.dgd_avg); cpi->pick_lr_ctxt.dgd_avg = NULL; diff --git a/third_party/aom/av1/encoder/speed_features.c b/third_party/aom/av1/encoder/speed_features.c index 256b6fc9eb..9a00042520 100644 --- a/third_party/aom/av1/encoder/speed_features.c +++ b/third_party/aom/av1/encoder/speed_features.c @@ -1461,7 +1461,7 @@ static void set_rt_speed_feature_framesize_dependent(const AV1_COMP *const cpi, // for resolutions below 720p. if (speed >= 11 && !is_720p_or_larger && cpi->oxcf.tune_cfg.content != AOM_CONTENT_SCREEN) { - sf->rt_sf.skip_cdef_sb = 2; + sf->rt_sf.skip_cdef_sb = 1; sf->rt_sf.force_only_last_ref = 1; sf->rt_sf.selective_cdf_update = 1; sf->rt_sf.use_nonrd_filter_search = 0; diff --git a/third_party/aom/av1/encoder/tune_vmaf.c b/third_party/aom/av1/encoder/tune_vmaf.c index 91db3db726..fdb7c77ebc 100644 --- a/third_party/aom/av1/encoder/tune_vmaf.c +++ b/third_party/aom/av1/encoder/tune_vmaf.c @@ -247,7 +247,9 @@ static AOM_INLINE void unsharp(const AV1_COMP *const cpi, // 8-tap Gaussian convolution filter with sigma = 1.0, sums to 128, // all co-efficients must be even. -DECLARE_ALIGNED(16, static const int16_t, gauss_filter[8]) = { 0, 8, 30, 52, +// The array is of size 9 to allow passing gauss_filter + 1 to +// _mm_loadu_si128() in prepare_coeffs_6t(). +DECLARE_ALIGNED(16, static const int16_t, gauss_filter[9]) = { 0, 8, 30, 52, 30, 8, 0, 0 }; static AOM_INLINE void gaussian_blur(const int bit_depth, const YV12_BUFFER_CONFIG *source, diff --git a/third_party/aom/av1/encoder/x86/pickrst_avx2.c b/third_party/aom/av1/encoder/x86/pickrst_avx2.c index 6658ed39a8..1f76576c9e 100644 --- a/third_party/aom/av1/encoder/x86/pickrst_avx2.c +++ b/third_party/aom/av1/encoder/x86/pickrst_avx2.c @@ -345,21 +345,27 @@ static INLINE void compute_stats_highbd_win5_opt_avx2( } void av1_compute_stats_highbd_avx2(int wiener_win, const uint8_t *dgd8, - const uint8_t *src8, int h_start, int h_end, + const uint8_t *src8, int16_t *dgd_avg, + int16_t *src_avg, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, aom_bit_depth_t bit_depth) { if (wiener_win == WIENER_WIN) { + (void)dgd_avg; + (void)src_avg; compute_stats_highbd_win7_opt_avx2(dgd8, src8, h_start, h_end, v_start, v_end, dgd_stride, src_stride, M, H, bit_depth); } else if (wiener_win == WIENER_WIN_CHROMA) { + (void)dgd_avg; + (void)src_avg; compute_stats_highbd_win5_opt_avx2(dgd8, src8, h_start, h_end, v_start, v_end, dgd_stride, src_stride, M, H, bit_depth); } else { - av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, v_start, - v_end, dgd_stride, src_stride, M, H, bit_depth); + av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg, + h_start, h_end, v_start, v_end, dgd_stride, + src_stride, M, H, bit_depth); } } #endif // CONFIG_AV1_HIGHBITDEPTH diff --git a/third_party/aom/av1/encoder/x86/pickrst_sse4.c b/third_party/aom/av1/encoder/x86/pickrst_sse4.c index 50db305802..3617d33fef 100644 --- a/third_party/aom/av1/encoder/x86/pickrst_sse4.c +++ b/third_party/aom/av1/encoder/x86/pickrst_sse4.c @@ -524,21 +524,27 @@ static INLINE void compute_stats_highbd_win5_opt_sse4_1( } void av1_compute_stats_highbd_sse4_1(int wiener_win, const uint8_t *dgd8, - const uint8_t *src8, int h_start, - int h_end, int v_start, int v_end, - int dgd_stride, int src_stride, int64_t *M, - int64_t *H, aom_bit_depth_t bit_depth) { + const uint8_t *src8, int16_t *dgd_avg, + int16_t *src_avg, int h_start, int h_end, + int v_start, int v_end, int dgd_stride, + int src_stride, int64_t *M, int64_t *H, + aom_bit_depth_t bit_depth) { if (wiener_win == WIENER_WIN) { + (void)dgd_avg; + (void)src_avg; compute_stats_highbd_win7_opt_sse4_1(dgd8, src8, h_start, h_end, v_start, v_end, dgd_stride, src_stride, M, H, bit_depth); } else if (wiener_win == WIENER_WIN_CHROMA) { + (void)dgd_avg; + (void)src_avg; compute_stats_highbd_win5_opt_sse4_1(dgd8, src8, h_start, h_end, v_start, v_end, dgd_stride, src_stride, M, H, bit_depth); } else { - av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, v_start, - v_end, dgd_stride, src_stride, M, H, bit_depth); + av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg, + h_start, h_end, v_start, v_end, dgd_stride, + src_stride, M, H, bit_depth); } } #endif // CONFIG_AV1_HIGHBITDEPTH diff --git a/third_party/aom/test/aom_image_test.cc b/third_party/aom/test/aom_image_test.cc index 03f4373f35..0dfb912215 100644 --- a/third_party/aom/test/aom_image_test.cc +++ b/third_party/aom/test/aom_image_test.cc @@ -9,6 +9,8 @@ * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ +#include + #include "aom/aom_image.h" #include "third_party/googletest/src/googletest/include/gtest/gtest.h" @@ -70,3 +72,66 @@ TEST(AomImageTest, AomImgAllocNv12) { EXPECT_EQ(img.planes[AOM_PLANE_V], nullptr); aom_img_free(&img); } + +TEST(AomImageTest, AomImgAllocHugeWidth) { + // The stride (0x80000000 * 2) would overflow unsigned int. + aom_image_t *image = + aom_img_alloc(nullptr, AOM_IMG_FMT_I42016, 0x80000000, 1, 1); + ASSERT_EQ(image, nullptr); + + // The stride (0x80000000) would overflow int. + image = aom_img_alloc(nullptr, AOM_IMG_FMT_I420, 0x80000000, 1, 1); + ASSERT_EQ(image, nullptr); + + // The aligned width (UINT_MAX + 1) would overflow unsigned int. + image = aom_img_alloc(nullptr, AOM_IMG_FMT_I420, UINT_MAX, 1, 1); + ASSERT_EQ(image, nullptr); + + image = aom_img_alloc_with_border(nullptr, AOM_IMG_FMT_I422, 1, INT_MAX, 1, + 0x40000000, 0); + if (image) { + uint16_t *y_plane = + reinterpret_cast(image->planes[AOM_PLANE_Y]); + y_plane[0] = 0; + y_plane[image->d_w - 1] = 0; + aom_img_free(image); + } + + image = aom_img_alloc(nullptr, AOM_IMG_FMT_I420, 0x7ffffffe, 1, 1); + if (image) { + aom_img_free(image); + } + + image = aom_img_alloc(nullptr, AOM_IMG_FMT_I420, 285245883, 64, 1); + if (image) { + aom_img_free(image); + } + + image = aom_img_alloc(nullptr, AOM_IMG_FMT_NV12, 285245883, 64, 1); + if (image) { + aom_img_free(image); + } + + image = aom_img_alloc(nullptr, AOM_IMG_FMT_YV12, 285245883, 64, 1); + if (image) { + aom_img_free(image); + } + + image = aom_img_alloc(nullptr, AOM_IMG_FMT_I42016, 65536, 2, 1); + if (image) { + uint16_t *y_plane = + reinterpret_cast(image->planes[AOM_PLANE_Y]); + y_plane[0] = 0; + y_plane[image->d_w - 1] = 0; + aom_img_free(image); + } + + image = aom_img_alloc(nullptr, AOM_IMG_FMT_I42016, 285245883, 2, 1); + if (image) { + uint16_t *y_plane = + reinterpret_cast(image->planes[AOM_PLANE_Y]); + y_plane[0] = 0; + y_plane[image->d_w - 1] = 0; + aom_img_free(image); + } +} diff --git a/third_party/aom/test/disflow_test.cc b/third_party/aom/test/disflow_test.cc index 4f004480e2..bee9e1261c 100644 --- a/third_party/aom/test/disflow_test.cc +++ b/third_party/aom/test/disflow_test.cc @@ -124,4 +124,9 @@ INSTANTIATE_TEST_SUITE_P(NEON, ComputeFlowTest, ::testing::Values(aom_compute_flow_at_point_neon)); #endif +#if HAVE_SVE +INSTANTIATE_TEST_SUITE_P(SVE, ComputeFlowTest, + ::testing::Values(aom_compute_flow_at_point_sve)); +#endif + } // namespace diff --git a/third_party/aom/test/ethread_test.cc b/third_party/aom/test/ethread_test.cc index ce45394eb8..415f5de269 100644 --- a/third_party/aom/test/ethread_test.cc +++ b/third_party/aom/test/ethread_test.cc @@ -18,6 +18,7 @@ #include "test/util.h" #include "test/y4m_video_source.h" #include "test/yuv_video_source.h" +#include "av1/encoder/enc_enums.h" #include "av1/encoder/firstpass.h" namespace { @@ -411,9 +412,7 @@ class AVxEncoderThreadTest const std::vector ref_size_enc, const std::vector ref_md5_enc, const std::vector ref_md5_dec) { - // This value should be kept the same as MAX_NUM_THREADS - // in aom_thread.h - cfg_.g_threads = 64; + cfg_.g_threads = MAX_NUM_THREADS; ASSERT_NO_FATAL_FAILURE(RunLoop(video)); std::vector multi_thr_max_row_mt_size_enc; std::vector multi_thr_max_row_mt_md5_enc; diff --git a/third_party/aom/test/frame_resize_test.cc b/third_party/aom/test/frame_resize_test.cc new file mode 100644 index 0000000000..8891304192 --- /dev/null +++ b/third_party/aom/test/frame_resize_test.cc @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2024, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include "config/av1_rtcd.h" +#include "test/acm_random.h" +#include "test/util.h" +#include "aom_ports/aom_timer.h" +#include "aom_ports/bitops.h" +#include "third_party/googletest/src/googletest/include/gtest/gtest.h" + +namespace { + +using ::testing::Combine; +using ::testing::Values; +using ::testing::ValuesIn; + +using std::make_tuple; +using std::tuple; + +const int kIters = 1000; + +typedef tuple FrameDimension; + +// Resolutions (width x height) to be tested for resizing. +const FrameDimension kFrameDim[] = { + make_tuple(3840, 2160), make_tuple(2560, 1440), make_tuple(1920, 1080), + make_tuple(1280, 720), make_tuple(640, 480), make_tuple(640, 360), + make_tuple(256, 256), +}; + +// Check that two 8-bit output buffers are identical. +void AssertOutputBufferEq(const uint8_t *p1, const uint8_t *p2, int width, + int height) { + ASSERT_TRUE(p1 != p2) << "Buffers must be at different memory locations"; + for (int j = 0; j < height; ++j) { + if (memcmp(p1, p2, sizeof(*p1) * width) == 0) { + p1 += width; + p2 += width; + continue; + } + for (int i = 0; i < width; ++i) { + ASSERT_EQ(p1[i], p2[i]) + << width << "x" << height << " Pixel mismatch at (" << i << ", " << j + << ")"; + } + } +} + +typedef bool (*LowBDResizeFunc)(uint8_t *intbuf, uint8_t *output, + int out_stride, int height, int height2, + int stride, int start_wd); +// Test parameter list: +// +typedef tuple ResizeTestParams; + +class AV1ResizeYTest : public ::testing::TestWithParam { + public: + void SetUp() { + test_fun_ = GET_PARAM(0); + frame_dim_ = GET_PARAM(1); + width_ = std::get<0>(frame_dim_); + height_ = std::get<1>(frame_dim_); + const int msb = get_msb(AOMMIN(width_, height_)); + n_levels_ = AOMMAX(msb - MIN_PYRAMID_SIZE_LOG2, 1); + + src_ = (uint8_t *)aom_malloc((width_ / 2) * height_ * sizeof(*src_)); + ref_dest_ = + (uint8_t *)aom_calloc((width_ * height_) / 4, sizeof(*ref_dest_)); + test_dest_ = + (uint8_t *)aom_calloc((width_ * height_) / 4, sizeof(*test_dest_)); + } + + void RunTest() { + int width2 = width_, height2 = height_; + + for (int i = 0; i < (width_ / 2) * height_; i++) src_[i] = rng_.Rand8(); + for (int level = 1; level < n_levels_; level++) { + width2 = (width_ >> level); + height2 = (height_ >> level); + resize_vert_dir_c(src_, ref_dest_, width2, height2 << 1, height2, width2, + 0); + test_fun_(src_, test_dest_, width2, height2 << 1, height2, width2, 0); + + AssertOutputBufferEq(ref_dest_, test_dest_, width2, height2); + } + } + + void SpeedTest() { + int width2 = width_, height2 = height_; + + for (int i = 0; i < (width_ / 2) * height_; i++) src_[i] = rng_.Rand8(); + for (int level = 1; level < n_levels_; level++) { + width2 = (width_ >> level); + height2 = (height_ >> level); + aom_usec_timer ref_timer; + aom_usec_timer_start(&ref_timer); + for (int j = 0; j < kIters; j++) { + resize_vert_dir_c(src_, ref_dest_, width2, height2 << 1, height2, + width2, 0); + } + aom_usec_timer_mark(&ref_timer); + const int64_t ref_time = aom_usec_timer_elapsed(&ref_timer); + + aom_usec_timer tst_timer; + aom_usec_timer_start(&tst_timer); + for (int j = 0; j < kIters; j++) { + test_fun_(src_, test_dest_, width2, height2 << 1, height2, width2, 0); + } + aom_usec_timer_mark(&tst_timer); + const int64_t tst_time = aom_usec_timer_elapsed(&tst_timer); + + std::cout << "level: " << level << " [" << width2 << " x " << height2 + << "] C time = " << ref_time << " , SIMD time = " << tst_time + << " scaling=" << float(1.00) * ref_time / tst_time << "x \n"; + } + } + + void TearDown() { + aom_free(src_); + aom_free(ref_dest_); + aom_free(test_dest_); + } + + private: + LowBDResizeFunc test_fun_; + FrameDimension frame_dim_; + int width_; + int height_; + int n_levels_; + uint8_t *src_; + uint8_t *ref_dest_; + uint8_t *test_dest_; + libaom_test::ACMRandom rng_; +}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AV1ResizeYTest); + +TEST_P(AV1ResizeYTest, RunTest) { RunTest(); } + +TEST_P(AV1ResizeYTest, DISABLED_SpeedTest) { SpeedTest(); } + +#if HAVE_AVX2 +INSTANTIATE_TEST_SUITE_P( + AVX2, AV1ResizeYTest, + ::testing::Combine(::testing::Values(resize_vert_dir_avx2), + ::testing::ValuesIn(kFrameDim))); +#endif + +} // namespace diff --git a/third_party/aom/test/test.cmake b/third_party/aom/test/test.cmake index e2f5da570d..2631c9fb39 100644 --- a/third_party/aom/test/test.cmake +++ b/third_party/aom/test/test.cmake @@ -209,6 +209,7 @@ if(NOT BUILD_SHARED_LIBS) "${AOM_ROOT}/test/fdct4x4_test.cc" "${AOM_ROOT}/test/fft_test.cc" "${AOM_ROOT}/test/firstpass_test.cc" + "${AOM_ROOT}/test/frame_resize_test.cc" "${AOM_ROOT}/test/fwht4x4_test.cc" "${AOM_ROOT}/test/hadamard_test.cc" "${AOM_ROOT}/test/horver_correlation_test.cc" diff --git a/third_party/aom/test/wiener_test.cc b/third_party/aom/test/wiener_test.cc index b995c84d8f..c38e10e3c2 100644 --- a/third_party/aom/test/wiener_test.cc +++ b/third_party/aom/test/wiener_test.cc @@ -397,6 +397,12 @@ INSTANTIATE_TEST_SUITE_P(NEON, WienerTest, ::testing::Values(av1_compute_stats_neon)); #endif // HAVE_NEON +#if HAVE_SVE + +INSTANTIATE_TEST_SUITE_P(SVE, WienerTest, + ::testing::Values(av1_compute_stats_sve)); +#endif // HAVE_SVE + } // namespace wiener_lowbd #if CONFIG_AV1_HIGHBITDEPTH @@ -514,25 +520,27 @@ static void compute_stats_highbd_win_opt_c(int wiener_win, const uint8_t *dgd8, } void compute_stats_highbd_opt_c(int wiener_win, const uint8_t *dgd, - const uint8_t *src, int h_start, int h_end, - int v_start, int v_end, int dgd_stride, - int src_stride, int64_t *M, int64_t *H, - aom_bit_depth_t bit_depth) { + const uint8_t *src, int16_t *d, int16_t *s, + int h_start, int h_end, int v_start, int v_end, + int dgd_stride, int src_stride, int64_t *M, + int64_t *H, aom_bit_depth_t bit_depth) { if (wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA) { compute_stats_highbd_win_opt_c(wiener_win, dgd, src, h_start, h_end, v_start, v_end, dgd_stride, src_stride, M, H, bit_depth); } else { - av1_compute_stats_highbd_c(wiener_win, dgd, src, h_start, h_end, v_start, - v_end, dgd_stride, src_stride, M, H, bit_depth); + av1_compute_stats_highbd_c(wiener_win, dgd, src, d, s, h_start, h_end, + v_start, v_end, dgd_stride, src_stride, M, H, + bit_depth); } } static const int kIterations = 100; typedef void (*compute_stats_Func)(int wiener_win, const uint8_t *dgd, - const uint8_t *src, int h_start, int h_end, - int v_start, int v_end, int dgd_stride, - int src_stride, int64_t *M, int64_t *H, + const uint8_t *src, int16_t *d, int16_t *s, + int h_start, int h_end, int v_start, + int v_end, int dgd_stride, int src_stride, + int64_t *M, int64_t *H, aom_bit_depth_t bit_depth); typedef std::tuple WienerTestParam; @@ -546,11 +554,17 @@ class WienerTestHighbd : public ::testing::TestWithParam { dgd_buf = (uint16_t *)aom_memalign( 32, MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*dgd_buf)); ASSERT_NE(dgd_buf, nullptr); + const size_t buf_size = + sizeof(*buf) * 6 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX; + buf = (int16_t *)aom_memalign(32, buf_size); + ASSERT_NE(buf, nullptr); + memset(buf, 0, buf_size); target_func_ = GET_PARAM(0); } void TearDown() override { aom_free(src_buf); aom_free(dgd_buf); + aom_free(buf); } void RunWienerTest(const int32_t wiener_win, int32_t run_times, aom_bit_depth_t bit_depth); @@ -562,6 +576,7 @@ class WienerTestHighbd : public ::testing::TestWithParam { libaom_test::ACMRandom rng_; uint16_t *src_buf; uint16_t *dgd_buf; + int16_t *buf; }; void WienerTestHighbd::RunWienerTest(const int32_t wiener_win, @@ -589,6 +604,9 @@ void WienerTestHighbd::RunWienerTest(const int32_t wiener_win, const int dgd_stride = h_end; const int src_stride = MAX_DATA_BLOCK; const int iters = run_times == 1 ? kIterations : 2; + int16_t *dgd_avg = buf; + int16_t *src_avg = + buf + (3 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX); for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) { for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) { dgd_buf[i] = rng_.Rand16() % (1 << bit_depth); @@ -601,16 +619,17 @@ void WienerTestHighbd::RunWienerTest(const int32_t wiener_win, aom_usec_timer timer; aom_usec_timer_start(&timer); for (int i = 0; i < run_times; ++i) { - av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, - v_start, v_end, dgd_stride, src_stride, M_ref, - H_ref, bit_depth); + av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg, + h_start, h_end, v_start, v_end, dgd_stride, + src_stride, M_ref, H_ref, bit_depth); } aom_usec_timer_mark(&timer); const double time1 = static_cast(aom_usec_timer_elapsed(&timer)); aom_usec_timer_start(&timer); for (int i = 0; i < run_times; ++i) { - target_func_(wiener_win, dgd8, src8, h_start, h_end, v_start, v_end, - dgd_stride, src_stride, M_test, H_test, bit_depth); + target_func_(wiener_win, dgd8, src8, dgd_avg, src_avg, h_start, h_end, + v_start, v_end, dgd_stride, src_stride, M_test, H_test, + bit_depth); } aom_usec_timer_mark(&timer); const double time2 = static_cast(aom_usec_timer_elapsed(&timer)); @@ -657,6 +676,9 @@ void WienerTestHighbd::RunWienerTest_ExtremeValues(const int32_t wiener_win, const int dgd_stride = h_end; const int src_stride = MAX_DATA_BLOCK; const int iters = 1; + int16_t *dgd_avg = buf; + int16_t *src_avg = + buf + (3 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX); for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) { // Fill with alternating extreme values to maximize difference with // the average. @@ -668,12 +690,13 @@ void WienerTestHighbd::RunWienerTest_ExtremeValues(const int32_t wiener_win, dgd_buf + wiener_halfwin * MAX_DATA_BLOCK + wiener_halfwin); const uint8_t *src8 = CONVERT_TO_BYTEPTR(src_buf); - av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, v_start, - v_end, dgd_stride, src_stride, M_ref, H_ref, - bit_depth); + av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg, + h_start, h_end, v_start, v_end, dgd_stride, + src_stride, M_ref, H_ref, bit_depth); - target_func_(wiener_win, dgd8, src8, h_start, h_end, v_start, v_end, - dgd_stride, src_stride, M_test, H_test, bit_depth); + target_func_(wiener_win, dgd8, src8, dgd_avg, src_avg, h_start, h_end, + v_start, v_end, dgd_stride, src_stride, M_test, H_test, + bit_depth); int failed = 0; for (int i = 0; i < wiener_win2; ++i) { -- cgit v1.2.3