Diffstat (limited to 'third_party/aom/aom_dsp')
-rw-r--r--  third_party/aom/aom_dsp/aom_dsp.cmake                       |   3
-rwxr-xr-x  third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl                |   2
-rw-r--r--  third_party/aom/aom_dsp/arm/aom_convolve8_neon.c            | 401
-rw-r--r--  third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c    | 428
-rw-r--r--  third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c       | 334
-rw-r--r--  third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c  | 104
-rw-r--r--  third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h  | 127
-rw-r--r--  third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c   | 268
-rw-r--r--  third_party/aom/aom_dsp/pyramid.c                           |  31
-rw-r--r--  third_party/aom/aom_dsp/x86/synonyms.h                      |   1
10 files changed, 1046 insertions, 653 deletions
diff --git a/third_party/aom/aom_dsp/aom_dsp.cmake b/third_party/aom/aom_dsp/aom_dsp.cmake
index de987cbd23..27099d36b2 100644
--- a/third_party/aom/aom_dsp/aom_dsp.cmake
+++ b/third_party/aom/aom_dsp/aom_dsp.cmake
@@ -205,6 +205,9 @@ if(CONFIG_AV1_ENCODER)
list(APPEND AOM_DSP_ENCODER_INTRIN_NEON
"${AOM_ROOT}/aom_dsp/flow_estimation/arm/disflow_neon.c")
+
+ list(APPEND AOM_DSP_ENCODER_INTRIN_SVE
+ "${AOM_ROOT}/aom_dsp/flow_estimation/arm/disflow_sve.c")
endif()
list(APPEND AOM_DSP_ENCODER_ASM_SSE2 "${AOM_ROOT}/aom_dsp/x86/sad4d_sse2.asm"
diff --git a/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl b/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl
index 7e746e9cb9..b75bdc5a19 100755
--- a/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/third_party/aom/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -1799,7 +1799,7 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
specialize qw/aom_compute_correlation sse4_1 avx2/;
add_proto qw/void aom_compute_flow_at_point/, "const uint8_t *src, const uint8_t *ref, int x, int y, int width, int height, int stride, double *u, double *v";
- specialize qw/aom_compute_flow_at_point sse4_1 avx2 neon/;
+ specialize qw/aom_compute_flow_at_point sse4_1 avx2 neon sve/;
}
} # CONFIG_AV1_ENCODER
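
The added `sve` token above registers an SVE-specialised aom_compute_flow_at_point with libaom's run-time CPU dispatch (RTCD), alongside the existing SSE4.1, AVX2 and NEON variants. As a rough, self-contained sketch of what that dispatch boils down to -- the HAS_* flags, names and stub bodies below are illustrative assumptions, not the generated aom_dsp_rtcd.h -- the setup step simply overwrites a function pointer with the most capable variant the CPU reports:

    #include <stdio.h>

    #define HAS_NEON (1 << 0)
    #define HAS_SVE  (1 << 1)

    /* Stand-ins for the C/NEON/SVE implementations linked into libaom. */
    static const char *flow_c(void)    { return "C"; }
    static const char *flow_neon(void) { return "NEON"; }
    static const char *flow_sve(void)  { return "SVE"; }

    static const char *(*aom_compute_flow_at_point_ptr)(void);

    static void setup_flow_rtcd(int cpu_flags) {
      aom_compute_flow_at_point_ptr = flow_c;  /* portable fallback */
      if (cpu_flags & HAS_NEON) aom_compute_flow_at_point_ptr = flow_neon;
      if (cpu_flags & HAS_SVE)  aom_compute_flow_at_point_ptr = flow_sve;
    }

    int main(void) {
      setup_flow_rtcd(HAS_NEON | HAS_SVE);
      printf("dispatching to: %s\n", aom_compute_flow_at_point_ptr());  /* SVE */
      return 0;
    }

The SVE entry only takes effect on CPUs that report the feature; everything else keeps using the NEON, x86 or C paths.
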
diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c
index 7441108b01..6a177b2e6b 100644
--- a/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c
+++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon.c
@@ -20,6 +20,7 @@
#include "aom/aom_integer.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/aom_filter.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
@@ -31,14 +32,14 @@ static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
const int16x8_t filter) {
const int16x4_t filter_lo = vget_low_s16(filter);
const int16x4_t filter_hi = vget_high_s16(filter);
- int16x4_t sum;
- sum = vmul_lane_s16(s0, filter_lo, 0);
+ int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0);
sum = vmla_lane_s16(sum, s1, filter_lo, 1);
sum = vmla_lane_s16(sum, s2, filter_lo, 2);
sum = vmla_lane_s16(sum, s5, filter_hi, 1);
sum = vmla_lane_s16(sum, s6, filter_hi, 2);
sum = vmla_lane_s16(sum, s7, filter_hi, 3);
+
sum = vqadd_s16(sum, vmul_lane_s16(s3, filter_lo, 3));
sum = vqadd_s16(sum, vmul_lane_s16(s4, filter_hi, 0));
return sum;
@@ -51,65 +52,56 @@ static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
const int16x8_t filter) {
const int16x4_t filter_lo = vget_low_s16(filter);
const int16x4_t filter_hi = vget_high_s16(filter);
- int16x8_t sum;
- sum = vmulq_lane_s16(s0, filter_lo, 0);
+ int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0);
sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
+
sum = vqaddq_s16(sum, vmulq_lane_s16(s3, filter_lo, 3));
sum = vqaddq_s16(sum, vmulq_lane_s16(s4, filter_hi, 0));
return vqrshrun_n_s16(sum, FILTER_BITS);
}
-void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
- uint8_t *dst, ptrdiff_t dst_stride,
- const int16_t *filter_x, int x_step_q4,
- const int16_t *filter_y, int y_step_q4, int w,
- int h) {
+static INLINE void convolve8_horiz_8tap_neon(const uint8_t *src,
+ ptrdiff_t src_stride, uint8_t *dst,
+ ptrdiff_t dst_stride,
+ const int16_t *filter_x, int w,
+ int h) {
const int16x8_t filter = vld1q_s16(filter_x);
- assert((intptr_t)dst % 4 == 0);
- assert(dst_stride % 4 == 0);
-
- (void)x_step_q4;
- (void)filter_y;
- (void)y_step_q4;
-
- src -= ((SUBPEL_TAPS / 2) - 1);
-
if (h == 4) {
- uint8x8_t t0, t1, t2, t3, d01, d23;
- int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3;
-
+ uint8x8_t t0, t1, t2, t3;
load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3);
transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3);
- s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
- s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
- s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
- s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
- s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
- s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
- s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+
+ int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+ int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+ int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+ int16x4_t s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
+ int16x4_t s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+ int16x4_t s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+ int16x4_t s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
src += 7;
do {
load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3);
transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3);
- s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
- s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
- s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
- s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
-
- d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter);
- d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter);
- d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter);
- d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter);
- d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
- d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+ int16x4_t s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+ int16x4_t s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+ int16x4_t s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+ int16x4_t s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
+
+ int16x4_t d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+ int16x4_t d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter);
+ int16x4_t d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter);
+ int16x4_t d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter);
+ uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
transpose_elems_inplace_u8_4x4(&d01, &d23);
@@ -123,39 +115,40 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
s4 = s8;
s5 = s9;
s6 = s10;
+
src += 4;
dst += 4;
w -= 4;
} while (w != 0);
} else {
- uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7, d0, d1, d2, d3;
- int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10;
-
if (w == 4) {
do {
+ uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
load_u8_8x8(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
- s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
- s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
- s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
- s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
- s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
- s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
- s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
+
+ int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+ int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+ int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+ int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+ int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
+ int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
load_u8_8x8(src + 7, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6,
&t7);
transpose_elems_u8_4x8(t0, t1, t2, t3, t4, t5, t6, t7, &t0, &t1, &t2,
&t3);
- s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
- s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
- s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
- s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
- d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
- d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter);
- d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter);
- d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter);
+ int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
+ int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
+ int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
+
+ uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+ uint8x8_t d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter);
+ uint8x8_t d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter);
+ uint8x8_t d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter);
transpose_elems_inplace_u8_8x4(&d0, &d1, &d2, &d3);
@@ -169,48 +162,49 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
h -= 8;
} while (h > 0);
} else {
- uint8x8_t d4, d5, d6, d7;
- int16x8_t s11, s12, s13, s14;
- int width;
- const uint8_t *s;
- uint8_t *d;
-
do {
- load_u8_8x8(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+ int width = w;
+ const uint8_t *s = src;
+ uint8_t *d = dst;
+
+ uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+ load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
- s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
- s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
- s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
- s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
- s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
- s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
- s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
-
- width = w;
- s = src + 7;
- d = dst;
+
+ int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+ int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+ int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+ int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+ int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
+ int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
+
+ s += 7;
do {
load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6,
&t7);
- s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
- s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
- s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
- s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
- s11 = vreinterpretq_s16_u16(vmovl_u8(t4));
- s12 = vreinterpretq_s16_u16(vmovl_u8(t5));
- s13 = vreinterpretq_s16_u16(vmovl_u8(t6));
- s14 = vreinterpretq_s16_u16(vmovl_u8(t7));
-
- d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
- d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter);
- d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter);
- d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter);
- d4 = convolve8_8(s4, s5, s6, s7, s8, s9, s10, s11, filter);
- d5 = convolve8_8(s5, s6, s7, s8, s9, s10, s11, s12, filter);
- d6 = convolve8_8(s6, s7, s8, s9, s10, s11, s12, s13, filter);
- d7 = convolve8_8(s7, s8, s9, s10, s11, s12, s13, s14, filter);
+
+ int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
+ int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
+ int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
+ int16x8_t s11 = vreinterpretq_s16_u16(vmovl_u8(t4));
+ int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t5));
+ int16x8_t s13 = vreinterpretq_s16_u16(vmovl_u8(t6));
+ int16x8_t s14 = vreinterpretq_s16_u16(vmovl_u8(t7));
+
+ uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+ uint8x8_t d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter);
+ uint8x8_t d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter);
+ uint8x8_t d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter);
+ uint8x8_t d4 = convolve8_8(s4, s5, s6, s7, s8, s9, s10, s11, filter);
+ uint8x8_t d5 = convolve8_8(s5, s6, s7, s8, s9, s10, s11, s12, filter);
+ uint8x8_t d6 =
+ convolve8_8(s6, s7, s8, s9, s10, s11, s12, s13, filter);
+ uint8x8_t d7 =
+ convolve8_8(s7, s8, s9, s10, s11, s12, s13, s14, filter);
transpose_elems_inplace_u8_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6,
&d7);
@@ -224,6 +218,7 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
s4 = s12;
s5 = s13;
s6 = s14;
+
s += 8;
d += 8;
width -= 8;
@@ -236,6 +231,137 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
}
}
+static INLINE int16x4_t convolve4_4(const int16x4_t s0, const int16x4_t s1,
+ const int16x4_t s2, const int16x4_t s3,
+ const int16x4_t filter) {
+ int16x4_t sum = vmul_lane_s16(s0, filter, 0);
+ sum = vmla_lane_s16(sum, s1, filter, 1);
+ sum = vmla_lane_s16(sum, s2, filter, 2);
+ sum = vmla_lane_s16(sum, s3, filter, 3);
+
+ return sum;
+}
+
+static INLINE uint8x8_t convolve4_8(const int16x8_t s0, const int16x8_t s1,
+ const int16x8_t s2, const int16x8_t s3,
+ const int16x4_t filter) {
+ int16x8_t sum = vmulq_lane_s16(s0, filter, 0);
+ sum = vmlaq_lane_s16(sum, s1, filter, 1);
+ sum = vmlaq_lane_s16(sum, s2, filter, 2);
+ sum = vmlaq_lane_s16(sum, s3, filter, 3);
+
+ // We halved the filter values so -1 from right shift.
+ return vqrshrun_n_s16(sum, FILTER_BITS - 1);
+}
+
+static INLINE void convolve8_horiz_4tap_neon(const uint8_t *src,
+ ptrdiff_t src_stride, uint8_t *dst,
+ ptrdiff_t dst_stride,
+ const int16_t *filter_x, int w,
+ int h) {
+ // All filter values are even, halve to reduce intermediate precision
+ // requirements.
+ const int16x4_t filter = vshr_n_s16(vld1_s16(filter_x + 2), 1);
+
+ if (w == 4) {
+ do {
+ int16x8_t t0 =
+ vreinterpretq_s16_u16(vmovl_u8(vld1_u8(src + 0 * src_stride)));
+ int16x8_t t1 =
+ vreinterpretq_s16_u16(vmovl_u8(vld1_u8(src + 1 * src_stride)));
+
+ int16x4_t s0[4], s1[4];
+ s0[0] = vget_low_s16(t0);
+ s0[1] = vget_low_s16(vextq_s16(t0, t0, 1));
+ s0[2] = vget_low_s16(vextq_s16(t0, t0, 2));
+ s0[3] = vget_low_s16(vextq_s16(t0, t0, 3));
+
+ s1[0] = vget_low_s16(t1);
+ s1[1] = vget_low_s16(vextq_s16(t1, t1, 1));
+ s1[2] = vget_low_s16(vextq_s16(t1, t1, 2));
+ s1[3] = vget_low_s16(vextq_s16(t1, t1, 3));
+
+ int16x4_t d0 = convolve4_4(s0[0], s0[1], s0[2], s0[3], filter);
+ int16x4_t d1 = convolve4_4(s1[0], s1[1], s1[2], s1[3], filter);
+ // We halved the filter values so -1 from right shift.
+ uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS - 1);
+
+ store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
+
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ } while (h > 0);
+ } else {
+ do {
+ int width = w;
+ const uint8_t *s = src;
+ uint8_t *d = dst;
+
+ int16x8_t t0 =
+ vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 0 * src_stride)));
+ int16x8_t t1 =
+ vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 1 * src_stride)));
+
+ s += 8;
+ do {
+ int16x8_t t2 =
+ vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 0 * src_stride)));
+ int16x8_t t3 =
+ vreinterpretq_s16_u16(vmovl_u8(vld1_u8(s + 1 * src_stride)));
+
+ int16x8_t s0[4], s1[4];
+ s0[0] = t0;
+ s0[1] = vextq_s16(t0, t2, 1);
+ s0[2] = vextq_s16(t0, t2, 2);
+ s0[3] = vextq_s16(t0, t2, 3);
+
+ s1[0] = t1;
+ s1[1] = vextq_s16(t1, t3, 1);
+ s1[2] = vextq_s16(t1, t3, 2);
+ s1[3] = vextq_s16(t1, t3, 3);
+
+ uint8x8_t d0 = convolve4_8(s0[0], s0[1], s0[2], s0[3], filter);
+ uint8x8_t d1 = convolve4_8(s1[0], s1[1], s1[2], s1[3], filter);
+
+ store_u8_8x2(d, dst_stride, d0, d1);
+
+ t0 = t2;
+ t1 = t3;
+
+ s += 8;
+ d += 8;
+ width -= 8;
+ } while (width != 0);
+ src += 2 * src_stride;
+ dst += 2 * dst_stride;
+ h -= 2;
+ } while (h > 0);
+ }
+}
+
+void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
+ uint8_t *dst, ptrdiff_t dst_stride,
+ const int16_t *filter_x, int x_step_q4,
+ const int16_t *filter_y, int y_step_q4, int w,
+ int h) {
+ assert((intptr_t)dst % 4 == 0);
+ assert(dst_stride % 4 == 0);
+
+ (void)x_step_q4;
+ (void)filter_y;
+ (void)y_step_q4;
+
+ src -= ((SUBPEL_TAPS / 2) - 1);
+
+ if (get_filter_taps_convolve8(filter_x) <= 4) {
+ convolve8_horiz_4tap_neon(src + 2, src_stride, dst, dst_stride, filter_x, w,
+ h);
+ } else {
+ convolve8_horiz_8tap_neon(src, src_stride, dst, dst_stride, filter_x, w, h);
+ }
+}
+
void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
uint8_t *dst, ptrdiff_t dst_stride,
const int16_t *filter_x, int x_step_q4,
@@ -253,33 +379,33 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
if (w == 4) {
- uint8x8_t t0, t1, t2, t3, t4, t5, t6, d01, d23;
- int16x4_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, d0, d1, d2, d3;
-
+ uint8x8_t t0, t1, t2, t3, t4, t5, t6;
load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
- s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
- s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
- s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
- s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
- s4 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t4)));
- s5 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t5)));
- s6 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t6)));
+
+ int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+ int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+ int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+ int16x4_t s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
+ int16x4_t s4 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t4)));
+ int16x4_t s5 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t5)));
+ int16x4_t s6 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t6)));
src += 7 * src_stride;
do {
load_u8_8x4(src, src_stride, &t0, &t1, &t2, &t3);
- s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
- s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
- s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
- s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
-
- d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter);
- d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter);
- d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter);
- d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter);
- d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
- d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+ int16x4_t s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+ int16x4_t s8 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+ int16x4_t s9 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+ int16x4_t s10 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
+
+ int16x4_t d0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+ int16x4_t d1 = convolve8_4(s1, s2, s3, s4, s5, s6, s7, s8, filter);
+ int16x4_t d2 = convolve8_4(s2, s3, s4, s5, s6, s7, s8, s9, filter);
+ int16x4_t d3 = convolve8_4(s3, s4, s5, s6, s7, s8, s9, s10, filter);
+ uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
@@ -291,42 +417,40 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
s4 = s8;
s5 = s9;
s6 = s10;
+
src += 4 * src_stride;
dst += 4 * dst_stride;
h -= 4;
} while (h != 0);
} else {
- uint8x8_t t0, t1, t2, t3, t4, t5, t6, d0, d1, d2, d3;
- int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10;
- int height;
- const uint8_t *s;
- uint8_t *d;
-
do {
+ uint8x8_t t0, t1, t2, t3, t4, t5, t6;
load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
- s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
- s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
- s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
- s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
- s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
- s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
- s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
-
- height = h;
- s = src + 7 * src_stride;
- d = dst;
+
+ int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+ int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+ int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+ int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+ int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
+ int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
+
+ int height = h;
+ const uint8_t *s = src + 7 * src_stride;
+ uint8_t *d = dst;
do {
load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);
- s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
- s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
- s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
- s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
- d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
- d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter);
- d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter);
- d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter);
+ int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
+ int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
+ int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
+ int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
+
+ uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+ uint8x8_t d1 = convolve8_8(s1, s2, s3, s4, s5, s6, s7, s8, filter);
+ uint8x8_t d2 = convolve8_8(s2, s3, s4, s5, s6, s7, s8, s9, filter);
+ uint8x8_t d3 = convolve8_8(s3, s4, s5, s6, s7, s8, s9, s10, filter);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -337,6 +461,7 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
s4 = s8;
s5 = s9;
s6 = s10;
+
s += 4 * src_stride;
d += 4 * dst_stride;
height -= 4;
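
A note on the new 4-tap path above: the "We halved the filter values so -1 from right shift" comments rely on the 4-tap sub-pixel filters having all-even taps, so halving them is exact and the final rounding shift can drop from FILTER_BITS to FILTER_BITS - 1 while keeping intermediates in a narrower range. A scalar sketch of that equivalence, assuming FILTER_BITS == 7 (its value in aom_dsp/aom_filter.h) and made-up all-even taps that sum to 128:

    #include <assert.h>
    #include <stdint.h>

    #define FILTER_BITS 7

    /* Rounding right shift with clamp to [0, 255], mirroring vqrshrun_n_s16. */
    static uint8_t round_shift_u8(int sum, int bits) {
      int v = (sum + (1 << (bits - 1))) >> bits;
      return (uint8_t)(v < 0 ? 0 : (v > 255 ? 255 : v));
    }

    int main(void) {
      const int16_t taps[4] = { -12, 92, 56, -8 };  /* all even, sum == 128 */
      const uint8_t src[4] = { 31, 204, 187, 9 };

      int full = 0, half = 0;
      for (int k = 0; k < 4; ++k) {
        full += taps[k] * src[k];
        half += (taps[k] / 2) * src[k];  /* exactly full / 2 since taps are even */
      }
      /* Halving every tap and shifting one bit less yields the same pixel. */
      assert(round_shift_u8(full, FILTER_BITS) ==
             round_shift_u8(half, FILTER_BITS - 1));
      return 0;
    }

In the new wrapper, get_filter_taps_convolve8() (from the newly included aom_dsp/arm/aom_filter.h) decides which filters take this cheaper 4-tap path.
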
diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c
index c82125ba17..120c479798 100644
--- a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c
+++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_dotprod.c
@@ -24,81 +24,72 @@
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
-DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
+// Filter values always sum to 128.
+#define FILTER_WEIGHT 128
+
+DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
-DECLARE_ALIGNED(16, static const uint8_t, dot_prod_tran_concat_tbl[32]) = {
- 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27,
- 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31
-};
-
-DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = {
- /* Shift left and insert new last column in transposed 4x4 block. */
+DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
+ // Shift left and insert new last column in transposed 4x4 block.
1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28,
- /* Shift left and insert two new columns in transposed 4x4 block. */
+ // Shift left and insert two new columns in transposed 4x4 block.
2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29,
- /* Shift left and insert three new columns in transposed 4x4 block. */
+ // Shift left and insert three new columns in transposed 4x4 block.
3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30
};
-static INLINE int16x4_t convolve8_4_sdot(uint8x16_t samples,
- const int8x8_t filter,
- const int32x4_t correction,
- const uint8x16_t range_limit,
- const uint8x16x2_t permute_tbl) {
- int8x16_t clamped_samples, permuted_samples[2];
- int32x4_t sum;
-
- /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
- clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
-
- /* Permute samples ready for dot product. */
- /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
- permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
- /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
- permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
-
- /* Accumulate dot product into 'correction' to account for range clamp. */
- sum = vdotq_lane_s32(correction, permuted_samples[0], filter, 0);
- sum = vdotq_lane_s32(sum, permuted_samples[1], filter, 1);
-
- /* Further narrowing and packing is performed by the caller. */
+static INLINE int16x4_t convolve8_4_h(const uint8x16_t samples,
+ const int8x8_t filters,
+ const uint8x16x2_t permute_tbl) {
+ // Transform sample range to [-128, 127] for 8-bit signed dot product.
+ int8x16_t samples_128 =
+ vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128)));
+
+ // Permute samples ready for dot product.
+ // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
+ // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
+ int8x16_t perm_samples[2] = { vqtbl1q_s8(samples_128, permute_tbl.val[0]),
+ vqtbl1q_s8(samples_128, permute_tbl.val[1]) };
+
+ // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
+ int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
+ int32x4_t sum = vdotq_lane_s32(acc, perm_samples[0], filters, 0);
+ sum = vdotq_lane_s32(sum, perm_samples[1], filters, 1);
+
+ // Further narrowing and packing is performed by the caller.
return vqmovn_s32(sum);
}
-static INLINE uint8x8_t convolve8_8_sdot(uint8x16_t samples,
- const int8x8_t filter,
- const int32x4_t correction,
- const uint8x16_t range_limit,
- const uint8x16x3_t permute_tbl) {
- int8x16_t clamped_samples, permuted_samples[3];
- int32x4_t sum0, sum1;
- int16x8_t sum;
-
- /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
- clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));
-
- /* Permute samples ready for dot product. */
- /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
- permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
- /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
- permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
- /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
- permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);
-
- /* Accumulate dot product into 'correction' to account for range clamp. */
- /* First 4 output values. */
- sum0 = vdotq_lane_s32(correction, permuted_samples[0], filter, 0);
- sum0 = vdotq_lane_s32(sum0, permuted_samples[1], filter, 1);
- /* Second 4 output values. */
- sum1 = vdotq_lane_s32(correction, permuted_samples[1], filter, 0);
- sum1 = vdotq_lane_s32(sum1, permuted_samples[2], filter, 1);
-
- /* Narrow and re-pack. */
- sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
+static INLINE uint8x8_t convolve8_8_h(const uint8x16_t samples,
+ const int8x8_t filters,
+ const uint8x16x3_t permute_tbl) {
+ // Transform sample range to [-128, 127] for 8-bit signed dot product.
+ int8x16_t samples_128 =
+ vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128)));
+
+ // Permute samples ready for dot product.
+ // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
+ // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
+ // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
+ int8x16_t perm_samples[3] = { vqtbl1q_s8(samples_128, permute_tbl.val[0]),
+ vqtbl1q_s8(samples_128, permute_tbl.val[1]),
+ vqtbl1q_s8(samples_128, permute_tbl.val[2]) };
+
+ // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
+ int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
+ // First 4 output values.
+ int32x4_t sum0 = vdotq_lane_s32(acc, perm_samples[0], filters, 0);
+ sum0 = vdotq_lane_s32(sum0, perm_samples[1], filters, 1);
+ // Second 4 output values.
+ int32x4_t sum1 = vdotq_lane_s32(acc, perm_samples[1], filters, 0);
+ sum1 = vdotq_lane_s32(sum1, perm_samples[2], filters, 1);
+
+ // Narrow and re-pack.
+ int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
return vqrshrun_n_s16(sum, FILTER_BITS);
}
@@ -108,10 +99,6 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
const int16_t *filter_y, int y_step_q4,
int w, int h) {
const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x));
- const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_x), 128);
- const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
- const uint8x16_t range_limit = vdupq_n_u8(128);
- uint8x16_t s0, s1, s2, s3;
assert((intptr_t)dst % 4 == 0);
assert(dst_stride % 4 == 0);
@@ -123,19 +110,17 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
src -= ((SUBPEL_TAPS / 2) - 1);
if (w == 4) {
- const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+ const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
do {
- int16x4_t t0, t1, t2, t3;
- uint8x8_t d01, d23;
-
+ uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
- t0 = convolve8_4_sdot(s0, filter, correction, range_limit, perm_tbl);
- t1 = convolve8_4_sdot(s1, filter, correction, range_limit, perm_tbl);
- t2 = convolve8_4_sdot(s2, filter, correction, range_limit, perm_tbl);
- t3 = convolve8_4_sdot(s3, filter, correction, range_limit, perm_tbl);
- d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS);
- d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS);
+ int16x4_t d0 = convolve8_4_h(s0, filter, perm_tbl);
+ int16x4_t d1 = convolve8_4_h(s1, filter, perm_tbl);
+ int16x4_t d2 = convolve8_4_h(s2, filter, perm_tbl);
+ int16x4_t d3 = convolve8_4_h(s3, filter, perm_tbl);
+ uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
@@ -145,23 +130,20 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
h -= 4;
} while (h > 0);
} else {
- const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
- const uint8_t *s;
- uint8_t *d;
- int width;
- uint8x8_t d0, d1, d2, d3;
+ const uint8x16x3_t perm_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
do {
- width = w;
- s = src;
- d = dst;
+ int width = w;
+ const uint8_t *s = src;
+ uint8_t *d = dst;
do {
+ uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- d0 = convolve8_8_sdot(s0, filter, correction, range_limit, perm_tbl);
- d1 = convolve8_8_sdot(s1, filter, correction, range_limit, perm_tbl);
- d2 = convolve8_8_sdot(s2, filter, correction, range_limit, perm_tbl);
- d3 = convolve8_8_sdot(s3, filter, correction, range_limit, perm_tbl);
+ uint8x8_t d0 = convolve8_8_h(s0, filter, perm_tbl);
+ uint8x8_t d1 = convolve8_8_h(s1, filter, perm_tbl);
+ uint8x8_t d2 = convolve8_8_h(s2, filter, perm_tbl);
+ uint8x8_t d3 = convolve8_8_h(s3, filter, perm_tbl);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -177,83 +159,88 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
}
static INLINE void transpose_concat_4x4(int8x8_t a0, int8x8_t a1, int8x8_t a2,
- int8x8_t a3, int8x16_t *b,
- const uint8x16_t permute_tbl) {
- /* Transpose 8-bit elements and concatenate result rows as follows:
- * a0: 00, 01, 02, 03, XX, XX, XX, XX
- * a1: 10, 11, 12, 13, XX, XX, XX, XX
- * a2: 20, 21, 22, 23, XX, XX, XX, XX
- * a3: 30, 31, 32, 33, XX, XX, XX, XX
- *
- * b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
- *
- * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it
- * as an argument is preferable to loading it directly from memory as this
- * inline helper is called many times from the same parent function.
- */
-
- int8x16x2_t samples = { { vcombine_s8(a0, a1), vcombine_s8(a2, a3) } };
- *b = vqtbl2q_s8(samples, permute_tbl);
+ int8x8_t a3, int8x16_t *b) {
+ // Transpose 8-bit elements and concatenate result rows as follows:
+ // a0: 00, 01, 02, 03, XX, XX, XX, XX
+ // a1: 10, 11, 12, 13, XX, XX, XX, XX
+ // a2: 20, 21, 22, 23, XX, XX, XX, XX
+ // a3: 30, 31, 32, 33, XX, XX, XX, XX
+ //
+ // b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
+
+ int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0));
+ int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0));
+ int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0));
+ int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0));
+
+ int8x16_t a01 = vzipq_s8(a0q, a1q).val[0];
+ int8x16_t a23 = vzipq_s8(a2q, a3q).val[0];
+
+ int16x8_t a0123 =
+ vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23)).val[0];
+
+ *b = vreinterpretq_s8_s16(a0123);
}
static INLINE void transpose_concat_8x4(int8x8_t a0, int8x8_t a1, int8x8_t a2,
int8x8_t a3, int8x16_t *b0,
- int8x16_t *b1,
- const uint8x16x2_t permute_tbl) {
- /* Transpose 8-bit elements and concatenate result rows as follows:
- * a0: 00, 01, 02, 03, 04, 05, 06, 07
- * a1: 10, 11, 12, 13, 14, 15, 16, 17
- * a2: 20, 21, 22, 23, 24, 25, 26, 27
- * a3: 30, 31, 32, 33, 34, 35, 36, 37
- *
- * b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
- * b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37
- *
- * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it
- * as an argument is preferable to loading it directly from memory as this
- * inline helper is called many times from the same parent function.
- */
-
- int8x16x2_t samples = { { vcombine_s8(a0, a1), vcombine_s8(a2, a3) } };
- *b0 = vqtbl2q_s8(samples, permute_tbl.val[0]);
- *b1 = vqtbl2q_s8(samples, permute_tbl.val[1]);
+ int8x16_t *b1) {
+ // Transpose 8-bit elements and concatenate result rows as follows:
+ // a0: 00, 01, 02, 03, 04, 05, 06, 07
+ // a1: 10, 11, 12, 13, 14, 15, 16, 17
+ // a2: 20, 21, 22, 23, 24, 25, 26, 27
+ // a3: 30, 31, 32, 33, 34, 35, 36, 37
+ //
+ // b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
+ // b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37
+
+ int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0));
+ int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0));
+ int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0));
+ int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0));
+
+ int8x16_t a01 = vzipq_s8(a0q, a1q).val[0];
+ int8x16_t a23 = vzipq_s8(a2q, a3q).val[0];
+
+ int16x8x2_t a0123 =
+ vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23));
+
+ *b0 = vreinterpretq_s8_s16(a0123.val[0]);
+ *b1 = vreinterpretq_s8_s16(a0123.val[1]);
}
-static INLINE int16x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo,
- const int8x16_t samples_hi,
- const int32x4_t correction,
- const int8x8_t filter) {
- /* Sample range-clamping and permutation are performed by the caller. */
- int32x4_t sum;
+static INLINE int16x4_t convolve8_4_v(const int8x16_t samples_lo,
+ const int8x16_t samples_hi,
+ const int8x8_t filters) {
+ // The sample range transform and permutation are performed by the caller.
- /* Accumulate dot product into 'correction' to account for range clamp. */
- sum = vdotq_lane_s32(correction, samples_lo, filter, 0);
- sum = vdotq_lane_s32(sum, samples_hi, filter, 1);
+ // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
+ int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
+ int32x4_t sum = vdotq_lane_s32(acc, samples_lo, filters, 0);
+ sum = vdotq_lane_s32(sum, samples_hi, filters, 1);
- /* Further narrowing and packing is performed by the caller. */
+ // Further narrowing and packing is performed by the caller.
return vqmovn_s32(sum);
}
-static INLINE uint8x8_t convolve8_8_sdot_partial(const int8x16_t samples0_lo,
- const int8x16_t samples0_hi,
- const int8x16_t samples1_lo,
- const int8x16_t samples1_hi,
- const int32x4_t correction,
- const int8x8_t filter) {
- /* Sample range-clamping and permutation are performed by the caller. */
- int32x4_t sum0, sum1;
- int16x8_t sum;
-
- /* Accumulate dot product into 'correction' to account for range clamp. */
- /* First 4 output values. */
- sum0 = vdotq_lane_s32(correction, samples0_lo, filter, 0);
- sum0 = vdotq_lane_s32(sum0, samples0_hi, filter, 1);
- /* Second 4 output values. */
- sum1 = vdotq_lane_s32(correction, samples1_lo, filter, 0);
- sum1 = vdotq_lane_s32(sum1, samples1_hi, filter, 1);
-
- /* Narrow and re-pack. */
- sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
+static INLINE uint8x8_t convolve8_8_v(const int8x16_t samples0_lo,
+ const int8x16_t samples0_hi,
+ const int8x16_t samples1_lo,
+ const int8x16_t samples1_hi,
+ const int8x8_t filters) {
+ // The sample range transform and permutation are performed by the caller.
+
+ // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
+ int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
+ // First 4 output values.
+ int32x4_t sum0 = vdotq_lane_s32(acc, samples0_lo, filters, 0);
+ sum0 = vdotq_lane_s32(sum0, samples0_hi, filters, 1);
+ // Second 4 output values.
+ int32x4_t sum1 = vdotq_lane_s32(acc, samples1_lo, filters, 0);
+ sum1 = vdotq_lane_s32(sum1, samples1_hi, filters, 1);
+
+ // Narrow and re-pack.
+ int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
return vqrshrun_n_s16(sum, FILTER_BITS);
}
@@ -263,10 +250,7 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
const int16_t *filter_y, int y_step_q4,
int w, int h) {
const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
- const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_y), 128);
- const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
- const uint8x8_t range_limit = vdup_n_u8(128);
- const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl);
+ const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl);
int8x16x2_t samples_LUT;
assert((intptr_t)dst % 4 == 0);
@@ -279,62 +263,58 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
if (w == 4) {
- const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl);
-
uint8x8_t t0, t1, t2, t3, t4, t5, t6;
load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
src += 7 * src_stride;
- /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
- int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit));
- int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit));
- int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit));
- int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit));
- int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit));
- int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit));
- int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit));
-
- /* This operation combines a conventional transpose and the sample permute
- * (see horizontal case) required before computing the dot product.
- */
+ // Clamp sample range to [-128, 127] for 8-bit signed dot product.
+ int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
+ int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
+ int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
+ int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
+ int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
+ int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
+ int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));
+
+ // This operation combines a conventional transpose and the sample permute
+ // (see horizontal case) required before computing the dot product.
int8x16_t s0123, s1234, s2345, s3456;
- transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl);
- transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl);
- transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl);
- transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl);
+ transpose_concat_4x4(s0, s1, s2, s3, &s0123);
+ transpose_concat_4x4(s1, s2, s3, s4, &s1234);
+ transpose_concat_4x4(s2, s3, s4, s5, &s2345);
+ transpose_concat_4x4(s3, s4, s5, s6, &s3456);
do {
uint8x8_t t7, t8, t9, t10;
load_u8_8x4(src, src_stride, &t7, &t8, &t9, &t10);
- int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit));
- int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit));
- int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit));
- int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit));
+ int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
+ int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
+ int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
+ int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128)));
int8x16_t s4567, s5678, s6789, s78910;
- transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl);
+ transpose_concat_4x4(s7, s8, s9, s10, &s78910);
- /* Merge new data into block from previous iteration. */
+ // Merge new data into block from previous iteration.
samples_LUT.val[0] = s3456;
samples_LUT.val[1] = s78910;
s4567 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
- int16x4_t d0 = convolve8_4_sdot_partial(s0123, s4567, correction, filter);
- int16x4_t d1 = convolve8_4_sdot_partial(s1234, s5678, correction, filter);
- int16x4_t d2 = convolve8_4_sdot_partial(s2345, s6789, correction, filter);
- int16x4_t d3 =
- convolve8_4_sdot_partial(s3456, s78910, correction, filter);
+ int16x4_t d0 = convolve8_4_v(s0123, s4567, filter);
+ int16x4_t d1 = convolve8_4_v(s1234, s5678, filter);
+ int16x4_t d2 = convolve8_4_v(s2345, s6789, filter);
+ int16x4_t d3 = convolve8_4_v(s3456, s78910, filter);
uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
- /* Prepare block for next iteration - re-using as much as possible. */
- /* Shuffle everything up four rows. */
+ // Prepare block for next iteration - re-using as much as possible.
+ // Shuffle everything up four rows.
s0123 = s4567;
s1234 = s5678;
s2345 = s6789;
@@ -345,8 +325,6 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
h -= 4;
} while (h != 0);
} else {
- const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl);
-
do {
int height = h;
const uint8_t *s = src;
@@ -356,44 +334,38 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
load_u8_8x7(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
s += 7 * src_stride;
- /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
- int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit));
- int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit));
- int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit));
- int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit));
- int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit));
- int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit));
- int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit));
-
- /* This operation combines a conventional transpose and the sample permute
- * (see horizontal case) required before computing the dot product.
- */
+ // Clamp sample range to [-128, 127] for 8-bit signed dot product.
+ int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
+ int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
+ int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
+ int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
+ int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
+ int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
+ int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));
+
+ // This operation combines a conventional transpose and the sample permute
+ // (see horizontal case) required before computing the dot product.
int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
s3456_lo, s3456_hi;
- transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi,
- tran_concat_tbl);
- transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi,
- tran_concat_tbl);
- transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi,
- tran_concat_tbl);
- transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi,
- tran_concat_tbl);
+ transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi);
+ transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi);
+ transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi);
+ transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi);
do {
uint8x8_t t7, t8, t9, t10;
load_u8_8x4(s, src_stride, &t7, &t8, &t9, &t10);
- int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit));
- int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit));
- int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit));
- int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit));
+ int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
+ int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
+ int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
+ int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128)));
int8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi,
s78910_lo, s78910_hi;
- transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi,
- tran_concat_tbl);
+ transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi);
- /* Merge new data into block from previous iteration. */
+ // Merge new data into block from previous iteration.
samples_LUT.val[0] = s3456_lo;
samples_LUT.val[1] = s78910_lo;
s4567_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
@@ -406,19 +378,19 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
s5678_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
s6789_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
- uint8x8_t d0 = convolve8_8_sdot_partial(s0123_lo, s4567_lo, s0123_hi,
- s4567_hi, correction, filter);
- uint8x8_t d1 = convolve8_8_sdot_partial(s1234_lo, s5678_lo, s1234_hi,
- s5678_hi, correction, filter);
- uint8x8_t d2 = convolve8_8_sdot_partial(s2345_lo, s6789_lo, s2345_hi,
- s6789_hi, correction, filter);
- uint8x8_t d3 = convolve8_8_sdot_partial(s3456_lo, s78910_lo, s3456_hi,
- s78910_hi, correction, filter);
+ uint8x8_t d0 =
+ convolve8_8_v(s0123_lo, s4567_lo, s0123_hi, s4567_hi, filter);
+ uint8x8_t d1 =
+ convolve8_8_v(s1234_lo, s5678_lo, s1234_hi, s5678_hi, filter);
+ uint8x8_t d2 =
+ convolve8_8_v(s2345_lo, s6789_lo, s2345_hi, s6789_hi, filter);
+ uint8x8_t d3 =
+ convolve8_8_v(s3456_lo, s78910_lo, s3456_hi, s78910_hi, filter);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
- /* Prepare block for next iteration - re-using as much as possible. */
- /* Shuffle everything up four rows. */
+ // Prepare block for next iteration - re-using as much as possible.
+ // Shuffle everything up four rows.
s0123_lo = s4567_lo;
s0123_hi = s4567_hi;
s1234_lo = s5678_lo;
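
The dotprod rewrite above drops the per-call `correction` and `range_limit` vectors: each sample is biased into the signed range for SDOT by subtracting 128, and the accumulator is seeded with 128 * FILTER_WEIGHT to undo that bias, which is exact because these 8-tap filters always sum to 128 (the FILTER_WEIGHT constant). A scalar sketch of why the bias cancels, with made-up taps that sum to 128:

    #include <assert.h>
    #include <stdint.h>

    #define FILTER_WEIGHT 128

    int main(void) {
      const int8_t taps[8] = { -1, 5, -13, 73, 73, -13, 5, -1 };  /* sum == 128 */
      const uint8_t src[8] = { 17, 250, 3, 128, 90, 200, 64, 11 };

      int32_t direct = 0;
      int32_t biased = 128 * FILTER_WEIGHT;  /* the 'acc' seed in the patch */
      for (int k = 0; k < 8; ++k) {
        direct += taps[k] * src[k];           /* unsigned-sample result */
        biased += taps[k] * (src[k] - 128);   /* what SDOT accumulates   */
      }
      /* sum(taps) == FILTER_WEIGHT, so the 128 * FILTER_WEIGHT seed cancels
       * the -128 bias applied to every sample. */
      assert(direct == biased);
      return 0;
    }

The i8mm version in the next file needs no bias at all, since USDOT consumes the unsigned samples directly.
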
diff --git a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index df6e4d2ab5..68e031461d 100644
--- a/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/third_party/aom/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -23,69 +23,60 @@
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
-DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
+DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
-DECLARE_ALIGNED(16, static const uint8_t, dot_prod_tran_concat_tbl[32]) = {
- 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27,
- 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31
-};
-
-DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = {
- /* Shift left and insert new last column in transposed 4x4 block. */
+DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
+ // Shift left and insert new last column in transposed 4x4 block.
1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28,
- /* Shift left and insert two new columns in transposed 4x4 block. */
+ // Shift left and insert two new columns in transposed 4x4 block.
2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29,
- /* Shift left and insert three new columns in transposed 4x4 block. */
+ // Shift left and insert three new columns in transposed 4x4 block.
3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30
};
-static INLINE int16x4_t convolve8_4_usdot(const uint8x16_t samples,
- const int8x8_t filter,
- const uint8x16x2_t permute_tbl) {
- uint8x16_t permuted_samples[2];
- int32x4_t sum;
-
- /* Permute samples ready for dot product. */
- /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
- permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
- /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
- permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
+static INLINE int16x4_t convolve8_4_h(const uint8x16_t samples,
+ const int8x8_t filters,
+ const uint8x16x2_t permute_tbl) {
+ // Permute samples ready for dot product.
+ // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
+ // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
+ uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+ vqtbl1q_u8(samples, permute_tbl.val[1]) };
- sum = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filter, 0);
- sum = vusdotq_lane_s32(sum, permuted_samples[1], filter, 1);
+ int32x4_t sum =
+ vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
+ sum = vusdotq_lane_s32(sum, permuted_samples[1], filters, 1);
- /* Further narrowing and packing is performed by the caller. */
+ // Further narrowing and packing is performed by the caller.
return vqmovn_s32(sum);
}
-static INLINE uint8x8_t convolve8_8_usdot(const uint8x16_t samples,
- const int8x8_t filter,
- const uint8x16x3_t permute_tbl) {
- uint8x16_t permuted_samples[3];
- int32x4_t sum0, sum1;
- int16x8_t sum;
-
- /* Permute samples ready for dot product. */
- /* { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 } */
- permuted_samples[0] = vqtbl1q_u8(samples, permute_tbl.val[0]);
- /* { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 } */
- permuted_samples[1] = vqtbl1q_u8(samples, permute_tbl.val[1]);
- /* { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 } */
- permuted_samples[2] = vqtbl1q_u8(samples, permute_tbl.val[2]);
-
- /* First 4 output values. */
- sum0 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filter, 0);
- sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filter, 1);
- /* Second 4 output values. */
- sum1 = vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filter, 0);
- sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filter, 1);
-
- /* Narrow and re-pack. */
- sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
+static INLINE uint8x8_t convolve8_8_h(const uint8x16_t samples,
+ const int8x8_t filters,
+ const uint8x16x3_t permute_tbl) {
+ // Permute samples ready for dot product.
+ // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
+ // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
+ // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
+ uint8x16_t permuted_samples[3] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+ vqtbl1q_u8(samples, permute_tbl.val[1]),
+ vqtbl1q_u8(samples, permute_tbl.val[2]) };
+
+ // First 4 output values.
+ int32x4_t sum0 =
+ vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
+ sum0 = vusdotq_lane_s32(sum0, permuted_samples[1], filters, 1);
+ // Second 4 output values.
+ int32x4_t sum1 =
+ vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0);
+ sum1 = vusdotq_lane_s32(sum1, permuted_samples[2], filters, 1);
+
+ // Narrow and re-pack.
+ int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
return vqrshrun_n_s16(sum, FILTER_BITS);
}
@@ -95,7 +86,6 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
const int16_t *filter_y, int y_step_q4,
int w, int h) {
const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x));
- uint8x16_t s0, s1, s2, s3;
assert((intptr_t)dst % 4 == 0);
assert(dst_stride % 4 == 0);
@@ -107,19 +97,17 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
src -= ((SUBPEL_TAPS / 2) - 1);
if (w == 4) {
- const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl);
+ const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
do {
- int16x4_t t0, t1, t2, t3;
- uint8x8_t d01, d23;
-
+ uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
- t0 = convolve8_4_usdot(s0, filter, perm_tbl);
- t1 = convolve8_4_usdot(s1, filter, perm_tbl);
- t2 = convolve8_4_usdot(s2, filter, perm_tbl);
- t3 = convolve8_4_usdot(s3, filter, perm_tbl);
- d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS);
- d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS);
+ int16x4_t d0 = convolve8_4_h(s0, filter, perm_tbl);
+ int16x4_t d1 = convolve8_4_h(s1, filter, perm_tbl);
+ int16x4_t d2 = convolve8_4_h(s2, filter, perm_tbl);
+ int16x4_t d3 = convolve8_4_h(s3, filter, perm_tbl);
+ uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+ uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
@@ -129,23 +117,20 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
h -= 4;
} while (h > 0);
} else {
- const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
- const uint8_t *s;
- uint8_t *d;
- int width;
- uint8x8_t d0, d1, d2, d3;
+ const uint8x16x3_t perm_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
do {
- width = w;
- s = src;
- d = dst;
+ int width = w;
+ const uint8_t *s = src;
+ uint8_t *d = dst;
do {
+ uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- d0 = convolve8_8_usdot(s0, filter, perm_tbl);
- d1 = convolve8_8_usdot(s1, filter, perm_tbl);
- d2 = convolve8_8_usdot(s2, filter, perm_tbl);
- d3 = convolve8_8_usdot(s3, filter, perm_tbl);
+ uint8x8_t d0 = convolve8_8_h(s0, filter, perm_tbl);
+ uint8x8_t d1 = convolve8_8_h(s1, filter, perm_tbl);
+ uint8x8_t d2 = convolve8_8_h(s2, filter, perm_tbl);
+ uint8x8_t d3 = convolve8_8_h(s3, filter, perm_tbl);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -162,79 +147,83 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
static INLINE void transpose_concat_4x4(uint8x8_t a0, uint8x8_t a1,
uint8x8_t a2, uint8x8_t a3,
- uint8x16_t *b,
- const uint8x16_t permute_tbl) {
- /* Transpose 8-bit elements and concatenate result rows as follows:
- * a0: 00, 01, 02, 03, XX, XX, XX, XX
- * a1: 10, 11, 12, 13, XX, XX, XX, XX
- * a2: 20, 21, 22, 23, XX, XX, XX, XX
- * a3: 30, 31, 32, 33, XX, XX, XX, XX
- *
- * b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
- *
- * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it
- * as an argument is preferable to loading it directly from memory as this
- * inline helper is called many times from the same parent function.
- */
-
- uint8x16x2_t samples = { { vcombine_u8(a0, a1), vcombine_u8(a2, a3) } };
- *b = vqtbl2q_u8(samples, permute_tbl);
+ uint8x16_t *b) {
+ // Transpose 8-bit elements and concatenate result rows as follows:
+ // a0: 00, 01, 02, 03, XX, XX, XX, XX
+ // a1: 10, 11, 12, 13, XX, XX, XX, XX
+ // a2: 20, 21, 22, 23, XX, XX, XX, XX
+ // a3: 30, 31, 32, 33, XX, XX, XX, XX
+ //
+ // b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
+
+ uint8x16_t a0q = vcombine_u8(a0, vdup_n_u8(0));
+ uint8x16_t a1q = vcombine_u8(a1, vdup_n_u8(0));
+ uint8x16_t a2q = vcombine_u8(a2, vdup_n_u8(0));
+ uint8x16_t a3q = vcombine_u8(a3, vdup_n_u8(0));
+
+ uint8x16_t a01 = vzipq_u8(a0q, a1q).val[0];
+ uint8x16_t a23 = vzipq_u8(a2q, a3q).val[0];
+
+ uint16x8_t a0123 =
+ vzipq_u16(vreinterpretq_u16_u8(a01), vreinterpretq_u16_u8(a23)).val[0];
+
+ *b = vreinterpretq_u8_u16(a0123);
}
static INLINE void transpose_concat_8x4(uint8x8_t a0, uint8x8_t a1,
uint8x8_t a2, uint8x8_t a3,
- uint8x16_t *b0, uint8x16_t *b1,
- const uint8x16x2_t permute_tbl) {
- /* Transpose 8-bit elements and concatenate result rows as follows:
- * a0: 00, 01, 02, 03, 04, 05, 06, 07
- * a1: 10, 11, 12, 13, 14, 15, 16, 17
- * a2: 20, 21, 22, 23, 24, 25, 26, 27
- * a3: 30, 31, 32, 33, 34, 35, 36, 37
- *
- * b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
- * b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37
- *
- * The 'permute_tbl' is always 'dot_prod_tran_concat_tbl' above. Passing it
- * as an argument is preferable to loading it directly from memory as this
- * inline helper is called many times from the same parent function.
- */
-
- uint8x16x2_t samples = { { vcombine_u8(a0, a1), vcombine_u8(a2, a3) } };
- *b0 = vqtbl2q_u8(samples, permute_tbl.val[0]);
- *b1 = vqtbl2q_u8(samples, permute_tbl.val[1]);
+ uint8x16_t *b0, uint8x16_t *b1) {
+ // Transpose 8-bit elements and concatenate result rows as follows:
+ // a0: 00, 01, 02, 03, 04, 05, 06, 07
+ // a1: 10, 11, 12, 13, 14, 15, 16, 17
+ // a2: 20, 21, 22, 23, 24, 25, 26, 27
+ // a3: 30, 31, 32, 33, 34, 35, 36, 37
+ //
+ // b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
+ // b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37
+
+ uint8x16_t a0q = vcombine_u8(a0, vdup_n_u8(0));
+ uint8x16_t a1q = vcombine_u8(a1, vdup_n_u8(0));
+ uint8x16_t a2q = vcombine_u8(a2, vdup_n_u8(0));
+ uint8x16_t a3q = vcombine_u8(a3, vdup_n_u8(0));
+
+ uint8x16_t a01 = vzipq_u8(a0q, a1q).val[0];
+ uint8x16_t a23 = vzipq_u8(a2q, a3q).val[0];
+
+ uint16x8x2_t a0123 =
+ vzipq_u16(vreinterpretq_u16_u8(a01), vreinterpretq_u16_u8(a23));
+
+ *b0 = vreinterpretq_u8_u16(a0123.val[0]);
+ *b1 = vreinterpretq_u8_u16(a0123.val[1]);
}
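
The zip-based transpose above produces the same byte ordering as the old vqtbl2q_u8 lookup, without needing a permute table in memory. A plain-C sketch of that ordering for the 4x4 case (the helper name is illustrative, not part of the patch):

#include <stdint.h>

// Interleave four 4-byte rows column-by-column, matching the b layout
// documented above: 00,10,20,30, 01,11,21,31, 02,12,22,32, 03,13,23,33.
static void transpose_concat_4x4_ref(const uint8_t a[4][4], uint8_t b[16]) {
  for (int col = 0; col < 4; col++) {
    for (int row = 0; row < 4; row++) {
      b[4 * col + row] = a[row][col];
    }
  }
}
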
-static INLINE int16x4_t convolve8_4_usdot_partial(const uint8x16_t samples_lo,
- const uint8x16_t samples_hi,
- const int8x8_t filter) {
- /* Sample permutation is performed by the caller. */
- int32x4_t sum;
-
- sum = vusdotq_lane_s32(vdupq_n_s32(0), samples_lo, filter, 0);
- sum = vusdotq_lane_s32(sum, samples_hi, filter, 1);
+static INLINE int16x4_t convolve8_4_v(const uint8x16_t samples_lo,
+ const uint8x16_t samples_hi,
+ const int8x8_t filters) {
+ // Sample permutation is performed by the caller.
+ int32x4_t sum = vusdotq_lane_s32(vdupq_n_s32(0), samples_lo, filters, 0);
+ sum = vusdotq_lane_s32(sum, samples_hi, filters, 1);
- /* Further narrowing and packing is performed by the caller. */
+ // Further narrowing and packing is performed by the caller.
return vqmovn_s32(sum);
}
-static INLINE uint8x8_t convolve8_8_usdot_partial(const uint8x16_t samples0_lo,
- const uint8x16_t samples0_hi,
- const uint8x16_t samples1_lo,
- const uint8x16_t samples1_hi,
- const int8x8_t filter) {
- /* Sample permutation is performed by the caller. */
- int32x4_t sum0, sum1;
- int16x8_t sum;
-
- /* First 4 output values. */
- sum0 = vusdotq_lane_s32(vdupq_n_s32(0), samples0_lo, filter, 0);
- sum0 = vusdotq_lane_s32(sum0, samples0_hi, filter, 1);
- /* Second 4 output values. */
- sum1 = vusdotq_lane_s32(vdupq_n_s32(0), samples1_lo, filter, 0);
- sum1 = vusdotq_lane_s32(sum1, samples1_hi, filter, 1);
-
- /* Narrow and re-pack. */
- sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
+static INLINE uint8x8_t convolve8_8_v(const uint8x16_t samples0_lo,
+ const uint8x16_t samples0_hi,
+ const uint8x16_t samples1_lo,
+ const uint8x16_t samples1_hi,
+ const int8x8_t filters) {
+ // Sample permutation is performed by the caller.
+
+ // First 4 output values.
+ int32x4_t sum0 = vusdotq_lane_s32(vdupq_n_s32(0), samples0_lo, filters, 0);
+ sum0 = vusdotq_lane_s32(sum0, samples0_hi, filters, 1);
+ // Second 4 output values.
+ int32x4_t sum1 = vusdotq_lane_s32(vdupq_n_s32(0), samples1_lo, filters, 0);
+ sum1 = vusdotq_lane_s32(sum1, samples1_hi, filters, 1);
+
+ // Narrow and re-pack.
+ int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
return vqrshrun_n_s16(sum, FILTER_BITS);
}
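
Each vusdotq_lane_s32 call accumulates four unsigned-by-signed 4-tap dot products: lane 0 applies filter taps 0-3 to samples_lo and lane 1 applies taps 4-7 to samples_hi, so the two calls together form the full 8-tap sum. A scalar sketch of the per-pixel arithmetic (illustrative only; the sample windows are assumed to be permuted by the caller as noted above):

#include <stdint.h>

// 8-tap unsigned-by-signed dot product for one output pixel; the callers
// round and narrow the result with vqrshrun_n_s16(..., FILTER_BITS).
static int32_t convolve8_dot_ref(const uint8_t samples[8],
                                 const int8_t filter[8]) {
  int32_t sum = 0;
  for (int k = 0; k < 8; k++) {
    sum += (int32_t)samples[k] * filter[k];
  }
  return sum;
}
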
@@ -244,7 +233,7 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
const int16_t *filter_y, int y_step_q4, int w,
int h) {
const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
- const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl);
+ const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl);
uint8x16x2_t samples_LUT;
assert((intptr_t)dst % 4 == 0);
@@ -257,47 +246,44 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
if (w == 4) {
- const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl);
-
uint8x8_t s0, s1, s2, s3, s4, s5, s6;
load_u8_8x7(src, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
src += 7 * src_stride;
- /* This operation combines a conventional transpose and the sample permute
- * (see horizontal case) required before computing the dot product.
- */
+ // This operation combines a conventional transpose and the sample permute
+ // (see horizontal case) required before computing the dot product.
uint8x16_t s0123, s1234, s2345, s3456;
- transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl);
- transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl);
- transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl);
- transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl);
+ transpose_concat_4x4(s0, s1, s2, s3, &s0123);
+ transpose_concat_4x4(s1, s2, s3, s4, &s1234);
+ transpose_concat_4x4(s2, s3, s4, s5, &s2345);
+ transpose_concat_4x4(s3, s4, s5, s6, &s3456);
do {
uint8x8_t s7, s8, s9, s10;
load_u8_8x4(src, src_stride, &s7, &s8, &s9, &s10);
uint8x16_t s4567, s5678, s6789, s78910;
- transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl);
+ transpose_concat_4x4(s7, s8, s9, s10, &s78910);
- /* Merge new data into block from previous iteration. */
+ // Merge new data into block from previous iteration.
samples_LUT.val[0] = s3456;
samples_LUT.val[1] = s78910;
s4567 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[0]);
s5678 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]);
s6789 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]);
- int16x4_t d0 = convolve8_4_usdot_partial(s0123, s4567, filter);
- int16x4_t d1 = convolve8_4_usdot_partial(s1234, s5678, filter);
- int16x4_t d2 = convolve8_4_usdot_partial(s2345, s6789, filter);
- int16x4_t d3 = convolve8_4_usdot_partial(s3456, s78910, filter);
+ int16x4_t d0 = convolve8_4_v(s0123, s4567, filter);
+ int16x4_t d1 = convolve8_4_v(s1234, s5678, filter);
+ int16x4_t d2 = convolve8_4_v(s2345, s6789, filter);
+ int16x4_t d3 = convolve8_4_v(s3456, s78910, filter);
uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
- /* Prepare block for next iteration - re-using as much as possible. */
- /* Shuffle everything up four rows. */
+ // Prepare block for next iteration - re-using as much as possible.
+ // Shuffle everything up four rows.
s0123 = s4567;
s1234 = s5678;
s2345 = s6789;
@@ -308,8 +294,6 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
h -= 4;
} while (h != 0);
} else {
- const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl);
-
do {
int height = h;
const uint8_t *s = src;
@@ -319,19 +303,14 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
load_u8_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
s += 7 * src_stride;
- /* This operation combines a conventional transpose and the sample permute
- * (see horizontal case) required before computing the dot product.
- */
+ // This operation combines a conventional transpose and the sample permute
+ // (see horizontal case) required before computing the dot product.
uint8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
s3456_lo, s3456_hi;
- transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi,
- tran_concat_tbl);
- transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi,
- tran_concat_tbl);
- transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi,
- tran_concat_tbl);
- transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi,
- tran_concat_tbl);
+ transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi);
+ transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi);
+ transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi);
+ transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi);
do {
uint8x8_t s7, s8, s9, s10;
@@ -339,10 +318,9 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
uint8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi,
s78910_lo, s78910_hi;
- transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi,
- tran_concat_tbl);
+ transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi);
- /* Merge new data into block from previous iteration. */
+ // Merge new data into block from previous iteration.
samples_LUT.val[0] = s3456_lo;
samples_LUT.val[1] = s78910_lo;
s4567_lo = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[0]);
@@ -355,19 +333,19 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
s5678_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]);
s6789_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]);
- uint8x8_t d0 = convolve8_8_usdot_partial(s0123_lo, s4567_lo, s0123_hi,
- s4567_hi, filter);
- uint8x8_t d1 = convolve8_8_usdot_partial(s1234_lo, s5678_lo, s1234_hi,
- s5678_hi, filter);
- uint8x8_t d2 = convolve8_8_usdot_partial(s2345_lo, s6789_lo, s2345_hi,
- s6789_hi, filter);
- uint8x8_t d3 = convolve8_8_usdot_partial(s3456_lo, s78910_lo, s3456_hi,
- s78910_hi, filter);
+ uint8x8_t d0 =
+ convolve8_8_v(s0123_lo, s4567_lo, s0123_hi, s4567_hi, filter);
+ uint8x8_t d1 =
+ convolve8_8_v(s1234_lo, s5678_lo, s1234_hi, s5678_hi, filter);
+ uint8x8_t d2 =
+ convolve8_8_v(s2345_lo, s6789_lo, s2345_hi, s6789_hi, filter);
+ uint8x8_t d3 =
+ convolve8_8_v(s3456_lo, s78910_lo, s3456_hi, s78910_hi, filter);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
- /* Prepare block for next iteration - re-using as much as possible. */
- /* Shuffle everything up four rows. */
+ // Prepare block for next iteration - re-using as much as possible.
+ // Shuffle everything up four rows.
s0123_lo = s4567_lo;
s0123_hi = s4567_hi;
s1234_lo = s5678_lo;
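
The vqtbl2q_u8 merge step in the vertical kernels lets each iteration reuse three of the four previously transposed rows in every column group, so only one new transpose_concat is needed per four output rows. A scalar sketch of what the three merge_block_tbl permutations are assumed to compute (kDotProdMergeBlockTbl is taken to shift each 4-byte column group and append bytes from the newly transposed block; the helper below is illustrative only):

#include <stdint.h>

// prev holds rows (n..n+3) per column group, next holds rows (n+4..n+7).
// n_new = 1, 2 or 3 corresponds to merge_block_tbl.val[0], [1] or [2].
static void merge_block_ref(const uint8_t prev[16], const uint8_t next[16],
                            int n_new, uint8_t out[16]) {
  for (int g = 0; g < 4; g++) {
    for (int i = 0; i < 4; i++) {
      out[4 * g + i] = (i + n_new < 4) ? prev[4 * g + i + n_new]
                                       : next[4 * g + i + n_new - 4];
    }
  }
}
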
diff --git a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c
index 62729133e3..5758d2887f 100644
--- a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c
+++ b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.c
@@ -16,36 +16,10 @@
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
+#include "aom_dsp/flow_estimation/arm/disflow_neon.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
-static INLINE void get_cubic_kernel_dbl(double x, double kernel[4]) {
- // Check that the fractional position is in range.
- //
- // Note: x is calculated from, e.g., `u_frac = u - floor(u)`.
- // Mathematically, this implies that 0 <= x < 1. However, in practice it is
- // possible to have x == 1 due to floating point rounding. This is fine,
- // and we still interpolate correctly if we allow x = 1.
- assert(0 <= x && x <= 1);
-
- double x2 = x * x;
- double x3 = x2 * x;
- kernel[0] = -0.5 * x + x2 - 0.5 * x3;
- kernel[1] = 1.0 - 2.5 * x2 + 1.5 * x3;
- kernel[2] = 0.5 * x + 2.0 * x2 - 1.5 * x3;
- kernel[3] = -0.5 * x2 + 0.5 * x3;
-}
-
-static INLINE void get_cubic_kernel_int(double x, int kernel[4]) {
- double kernel_dbl[4];
- get_cubic_kernel_dbl(x, kernel_dbl);
-
- kernel[0] = (int)rint(kernel_dbl[0] * (1 << DISFLOW_INTERP_BITS));
- kernel[1] = (int)rint(kernel_dbl[1] * (1 << DISFLOW_INTERP_BITS));
- kernel[2] = (int)rint(kernel_dbl[2] * (1 << DISFLOW_INTERP_BITS));
- kernel[3] = (int)rint(kernel_dbl[3] * (1 << DISFLOW_INTERP_BITS));
-}
-
// Compare two regions of width x height pixels, one rooted at position
// (x, y) in src and the other at (x + u, y + v) in ref.
// This function returns the sum of squared pixel differences between
@@ -157,82 +131,6 @@ static INLINE void compute_flow_error(const uint8_t *src, const uint8_t *ref,
}
}
-static INLINE void sobel_filter_x(const uint8_t *src, int src_stride,
- int16_t *dst, int dst_stride) {
- int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
-
- // Horizontal filter, using kernel {1, 0, -1}.
- const uint8_t *src_start = src - 1 * src_stride - 1;
-
- for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) {
- uint8x16_t s = vld1q_u8(src_start + i * src_stride);
- uint8x8_t s0 = vget_low_u8(s);
- uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2));
-
- // Given that the kernel is {1, 0, -1} the convolution is a simple
- // subtraction.
- int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s0, s2));
-
- vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, diff);
- }
-
- // Vertical filter, using kernel {1, 2, 1}.
- // This kernel can be split into two 2-taps kernels of value {1, 1}.
- // That way we need only 3 add operations to perform the convolution, one of
- // which can be reused for the next line.
- int16x8_t s0 = vld1q_s16(tmp);
- int16x8_t s1 = vld1q_s16(tmp + DISFLOW_PATCH_SIZE);
- int16x8_t sum01 = vaddq_s16(s0, s1);
- for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
- int16x8_t s2 = vld1q_s16(tmp + (i + 2) * DISFLOW_PATCH_SIZE);
-
- int16x8_t sum12 = vaddq_s16(s1, s2);
- int16x8_t sum = vaddq_s16(sum01, sum12);
-
- vst1q_s16(dst + i * dst_stride, sum);
-
- sum01 = sum12;
- s1 = s2;
- }
-}
-
-static INLINE void sobel_filter_y(const uint8_t *src, int src_stride,
- int16_t *dst, int dst_stride) {
- int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
-
- // Horizontal filter, using kernel {1, 2, 1}.
- // This kernel can be split into two 2-taps kernels of value {1, 1}.
- // That way we need only 3 add operations to perform the convolution.
- const uint8_t *src_start = src - 1 * src_stride - 1;
-
- for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) {
- uint8x16_t s = vld1q_u8(src_start + i * src_stride);
- uint8x8_t s0 = vget_low_u8(s);
- uint8x8_t s1 = vget_low_u8(vextq_u8(s, s, 1));
- uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2));
-
- uint16x8_t sum01 = vaddl_u8(s0, s1);
- uint16x8_t sum12 = vaddl_u8(s1, s2);
- uint16x8_t sum = vaddq_u16(sum01, sum12);
-
- vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, vreinterpretq_s16_u16(sum));
- }
-
- // Vertical filter, using kernel {1, 0, -1}.
- // Load the whole block at once to avoid redundant loads during convolution.
- int16x8_t t[10];
- load_s16_8x10(tmp, DISFLOW_PATCH_SIZE, &t[0], &t[1], &t[2], &t[3], &t[4],
- &t[5], &t[6], &t[7], &t[8], &t[9]);
-
- for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
- // Given that the kernel is {1, 0, -1} the convolution is a simple
- // subtraction.
- int16x8_t diff = vsubq_s16(t[i], t[i + 2]);
-
- vst1q_s16(dst + i * dst_stride, diff);
- }
-}
-
// Computes the components of the system of equations used to solve for
// a flow vector.
//
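
The Sobel helpers removed here move into the shared disflow_neon.h header below so the new SVE implementation can reuse them. Both apply separable 3x3 kernels: {1, 0, -1} horizontally with {1, 2, 1} vertically for the x gradient, and the transposed pair for the y gradient. A scalar reference for one x-gradient sample, as a sketch for comparison (not part of the patch):

#include <stdint.h>

// Equivalent 3x3 Sobel-x response at (row, col); sobel_filter_x() computes
// this for a whole DISFLOW_PATCH_SIZE x DISFLOW_PATCH_SIZE patch at once.
static int16_t sobel_x_ref(const uint8_t *src, int stride, int row, int col) {
  static const int hk[3] = { 1, 0, -1 };
  static const int vk[3] = { 1, 2, 1 };
  int sum = 0;
  for (int i = -1; i <= 1; i++) {
    for (int j = -1; j <= 1; j++) {
      sum += vk[i + 1] * hk[j + 1] * src[(row + i) * stride + (col + j)];
    }
  }
  return (int16_t)sum;
}
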
diff --git a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h
new file mode 100644
index 0000000000..d991a13460
--- /dev/null
+++ b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_neon.h
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2024, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_AOM_DSP_FLOW_ESTIMATION_ARM_DISFLOW_NEON_H_
+#define AOM_AOM_DSP_FLOW_ESTIMATION_ARM_DISFLOW_NEON_H_
+
+#include "aom_dsp/flow_estimation/disflow.h"
+
+#include <arm_neon.h>
+#include <math.h>
+
+#include "aom_dsp/arm/mem_neon.h"
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+static INLINE void get_cubic_kernel_dbl(double x, double kernel[4]) {
+ // Check that the fractional position is in range.
+ //
+ // Note: x is calculated from, e.g., `u_frac = u - floor(u)`.
+ // Mathematically, this implies that 0 <= x < 1. However, in practice it is
+ // possible to have x == 1 due to floating point rounding. This is fine,
+ // and we still interpolate correctly if we allow x = 1.
+ assert(0 <= x && x <= 1);
+
+ double x2 = x * x;
+ double x3 = x2 * x;
+ kernel[0] = -0.5 * x + x2 - 0.5 * x3;
+ kernel[1] = 1.0 - 2.5 * x2 + 1.5 * x3;
+ kernel[2] = 0.5 * x + 2.0 * x2 - 1.5 * x3;
+ kernel[3] = -0.5 * x2 + 0.5 * x3;
+}
+
+static INLINE void get_cubic_kernel_int(double x, int kernel[4]) {
+ double kernel_dbl[4];
+ get_cubic_kernel_dbl(x, kernel_dbl);
+
+ kernel[0] = (int)rint(kernel_dbl[0] * (1 << DISFLOW_INTERP_BITS));
+ kernel[1] = (int)rint(kernel_dbl[1] * (1 << DISFLOW_INTERP_BITS));
+ kernel[2] = (int)rint(kernel_dbl[2] * (1 << DISFLOW_INTERP_BITS));
+ kernel[3] = (int)rint(kernel_dbl[3] * (1 << DISFLOW_INTERP_BITS));
+}
+
+static INLINE void sobel_filter_x(const uint8_t *src, int src_stride,
+ int16_t *dst, int dst_stride) {
+ int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
+
+ // Horizontal filter, using kernel {1, 0, -1}.
+ const uint8_t *src_start = src - 1 * src_stride - 1;
+
+ for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) {
+ uint8x16_t s = vld1q_u8(src_start + i * src_stride);
+ uint8x8_t s0 = vget_low_u8(s);
+ uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2));
+
+ // Given that the kernel is {1, 0, -1} the convolution is a simple
+ // subtraction.
+ int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(s0, s2));
+
+ vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, diff);
+ }
+
+ // Vertical filter, using kernel {1, 2, 1}.
+ // This kernel can be split into two 2-tap kernels of value {1, 1}.
+ // That way we need only 3 add operations to perform the convolution, one of
+ // which can be reused for the next line.
+ int16x8_t s0 = vld1q_s16(tmp);
+ int16x8_t s1 = vld1q_s16(tmp + DISFLOW_PATCH_SIZE);
+ int16x8_t sum01 = vaddq_s16(s0, s1);
+ for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+ int16x8_t s2 = vld1q_s16(tmp + (i + 2) * DISFLOW_PATCH_SIZE);
+
+ int16x8_t sum12 = vaddq_s16(s1, s2);
+ int16x8_t sum = vaddq_s16(sum01, sum12);
+
+ vst1q_s16(dst + i * dst_stride, sum);
+
+ sum01 = sum12;
+ s1 = s2;
+ }
+}
+
+static INLINE void sobel_filter_y(const uint8_t *src, int src_stride,
+ int16_t *dst, int dst_stride) {
+ int16_t tmp[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 2)];
+
+ // Horizontal filter, using kernel {1, 2, 1}.
+ // This kernel can be split into two 2-tap kernels of value {1, 1}.
+ // That way we need only 3 add operations to perform the convolution.
+ const uint8_t *src_start = src - 1 * src_stride - 1;
+
+ for (int i = 0; i < DISFLOW_PATCH_SIZE + 2; i++) {
+ uint8x16_t s = vld1q_u8(src_start + i * src_stride);
+ uint8x8_t s0 = vget_low_u8(s);
+ uint8x8_t s1 = vget_low_u8(vextq_u8(s, s, 1));
+ uint8x8_t s2 = vget_low_u8(vextq_u8(s, s, 2));
+
+ uint16x8_t sum01 = vaddl_u8(s0, s1);
+ uint16x8_t sum12 = vaddl_u8(s1, s2);
+ uint16x8_t sum = vaddq_u16(sum01, sum12);
+
+ vst1q_s16(tmp + i * DISFLOW_PATCH_SIZE, vreinterpretq_s16_u16(sum));
+ }
+
+ // Vertical filter, using kernel {1, 0, -1}.
+ // Load the whole block at once to avoid redundant loads during convolution.
+ int16x8_t t[10];
+ load_s16_8x10(tmp, DISFLOW_PATCH_SIZE, &t[0], &t[1], &t[2], &t[3], &t[4],
+ &t[5], &t[6], &t[7], &t[8], &t[9]);
+
+ for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+ // Given that the kernel is {1, 0, -1} the convolution is a simple
+ // subtraction.
+ int16x8_t diff = vsubq_s16(t[i], t[i + 2]);
+
+ vst1q_s16(dst + i * dst_stride, diff);
+ }
+}
+
+#endif // AOM_AOM_DSP_FLOW_ESTIMATION_ARM_DISFLOW_NEON_H_
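
As a quick worked example of the cubic kernel above, assuming DISFLOW_INTERP_BITS is 14: at the half-pixel offset x = 0.5 the double-precision weights are {-0.0625, 0.5625, 0.5625, -0.0625}, which get_cubic_kernel_int() scales to exactly {-1024, 9216, 9216, -1024}. A minimal self-check sketch (hypothetical, relying on the helpers declared above):

#include <assert.h>

// Hypothetical self-check built on the header above; assumes
// DISFLOW_INTERP_BITS == 14 so the scaled weights are exact integers.
static void check_cubic_kernel_half_pel(void) {
  int kernel[4];
  get_cubic_kernel_int(0.5, kernel);
  assert(kernel[0] == -1024 && kernel[1] == 9216 && kernel[2] == 9216 &&
         kernel[3] == -1024);
}
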
diff --git a/third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c
new file mode 100644
index 0000000000..7b01e90d12
--- /dev/null
+++ b/third_party/aom/aom_dsp/flow_estimation/arm/disflow_sve.c
@@ -0,0 +1,268 @@
+/*
+ * Copyright (c) 2024, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include "aom_dsp/flow_estimation/disflow.h"
+
+#include <arm_neon.h>
+#include <arm_sve.h>
+#include <math.h>
+
+#include "aom_dsp/arm/aom_neon_sve_bridge.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
+#include "aom_dsp/flow_estimation/arm/disflow_neon.h"
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
+ 0, 2, 4, 6, 1, 3, 5, 7,
+};
+
+// Compare two regions of width x height pixels, one rooted at position
+// (x, y) in src and the other at (x + u, y + v) in ref.
+// This function returns the sum of squared pixel differences between
+// the two regions.
+static INLINE void compute_flow_error(const uint8_t *src, const uint8_t *ref,
+ int width, int height, int stride, int x,
+ int y, double u, double v, int16_t *dt) {
+ // Split offset into integer and fractional parts, and compute cubic
+ // interpolation kernels
+ const int u_int = (int)floor(u);
+ const int v_int = (int)floor(v);
+ const double u_frac = u - floor(u);
+ const double v_frac = v - floor(v);
+
+ int h_kernel[4];
+ int v_kernel[4];
+ get_cubic_kernel_int(u_frac, h_kernel);
+ get_cubic_kernel_int(v_frac, v_kernel);
+
+ int16_t tmp_[DISFLOW_PATCH_SIZE * (DISFLOW_PATCH_SIZE + 3)];
+
+ // Clamp coordinates so that all pixels we fetch will remain within the
+ // allocated border region, but allow them to go far enough out that
+ // the border pixels' values do not change.
+ // Since we are calculating an 8x8 block, the bottom-right pixel
+ // in the block has coordinates (x0 + 7, y0 + 7). Then, the cubic
+ // interpolation has 4 taps, meaning that the output of pixel
+ // (x_w, y_w) depends on the pixels in the range
+ // ([x_w - 1, x_w + 2], [y_w - 1, y_w + 2]).
+ //
+ // Thus the most extreme coordinates which will be fetched are
+ // (x0 - 1, y0 - 1) and (x0 + 9, y0 + 9).
+ const int x0 = clamp(x + u_int, -9, width);
+ const int y0 = clamp(y + v_int, -9, height);
+
+ // Horizontal convolution.
+ const uint8_t *ref_start = ref + (y0 - 1) * stride + (x0 - 1);
+ const int16x4_t h_kernel_s16 = vmovn_s32(vld1q_s32(h_kernel));
+ const int16x8_t h_filter = vcombine_s16(h_kernel_s16, vdup_n_s16(0));
+ const uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
+
+ for (int i = 0; i < DISFLOW_PATCH_SIZE + 3; ++i) {
+ svuint16_t r0 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 0);
+ svuint16_t r1 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 1);
+ svuint16_t r2 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 2);
+ svuint16_t r3 = svld1ub_u16(svptrue_b16(), ref_start + i * stride + 3);
+
+ int16x8_t s0 = vreinterpretq_s16_u16(svget_neonq_u16(r0));
+ int16x8_t s1 = vreinterpretq_s16_u16(svget_neonq_u16(r1));
+ int16x8_t s2 = vreinterpretq_s16_u16(svget_neonq_u16(r2));
+ int16x8_t s3 = vreinterpretq_s16_u16(svget_neonq_u16(r3));
+
+ int64x2_t sum04 = aom_svdot_lane_s16(vdupq_n_s64(0), s0, h_filter, 0);
+ int64x2_t sum15 = aom_svdot_lane_s16(vdupq_n_s64(0), s1, h_filter, 0);
+ int64x2_t sum26 = aom_svdot_lane_s16(vdupq_n_s64(0), s2, h_filter, 0);
+ int64x2_t sum37 = aom_svdot_lane_s16(vdupq_n_s64(0), s3, h_filter, 0);
+
+ int32x4_t res0 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
+ int32x4_t res1 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
+
+ // 6 is the maximum allowable number of extra bits which will avoid
+ // the intermediate values overflowing an int16_t. The most extreme
+ // intermediate value occurs when:
+ // * The input pixels are [0, 255, 255, 0]
+ // * u_frac = 0.5
+ // In this case, the un-scaled output is 255 * 1.125 = 286.875.
+ // As an integer with 6 fractional bits, that is 18360, which fits
+ // in an int16_t. But with 7 fractional bits it would be 36720,
+ // which is too large.
+ int16x8_t res = vcombine_s16(vrshrn_n_s32(res0, DISFLOW_INTERP_BITS - 6),
+ vrshrn_n_s32(res1, DISFLOW_INTERP_BITS - 6));
+
+ res = aom_tbl_s16(res, idx);
+
+ vst1q_s16(tmp_ + i * DISFLOW_PATCH_SIZE, res);
+ }
+
+ // Vertical convolution.
+ int16x4_t v_filter = vmovn_s32(vld1q_s32(v_kernel));
+ int16_t *tmp_start = tmp_ + DISFLOW_PATCH_SIZE;
+
+ for (int i = 0; i < DISFLOW_PATCH_SIZE; ++i) {
+ int16x8_t t0 = vld1q_s16(tmp_start + (i - 1) * DISFLOW_PATCH_SIZE);
+ int16x8_t t1 = vld1q_s16(tmp_start + i * DISFLOW_PATCH_SIZE);
+ int16x8_t t2 = vld1q_s16(tmp_start + (i + 1) * DISFLOW_PATCH_SIZE);
+ int16x8_t t3 = vld1q_s16(tmp_start + (i + 2) * DISFLOW_PATCH_SIZE);
+
+ int32x4_t sum_lo = vmull_lane_s16(vget_low_s16(t0), v_filter, 0);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t1), v_filter, 1);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t2), v_filter, 2);
+ sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(t3), v_filter, 3);
+
+ int32x4_t sum_hi = vmull_lane_s16(vget_high_s16(t0), v_filter, 0);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t1), v_filter, 1);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t2), v_filter, 2);
+ sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(t3), v_filter, 3);
+
+ uint8x8_t s = vld1_u8(src + (i + y) * stride + x);
+ int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, 3));
+
+ // This time, we have to round off the 6 extra bits which were kept
+ // earlier, but we also want to keep DISFLOW_DERIV_SCALE_LOG2 extra bits
+ // of precision to match the scale of the dx and dy arrays.
+ sum_lo = vrshrq_n_s32(sum_lo,
+ DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
+ sum_hi = vrshrq_n_s32(sum_hi,
+ DISFLOW_INTERP_BITS + 6 - DISFLOW_DERIV_SCALE_LOG2);
+ int32x4_t err_lo = vsubw_s16(sum_lo, vget_low_s16(s_s16));
+ int32x4_t err_hi = vsubw_s16(sum_hi, vget_high_s16(s_s16));
+ vst1q_s16(dt + i * DISFLOW_PATCH_SIZE,
+ vcombine_s16(vmovn_s32(err_lo), vmovn_s32(err_hi)));
+ }
+}
+
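
The precision argument in the comments above can be checked with the numbers it quotes: for input pixels {0, 255, 255, 0} at u_frac = 0.5 the unscaled output is 255 * 1.125 = 286.875, which is 18360 at 6 fractional bits but 36720 at 7, so 6 is the largest headroom that still fits an int16_t. A scalar sketch of one horizontal output sample (hypothetical helper, assuming DISFLOW_INTERP_BITS is 14):

#include <stdint.h>

#define INTERP_BITS 14  // Assumed value of DISFLOW_INTERP_BITS.

// One horizontal output sample, kept at 6 fractional bits exactly as the
// vrshrn_n_s32(..., DISFLOW_INTERP_BITS - 6) narrowing does above.
static int16_t interp_h_ref(const uint8_t p[4], const int kernel[4]) {
  int32_t sum = 0;
  for (int k = 0; k < 4; k++) {
    sum += kernel[k] * p[k];
  }
  const int shift = INTERP_BITS - 6;
  return (int16_t)((sum + (1 << (shift - 1))) >> shift);
}
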
+// Computes the components of the system of equations used to solve for
+// a flow vector.
+//
+// The flow equations are a least-squares system, derived as follows:
+//
+// For each pixel in the patch, we calculate the current error `dt`,
+// and the x and y gradients `dx` and `dy` of the source patch.
+// This means that, to first order, the squared error for this pixel is
+//
+// (dt + u * dx + v * dy)^2
+//
+// where (u, v) are the incremental changes to the flow vector.
+//
+// We then want to find the values of u and v which minimize the sum
+// of the squared error across all pixels. Conveniently, this fits exactly
+// into the form of a least squares problem, with one equation
+//
+// u * dx + v * dy = -dt
+//
+// for each pixel.
+//
+// Summing across all pixels in a square window of size DISFLOW_PATCH_SIZE,
+// and absorbing the - sign elsewhere, this results in the least squares system
+//
+// M = |sum(dx * dx) sum(dx * dy)|
+// |sum(dx * dy) sum(dy * dy)|
+//
+// b = |sum(dx * dt)|
+// |sum(dy * dt)|
+static INLINE void compute_flow_matrix(const int16_t *dx, int dx_stride,
+ const int16_t *dy, int dy_stride,
+ double *M_inv) {
+ int64x2_t sum[3] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0) };
+
+ for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+ int16x8_t x = vld1q_s16(dx + i * dx_stride);
+ int16x8_t y = vld1q_s16(dy + i * dy_stride);
+
+ sum[0] = aom_sdotq_s16(sum[0], x, x);
+ sum[1] = aom_sdotq_s16(sum[1], x, y);
+ sum[2] = aom_sdotq_s16(sum[2], y, y);
+ }
+
+ sum[0] = vpaddq_s64(sum[0], sum[1]);
+ sum[2] = vpaddq_s64(sum[1], sum[2]);
+ int32x4_t res = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
+
+ // Apply regularization
+ // We follow the standard regularization method of adding `k * I` before
+ // inverting. This ensures that the matrix will be invertible.
+ //
+ // Setting the regularization strength k to 1 seems to work well here, as
+ // typical values coming from the other equations are very large (1e5 to
+ // 1e6, with an upper limit of around 6e7, at the time of writing).
+ // It also preserves the property that all matrix values are whole numbers,
+ // which is convenient for integerized SIMD implementation.
+
+ double M0 = (double)vgetq_lane_s32(res, 0) + 1;
+ double M1 = (double)vgetq_lane_s32(res, 1);
+ double M2 = (double)vgetq_lane_s32(res, 2);
+ double M3 = (double)vgetq_lane_s32(res, 3) + 1;
+
+ // Invert matrix M.
+ double det = (M0 * M3) - (M1 * M2);
+ assert(det >= 1);
+ const double det_inv = 1 / det;
+
+ M_inv[0] = M3 * det_inv;
+ M_inv[1] = -M1 * det_inv;
+ M_inv[2] = -M2 * det_inv;
+ M_inv[3] = M0 * det_inv;
+}
+
+static INLINE void compute_flow_vector(const int16_t *dx, int dx_stride,
+ const int16_t *dy, int dy_stride,
+ const int16_t *dt, int dt_stride,
+ int *b) {
+ int64x2_t b_s64[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
+
+ for (int i = 0; i < DISFLOW_PATCH_SIZE; i++) {
+ int16x8_t dx16 = vld1q_s16(dx + i * dx_stride);
+ int16x8_t dy16 = vld1q_s16(dy + i * dy_stride);
+ int16x8_t dt16 = vld1q_s16(dt + i * dt_stride);
+
+ b_s64[0] = aom_sdotq_s16(b_s64[0], dx16, dt16);
+ b_s64[1] = aom_sdotq_s16(b_s64[1], dy16, dt16);
+ }
+
+ b_s64[0] = vpaddq_s64(b_s64[0], b_s64[1]);
+ vst1_s32(b, vmovn_s64(b_s64[0]));
+}
+
+void aom_compute_flow_at_point_sve(const uint8_t *src, const uint8_t *ref,
+ int x, int y, int width, int height,
+ int stride, double *u, double *v) {
+ double M_inv[4];
+ int b[2];
+ int16_t dt[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+ int16_t dx[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+ int16_t dy[DISFLOW_PATCH_SIZE * DISFLOW_PATCH_SIZE];
+
+ // Compute gradients within this patch
+ const uint8_t *src_patch = &src[y * stride + x];
+ sobel_filter_x(src_patch, stride, dx, DISFLOW_PATCH_SIZE);
+ sobel_filter_y(src_patch, stride, dy, DISFLOW_PATCH_SIZE);
+
+ compute_flow_matrix(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, M_inv);
+
+ for (int itr = 0; itr < DISFLOW_MAX_ITR; itr++) {
+ compute_flow_error(src, ref, width, height, stride, x, y, *u, *v, dt);
+ compute_flow_vector(dx, DISFLOW_PATCH_SIZE, dy, DISFLOW_PATCH_SIZE, dt,
+ DISFLOW_PATCH_SIZE, b);
+
+ // Solve flow equations to find a better estimate for the flow vector
+ // at this point
+ const double step_u = M_inv[0] * b[0] + M_inv[1] * b[1];
+ const double step_v = M_inv[2] * b[0] + M_inv[3] * b[1];
+ *u += fclamp(step_u * DISFLOW_STEP_SIZE, -2, 2);
+ *v += fclamp(step_v * DISFLOW_STEP_SIZE, -2, 2);
+
+ if (fabs(step_u) + fabs(step_v) < DISFLOW_STEP_SIZE_THRESOLD) {
+ // Stop iteration when we're close to convergence
+ break;
+ }
+ }
+}
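
The iteration loop above solves a regularized 2x2 least-squares system: compute_flow_matrix() adds 1 to the diagonal of M and inverts it once per patch, and each iteration multiplies the freshly computed b vector by that inverse. A plain-C sketch of the per-iteration step, matching the M_inv layout produced above (illustrative only):

// Plain-C model of the per-iteration solve: step = M^-1 * b, with
// M_inv = {M3, -M1, -M2, M0} / det as produced by compute_flow_matrix().
// aom_compute_flow_at_point_sve() then scales the step by DISFLOW_STEP_SIZE
// and clamps it to [-2, 2] before accumulating into (u, v).
static void flow_step_ref(const double M_inv[4], const int b[2],
                          double *step_u, double *step_v) {
  *step_u = M_inv[0] * b[0] + M_inv[1] * b[1];
  *step_v = M_inv[2] * b[0] + M_inv[3] * b[1];
}
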
diff --git a/third_party/aom/aom_dsp/pyramid.c b/third_party/aom/aom_dsp/pyramid.c
index 5de001dbd5..05ddbb2f5f 100644
--- a/third_party/aom/aom_dsp/pyramid.c
+++ b/third_party/aom/aom_dsp/pyramid.c
@@ -305,6 +305,7 @@ static INLINE int fill_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth,
// Fill in the remaining levels through progressive downsampling
for (int level = already_filled_levels; level < n_levels; ++level) {
+ bool mem_status = false;
PyramidLayer *prev_layer = &frame_pyr->layers[level - 1];
uint8_t *prev_buffer = prev_layer->buffer;
int prev_stride = prev_layer->stride;
@@ -315,6 +316,11 @@ static INLINE int fill_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth,
int this_height = this_layer->height;
int this_stride = this_layer->stride;
+ // The width and height of the previous layer that needs to be considered to
+ // derive the current layer frame.
+ const int input_layer_width = this_width << 1;
+ const int input_layer_height = this_height << 1;
+
    // Compute this pyramid level by downsampling the previous level.
//
// We downsample by a factor of exactly 2, clipping the rightmost and
@@ -329,13 +335,30 @@ static INLINE int fill_pyramid(const YV12_BUFFER_CONFIG *frame, int bit_depth,
// 2) Up/downsampling by a factor of 2 can be implemented much more
// efficiently than up/downsampling by a generic ratio.
// TODO(rachelbarker): Use optimized downsample-by-2 function
- if (!av1_resize_plane(prev_buffer, this_height << 1, this_width << 1,
- prev_stride, this_buffer, this_height, this_width,
- this_stride)) {
- // If we can't allocate memory, we'll have to terminate early
+
+ // SIMD support has been added specifically for cases where the downsample
+ // factor is exactly 2. In such instances, horizontal and vertical resizing
+ // is performed utilizing the down2_symeven() function, which considers the
+ // even dimensions of the input layer.
+ if (should_resize_by_half(input_layer_height, input_layer_width,
+ this_height, this_width)) {
+ assert(input_layer_height % 2 == 0 && input_layer_width % 2 == 0 &&
+ "Input width or height cannot be odd.");
+ mem_status = av1_resize_plane_to_half(
+ prev_buffer, input_layer_height, input_layer_width, prev_stride,
+ this_buffer, this_height, this_width, this_stride);
+ } else {
+ mem_status = av1_resize_plane(prev_buffer, input_layer_height,
+ input_layer_width, prev_stride, this_buffer,
+ this_height, this_width, this_stride);
+ }
+
+ // Terminate early in cases of memory allocation failure.
+ if (!mem_status) {
frame_pyr->filled_levels = n_levels;
return -1;
}
+
fill_border(this_buffer, this_width, this_height, this_stride);
}
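
The fast path above requires even input dimensions, which input_layer_width and input_layer_height guarantee by construction: they are defined as exactly twice the output size, clipping an odd rightmost column or bottom row of the previous layer. The sketch below is a hypothetical model of the dispatch condition, not the actual should_resize_by_half() implementation:

#include <stdbool.h>

// Hypothetical model: take the optimized av1_resize_plane_to_half() path
// only when the output is exactly half of the (even) input in both axes.
static bool resize_by_half_ok(int in_height, int in_width, int out_height,
                              int out_width) {
  return in_height == 2 * out_height && in_width == 2 * out_width;
}
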
diff --git a/third_party/aom/aom_dsp/x86/synonyms.h b/third_party/aom/aom_dsp/x86/synonyms.h
index 74318de2e5..f9bc9ac733 100644
--- a/third_party/aom/aom_dsp/x86/synonyms.h
+++ b/third_party/aom/aom_dsp/x86/synonyms.h
@@ -46,7 +46,6 @@ static INLINE __m128i xx_loadu_128(const void *a) {
return _mm_loadu_si128((const __m128i *)a);
}
-
// _mm_loadu_si64 has been introduced in GCC 9, reimplement the function
// manually on older compilers.
#if !defined(__clang__) && __GNUC_MAJOR__ < 9