Diffstat (limited to 'third_party/aom/av1/encoder/arm/neon')
24 files changed, 14687 insertions, 0 deletions
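The first two files added below, av1_error_neon.c and av1_error_sve.c, provide Arm NEON and SVE kernels for the encoder's block-error metric: the sum of squared differences between the transform coefficients (coeff) and their dequantized counterparts (dqcoeff); the non-lp variants additionally return the sum of squared coefficients through *ssz. As a reference point for the vector code, a minimal scalar sketch of the low-precision variant follows; the function name is illustrative and not part of the patch.

// Scalar model of the sum of squared errors computed by the NEON/SVE kernels
// below. Illustrative sketch only; not part of the upstream change.
static int64_t block_error_lp_ref(const int16_t *coeff, const int16_t *dqcoeff,
                                  int block_size) {
  int64_t err = 0;
  for (int i = 0; i < block_size; ++i) {
    const int32_t diff = coeff[i] - dqcoeff[i];
    err += (int64_t)diff * diff;
  }
  return err;
}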
diff --git a/third_party/aom/av1/encoder/arm/neon/av1_error_neon.c b/third_party/aom/av1/encoder/arm/neon/av1_error_neon.c
new file mode 100644
index 0000000000..26d06b46fe
--- /dev/null
+++ b/third_party/aom/av1/encoder/arm/neon/av1_error_neon.c
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2015 The WebM project authors. All Rights Reserved.
+ * Copyright (c) 2019, Alliance for Open Media. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+
+#include "config/aom_config.h"
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "aom_dsp/arm/mem_neon.h"
+#include "aom_dsp/arm/sum_neon.h"
+
+int64_t av1_block_error_neon(const tran_low_t *coeff, const tran_low_t *dqcoeff,
+                             intptr_t block_size, int64_t *ssz) {
+  uint64x2_t err_u64 = vdupq_n_u64(0);
+  int64x2_t ssz_s64 = vdupq_n_s64(0);
+
+  assert(block_size >= 16);
+  assert((block_size % 16) == 0);
+
+  do {
+    const int16x8_t c0 = load_tran_low_to_s16q(coeff);
+    const int16x8_t c1 = load_tran_low_to_s16q(coeff + 8);
+    const int16x8_t d0 = load_tran_low_to_s16q(dqcoeff);
+    const int16x8_t d1 = load_tran_low_to_s16q(dqcoeff + 8);
+
+    const uint16x8_t diff0 = vreinterpretq_u16_s16(vabdq_s16(c0, d0));
+    const uint16x8_t diff1 = vreinterpretq_u16_s16(vabdq_s16(c1, d1));
+
+    // By operating on unsigned integers we can store up to 4 squared diff in a
+    // 32-bit element before having to widen to 64 bits.
+    uint32x4_t err = vmull_u16(vget_low_u16(diff0), vget_low_u16(diff0));
+    err = vmlal_u16(err, vget_high_u16(diff0), vget_high_u16(diff0));
+    err = vmlal_u16(err, vget_low_u16(diff1), vget_low_u16(diff1));
+    err = vmlal_u16(err, vget_high_u16(diff1), vget_high_u16(diff1));
+    err_u64 = vpadalq_u32(err_u64, err);
+
+    // We can't do the same here as we're operating on signed integers, so we
+    // can only accumulate 2 squares.
+    int32x4_t ssz0 = vmull_s16(vget_low_s16(c0), vget_low_s16(c0));
+    ssz0 = vmlal_s16(ssz0, vget_high_s16(c0), vget_high_s16(c0));
+    ssz_s64 = vpadalq_s32(ssz_s64, ssz0);
+
+    int32x4_t ssz1 = vmull_s16(vget_low_s16(c1), vget_low_s16(c1));
+    ssz1 = vmlal_s16(ssz1, vget_high_s16(c1), vget_high_s16(c1));
+    ssz_s64 = vpadalq_s32(ssz_s64, ssz1);
+
+    coeff += 16;
+    dqcoeff += 16;
+    block_size -= 16;
+  } while (block_size != 0);
+
+  *ssz = horizontal_add_s64x2(ssz_s64);
+  return (int64_t)horizontal_add_u64x2(err_u64);
+}
+
+int64_t av1_block_error_lp_neon(const int16_t *coeff, const int16_t *dqcoeff,
+                                int block_size) {
+  uint64x2_t err_u64 = vdupq_n_u64(0);
+
+  assert(block_size >= 16);
+  assert((block_size % 16) == 0);
+
+  do {
+    const int16x8_t c0 = vld1q_s16(coeff);
+    const int16x8_t c1 = vld1q_s16(coeff + 8);
+    const int16x8_t d0 = vld1q_s16(dqcoeff);
+    const int16x8_t d1 = vld1q_s16(dqcoeff + 8);
+
+    const uint16x8_t diff0 = vreinterpretq_u16_s16(vabdq_s16(c0, d0));
+    const uint16x8_t diff1 = vreinterpretq_u16_s16(vabdq_s16(c1, d1));
+
+    // By operating on unsigned integers we can store up to 4 squared diff in a
+    // 32-bit element before having to widen to 64 bits.
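    // (Editorial note, not part of the upstream patch.) vabdq_s16 above gives
    // |coeff - dqcoeff| as an unsigned 16-bit value, so each square formed by
    // vmull_u16/vmlal_u16 below fits a 32-bit lane, and vpadalq_u32 then
    // pairwise-accumulates those lanes into the 64-bit err_u64 totals, i.e.
    // the widening to 64 bits happens only once per 16 coefficients.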
+    uint32x4_t err = vmull_u16(vget_low_u16(diff0), vget_low_u16(diff0));
+    err = vmlal_u16(err, vget_high_u16(diff0), vget_high_u16(diff0));
+    err = vmlal_u16(err, vget_low_u16(diff1), vget_low_u16(diff1));
+    err = vmlal_u16(err, vget_high_u16(diff1), vget_high_u16(diff1));
+    err_u64 = vpadalq_u32(err_u64, err);
+
+    coeff += 16;
+    dqcoeff += 16;
+    block_size -= 16;
+  } while (block_size != 0);
+
+  return (int64_t)horizontal_add_u64x2(err_u64);
+}
diff --git a/third_party/aom/av1/encoder/arm/neon/av1_error_sve.c b/third_party/aom/av1/encoder/arm/neon/av1_error_sve.c
new file mode 100644
index 0000000000..63aad0b785
--- /dev/null
+++ b/third_party/aom/av1/encoder/arm/neon/av1_error_sve.c
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <arm_neon.h>
+#include <assert.h>
+
+#include "config/aom_config.h"
+
+#include "aom_dsp/aom_dsp_common.h"
+#include "aom_dsp/arm/dot_sve.h"
+#include "aom_dsp/arm/mem_neon.h"
+
+int64_t av1_block_error_sve(const tran_low_t *coeff, const tran_low_t *dqcoeff,
+                            intptr_t block_size, int64_t *ssz) {
+  int64x2_t error[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
+  int64x2_t sqcoeff[2] = { vdupq_n_s64(0), vdupq_n_s64(0) };
+
+  assert(block_size >= 16);
+  assert((block_size % 16) == 0);
+
+  do {
+    const int16x8_t c0 = load_tran_low_to_s16q(coeff);
+    const int16x8_t c1 = load_tran_low_to_s16q(coeff + 8);
+    const int16x8_t d0 = load_tran_low_to_s16q(dqcoeff);
+    const int16x8_t d1 = load_tran_low_to_s16q(dqcoeff + 8);
+
+    const int16x8_t diff0 = vsubq_s16(c0, d0);
+    const int16x8_t diff1 = vsubq_s16(c1, d1);
+
+    error[0] = aom_sdotq_s16(error[0], diff0, diff0);
+    error[1] = aom_sdotq_s16(error[1], diff1, diff1);
+    sqcoeff[0] = aom_sdotq_s16(sqcoeff[0], c0, c0);
+    sqcoeff[1] = aom_sdotq_s16(sqcoeff[1], c1, c1);
+
+    coeff += 16;
+    dqcoeff += 16;
+    block_size -= 16;
+  } while (block_size != 0);
+
+  *ssz = vaddvq_s64(vaddq_s64(sqcoeff[0], sqcoeff[1]));
+  return vaddvq_s64(vaddq_s64(error[0], error[1]));
+}
+
+int64_t av1_block_error_lp_sve(const int16_t *coeff, const int16_t *dqcoeff,
+                               int block_size) {
+  if (block_size % 32 == 0) {
+    int64x2_t error[4] = { vdupq_n_s64(0), vdupq_n_s64(0), vdupq_n_s64(0),
+                           vdupq_n_s64(0) };
+
+    do {
+      const int16x8_t c0 = vld1q_s16(coeff);
+      const int16x8_t c1 = vld1q_s16(coeff + 8);
+      const int16x8_t c2 = vld1q_s16(coeff + 16);
+      const int16x8_t c3 = vld1q_s16(coeff + 24);
+      const int16x8_t d0 = vld1q_s16(dqcoeff);
+      const int16x8_t d1 = vld1q_s16(dqcoeff + 8);
+      const int16x8_t d2 = vld1q_s16(dqcoeff + 16);
+      const int16x8_t d3 = vld1q_s16(dqcoeff + 24);
+
+      const int16x8_t diff0 = vsubq_s16(c0, d0);
+      const int16x8_t diff1 = vsubq_s16(c1, d1);
+      const int16x8_t diff2 = vsubq_s16(c2, d2);
+      const int16x8_t diff3 = vsubq_s16(c3, d3);
+
+      error[0] = aom_sdotq_s16(error[0], diff0, diff0);
+      error[1] = aom_sdotq_s16(error[1], diff1, diff1);
+      error[2] = aom_sdotq_s16(error[2], diff2, diff2);
+      error[3] = aom_sdotq_s16(error[3], diff3, diff3);
+
+      coeff += 32;
+      dqcoeff += 32;
+      block_size -= 32;
+    } while (block_size != 0);
+
+    error[0] = vaddq_s64(error[0], error[1]);
+    error[2] = vaddq_s64(error[2], error[3]);
+    error[0] = vaddq_s64(error[0], error[2]);
+
return vaddvq_s64(error[0]); + } + assert(block_size == 16); + + int64x2_t error[2] = { vdupq_n_s64(0), vdupq_n_s64(0) }; + + do { + const int16x8_t c0 = vld1q_s16(coeff); + const int16x8_t c1 = vld1q_s16(coeff + 8); + const int16x8_t d0 = vld1q_s16(dqcoeff); + const int16x8_t d1 = vld1q_s16(dqcoeff + 8); + + const int16x8_t diff0 = vsubq_s16(c0, d0); + const int16x8_t diff1 = vsubq_s16(c1, d1); + + error[0] = aom_sdotq_s16(error[0], diff0, diff0); + error[1] = aom_sdotq_s16(error[1], diff1, diff1); + + coeff += 16; + dqcoeff += 16; + block_size -= 16; + } while (block_size != 0); + + return vaddvq_s64(vaddq_s64(error[0], error[1])); +} diff --git a/third_party/aom/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c b/third_party/aom/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c new file mode 100644 index 0000000000..5148ee74a9 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/av1_fwd_txfm2d_neon.c @@ -0,0 +1,3090 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> +#include <assert.h> + +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_dsp/txfm_common.h" +#include "aom_ports/mem.h" +#include "av1/common/av1_txfm.h" +#include "av1/encoder/av1_fwd_txfm1d_cfg.h" +#include "config/aom_config.h" +#include "config/av1_rtcd.h" +#include "shift_neon.h" +#include "txfm_neon.h" + +#define TXFM_COS_BIT_MAX 13 + +// A note on butterfly helper naming: +// +// butterfly_[input_ty]_[acc_ty]_[input_num]_[weight_num]_[weight_neg]_neon +// e.g. butterfly_s32_s32_x4_0231_neon +// | | | ^ Weights are applied as indices 0, 2, 3, 1 +// | | | (see more detail below) +// | | ^ (int32)x4 input/output parameters +// | ^ 32-bit accumulators internally +// ^ 32-bit input/output parameters +// +// Weights are stored as 4-tuples in Q2.13 format as (w0, 1-w0, -w0, w0-1) to +// avoid needing separate negation instructions. 
This is represented in the +// helper naming by referring to the lane index in the loaded tuple that each +// multiply is performed with: +// +// in0 in1 +// /---------- +// out0 | w0 w1 ==> out0 = in0 * w0 + in1 * w1 +// out1 | w2 w3 ==> out1 = in0 * w2 + in1 * w3 +// +// So for indices 0331 from the earlier example, we end up with: +// +// in0 in1 +// /------------------ +// out0 | (lane 0) (lane 2) ==> out0 = in0 * w0 + in1 * -w0 +// out1 | (lane 3) (lane 1) ==> out1 = in0 * (w0-1) + in1 * (1-w0) + +static AOM_FORCE_INLINE void butterfly_s32_s32_x4_0112_neon( + const int16x4_t w0101_s16, const int32x4_t in0, const int32x4_t in1, + int32x4_t *out0, int32x4_t *out1) { + int32x4_t w0101 = vmovl_s16(w0101_s16); + int32x4_t o0 = vmulq_lane_s32(in0, vget_low_s32(w0101), 0); + o0 = vmlaq_lane_s32(o0, in1, vget_low_s32(w0101), 1); + int32x4_t o1 = vmulq_lane_s32(in0, vget_low_s32(w0101), 1); + o1 = vmlaq_lane_s32(o1, in1, vget_high_s32(w0101), 0); + *out0 = vrshrq_n_s32(o0, TXFM_COS_BIT_MAX); + *out1 = vrshrq_n_s32(o1, TXFM_COS_BIT_MAX); +} + +static AOM_FORCE_INLINE void butterfly_s32_s32_x4_0332_neon( + const int16x4_t w0101_s16, const int32x4_t in0, const int32x4_t in1, + int32x4_t *out0, int32x4_t *out1) { + int32x4_t w0101 = vmovl_s16(w0101_s16); + int32x4_t o0 = vmulq_lane_s32(in0, vget_low_s32(w0101), 0); + o0 = vmlaq_lane_s32(o0, in1, vget_high_s32(w0101), 1); + int32x4_t o1 = vmulq_lane_s32(in0, vget_high_s32(w0101), 1); + o1 = vmlaq_lane_s32(o1, in1, vget_high_s32(w0101), 0); + *out0 = vrshrq_n_s32(o0, TXFM_COS_BIT_MAX); + *out1 = vrshrq_n_s32(o1, TXFM_COS_BIT_MAX); +} + +static AOM_FORCE_INLINE void butterfly_s32_s32_x4_1003_neon( + const int16x4_t w0101_s16, const int32x4_t in0, const int32x4_t in1, + int32x4_t *out0, int32x4_t *out1) { + int32x4_t w0101 = vmovl_s16(w0101_s16); + int32x4_t o0 = vmulq_lane_s32(in0, vget_low_s32(w0101), 1); + o0 = vmlaq_lane_s32(o0, in1, vget_low_s32(w0101), 0); + int32x4_t o1 = vmulq_lane_s32(in0, vget_low_s32(w0101), 0); + o1 = vmlaq_lane_s32(o1, in1, vget_high_s32(w0101), 1); + *out0 = vrshrq_n_s32(o0, TXFM_COS_BIT_MAX); + *out1 = vrshrq_n_s32(o1, TXFM_COS_BIT_MAX); +} + +static AOM_FORCE_INLINE void butterfly_s32_s32_x4_1223_neon( + const int16x4_t w0101_s16, const int32x4_t in0, const int32x4_t in1, + int32x4_t *out0, int32x4_t *out1) { + int32x4_t w0101 = vmovl_s16(w0101_s16); + int32x4_t o0 = vmulq_lane_s32(in0, vget_low_s32(w0101), 1); + o0 = vmlaq_lane_s32(o0, in1, vget_high_s32(w0101), 0); + int32x4_t o1 = vmulq_lane_s32(in0, vget_high_s32(w0101), 0); + o1 = vmlaq_lane_s32(o1, in1, vget_high_s32(w0101), 1); + *out0 = vrshrq_n_s32(o0, TXFM_COS_BIT_MAX); + *out1 = vrshrq_n_s32(o1, TXFM_COS_BIT_MAX); +} + +#define butterfly_s16_s32_x4_neon(wvec, lane0, lane1, lane2, lane3, in0, in1, \ + out0, out1) \ + do { \ + int32x4_t u0 = vmull_lane_s16(in0, wvec, lane0); \ + u0 = vmlal_lane_s16(u0, in1, wvec, lane1); \ + int32x4_t v0 = vmull_lane_s16(in0, wvec, lane2); \ + v0 = vmlal_lane_s16(v0, in1, wvec, lane3); \ + *out0 = vqrshrn_n_s32(u0, TXFM_COS_BIT_MAX); \ + *out1 = vqrshrn_n_s32(v0, TXFM_COS_BIT_MAX); \ + } while (0) + +static AOM_FORCE_INLINE void butterfly_s16_s32_x4_0112_neon( + const int16x4_t w0101, const int16x4_t in0, const int16x4_t in1, + int16x4_t *out0, int16x4_t *out1) { + butterfly_s16_s32_x4_neon(w0101, 0, 1, 1, 2, in0, in1, out0, out1); +} + +static AOM_FORCE_INLINE void butterfly_s16_s32_x4_0332_neon( + const int16x4_t w0101, const int16x4_t in0, const int16x4_t in1, + int16x4_t *out0, int16x4_t *out1) { + 
butterfly_s16_s32_x4_neon(w0101, 0, 3, 3, 2, in0, in1, out0, out1); +} + +static AOM_FORCE_INLINE void butterfly_s16_s32_x4_1003_neon( + const int16x4_t w0101, const int16x4_t in0, const int16x4_t in1, + int16x4_t *out0, int16x4_t *out1) { + butterfly_s16_s32_x4_neon(w0101, 1, 0, 0, 3, in0, in1, out0, out1); +} + +static AOM_FORCE_INLINE void butterfly_s16_s32_x4_1223_neon( + const int16x4_t w0101, const int16x4_t in0, const int16x4_t in1, + int16x4_t *out0, int16x4_t *out1) { + butterfly_s16_s32_x4_neon(w0101, 1, 2, 2, 3, in0, in1, out0, out1); +} + +#define butterfly_s16_s32_x8_neon(wvec, lane0, lane1, lane2, lane3, in0, in1, \ + out0, out1) \ + do { \ + int32x4_t u0 = vmull_lane_s16(vget_low_s16(in0), wvec, lane0); \ + u0 = vmlal_lane_s16(u0, vget_low_s16(in1), wvec, lane1); \ + int32x4_t u1 = vmull_lane_s16(vget_high_s16(in0), wvec, lane0); \ + u1 = vmlal_lane_s16(u1, vget_high_s16(in1), wvec, lane1); \ + int32x4_t v0 = vmull_lane_s16(vget_low_s16(in0), wvec, lane2); \ + v0 = vmlal_lane_s16(v0, vget_low_s16(in1), wvec, lane3); \ + int32x4_t v1 = vmull_lane_s16(vget_high_s16(in0), wvec, lane2); \ + v1 = vmlal_lane_s16(v1, vget_high_s16(in1), wvec, lane3); \ + const int16x4_t c0 = vrshrn_n_s32(u0, TXFM_COS_BIT_MAX); \ + const int16x4_t c1 = vrshrn_n_s32(u1, TXFM_COS_BIT_MAX); \ + const int16x4_t d0 = vrshrn_n_s32(v0, TXFM_COS_BIT_MAX); \ + const int16x4_t d1 = vrshrn_n_s32(v1, TXFM_COS_BIT_MAX); \ + *out0 = vcombine_s16(c0, c1); \ + *out1 = vcombine_s16(d0, d1); \ + } while (0) + +static AOM_FORCE_INLINE void butterfly_s16_s32_x8_0112_neon( + const int16x4_t w0101, const int16x8_t in0, const int16x8_t in1, + int16x8_t *out0, int16x8_t *out1) { + butterfly_s16_s32_x8_neon(w0101, 0, 1, 1, 2, in0, in1, out0, out1); +} + +static AOM_FORCE_INLINE void butterfly_s16_s32_x8_0332_neon( + const int16x4_t w0101, const int16x8_t in0, const int16x8_t in1, + int16x8_t *out0, int16x8_t *out1) { + butterfly_s16_s32_x8_neon(w0101, 0, 3, 3, 2, in0, in1, out0, out1); +} + +static AOM_FORCE_INLINE void butterfly_s16_s32_x8_1003_neon( + const int16x4_t w0101, const int16x8_t in0, const int16x8_t in1, + int16x8_t *out0, int16x8_t *out1) { + butterfly_s16_s32_x8_neon(w0101, 1, 0, 0, 3, in0, in1, out0, out1); +} + +static AOM_FORCE_INLINE void butterfly_s16_s32_x8_1223_neon( + const int16x4_t w0101, const int16x8_t in0, const int16x8_t in1, + int16x8_t *out0, int16x8_t *out1) { + butterfly_s16_s32_x8_neon(w0101, 1, 2, 2, 3, in0, in1, out0, out1); +} + +static AOM_FORCE_INLINE void flip_buf_4_neon(int16x4_t *in, int16x4_t *out, + int size) { + for (int i = 0; i < size; ++i) { + out[size - i - 1] = in[i]; + } +} + +static AOM_FORCE_INLINE void flip_buf_8_neon(int16x8_t *in, int16x8_t *out, + int size) { + for (int i = 0; i < size; ++i) { + out[size - i - 1] = in[i]; + } +} + +static AOM_FORCE_INLINE void store_buffer_interleaved_s32_x8( + int32_t *const out, const int32x4_t *const in1, const int32x4_t *const in2, + const int stride, const int out_size) { + for (int i = 0; i < out_size; ++i) { + vst1q_s32(out + stride * i, in1[i]); + vst1q_s32(out + stride * i + 4, in2[i]); + } +} + +static AOM_FORCE_INLINE void load_buffer_s16_x4(const int16_t *in, + const int stride, + int16x4_t *const out, + const int out_size) { + for (int i = 0; i < out_size; ++i) { + out[i] = vld1_s16(in); + in += stride; + } +} + +static AOM_FORCE_INLINE void load_buffer_s16_x8(const int16_t *in, int stride, + int16x8_t *out, int out_size) { + for (int i = 0; i < out_size; ++i) { + out[i] = vld1q_s16(in + i * stride); + } +} + +static 
AOM_FORCE_INLINE void store_buffer_s16_x4(const int16x4_t *const in, + int32_t *const out, + const int stride, + const int out_size) { + for (int i = 0; i < out_size; ++i) { + vst1q_s32(out + i * stride, vmovl_s16(in[i])); + } +} + +static AOM_FORCE_INLINE void store_buffer_s16_x8(const int16x8_t *const in, + int32_t *const out, + const int stride, + const int out_size) { + for (int i = 0; i < out_size; ++i) { + vst1q_s32(out + i * stride + 0, vmovl_s16(vget_low_s16(in[i]))); + vst1q_s32(out + i * stride + 4, vmovl_s16(vget_high_s16(in[i]))); + } +} + +// A note on naming: +// round_shift_[sqrt2]_s16_s32_4x1_neon(...) +// | | | ^ 1 => a single vector +// | | | n => an array of vectors +// | | | ^ input/output vector element count +// | | ^ output type +// | ^ input type +// ^ multiplicand and shift identifier + +static AOM_FORCE_INLINE int16x4_t +round_shift_sqrt2_s16_s16_4x1_neon(int16x4_t a) { + return vqrshrn_n_s32(vmull_n_s16(a, NewSqrt2), NewSqrt2Bits); +} + +static AOM_FORCE_INLINE int16x8_t +round_shift_sqrt2_s16_s16_8x1_neon(int16x8_t a) { + return vcombine_s16(round_shift_sqrt2_s16_s16_4x1_neon(vget_low_s16(a)), + round_shift_sqrt2_s16_s16_4x1_neon(vget_high_s16(a))); +} + +static AOM_FORCE_INLINE int16x4_t +round_shift_2sqrt2_s16_s16_4x1_neon(int16x4_t a) { + return vqrshrn_n_s32(vmull_n_s16(a, 2 * NewSqrt2), NewSqrt2Bits); +} + +static AOM_FORCE_INLINE int16x8_t +round_shift_2sqrt2_s16_s16_8x1_neon(int16x8_t a) { + return vcombine_s16(round_shift_2sqrt2_s16_s16_4x1_neon(vget_low_s16(a)), + round_shift_2sqrt2_s16_s16_4x1_neon(vget_high_s16(a))); +} + +static AOM_FORCE_INLINE int32x4_t +round_shift_sqrt2_s16_s32_4x1_neon(int16x4_t a) { + return vrshrq_n_s32(vmull_n_s16(a, NewSqrt2), NewSqrt2Bits); +} + +static AOM_FORCE_INLINE int32x4_t +round_shift_sqrt2_s32_s32_4x1_neon(int32x4_t a) { + return vrshrq_n_s32(vmulq_n_s32(a, NewSqrt2), NewSqrt2Bits); +} + +#define ROUND_SHIFT_SQRT_LOOP_HELPER(name, type0, type1, fn) \ + static AOM_FORCE_INLINE void name(const type0 *in, type1 *out, int size) { \ + for (int i = 0; i < size; ++i) { \ + out[i] = fn(in[i]); \ + } \ + } + +ROUND_SHIFT_SQRT_LOOP_HELPER(round_shift_sqrt2_s32_s32_4xn_neon, int32x4_t, + int32x4_t, round_shift_sqrt2_s32_s32_4x1_neon) +ROUND_SHIFT_SQRT_LOOP_HELPER(round_shift_sqrt2_s16_s16_4xn_neon, int16x4_t, + int16x4_t, round_shift_sqrt2_s16_s16_4x1_neon) +ROUND_SHIFT_SQRT_LOOP_HELPER(round_shift_sqrt2_s16_s16_8xn_neon, int16x8_t, + int16x8_t, round_shift_sqrt2_s16_s16_8x1_neon) +ROUND_SHIFT_SQRT_LOOP_HELPER(round_shift_2sqrt2_s16_s16_4xn_neon, int16x4_t, + int16x4_t, round_shift_2sqrt2_s16_s16_4x1_neon) +ROUND_SHIFT_SQRT_LOOP_HELPER(round_shift_2sqrt2_s16_s16_8xn_neon, int16x8_t, + int16x8_t, round_shift_2sqrt2_s16_s16_8x1_neon) + +static AOM_FORCE_INLINE void store_rect_buffer_s16_x4(const int16x4_t *const in, + int32_t *const out, + const int stride, + const int out_size) { + for (int i = 0; i < out_size; ++i) { + vst1q_s32(out + i * stride, round_shift_sqrt2_s16_s32_4x1_neon(in[i])); + } +} + +static AOM_FORCE_INLINE void store_rect_buffer_s16_x8(const int16x8_t *const in, + int32_t *const out, + const int stride, + const int out_size) { + for (int i = 0; i < out_size; ++i) { + vst1q_s32(out + i * stride + 0, + round_shift_sqrt2_s16_s32_4x1_neon(vget_low_s16(in[i]))); + vst1q_s32(out + i * stride + 4, + round_shift_sqrt2_s16_s32_4x1_neon(vget_high_s16(in[i]))); + } +} + +static AOM_FORCE_INLINE void fadst4x4_neon(const int16x4_t *input, + int16x4_t *output, int cos_bit) { + int32x4_t u[6], v[6]; + const int16x4_t sinpi 
= vld1_s16(sinpi_arr_q13(cos_bit)); + const int16x4_t u01 = vqadd_s16(input[0], input[1]); + + v[5] = vmull_lane_s16(input[2], sinpi, 2); + v[0] = vmull_lane_s16(input[1], sinpi, 1); + v[0] = vmlal_lane_s16(v[0], input[0], sinpi, 0); + v[1] = vmlal_lane_s16(v[5], input[3], sinpi, 3); + v[2] = vmull_lane_s16(u01, sinpi, 2); + v[3] = vmull_lane_s16(input[0], sinpi, 3); + v[3] = vmlsl_lane_s16(v[3], input[1], sinpi, 0); + v[4] = vmlsl_lane_s16(v[5], input[3], sinpi, 1); + + u[0] = vaddq_s32(v[0], v[1]); + u[1] = vmlsl_lane_s16(v[2], input[3], sinpi, 2); + u[2] = vsubq_s32(v[3], v[4]); + u[3] = vsubq_s32(u[2], u[0]); + u[3] = vmlaq_n_s32(u[3], v[5], 3); + + output[0] = vrshrn_n_s32(u[0], TXFM_COS_BIT_MAX); + output[1] = vrshrn_n_s32(u[1], TXFM_COS_BIT_MAX); + output[2] = vrshrn_n_s32(u[2], TXFM_COS_BIT_MAX); + output[3] = vrshrn_n_s32(u[3], TXFM_COS_BIT_MAX); +} + +static AOM_FORCE_INLINE void fadst4x8_neon(const int16x4_t *input, + int16x4_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi4_12 = vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + + // stage 1-2 + int16x4_t x2[8]; + butterfly_s16_s32_x4_0332_neon(cospi32, input[4], input[3], &x2[2], &x2[3]); + butterfly_s16_s32_x4_0112_neon(cospi32, input[2], input[5], &x2[7], &x2[6]); + + // stage 3 + int16x4_t x3[8]; + x3[0] = vqadd_s16(input[0], x2[2]); + x3[1] = vqsub_s16(x2[3], input[7]); + x3[2] = vqsub_s16(input[0], x2[2]); + x3[3] = vqadd_s16(input[7], x2[3]); + x3[4] = vqsub_s16(x2[6], input[1]); + x3[5] = vqadd_s16(input[6], x2[7]); + x3[6] = vqadd_s16(input[1], x2[6]); + x3[7] = vqsub_s16(input[6], x2[7]); + + // stage 4 + int16x4_t x4[8]; + butterfly_s16_s32_x4_0112_neon(cospi16, x3[4], x3[5], &x4[4], &x4[5]); + butterfly_s16_s32_x4_0112_neon(cospi16, x3[7], x3[6], &x4[6], &x4[7]); + + // stage 5 + int16x4_t x5[8]; + x5[0] = vqadd_s16(x3[0], x4[4]); + x5[1] = vqadd_s16(x3[1], x4[5]); + x5[2] = vqadd_s16(x3[2], x4[6]); + x5[3] = vqsub_s16(x4[7], x3[3]); + x5[4] = vqsub_s16(x3[0], x4[4]); + x5[5] = vqsub_s16(x3[1], x4[5]); + x5[6] = vqsub_s16(x3[2], x4[6]); + x5[7] = vqadd_s16(x3[3], x4[7]); + + // stage 6-7 + butterfly_s16_s32_x4_0112_neon(cospi4, x5[0], x5[1], &output[7], &output[0]); + butterfly_s16_s32_x4_0112_neon(cospi20, x5[2], x5[3], &output[5], &output[2]); + butterfly_s16_s32_x4_1003_neon(cospi28, x5[4], x5[5], &output[3], &output[4]); + butterfly_s16_s32_x4_0112_neon(cospi12, x5[6], x5[7], &output[6], &output[1]); +} + +static AOM_FORCE_INLINE void fadst8x4_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + int32x4_t u_lo[4], u_hi[4]; + const int16x4_t sinpi = vld1_s16(sinpi_arr_q13(cos_bit)); + const int16x8_t u01 = vqaddq_s16(input[0], input[1]); + + u_lo[0] = vmull_lane_s16(vget_low_s16(input[1]), sinpi, 1); + u_hi[0] = vmull_lane_s16(vget_high_s16(input[1]), sinpi, 1); + + u_lo[0] = vmlal_lane_s16(u_lo[0], vget_low_s16(input[0]), sinpi, 0); + u_hi[0] = vmlal_lane_s16(u_hi[0], vget_high_s16(input[0]), sinpi, 0); + + u_lo[0] = vmlal_lane_s16(u_lo[0], vget_low_s16(input[3]), sinpi, 3); + u_hi[0] = vmlal_lane_s16(u_hi[0], vget_high_s16(input[3]), 
sinpi, 3); + + u_lo[0] = vmlal_lane_s16(u_lo[0], vget_low_s16(input[2]), sinpi, 2); + u_hi[0] = vmlal_lane_s16(u_hi[0], vget_high_s16(input[2]), sinpi, 2); + + u_lo[1] = vmull_lane_s16(vget_low_s16(u01), sinpi, 2); + u_hi[1] = vmull_lane_s16(vget_high_s16(u01), sinpi, 2); + + u_lo[2] = vmull_lane_s16(vget_low_s16(input[0]), sinpi, 3); + u_hi[2] = vmull_lane_s16(vget_high_s16(input[0]), sinpi, 3); + + u_lo[2] = vmlsl_lane_s16(u_lo[2], vget_low_s16(input[1]), sinpi, 0); + u_hi[2] = vmlsl_lane_s16(u_hi[2], vget_high_s16(input[1]), sinpi, 0); + + u_lo[2] = vmlal_lane_s16(u_lo[2], vget_low_s16(input[3]), sinpi, 1); + u_hi[2] = vmlal_lane_s16(u_hi[2], vget_high_s16(input[3]), sinpi, 1); + + u_lo[2] = vmlsl_lane_s16(u_lo[2], vget_low_s16(input[2]), sinpi, 2); + u_hi[2] = vmlsl_lane_s16(u_hi[2], vget_high_s16(input[2]), sinpi, 2); + + u_lo[1] = vmlsl_lane_s16(u_lo[1], vget_low_s16(input[3]), sinpi, 2); + u_hi[1] = vmlsl_lane_s16(u_hi[1], vget_high_s16(input[3]), sinpi, 2); + + u_lo[3] = vsubq_s32(u_lo[2], u_lo[0]); + u_hi[3] = vsubq_s32(u_hi[2], u_hi[0]); + + const int16x4_t sinpix3 = vmul_n_s16(sinpi, 3); + u_lo[3] = vmlal_lane_s16(u_lo[3], vget_low_s16(input[2]), sinpix3, 2); + u_hi[3] = vmlal_lane_s16(u_hi[3], vget_high_s16(input[2]), sinpix3, 2); + + output[0] = vcombine_s16(vrshrn_n_s32(u_lo[0], TXFM_COS_BIT_MAX), + vrshrn_n_s32(u_hi[0], TXFM_COS_BIT_MAX)); + output[1] = vcombine_s16(vrshrn_n_s32(u_lo[1], TXFM_COS_BIT_MAX), + vrshrn_n_s32(u_hi[1], TXFM_COS_BIT_MAX)); + output[2] = vcombine_s16(vrshrn_n_s32(u_lo[2], TXFM_COS_BIT_MAX), + vrshrn_n_s32(u_hi[2], TXFM_COS_BIT_MAX)); + output[3] = vcombine_s16(vrshrn_n_s32(u_lo[3], TXFM_COS_BIT_MAX), + vrshrn_n_s32(u_hi[3], TXFM_COS_BIT_MAX)); +} + +static AOM_FORCE_INLINE void fdct4x4_neon(const int16x4_t *input, + int16x4_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + const int16x4_t cospi16 = vld1_s16(&cospi[4 * 1]); + + int16x4_t in12a = vadd_s16(input[1], input[2]); + int16x4_t in12s = vsub_s16(input[1], input[2]); + int16x4_t in03a = vadd_s16(input[0], input[3]); + int16x4_t in03s = vsub_s16(input[0], input[3]); + + int32x4_t u0ad1 = vmull_n_s16(in12a, cospi[4 * 0]); + int32x4_t u0ad2 = vmull_n_s16(in03a, cospi[4 * 0]); + + int32x4_t u[4]; + u[0] = vaddq_s32(u0ad1, u0ad2); + u[1] = vsubq_s32(u0ad2, u0ad1); + u[2] = vmull_lane_s16(in12s, cospi16, 1); + u[2] = vmlal_lane_s16(u[2], in03s, cospi16, 0); + u[3] = vmull_lane_s16(in03s, cospi16, 1); + u[3] = vmlsl_lane_s16(u[3], in12s, cospi16, 0); + + output[0] = vrshrn_n_s32(u[0], TXFM_COS_BIT_MAX); + output[1] = vrshrn_n_s32(u[2], TXFM_COS_BIT_MAX); + output[2] = vrshrn_n_s32(u[1], TXFM_COS_BIT_MAX); + output[3] = vrshrn_n_s32(u[3], TXFM_COS_BIT_MAX); +} + +// Butterfly pre-processing: +// e.g. 
n=4: +// out[0] = in[0] + in[3] +// out[1] = in[1] + in[2] +// out[2] = in[1] - in[2] +// out[3] = in[0] - in[3] + +static AOM_FORCE_INLINE void butterfly_dct_pre_s16_x4(const int16x4_t *input, + int16x4_t *output, + int n) { + for (int i = 0; i < n / 2; ++i) { + output[i] = vqadd_s16(input[i], input[n - i - 1]); + } + for (int i = 0; i < n / 2; ++i) { + output[n / 2 + i] = vqsub_s16(input[n / 2 - i - 1], input[n / 2 + i]); + } +} + +static AOM_FORCE_INLINE void butterfly_dct_pre_s16_x8(const int16x8_t *input, + int16x8_t *output, + int n) { + for (int i = 0; i < n / 2; ++i) { + output[i] = vqaddq_s16(input[i], input[n - i - 1]); + } + for (int i = 0; i < n / 2; ++i) { + output[n / 2 + i] = vqsubq_s16(input[n / 2 - i - 1], input[n / 2 + i]); + } +} + +static AOM_FORCE_INLINE void butterfly_dct_pre_s32_x4(const int32x4_t *input, + int32x4_t *output, + int n) { + for (int i = 0; i < n / 2; ++i) { + output[i] = vqaddq_s32(input[i], input[n - i - 1]); + } + for (int i = 0; i < n / 2; ++i) { + output[n / 2 + i] = vqsubq_s32(input[n / 2 - i - 1], input[n / 2 + i]); + } +} + +// Butterfly post-processing: +// e.g. n=8: +// out[0] = in0[0] + in1[3]; +// out[1] = in0[1] + in1[2]; +// out[2] = in0[1] - in1[2]; +// out[3] = in0[0] - in1[3]; +// out[4] = in0[7] - in1[4]; +// out[5] = in0[6] - in1[5]; +// out[6] = in0[6] + in1[5]; +// out[7] = in0[7] + in1[4]; + +static AOM_FORCE_INLINE void butterfly_dct_post_s16_x4(const int16x4_t *in0, + const int16x4_t *in1, + int16x4_t *output, + int n) { + for (int i = 0; i < n / 4; ++i) { + output[i] = vqadd_s16(in0[i], in1[n / 2 - i - 1]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 4 + i] = vqsub_s16(in0[n / 4 - i - 1], in1[n / 4 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 2 + i] = vqsub_s16(in0[n - i - 1], in1[n / 2 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[(3 * n) / 4 + i] = + vqadd_s16(in0[(3 * n) / 4 + i], in1[(3 * n) / 4 - i - 1]); + } +} + +static AOM_FORCE_INLINE void butterfly_dct_post_s16_x8(const int16x8_t *in0, + const int16x8_t *in1, + int16x8_t *output, + int n) { + for (int i = 0; i < n / 4; ++i) { + output[i] = vqaddq_s16(in0[i], in1[n / 2 - i - 1]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 4 + i] = vqsubq_s16(in0[n / 4 - i - 1], in1[n / 4 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 2 + i] = vqsubq_s16(in0[n - i - 1], in1[n / 2 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[(3 * n) / 4 + i] = + vqaddq_s16(in0[(3 * n) / 4 + i], in1[(3 * n) / 4 - i - 1]); + } +} + +static AOM_FORCE_INLINE void butterfly_dct_post_s32_x4(const int32x4_t *in0, + const int32x4_t *in1, + int32x4_t *output, + int n) { + for (int i = 0; i < n / 4; ++i) { + output[i] = vqaddq_s32(in0[i], in1[n / 2 - i - 1]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 4 + i] = vqsubq_s32(in0[n / 4 - i - 1], in1[n / 4 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 2 + i] = vqsubq_s32(in0[n - i - 1], in1[n / 2 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[(3 * n) / 4 + i] = + vqaddq_s32(in0[(3 * n) / 4 + i], in1[(3 * n) / 4 - i - 1]); + } +} + +static AOM_FORCE_INLINE void fdct8x4_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + + // stage 1 + int16x8_t x1[4]; + butterfly_dct_pre_s16_x8(input, x1, 4); + + // stage 2 + int16x8_t x2[4]; + 
butterfly_s16_s32_x8_0112_neon(cospi32, x1[0], x1[1], &x2[0], &x2[1]); + butterfly_s16_s32_x8_0112_neon(cospi16, x1[3], x1[2], &x2[2], &x2[3]); + + // stage 3 + output[0] = x2[0]; + output[1] = x2[2]; + output[2] = x2[1]; + output[3] = x2[3]; +} + +static AOM_FORCE_INLINE void fdct4x8_neon(const int16x4_t *input, + int16x4_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + + // stage 1 + int16x4_t x1[8]; + butterfly_dct_pre_s16_x4(input, x1, 8); + + // stage 2 + int16x4_t x2[8]; + butterfly_dct_pre_s16_x4(x1, x2, 4); + butterfly_s16_s32_x4_0112_neon(cospi32, x1[6], x1[5], &x2[6], &x2[5]); + + // stage 3 + int16x4_t x3[8]; + butterfly_s16_s32_x4_0112_neon(cospi32, x2[0], x2[1], &output[0], &output[4]); + butterfly_s16_s32_x4_0112_neon(cospi16, x2[3], x2[2], &output[2], &output[6]); + butterfly_dct_post_s16_x4(x1 + 4, x2 + 4, x3 + 4, 4); + + // stage 4-5 + butterfly_s16_s32_x4_0112_neon(cospi8, x3[7], x3[4], &output[1], &output[7]); + butterfly_s16_s32_x4_1003_neon(cospi24, x3[6], x3[5], &output[5], &output[3]); +} + +static AOM_FORCE_INLINE void fdct8x8_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + + // stage 1 + int16x8_t x1[8]; + butterfly_dct_pre_s16_x8(input, x1, 8); + + // stage 2 + int16x8_t x2[8]; + butterfly_dct_pre_s16_x8(x1, x2, 4); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[6], x1[5], &x2[6], &x2[5]); + + // stage 3 + int16x8_t x3[8]; + butterfly_s16_s32_x8_0112_neon(cospi32, x2[0], x2[1], &output[0], &output[4]); + butterfly_s16_s32_x8_0112_neon(cospi16, x2[3], x2[2], &output[2], &output[6]); + butterfly_dct_post_s16_x8(x1 + 4, x2 + 4, x3 + 4, 4); + + // stage 4-5 + butterfly_s16_s32_x8_0112_neon(cospi8, x3[7], x3[4], &output[1], &output[7]); + butterfly_s16_s32_x8_1003_neon(cospi24, x3[6], x3[5], &output[5], &output[3]); +} + +static AOM_FORCE_INLINE void fdct4x16_neon(const int16x4_t *input, + int16x4_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi4_12 = vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + + // stage 1 + int16x4_t x1[16]; + butterfly_dct_pre_s16_x4(input, x1, 16); + + // stage 2 + int16x4_t x2[16]; + butterfly_dct_pre_s16_x4(x1, x2, 8); + butterfly_s16_s32_x4_0112_neon(cospi32, x1[13], x1[10], 
&x2[13], &x2[10]); + butterfly_s16_s32_x4_0112_neon(cospi32, x1[12], x1[11], &x2[12], &x2[11]); + + // stage 3 + int16x4_t x3[16]; + butterfly_dct_pre_s16_x4(x2, x3, 4); + butterfly_s16_s32_x4_0112_neon(cospi32, x2[6], x2[5], &x3[6], &x3[5]); + butterfly_dct_post_s16_x4(x1 + 8, x2 + 8, x3 + 8, 8); + + // stage 4 + int16x4_t x4[16]; + butterfly_s16_s32_x4_0112_neon(cospi32, x3[0], x3[1], &output[0], &output[8]); + butterfly_s16_s32_x4_0112_neon(cospi16, x3[3], x3[2], &output[4], + &output[12]); + butterfly_dct_post_s16_x4(x2 + 4, x3 + 4, x4 + 4, 4); + butterfly_s16_s32_x4_0112_neon(cospi16, x3[14], x3[9], &x4[14], &x4[9]); + butterfly_s16_s32_x4_1223_neon(cospi16, x3[13], x3[10], &x4[13], &x4[10]); + + // stage 5 + int16x4_t x5[16]; + butterfly_s16_s32_x4_0112_neon(cospi8, x4[7], x4[4], &output[2], &output[14]); + butterfly_s16_s32_x4_1003_neon(cospi24, x4[6], x4[5], &output[10], + &output[6]); + butterfly_dct_post_s16_x4(x3 + 8, x4 + 8, x5 + 8, 4); + butterfly_dct_post_s16_x4(x3 + 12, x4 + 12, x5 + 12, 4); + + // stage 6-7 + butterfly_s16_s32_x4_0112_neon(cospi4, x5[15], x5[8], &output[1], + &output[15]); + butterfly_s16_s32_x4_1003_neon(cospi28, x5[14], x5[9], &output[9], + &output[7]); + butterfly_s16_s32_x4_0112_neon(cospi20, x5[13], x5[10], &output[5], + &output[11]); + butterfly_s16_s32_x4_1003_neon(cospi12, x5[12], x5[11], &output[13], + &output[3]); +} + +static AOM_FORCE_INLINE void fdct8x16_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi4_12 = vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + + // stage 1 + int16x8_t x1[16]; + butterfly_dct_pre_s16_x8(input, x1, 16); + + // stage 2 + int16x8_t x2[16]; + butterfly_dct_pre_s16_x8(x1, x2, 8); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[13], x1[10], &x2[13], &x2[10]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[12], x1[11], &x2[12], &x2[11]); + + // stage 3 + int16x8_t x3[16]; + butterfly_dct_pre_s16_x8(x2, x3, 4); + butterfly_s16_s32_x8_0112_neon(cospi32, x2[6], x2[5], &x3[6], &x3[5]); + butterfly_dct_post_s16_x8(x1 + 8, x2 + 8, x3 + 8, 8); + + // stage 4 + int16x8_t x4[16]; + butterfly_s16_s32_x8_0112_neon(cospi32, x3[0], x3[1], &output[0], &output[8]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[3], x3[2], &output[4], + &output[12]); + butterfly_dct_post_s16_x8(x2 + 4, x3 + 4, x4 + 4, 4); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[14], x3[9], &x4[14], &x4[9]); + butterfly_s16_s32_x8_1223_neon(cospi16, x3[13], x3[10], &x4[13], &x4[10]); + + // stage 5 + int16x8_t x5[16]; + butterfly_s16_s32_x8_0112_neon(cospi8, x4[7], x4[4], &output[2], &output[14]); + butterfly_s16_s32_x8_1003_neon(cospi24, x4[6], x4[5], &output[10], + &output[6]); + butterfly_dct_post_s16_x8(x3 + 8, x4 + 8, x5 + 8, 4); + butterfly_dct_post_s16_x8(x3 + 12, x4 + 12, x5 + 12, 4); + + // stage 6-7 + butterfly_s16_s32_x8_0112_neon(cospi4, x5[15], x5[8], &output[1], + &output[15]); + 
butterfly_s16_s32_x8_1003_neon(cospi28, x5[14], x5[9], &output[9], + &output[7]); + butterfly_s16_s32_x8_0112_neon(cospi20, x5[13], x5[10], &output[5], + &output[11]); + butterfly_s16_s32_x8_1003_neon(cospi12, x5[12], x5[11], &output[13], + &output[3]); +} + +static AOM_FORCE_INLINE void fdct8x32_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi4_12 = vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + const int16x8_t cospi2_6 = vld1q_s16(&cospi[4 * 8]); + const int16x8_t cospi10_14 = vld1q_s16(&cospi[4 * 10]); + const int16x8_t cospi18_22 = vld1q_s16(&cospi[4 * 12]); + const int16x8_t cospi26_30 = vld1q_s16(&cospi[4 * 14]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + const int16x4_t cospi2 = vget_low_s16(cospi2_6); + const int16x4_t cospi6 = vget_high_s16(cospi2_6); + const int16x4_t cospi10 = vget_low_s16(cospi10_14); + const int16x4_t cospi14 = vget_high_s16(cospi10_14); + const int16x4_t cospi18 = vget_low_s16(cospi18_22); + const int16x4_t cospi22 = vget_high_s16(cospi18_22); + const int16x4_t cospi26 = vget_low_s16(cospi26_30); + const int16x4_t cospi30 = vget_high_s16(cospi26_30); + + // stage 1 + int16x8_t x1[32]; + butterfly_dct_pre_s16_x8(input, x1, 32); + + // stage 2 + int16x8_t x2[32]; + butterfly_dct_pre_s16_x8(x1, x2, 16); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[27], x1[20], &x2[27], &x2[20]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[26], x1[21], &x2[26], &x2[21]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[25], x1[22], &x2[25], &x2[22]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[24], x1[23], &x2[24], &x2[23]); + + // stage 3 + int16x8_t x3[32]; + butterfly_dct_pre_s16_x8(x2, x3, 8); + butterfly_s16_s32_x8_0112_neon(cospi32, x2[13], x2[10], &x3[13], &x3[10]); + butterfly_s16_s32_x8_0112_neon(cospi32, x2[12], x2[11], &x3[12], &x3[11]); + butterfly_dct_post_s16_x8(x1 + 16, x2 + 16, x3 + 16, 16); + + // stage 4 + int16x8_t x4[32]; + butterfly_dct_pre_s16_x8(x3, x4, 4); + butterfly_s16_s32_x8_0112_neon(cospi32, x3[6], x3[5], &x4[6], &x4[5]); + butterfly_dct_post_s16_x8(x2 + 8, x3 + 8, x4 + 8, 8); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[29], x3[18], &x4[29], &x4[18]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[28], x3[19], &x4[28], &x4[19]); + butterfly_s16_s32_x8_1223_neon(cospi16, x3[27], x3[20], &x4[27], &x4[20]); + butterfly_s16_s32_x8_1223_neon(cospi16, x3[26], x3[21], &x4[26], &x4[21]); + + // stage 5 + int16x8_t x5[32]; + butterfly_s16_s32_x8_0112_neon(cospi32, x4[0], x4[1], &output[0], + &output[16]); + butterfly_s16_s32_x8_0112_neon(cospi16, x4[3], x4[2], &output[8], + &output[24]); + butterfly_dct_post_s16_x8(x3 + 4, x4 + 4, x5 + 4, 4); + butterfly_s16_s32_x8_0112_neon(cospi16, x4[14], x4[9], &x5[14], &x5[9]); + butterfly_s16_s32_x8_1223_neon(cospi16, x4[13], x4[10], &x5[13], &x5[10]); + butterfly_dct_post_s16_x8(x3 + 16, x4 + 16, x5 + 16, 8); + butterfly_dct_post_s16_x8(x3 + 24, x4 + 24, x5 + 24, 8); + + // stage 6 + int16x8_t x6[32]; 
+ butterfly_s16_s32_x8_0112_neon(cospi8, x5[7], x5[4], &output[4], &output[28]); + butterfly_s16_s32_x8_1003_neon(cospi24, x5[6], x5[5], &output[20], + &output[12]); + butterfly_dct_post_s16_x8(x4 + 8, x5 + 8, x6 + 8, 4); + butterfly_dct_post_s16_x8(x4 + 12, x5 + 12, x6 + 12, 4); + butterfly_s16_s32_x8_0112_neon(cospi8, x5[30], x5[17], &x6[30], &x6[17]); + butterfly_s16_s32_x8_1223_neon(cospi8, x5[29], x5[18], &x6[29], &x6[18]); + butterfly_s16_s32_x8_1003_neon(cospi24, x5[26], x5[21], &x6[26], &x6[21]); + butterfly_s16_s32_x8_0332_neon(cospi24, x5[25], x5[22], &x6[25], &x6[22]); + + // stage 7 + int16x8_t x7[32]; + butterfly_s16_s32_x8_0112_neon(cospi4, x6[15], x6[8], &output[2], + &output[30]); + butterfly_s16_s32_x8_1003_neon(cospi28, x6[14], x6[9], &output[18], + &output[14]); + butterfly_s16_s32_x8_0112_neon(cospi20, x6[13], x6[10], &output[10], + &output[22]); + butterfly_s16_s32_x8_1003_neon(cospi12, x6[12], x6[11], &output[26], + &output[6]); + butterfly_dct_post_s16_x8(x5 + 16, x6 + 16, x7 + 16, 4); + butterfly_dct_post_s16_x8(x5 + 20, x6 + 20, x7 + 20, 4); + butterfly_dct_post_s16_x8(x5 + 24, x6 + 24, x7 + 24, 4); + butterfly_dct_post_s16_x8(x5 + 28, x6 + 28, x7 + 28, 4); + + butterfly_s16_s32_x8_0112_neon(cospi2, x7[31], x7[16], &output[1], + &output[31]); + butterfly_s16_s32_x8_1003_neon(cospi30, x7[30], x7[17], &output[17], + &output[15]); + butterfly_s16_s32_x8_0112_neon(cospi18, x7[29], x7[18], &output[9], + &output[23]); + butterfly_s16_s32_x8_1003_neon(cospi14, x7[28], x7[19], &output[25], + &output[7]); + butterfly_s16_s32_x8_0112_neon(cospi10, x7[27], x7[20], &output[5], + &output[27]); + butterfly_s16_s32_x8_1003_neon(cospi22, x7[26], x7[21], &output[21], + &output[11]); + butterfly_s16_s32_x8_0112_neon(cospi26, x7[25], x7[22], &output[13], + &output[19]); + butterfly_s16_s32_x8_1003_neon(cospi6, x7[24], x7[23], &output[29], + &output[3]); +} + +static AOM_FORCE_INLINE void fdct8x64_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi4_12 = vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + const int16x8_t cospi2_6 = vld1q_s16(&cospi[4 * 8]); + const int16x8_t cospi10_14 = vld1q_s16(&cospi[4 * 10]); + const int16x8_t cospi18_22 = vld1q_s16(&cospi[4 * 12]); + const int16x8_t cospi26_30 = vld1q_s16(&cospi[4 * 14]); + const int16x8_t cospi1_3 = vld1q_s16(&cospi[4 * 16]); + const int16x8_t cospi5_7 = vld1q_s16(&cospi[4 * 18]); + const int16x8_t cospi9_11 = vld1q_s16(&cospi[4 * 20]); + const int16x8_t cospi13_15 = vld1q_s16(&cospi[4 * 22]); + const int16x8_t cospi17_19 = vld1q_s16(&cospi[4 * 24]); + const int16x8_t cospi21_23 = vld1q_s16(&cospi[4 * 26]); + const int16x8_t cospi25_27 = vld1q_s16(&cospi[4 * 28]); + const int16x8_t cospi29_31 = vld1q_s16(&cospi[4 * 30]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + const int16x4_t cospi2 = vget_low_s16(cospi2_6); + const int16x4_t cospi6 = vget_high_s16(cospi2_6); + const int16x4_t cospi10 = vget_low_s16(cospi10_14); 
+ const int16x4_t cospi14 = vget_high_s16(cospi10_14); + const int16x4_t cospi18 = vget_low_s16(cospi18_22); + const int16x4_t cospi22 = vget_high_s16(cospi18_22); + const int16x4_t cospi26 = vget_low_s16(cospi26_30); + const int16x4_t cospi30 = vget_high_s16(cospi26_30); + const int16x4_t cospi1 = vget_low_s16(cospi1_3); + const int16x4_t cospi3 = vget_high_s16(cospi1_3); + const int16x4_t cospi5 = vget_low_s16(cospi5_7); + const int16x4_t cospi7 = vget_high_s16(cospi5_7); + const int16x4_t cospi9 = vget_low_s16(cospi9_11); + const int16x4_t cospi11 = vget_high_s16(cospi9_11); + const int16x4_t cospi13 = vget_low_s16(cospi13_15); + const int16x4_t cospi15 = vget_high_s16(cospi13_15); + const int16x4_t cospi17 = vget_low_s16(cospi17_19); + const int16x4_t cospi19 = vget_high_s16(cospi17_19); + const int16x4_t cospi21 = vget_low_s16(cospi21_23); + const int16x4_t cospi23 = vget_high_s16(cospi21_23); + const int16x4_t cospi25 = vget_low_s16(cospi25_27); + const int16x4_t cospi27 = vget_high_s16(cospi25_27); + const int16x4_t cospi29 = vget_low_s16(cospi29_31); + const int16x4_t cospi31 = vget_high_s16(cospi29_31); + + // stage 1 + int16x8_t x1[64]; + butterfly_dct_pre_s16_x8(input, x1, 64); + + // stage 2 + int16x8_t x2[64]; + butterfly_dct_pre_s16_x8(x1, x2, 32); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[55], x1[40], &x2[55], &x2[40]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[54], x1[41], &x2[54], &x2[41]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[53], x1[42], &x2[53], &x2[42]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[52], x1[43], &x2[52], &x2[43]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[51], x1[44], &x2[51], &x2[44]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[50], x1[45], &x2[50], &x2[45]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[49], x1[46], &x2[49], &x2[46]); + butterfly_s16_s32_x8_0112_neon(cospi32, x1[48], x1[47], &x2[48], &x2[47]); + + // stage 3 + int16x8_t x3[64]; + butterfly_dct_pre_s16_x8(x2, x3, 16); + x3[16] = x2[16]; + x3[17] = x2[17]; + x3[18] = x2[18]; + x3[19] = x2[19]; + butterfly_s16_s32_x8_0112_neon(cospi32, x2[27], x2[20], &x3[27], &x3[20]); + butterfly_s16_s32_x8_0112_neon(cospi32, x2[26], x2[21], &x3[26], &x3[21]); + butterfly_s16_s32_x8_0112_neon(cospi32, x2[25], x2[22], &x3[25], &x3[22]); + butterfly_s16_s32_x8_0112_neon(cospi32, x2[24], x2[23], &x3[24], &x3[23]); + x3[28] = x2[28]; + x3[29] = x2[29]; + x3[30] = x2[30]; + x3[31] = x2[31]; + butterfly_dct_post_s16_x8(x1 + 32, x2 + 32, x3 + 32, 32); + + // stage 4 + int16x8_t x4[64]; + butterfly_dct_pre_s16_x8(x3, x4, 8); + butterfly_s16_s32_x8_0112_neon(cospi32, x3[13], x3[10], &x4[13], &x4[10]); + butterfly_s16_s32_x8_0112_neon(cospi32, x3[12], x3[11], &x4[12], &x4[11]); + butterfly_dct_post_s16_x8(x3 + 16, x3 + 16, x4 + 16, 16); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[59], x3[36], &x4[59], &x4[36]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[58], x3[37], &x4[58], &x4[37]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[57], x3[38], &x4[57], &x4[38]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[56], x3[39], &x4[56], &x4[39]); + butterfly_s16_s32_x8_1223_neon(cospi16, x3[55], x3[40], &x4[55], &x4[40]); + butterfly_s16_s32_x8_1223_neon(cospi16, x3[54], x3[41], &x4[54], &x4[41]); + butterfly_s16_s32_x8_1223_neon(cospi16, x3[53], x3[42], &x4[53], &x4[42]); + butterfly_s16_s32_x8_1223_neon(cospi16, x3[52], x3[43], &x4[52], &x4[43]); + + // stage 5 + int16x8_t x5[64]; + butterfly_dct_pre_s16_x8(x4, x5, 4); + butterfly_s16_s32_x8_0112_neon(cospi32, x4[6], x4[5], &x5[6], &x5[5]); + 
butterfly_dct_post_s16_x8(x3 + 8, x4 + 8, x5 + 8, 8); + butterfly_s16_s32_x8_0112_neon(cospi16, x4[29], x4[18], &x5[29], &x5[18]); + butterfly_s16_s32_x8_0112_neon(cospi16, x4[28], x4[19], &x5[28], &x5[19]); + butterfly_s16_s32_x8_1223_neon(cospi16, x4[27], x4[20], &x5[27], &x5[20]); + butterfly_s16_s32_x8_1223_neon(cospi16, x4[26], x4[21], &x5[26], &x5[21]); + butterfly_dct_post_s16_x8(x3 + 32, x4 + 32, x5 + 32, 16); + butterfly_dct_post_s16_x8(x3 + 48, x4 + 48, x5 + 48, 16); + + // stage 6 + int16x8_t x6[64]; + butterfly_s16_s32_x8_0112_neon(cospi32, x5[1], x5[0], &x6[0], &x6[1]); + butterfly_s16_s32_x8_0112_neon(cospi16, x5[3], x5[2], &x6[2], &x6[3]); + butterfly_dct_post_s16_x8(x4 + 4, x5 + 4, x6 + 4, 4); + butterfly_s16_s32_x8_0112_neon(cospi16, x5[14], x5[9], &x6[14], &x6[9]); + butterfly_s16_s32_x8_1223_neon(cospi16, x5[13], x5[10], &x6[13], &x6[10]); + butterfly_dct_post_s16_x8(x4 + 16, x5 + 16, x6 + 16, 8); + butterfly_dct_post_s16_x8(x4 + 24, x5 + 24, x6 + 24, 8); + butterfly_s16_s32_x8_0112_neon(cospi8, x5[61], x5[34], &x6[61], &x6[34]); + butterfly_s16_s32_x8_0112_neon(cospi8, x5[60], x5[35], &x6[60], &x6[35]); + butterfly_s16_s32_x8_1223_neon(cospi8, x5[59], x5[36], &x6[59], &x6[36]); + butterfly_s16_s32_x8_1223_neon(cospi8, x5[58], x5[37], &x6[58], &x6[37]); + butterfly_s16_s32_x8_1003_neon(cospi24, x5[53], x5[42], &x6[53], &x6[42]); + butterfly_s16_s32_x8_1003_neon(cospi24, x5[52], x5[43], &x6[52], &x6[43]); + butterfly_s16_s32_x8_0332_neon(cospi24, x5[51], x5[44], &x6[51], &x6[44]); + butterfly_s16_s32_x8_0332_neon(cospi24, x5[50], x5[45], &x6[50], &x6[45]); + + // stage 7 + int16x8_t x7[64]; + butterfly_s16_s32_x8_0112_neon(cospi8, x6[7], x6[4], &x7[4], &x7[7]); + butterfly_s16_s32_x8_1003_neon(cospi24, x6[6], x6[5], &x7[5], &x7[6]); + butterfly_dct_post_s16_x8(x5 + 8, x6 + 8, x7 + 8, 4); + butterfly_dct_post_s16_x8(x5 + 12, x6 + 12, x7 + 12, 4); + butterfly_s16_s32_x8_0112_neon(cospi8, x6[30], x6[17], &x7[30], &x7[17]); + butterfly_s16_s32_x8_1223_neon(cospi8, x6[29], x6[18], &x7[29], &x7[18]); + butterfly_s16_s32_x8_1003_neon(cospi24, x6[26], x6[21], &x7[26], &x7[21]); + butterfly_s16_s32_x8_0332_neon(cospi24, x6[25], x6[22], &x7[25], &x7[22]); + butterfly_dct_post_s16_x8(x5 + 32, x6 + 32, x7 + 32, 8); + butterfly_dct_post_s16_x8(x5 + 40, x6 + 40, x7 + 40, 8); + butterfly_dct_post_s16_x8(x5 + 48, x6 + 48, x7 + 48, 8); + butterfly_dct_post_s16_x8(x5 + 56, x6 + 56, x7 + 56, 8); + + // stage 8 + int16x8_t x8[64]; + butterfly_s16_s32_x8_0112_neon(cospi4, x7[15], x7[8], &x8[8], &x8[15]); + butterfly_s16_s32_x8_1003_neon(cospi28, x7[14], x7[9], &x8[9], &x8[14]); + butterfly_s16_s32_x8_0112_neon(cospi20, x7[13], x7[10], &x8[10], &x8[13]); + butterfly_s16_s32_x8_1003_neon(cospi12, x7[12], x7[11], &x8[11], &x8[12]); + butterfly_dct_post_s16_x8(x6 + 16, x7 + 16, x8 + 16, 4); + butterfly_dct_post_s16_x8(x6 + 20, x7 + 20, x8 + 20, 4); + butterfly_dct_post_s16_x8(x6 + 24, x7 + 24, x8 + 24, 4); + butterfly_dct_post_s16_x8(x6 + 28, x7 + 28, x8 + 28, 4); + butterfly_s16_s32_x8_0112_neon(cospi4, x7[62], x7[33], &x8[62], &x8[33]); + butterfly_s16_s32_x8_1223_neon(cospi4, x7[61], x7[34], &x8[61], &x8[34]); + butterfly_s16_s32_x8_1003_neon(cospi28, x7[58], x7[37], &x8[58], &x8[37]); + butterfly_s16_s32_x8_0332_neon(cospi28, x7[57], x7[38], &x8[57], &x8[38]); + butterfly_s16_s32_x8_0112_neon(cospi20, x7[54], x7[41], &x8[54], &x8[41]); + butterfly_s16_s32_x8_1223_neon(cospi20, x7[53], x7[42], &x8[53], &x8[42]); + butterfly_s16_s32_x8_1003_neon(cospi12, x7[50], x7[45], &x8[50], &x8[45]); + 
butterfly_s16_s32_x8_0332_neon(cospi12, x7[49], x7[46], &x8[49], &x8[46]); + + // stage 9 + int16x8_t x9[64]; + butterfly_s16_s32_x8_0112_neon(cospi2, x8[31], x8[16], &x9[16], &x9[31]); + butterfly_s16_s32_x8_1003_neon(cospi30, x8[30], x8[17], &x9[17], &x9[30]); + butterfly_s16_s32_x8_0112_neon(cospi18, x8[29], x8[18], &x9[18], &x9[29]); + butterfly_s16_s32_x8_1003_neon(cospi14, x8[28], x8[19], &x9[19], &x9[28]); + butterfly_s16_s32_x8_0112_neon(cospi10, x8[27], x8[20], &x9[20], &x9[27]); + butterfly_s16_s32_x8_1003_neon(cospi22, x8[26], x8[21], &x9[21], &x9[26]); + butterfly_s16_s32_x8_0112_neon(cospi26, x8[25], x8[22], &x9[22], &x9[25]); + butterfly_s16_s32_x8_1003_neon(cospi6, x8[24], x8[23], &x9[23], &x9[24]); + butterfly_dct_post_s16_x8(x7 + 32, x8 + 32, x9 + 32, 4); + butterfly_dct_post_s16_x8(x7 + 36, x8 + 36, x9 + 36, 4); + butterfly_dct_post_s16_x8(x7 + 40, x8 + 40, x9 + 40, 4); + butterfly_dct_post_s16_x8(x7 + 44, x8 + 44, x9 + 44, 4); + butterfly_dct_post_s16_x8(x7 + 48, x8 + 48, x9 + 48, 4); + butterfly_dct_post_s16_x8(x7 + 52, x8 + 52, x9 + 52, 4); + butterfly_dct_post_s16_x8(x7 + 56, x8 + 56, x9 + 56, 4); + butterfly_dct_post_s16_x8(x7 + 60, x8 + 60, x9 + 60, 4); + + // stage 10 + butterfly_s16_s32_x8_0112_neon(cospi1, x9[63], x9[32], &output[1], + &output[63]); + butterfly_s16_s32_x8_1003_neon(cospi31, x9[62], x9[33], &output[33], + &output[31]); + butterfly_s16_s32_x8_0112_neon(cospi17, x9[61], x9[34], &output[17], + &output[47]); + butterfly_s16_s32_x8_1003_neon(cospi15, x9[60], x9[35], &output[49], + &output[15]); + butterfly_s16_s32_x8_0112_neon(cospi9, x9[59], x9[36], &output[9], + &output[55]); + butterfly_s16_s32_x8_1003_neon(cospi23, x9[58], x9[37], &output[41], + &output[23]); + butterfly_s16_s32_x8_0112_neon(cospi25, x9[57], x9[38], &output[25], + &output[39]); + butterfly_s16_s32_x8_1003_neon(cospi7, x9[56], x9[39], &output[57], + &output[7]); + butterfly_s16_s32_x8_0112_neon(cospi5, x9[55], x9[40], &output[5], + &output[59]); + butterfly_s16_s32_x8_1003_neon(cospi27, x9[54], x9[41], &output[37], + &output[27]); + butterfly_s16_s32_x8_0112_neon(cospi21, x9[53], x9[42], &output[21], + &output[43]); + butterfly_s16_s32_x8_1003_neon(cospi11, x9[52], x9[43], &output[53], + &output[11]); + butterfly_s16_s32_x8_0112_neon(cospi13, x9[51], x9[44], &output[13], + &output[51]); + butterfly_s16_s32_x8_1003_neon(cospi19, x9[50], x9[45], &output[45], + &output[19]); + butterfly_s16_s32_x8_0112_neon(cospi29, x9[49], x9[46], &output[29], + &output[35]); + butterfly_s16_s32_x8_1003_neon(cospi3, x9[48], x9[47], &output[61], + &output[3]); + + // stage 11 + output[0] = x6[0]; + output[2] = x9[16]; + output[4] = x8[8]; + output[6] = x9[24]; + output[8] = x7[4]; + output[10] = x9[20]; + output[12] = x8[12]; + output[14] = x9[28]; + output[16] = x6[2]; + output[18] = x9[18]; + output[20] = x8[10]; + output[22] = x9[26]; + output[24] = x7[6]; + output[26] = x9[22]; + output[28] = x8[14]; + output[30] = x9[30]; + output[32] = x6[1]; + output[34] = x9[17]; + output[36] = x8[9]; + output[38] = x9[25]; + output[40] = x7[5]; + output[42] = x9[21]; + output[44] = x8[13]; + output[46] = x9[29]; + output[48] = x6[3]; + output[52] = x8[11]; + output[54] = x9[27]; + output[56] = x7[7]; + output[58] = x9[23]; + output[60] = x8[15]; + output[62] = x9[31]; +} + +static AOM_FORCE_INLINE void fadst8x8_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi4_12 = 
vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + + // stage 2 + int16x8_t x2[8]; + butterfly_s16_s32_x8_0332_neon(cospi32, input[4], input[3], &x2[2], &x2[3]); + butterfly_s16_s32_x8_0112_neon(cospi32, input[2], input[5], &x2[7], &x2[6]); + + // stage 3 + int16x8_t x3[8]; + x3[0] = vqaddq_s16(input[0], x2[2]); + x3[1] = vqsubq_s16(x2[3], input[7]); + x3[2] = vqsubq_s16(input[0], x2[2]); + x3[3] = vqaddq_s16(input[7], x2[3]); + x3[4] = vqsubq_s16(x2[6], input[1]); + x3[5] = vqaddq_s16(input[6], x2[7]); + x3[6] = vqaddq_s16(input[1], x2[6]); + x3[7] = vqsubq_s16(input[6], x2[7]); + + // stage 4 + butterfly_s16_s32_x8_0112_neon(cospi16, x3[4], x3[5], &x3[4], &x3[5]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[7], x3[6], &x3[6], &x3[7]); + + // stage 5 + int16x8_t x5[8]; + x5[0] = vqaddq_s16(x3[0], x3[4]); + x5[1] = vqaddq_s16(x3[1], x3[5]); + x5[2] = vqaddq_s16(x3[2], x3[6]); + x5[3] = vqsubq_s16(x3[7], x3[3]); + x5[4] = vqsubq_s16(x3[0], x3[4]); + x5[5] = vqsubq_s16(x3[1], x3[5]); + x5[6] = vqsubq_s16(x3[2], x3[6]); + x5[7] = vqaddq_s16(x3[3], x3[7]); + + // stage 6 + butterfly_s16_s32_x8_0112_neon(cospi4, x5[0], x5[1], &output[7], &output[0]); + butterfly_s16_s32_x8_0112_neon(cospi20, x5[2], x5[3], &output[5], &output[2]); + butterfly_s16_s32_x8_1003_neon(cospi28, x5[4], x5[5], &output[3], &output[4]); + butterfly_s16_s32_x8_0112_neon(cospi12, x5[6], x5[7], &output[6], &output[1]); +} + +static AOM_FORCE_INLINE void fadst4x16_neon(const int16x4_t *input, + int16x4_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi2_6 = vld1q_s16(&cospi[4 * 8]); + const int16x8_t cospi10_14 = vld1q_s16(&cospi[4 * 10]); + const int16x8_t cospi18_22 = vld1q_s16(&cospi[4 * 12]); + const int16x8_t cospi26_30 = vld1q_s16(&cospi[4 * 14]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi2 = vget_low_s16(cospi2_6); + const int16x4_t cospi6 = vget_high_s16(cospi2_6); + const int16x4_t cospi10 = vget_low_s16(cospi10_14); + const int16x4_t cospi14 = vget_high_s16(cospi10_14); + const int16x4_t cospi18 = vget_low_s16(cospi18_22); + const int16x4_t cospi22 = vget_high_s16(cospi18_22); + const int16x4_t cospi26 = vget_low_s16(cospi26_30); + const int16x4_t cospi30 = vget_high_s16(cospi26_30); + + // stage 2 + int16x4_t x2[8]; + butterfly_s16_s32_x4_0332_neon(cospi32, input[8], input[7], &x2[0], &x2[1]); + butterfly_s16_s32_x4_0112_neon(cospi32, input[4], input[11], &x2[3], &x2[2]); + butterfly_s16_s32_x4_0112_neon(cospi32, input[6], input[9], &x2[5], &x2[4]); + butterfly_s16_s32_x4_0332_neon(cospi32, input[10], input[5], &x2[6], &x2[7]); + + // stage 3 + int16x4_t x3[16]; + x3[0] = vqadd_s16(input[0], x2[0]); + x3[1] = vqsub_s16(x2[1], input[15]); + x3[2] = vqsub_s16(input[0], x2[0]); + x3[3] = vqadd_s16(input[15], x2[1]); + x3[4] = vqsub_s16(x2[2], input[3]); + x3[5] = vqadd_s16(input[12], x2[3]); + x3[6] = 
vqadd_s16(input[3], x2[2]); + x3[7] = vqsub_s16(input[12], x2[3]); + x3[8] = vqsub_s16(x2[4], input[1]); + x3[9] = vqadd_s16(input[14], x2[5]); + x3[10] = vqadd_s16(input[1], x2[4]); + x3[11] = vqsub_s16(input[14], x2[5]); + x3[12] = vqadd_s16(input[2], x2[6]); + x3[13] = vqsub_s16(x2[7], input[13]); + x3[14] = vqsub_s16(input[2], x2[6]); + x3[15] = vqadd_s16(input[13], x2[7]); + + // stage 4 + butterfly_s16_s32_x4_0112_neon(cospi16, x3[4], x3[5], &x3[4], &x3[5]); + butterfly_s16_s32_x4_0112_neon(cospi16, x3[7], x3[6], &x3[6], &x3[7]); + butterfly_s16_s32_x4_0112_neon(cospi16, x3[12], x3[13], &x3[12], &x3[13]); + butterfly_s16_s32_x4_0332_neon(cospi16, x3[14], x3[15], &x3[15], &x3[14]); + + // stage 5 + int16x4_t x5[16]; + x5[0] = vqadd_s16(x3[0], x3[4]); + x5[1] = vqadd_s16(x3[1], x3[5]); + x5[2] = vqadd_s16(x3[2], x3[6]); + x5[3] = vqsub_s16(x3[7], x3[3]); + x5[4] = vqsub_s16(x3[0], x3[4]); + x5[5] = vqsub_s16(x3[1], x3[5]); + x5[6] = vqsub_s16(x3[2], x3[6]); + x5[7] = vqadd_s16(x3[3], x3[7]); + x5[8] = vqadd_s16(x3[8], x3[12]); + x5[9] = vqadd_s16(x3[9], x3[13]); + x5[10] = vqsub_s16(x3[14], x3[10]); + x5[11] = vqadd_s16(x3[11], x3[15]); + x5[12] = vqsub_s16(x3[8], x3[12]); + x5[13] = vqsub_s16(x3[9], x3[13]); + x5[14] = vqadd_s16(x3[10], x3[14]); + x5[15] = vqsub_s16(x3[11], x3[15]); + + // stage 6 + butterfly_s16_s32_x4_0112_neon(cospi8, x5[8], x5[9], &x5[8], &x5[9]); + butterfly_s16_s32_x4_1003_neon(cospi24, x5[10], x5[11], &x5[10], &x5[11]); + butterfly_s16_s32_x4_1003_neon(cospi8, x5[13], x5[12], &x5[13], &x5[12]); + butterfly_s16_s32_x4_1003_neon(cospi24, x5[15], x5[14], &x5[14], &x5[15]); + + // stage 7 + int16x4_t x7[16]; + x7[0] = vqadd_s16(x5[0], x5[8]); + x7[1] = vqadd_s16(x5[1], x5[9]); + x7[2] = vqadd_s16(x5[2], x5[10]); + x7[3] = vqadd_s16(x5[3], x5[11]); + x7[4] = vqadd_s16(x5[4], x5[12]); + x7[5] = vqadd_s16(x5[5], x5[13]); + x7[6] = vqadd_s16(x5[6], x5[14]); + x7[7] = vqsub_s16(x5[15], x5[7]); + x7[8] = vqsub_s16(x5[0], x5[8]); + x7[9] = vqsub_s16(x5[1], x5[9]); + x7[10] = vqsub_s16(x5[2], x5[10]); + x7[11] = vqsub_s16(x5[3], x5[11]); + x7[12] = vqsub_s16(x5[4], x5[12]); + x7[13] = vqsub_s16(x5[5], x5[13]); + x7[14] = vqsub_s16(x5[6], x5[14]); + x7[15] = vqadd_s16(x5[7], x5[15]); + + // stage 8 + butterfly_s16_s32_x4_0112_neon(cospi2, x7[0], x7[1], &output[15], &output[0]); + butterfly_s16_s32_x4_0112_neon(cospi10, x7[2], x7[3], &output[13], + &output[2]); + butterfly_s16_s32_x4_0112_neon(cospi18, x7[4], x7[5], &output[11], + &output[4]); + butterfly_s16_s32_x4_0112_neon(cospi26, x7[6], x7[7], &output[9], &output[6]); + butterfly_s16_s32_x4_1003_neon(cospi30, x7[8], x7[9], &output[7], &output[8]); + butterfly_s16_s32_x4_1003_neon(cospi22, x7[10], x7[11], &output[5], + &output[10]); + butterfly_s16_s32_x4_1003_neon(cospi14, x7[12], x7[13], &output[3], + &output[12]); + butterfly_s16_s32_x4_0112_neon(cospi6, x7[14], x7[15], &output[14], + &output[1]); +} + +static AOM_FORCE_INLINE void fadst8x16_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi2_6 = vld1q_s16(&cospi[4 * 8]); + const int16x8_t cospi10_14 = vld1q_s16(&cospi[4 * 10]); + const int16x8_t cospi18_22 = vld1q_s16(&cospi[4 * 12]); + const int16x8_t cospi26_30 = vld1q_s16(&cospi[4 * 14]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t 
cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi2 = vget_low_s16(cospi2_6); + const int16x4_t cospi6 = vget_high_s16(cospi2_6); + const int16x4_t cospi10 = vget_low_s16(cospi10_14); + const int16x4_t cospi14 = vget_high_s16(cospi10_14); + const int16x4_t cospi18 = vget_low_s16(cospi18_22); + const int16x4_t cospi22 = vget_high_s16(cospi18_22); + const int16x4_t cospi26 = vget_low_s16(cospi26_30); + const int16x4_t cospi30 = vget_high_s16(cospi26_30); + + // stage 2 + int16x8_t x2[8]; + butterfly_s16_s32_x8_0332_neon(cospi32, input[8], input[7], &x2[0], &x2[1]); + butterfly_s16_s32_x8_0112_neon(cospi32, input[4], input[11], &x2[3], &x2[2]); + butterfly_s16_s32_x8_0112_neon(cospi32, input[6], input[9], &x2[5], &x2[4]); + butterfly_s16_s32_x8_0332_neon(cospi32, input[10], input[5], &x2[6], &x2[7]); + + // stage 3 + int16x8_t x3[16]; + x3[0] = vqaddq_s16(input[0], x2[0]); + x3[1] = vqsubq_s16(x2[1], input[15]); + x3[2] = vqsubq_s16(input[0], x2[0]); + x3[3] = vqaddq_s16(input[15], x2[1]); + x3[4] = vqsubq_s16(x2[2], input[3]); + x3[5] = vqaddq_s16(input[12], x2[3]); + x3[6] = vqaddq_s16(input[3], x2[2]); + x3[7] = vqsubq_s16(input[12], x2[3]); + x3[8] = vqsubq_s16(x2[4], input[1]); + x3[9] = vqaddq_s16(input[14], x2[5]); + x3[10] = vqaddq_s16(input[1], x2[4]); + x3[11] = vqsubq_s16(input[14], x2[5]); + x3[12] = vqaddq_s16(input[2], x2[6]); + x3[13] = vqsubq_s16(x2[7], input[13]); + x3[14] = vqsubq_s16(input[2], x2[6]); + x3[15] = vqaddq_s16(input[13], x2[7]); + + // stage 4 + butterfly_s16_s32_x8_0112_neon(cospi16, x3[4], x3[5], &x3[4], &x3[5]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[7], x3[6], &x3[6], &x3[7]); + butterfly_s16_s32_x8_0112_neon(cospi16, x3[12], x3[13], &x3[12], &x3[13]); + butterfly_s16_s32_x8_0332_neon(cospi16, x3[14], x3[15], &x3[15], &x3[14]); + + // stage 5 + int16x8_t x5[16]; + x5[0] = vqaddq_s16(x3[0], x3[4]); + x5[1] = vqaddq_s16(x3[1], x3[5]); + x5[2] = vqaddq_s16(x3[2], x3[6]); + x5[3] = vqsubq_s16(x3[7], x3[3]); + x5[4] = vqsubq_s16(x3[0], x3[4]); + x5[5] = vqsubq_s16(x3[1], x3[5]); + x5[6] = vqsubq_s16(x3[2], x3[6]); + x5[7] = vqaddq_s16(x3[3], x3[7]); + x5[8] = vqaddq_s16(x3[8], x3[12]); + x5[9] = vqaddq_s16(x3[9], x3[13]); + x5[10] = vqsubq_s16(x3[14], x3[10]); + x5[11] = vqaddq_s16(x3[11], x3[15]); + x5[12] = vqsubq_s16(x3[8], x3[12]); + x5[13] = vqsubq_s16(x3[9], x3[13]); + x5[14] = vqaddq_s16(x3[10], x3[14]); + x5[15] = vqsubq_s16(x3[11], x3[15]); + + // stage 6 + butterfly_s16_s32_x8_0112_neon(cospi8, x5[8], x5[9], &x5[8], &x5[9]); + butterfly_s16_s32_x8_1003_neon(cospi24, x5[10], x5[11], &x5[10], &x5[11]); + butterfly_s16_s32_x8_1003_neon(cospi8, x5[13], x5[12], &x5[13], &x5[12]); + butterfly_s16_s32_x8_1003_neon(cospi24, x5[15], x5[14], &x5[14], &x5[15]); + + // stage 7 + int16x8_t x7[16]; + x7[0] = vqaddq_s16(x5[0], x5[8]); + x7[1] = vqaddq_s16(x5[1], x5[9]); + x7[2] = vqaddq_s16(x5[2], x5[10]); + x7[3] = vqaddq_s16(x5[3], x5[11]); + x7[4] = vqaddq_s16(x5[4], x5[12]); + x7[5] = vqaddq_s16(x5[5], x5[13]); + x7[6] = vqaddq_s16(x5[6], x5[14]); + x7[7] = vqsubq_s16(x5[15], x5[7]); + x7[8] = vqsubq_s16(x5[0], x5[8]); + x7[9] = vqsubq_s16(x5[1], x5[9]); + x7[10] = vqsubq_s16(x5[2], x5[10]); + x7[11] = vqsubq_s16(x5[3], x5[11]); + x7[12] = vqsubq_s16(x5[4], x5[12]); + x7[13] = vqsubq_s16(x5[5], x5[13]); + x7[14] = vqsubq_s16(x5[6], x5[14]); + x7[15] = vqaddq_s16(x5[7], x5[15]); + + // stage 8 + butterfly_s16_s32_x8_0112_neon(cospi2, x7[0], x7[1], &output[15], &output[0]); + 
butterfly_s16_s32_x8_0112_neon(cospi10, x7[2], x7[3], &output[13], + &output[2]); + butterfly_s16_s32_x8_0112_neon(cospi18, x7[4], x7[5], &output[11], + &output[4]); + butterfly_s16_s32_x8_0112_neon(cospi26, x7[6], x7[7], &output[9], &output[6]); + butterfly_s16_s32_x8_1003_neon(cospi30, x7[8], x7[9], &output[7], &output[8]); + butterfly_s16_s32_x8_1003_neon(cospi22, x7[10], x7[11], &output[5], + &output[10]); + butterfly_s16_s32_x8_1003_neon(cospi14, x7[12], x7[13], &output[3], + &output[12]); + butterfly_s16_s32_x8_0112_neon(cospi6, x7[14], x7[15], &output[14], + &output[1]); +} + +static AOM_FORCE_INLINE void fidentity4x4_neon(const int16x4_t *const input, + int16x4_t *const output, + const int cos_bit) { + (void)cos_bit; + round_shift_sqrt2_s16_s16_4xn_neon(input, output, 4); +} + +static AOM_FORCE_INLINE void fidentity8x4_neon(const int16x8_t *const input, + int16x8_t *const output, + const int cos_bit) { + (void)cos_bit; + round_shift_sqrt2_s16_s16_8xn_neon(input, output, 4); +} + +static AOM_FORCE_INLINE void fidentity4x8_neon(const int16x4_t *input, + int16x4_t *output, int cos_bit) { + (void)cos_bit; + shift_left_1_s16_x4(input, output, 8); +} + +static AOM_FORCE_INLINE void fidentity8x8_neon(const int16x8_t *input, + int16x8_t *output, int cos_bit) { + (void)cos_bit; + shift_left_1_s16_x8(input, output, 8); +} + +static AOM_FORCE_INLINE void fidentity4x16_neon(const int16x4_t *input, + int16x4_t *output, + int cos_bit) { + (void)cos_bit; + round_shift_2sqrt2_s16_s16_4xn_neon(input, output, 16); +} + +static AOM_FORCE_INLINE void fidentity8x16_neon(const int16x8_t *input, + int16x8_t *output, + int cos_bit) { + (void)cos_bit; + round_shift_2sqrt2_s16_s16_8xn_neon(input, output, 16); +} + +static AOM_FORCE_INLINE void fidentity8x32_neon(const int16x8_t *input, + int16x8_t *output, + int cos_bit) { + (void)cos_bit; + shift_left_2_s16_x8(input, output, 32); +} + +#define TRANSFORM_COL(name, tw, n) \ + static void name##_col_neon(const int16_t *input, int16x##tw##_t *output, \ + int stride, int cos_bit) { \ + int16x##tw##_t buf0[n]; \ + load_buffer_s16_x##tw(input, stride, buf0, n); \ + shift_left_2_s16_x##tw(buf0, buf0, n); \ + name##_neon(buf0, output, cos_bit); \ + } + +TRANSFORM_COL(fadst4x4, 4, 4) +TRANSFORM_COL(fadst4x8, 4, 8) +TRANSFORM_COL(fadst4x16, 4, 16) +TRANSFORM_COL(fadst8x4, 8, 4) +TRANSFORM_COL(fadst8x8, 8, 8) +TRANSFORM_COL(fadst8x16, 8, 16) +TRANSFORM_COL(fdct4x4, 4, 4) +TRANSFORM_COL(fdct4x8, 4, 8) +TRANSFORM_COL(fdct4x16, 4, 16) +TRANSFORM_COL(fdct8x4, 8, 4) +TRANSFORM_COL(fdct8x8, 8, 8) +TRANSFORM_COL(fdct8x16, 8, 16) +TRANSFORM_COL(fdct8x32, 8, 32) +TRANSFORM_COL(fidentity4x4, 4, 4) +TRANSFORM_COL(fidentity4x8, 4, 8) +TRANSFORM_COL(fidentity4x16, 4, 16) +TRANSFORM_COL(fidentity8x4, 8, 4) +TRANSFORM_COL(fidentity8x8, 8, 8) +TRANSFORM_COL(fidentity8x16, 8, 16) +TRANSFORM_COL(fidentity8x32, 8, 32) + +#define TRANSFORM_ROW(name, tw, n) \ + static void name##_row_neon(const int16x##tw##_t *input, int32_t *output, \ + int stride, int cos_bit) { \ + int16x##tw##_t buf0[n]; \ + name##_neon(input, buf0, cos_bit); \ + store_buffer_s16_x##tw(buf0, output, stride, n); \ + } + +#define TRANSFORM_ROW_RECT(name, tw, n) \ + static void name##_row_rect_neon(const int16x##tw##_t *input, \ + int32_t *output, int stride, int cos_bit) { \ + int16x##tw##_t buf0[n]; \ + name##_neon(input, buf0, cos_bit); \ + store_rect_buffer_s16_x##tw(buf0, output, stride, n); \ + } + +TRANSFORM_ROW(fadst4x4, 4, 4) +TRANSFORM_ROW(fadst4x16, 4, 16) +TRANSFORM_ROW(fadst8x4, 8, 4) 
+TRANSFORM_ROW(fadst8x8, 8, 8) +TRANSFORM_ROW(fadst8x16, 8, 16) +TRANSFORM_ROW(fdct4x4, 4, 4) +TRANSFORM_ROW(fdct4x16, 4, 16) +TRANSFORM_ROW(fdct8x4, 8, 4) +TRANSFORM_ROW(fdct8x8, 8, 8) +TRANSFORM_ROW(fdct8x16, 8, 16) +TRANSFORM_ROW(fdct8x32, 8, 32) +TRANSFORM_ROW(fidentity4x4, 4, 4) +TRANSFORM_ROW(fidentity4x16, 4, 16) +TRANSFORM_ROW(fidentity8x4, 8, 4) +TRANSFORM_ROW(fidentity8x8, 8, 8) +TRANSFORM_ROW(fidentity8x16, 8, 16) +TRANSFORM_ROW(fidentity8x32, 8, 32) + +TRANSFORM_ROW_RECT(fadst4x8, 4, 8) +TRANSFORM_ROW_RECT(fadst8x4, 8, 4) +TRANSFORM_ROW_RECT(fadst8x8, 8, 8) +TRANSFORM_ROW_RECT(fadst8x16, 8, 16) +TRANSFORM_ROW_RECT(fdct4x8, 4, 8) +TRANSFORM_ROW_RECT(fdct8x4, 8, 4) +TRANSFORM_ROW_RECT(fdct8x8, 8, 8) +TRANSFORM_ROW_RECT(fdct8x16, 8, 16) +TRANSFORM_ROW_RECT(fdct8x32, 8, 32) +TRANSFORM_ROW_RECT(fidentity4x8, 4, 8) +TRANSFORM_ROW_RECT(fidentity8x4, 8, 4) +TRANSFORM_ROW_RECT(fidentity8x8, 8, 8) +TRANSFORM_ROW_RECT(fidentity8x16, 8, 16) +TRANSFORM_ROW_RECT(fidentity8x32, 8, 32) + +typedef void (*transform_1d_lbd_4_neon)(const int16x4_t *input, + int16x4_t *output, int cos_bit); +typedef void (*transform_1d_lbd_8_neon)(const int16x8_t *input, + int16x8_t *output, int cos_bit); + +typedef void (*col_transform_1d_lbd_4_neon)(const int16_t *input, + int16x4_t *output, int stride, + int cos_bit); +typedef void (*col_transform_1d_lbd_8_neon)(const int16_t *input, + int16x8_t *output, int stride, + int cos_bit); + +typedef void (*row_transform_1d_lbd_4_neon)(const int16x4_t *input, + int32_t *output, int stride, + int cos_bit); +typedef void (*row_transform_1d_lbd_8_neon)(const int16x8_t *input, + int32_t *output, int stride, + int cos_bit); + +static const col_transform_1d_lbd_4_neon col_txfm4x8_arr[TX_TYPES] = { + fdct4x8_col_neon, // DCT_DCT + fadst4x8_col_neon, // ADST_DCT + fdct4x8_col_neon, // DCT_ADST + fadst4x8_col_neon, // ADST_ADST + fadst4x8_col_neon, // FLIPADST_DCT + fdct4x8_col_neon, // DCT_FLIPADST + fadst4x8_col_neon, // FLIPADST_FLIPADST + fadst4x8_col_neon, // ADST_FLIPADST + fadst4x8_col_neon, // FLIPADST_ADST + fidentity4x8_col_neon, // IDTX + fdct4x8_col_neon, // V_DCT + fidentity4x8_col_neon, // H_DCT + fadst4x8_col_neon, // V_ADST + fidentity4x8_col_neon, // H_ADST + fadst4x8_col_neon, // V_FLIPADST + fidentity4x8_col_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_txfm8x4_arr[TX_TYPES] = { + fdct8x4_row_neon, // DCT_DCT + fdct8x4_row_neon, // ADST_DCT + fadst8x4_row_neon, // DCT_ADST + fadst8x4_row_neon, // ADST_ADST + fdct8x4_row_neon, // FLIPADST_DCT + fadst8x4_row_neon, // DCT_FLIPADST + fadst8x4_row_neon, // FLIPADST_FLIPADST + fadst8x4_row_neon, // ADST_FLIPADST + fadst8x4_row_neon, // FLIPADST_ADST + fidentity8x4_row_neon, // IDTX + fidentity8x4_row_neon, // V_DCT + fdct8x4_row_neon, // H_DCT + fidentity8x4_row_neon, // V_ADST + fadst8x4_row_neon, // H_ADST + fidentity8x4_row_neon, // V_FLIPADST + fadst8x4_row_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_rect_txfm8x4_arr[TX_TYPES] = { + fdct8x4_row_rect_neon, // DCT_DCT + fdct8x4_row_rect_neon, // ADST_DCT + fadst8x4_row_rect_neon, // DCT_ADST + fadst8x4_row_rect_neon, // ADST_ADST + fdct8x4_row_rect_neon, // FLIPADST_DCT + fadst8x4_row_rect_neon, // DCT_FLIPADST + fadst8x4_row_rect_neon, // FLIPADST_FLIPADST + fadst8x4_row_rect_neon, // ADST_FLIPADST + fadst8x4_row_rect_neon, // FLIPADST_ADST + fidentity8x4_row_rect_neon, // IDTX + fidentity8x4_row_rect_neon, // V_DCT + fdct8x4_row_rect_neon, // H_DCT + fidentity8x4_row_rect_neon, // V_ADST + 
fadst8x4_row_rect_neon, // H_ADST + fidentity8x4_row_rect_neon, // V_FLIPADST + fadst8x4_row_rect_neon // H_FLIPADST +}; + +static const col_transform_1d_lbd_8_neon col_txfm8x4_arr[TX_TYPES] = { + fdct8x4_col_neon, // DCT_DCT + fadst8x4_col_neon, // ADST_DCT + fdct8x4_col_neon, // DCT_ADST + fadst8x4_col_neon, // ADST_ADST + fadst8x4_col_neon, // FLIPADST_DCT + fdct8x4_col_neon, // DCT_FLIPADST + fadst8x4_col_neon, // FLIPADST_FLIPADST + fadst8x4_col_neon, // ADST_FLIPADST + fadst8x4_col_neon, // FLIPADST_ADST + fidentity8x4_col_neon, // IDTX + fdct8x4_col_neon, // V_DCT + fidentity8x4_col_neon, // H_DCT + fadst8x4_col_neon, // V_ADST + fidentity8x4_col_neon, // H_ADST + fadst8x4_col_neon, // V_FLIPADST + fidentity8x4_col_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_4_neon row_rect_txfm4x8_arr[TX_TYPES] = { + fdct4x8_row_rect_neon, // DCT_DCT + fdct4x8_row_rect_neon, // ADST_DCT + fadst4x8_row_rect_neon, // DCT_ADST + fadst4x8_row_rect_neon, // ADST_ADST + fdct4x8_row_rect_neon, // FLIPADST_DCT + fadst4x8_row_rect_neon, // DCT_FLIPADST + fadst4x8_row_rect_neon, // FLIPADST_FLIPADST + fadst4x8_row_rect_neon, // ADST_FLIPADST + fadst4x8_row_rect_neon, // FLIPADST_ADST + fidentity4x8_row_rect_neon, // IDTX + fidentity4x8_row_rect_neon, // V_DCT + fdct4x8_row_rect_neon, // H_DCT + fidentity4x8_row_rect_neon, // V_ADST + fadst4x8_row_rect_neon, // H_ADST + fidentity4x8_row_rect_neon, // V_FLIPADST + fadst4x8_row_rect_neon // H_FLIPADST +}; + +static const col_transform_1d_lbd_8_neon col_txfm8x8_arr[TX_TYPES] = { + fdct8x8_col_neon, // DCT_DCT + fadst8x8_col_neon, // ADST_DCT + fdct8x8_col_neon, // DCT_ADST + fadst8x8_col_neon, // ADST_ADST + fadst8x8_col_neon, // FLIPADST_DCT + fdct8x8_col_neon, // DCT_FLIPADST + fadst8x8_col_neon, // FLIPADST_FLIPADST + fadst8x8_col_neon, // ADST_FLIPADST + fadst8x8_col_neon, // FLIPADST_ADST + fidentity8x8_col_neon, // IDTX + fdct8x8_col_neon, // V_DCT + fidentity8x8_col_neon, // H_DCT + fadst8x8_col_neon, // V_ADST + fidentity8x8_col_neon, // H_ADST + fadst8x8_col_neon, // V_FLIPADST + fidentity8x8_col_neon, // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_txfm8x8_arr[TX_TYPES] = { + fdct8x8_row_neon, // DCT_DCT + fdct8x8_row_neon, // ADST_DCT + fadst8x8_row_neon, // DCT_ADST + fadst8x8_row_neon, // ADST_ADST + fdct8x8_row_neon, // FLIPADST_DCT + fadst8x8_row_neon, // DCT_FLIPADST + fadst8x8_row_neon, // FLIPADST_FLIPADST + fadst8x8_row_neon, // ADST_FLIPADST + fadst8x8_row_neon, // FLIPADST_ADST + fidentity8x8_row_neon, // IDTX + fidentity8x8_row_neon, // V_DCT + fdct8x8_row_neon, // H_DCT + fidentity8x8_row_neon, // V_ADST + fadst8x8_row_neon, // H_ADST + fidentity8x8_row_neon, // V_FLIPADST + fadst8x8_row_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_rect_txfm8x8_arr[TX_TYPES] = { + fdct8x8_row_rect_neon, // DCT_DCT + fdct8x8_row_rect_neon, // ADST_DCT + fadst8x8_row_rect_neon, // DCT_ADST + fadst8x8_row_rect_neon, // ADST_ADST + fdct8x8_row_rect_neon, // FLIPADST_DCT + fadst8x8_row_rect_neon, // DCT_FLIPADST + fadst8x8_row_rect_neon, // FLIPADST_FLIPADST + fadst8x8_row_rect_neon, // ADST_FLIPADST + fadst8x8_row_rect_neon, // FLIPADST_ADST + fidentity8x8_row_rect_neon, // IDTX + fidentity8x8_row_rect_neon, // V_DCT + fdct8x8_row_rect_neon, // H_DCT + fidentity8x8_row_rect_neon, // V_ADST + fadst8x8_row_rect_neon, // H_ADST + fidentity8x8_row_rect_neon, // V_FLIPADST + fadst8x8_row_rect_neon // H_FLIPADST +}; + +static const col_transform_1d_lbd_4_neon col_txfm4x16_arr[TX_TYPES] = { + 
fdct4x16_col_neon, // DCT_DCT + fadst4x16_col_neon, // ADST_DCT + fdct4x16_col_neon, // DCT_ADST + fadst4x16_col_neon, // ADST_ADST + fadst4x16_col_neon, // FLIPADST_DCT + fdct4x16_col_neon, // DCT_FLIPADST + fadst4x16_col_neon, // FLIPADST_FLIPADST + fadst4x16_col_neon, // ADST_FLIPADST + fadst4x16_col_neon, // FLIPADST_ADST + fidentity4x16_col_neon, // IDTX + fdct4x16_col_neon, // V_DCT + fidentity4x16_col_neon, // H_DCT + fadst4x16_col_neon, // V_ADST + fidentity4x16_col_neon, // H_ADST + fadst4x16_col_neon, // V_FLIPADST + fidentity4x16_col_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_4_neon row_txfm4x16_arr[TX_TYPES] = { + fdct4x16_row_neon, // DCT_DCT + fdct4x16_row_neon, // ADST_DCT + fadst4x16_row_neon, // DCT_ADST + fadst4x16_row_neon, // ADST_ADST + fdct4x16_row_neon, // FLIPADST_DCT + fadst4x16_row_neon, // DCT_FLIPADST + fadst4x16_row_neon, // FLIPADST_FLIPADST + fadst4x16_row_neon, // ADST_FLIPADST + fadst4x16_row_neon, // FLIPADST_ADST + fidentity4x16_row_neon, // IDTX + fidentity4x16_row_neon, // V_DCT + fdct4x16_row_neon, // H_DCT + fidentity4x16_row_neon, // V_ADST + fadst4x16_row_neon, // H_ADST + fidentity4x16_row_neon, // V_FLIPADST + fadst4x16_row_neon // H_FLIPADST +}; + +static const col_transform_1d_lbd_8_neon col_txfm8x16_arr[TX_TYPES] = { + fdct8x16_col_neon, // DCT_DCT + fadst8x16_col_neon, // ADST_DCT + fdct8x16_col_neon, // DCT_ADST + fadst8x16_col_neon, // ADST_ADST + fadst8x16_col_neon, // FLIPADST_DCT + fdct8x16_col_neon, // DCT_FLIPADST + fadst8x16_col_neon, // FLIPADST_FLIPADST + fadst8x16_col_neon, // ADST_FLIPADST + fadst8x16_col_neon, // FLIPADST_ADST + fidentity8x16_col_neon, // IDTX + fdct8x16_col_neon, // V_DCT + fidentity8x16_col_neon, // H_DCT + fadst8x16_col_neon, // V_ADST + fidentity8x16_col_neon, // H_ADST + fadst8x16_col_neon, // V_FLIPADST + fidentity8x16_col_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_txfm8x16_arr[TX_TYPES] = { + fdct8x16_row_neon, // DCT_DCT + fdct8x16_row_neon, // ADST_DCT + fadst8x16_row_neon, // DCT_ADST + fadst8x16_row_neon, // ADST_ADST + fdct8x16_row_neon, // FLIPADST_DCT + fadst8x16_row_neon, // DCT_FLIPADST + fadst8x16_row_neon, // FLIPADST_FLIPADST + fadst8x16_row_neon, // ADST_FLIPADST + fadst8x16_row_neon, // FLIPADST_ADST + fidentity8x16_row_neon, // IDTX + fidentity8x16_row_neon, // V_DCT + fdct8x16_row_neon, // H_DCT + fidentity8x16_row_neon, // V_ADST + fadst8x16_row_neon, // H_ADST + fidentity8x16_row_neon, // V_FLIPADST + fadst8x16_row_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_rect_txfm8x16_arr[TX_TYPES] = { + fdct8x16_row_rect_neon, // DCT_DCT + fdct8x16_row_rect_neon, // ADST_DCT + fadst8x16_row_rect_neon, // DCT_ADST + fadst8x16_row_rect_neon, // ADST_ADST + fdct8x16_row_rect_neon, // FLIPADST_DCT + fadst8x16_row_rect_neon, // DCT_FLIPADST + fadst8x16_row_rect_neon, // FLIPADST_FLIPADST + fadst8x16_row_rect_neon, // ADST_FLIPADST + fadst8x16_row_rect_neon, // FLIPADST_ADST + fidentity8x16_row_rect_neon, // IDTX + fidentity8x16_row_rect_neon, // V_DCT + fdct8x16_row_rect_neon, // H_DCT + fidentity8x16_row_rect_neon, // V_ADST + fadst8x16_row_rect_neon, // H_ADST + fidentity8x16_row_rect_neon, // V_FLIPADST + fadst8x16_row_rect_neon // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_txfm8x32_arr[TX_TYPES] = { + fdct8x32_row_neon, // DCT_DCT + NULL, // ADST_DCT + NULL, // DCT_ADST + NULL, // ADST_ADST + NULL, // FLIPADST_DCT + NULL, // DCT_FLIPADST + NULL, // FLIPADST_FLIPADST + NULL, // ADST_FLIPADST + NULL, // 
FLIPADST_ADST + fidentity8x32_row_neon, // IDTX + fidentity8x32_row_neon, // V_DCT + fdct8x32_row_neon, // H_DCT + NULL, // V_ADST + NULL, // H_ADST + NULL, // V_FLIPADST + NULL // H_FLIPADST +}; + +static const row_transform_1d_lbd_8_neon row_rect_txfm8x32_arr[TX_TYPES] = { + fdct8x32_row_rect_neon, // DCT_DCT + NULL, // ADST_DCT + NULL, // DCT_ADST + NULL, // ADST_ADST + NULL, // FLIPADST_DCT + NULL, // DCT_FLIPADST + NULL, // FLIPADST_FLIPADST + NULL, // ADST_FLIPADST + NULL, // FLIPADST_ADST + fidentity8x32_row_rect_neon, // IDTX + fidentity8x32_row_rect_neon, // V_DCT + fdct8x32_row_rect_neon, // H_DCT + NULL, // V_ADST + NULL, // H_ADST + NULL, // V_FLIPADST + NULL // H_FLIPADST +}; + +static const col_transform_1d_lbd_8_neon col_txfm8x32_arr[TX_TYPES] = { + fdct8x32_col_neon, // DCT_DCT + NULL, // ADST_DCT + NULL, // DCT_ADST + NULL, // ADST_ADST + NULL, // FLIPADST_DCT + NULL, // DCT_FLIPADST + NULL, // FLIPADST_FLIPADST + NULL, // ADST_FLIPADST + NULL, // FLIPADST_ADST + fidentity8x32_col_neon, // IDTX + fdct8x32_col_neon, // V_DCT + fidentity8x32_col_neon, // H_DCT + NULL, // V_ADST + NULL, // H_ADST + NULL, // V_FLIPADST + NULL // H_FLIPADST +}; + +static void lowbd_fwd_txfm2d_4x4_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 4); + + int16x4_t buf0[4], buf1[4]; + switch (tx_type) { + case DCT_DCT: + fdct4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fdct4x4_row_neon(buf1, output, 4, 13); + break; + case ADST_DCT: + fadst4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fdct4x4_row_neon(buf1, output, 4, 13); + break; + case DCT_ADST: + fdct4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fadst4x4_row_neon(buf1, output, 4, 13); + break; + case ADST_ADST: + fadst4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fadst4x4_row_neon(buf1, output, 4, 13); + break; + case FLIPADST_DCT: + fadst4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fdct4x4_row_neon(buf1, output, 4, 13); + break; + case DCT_FLIPADST: + fdct4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + flip_buf_4_neon(buf1, buf0, 4); + fadst4x4_row_neon(buf0, output, 4, 13); + break; + case FLIPADST_FLIPADST: + fadst4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + flip_buf_4_neon(buf1, buf0, 4); + fadst4x4_row_neon(buf0, output, 4, 13); + break; + case ADST_FLIPADST: + fadst4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + flip_buf_4_neon(buf1, buf0, 4); + fadst4x4_row_neon(buf0, output, 4, 13); + break; + case FLIPADST_ADST: + fadst4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fadst4x4_row_neon(buf1, output, 4, 13); + break; + case IDTX: + fidentity4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fidentity4x4_row_neon(buf1, output, 4, 13); + break; + case V_DCT: + fdct4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fidentity4x4_row_neon(buf1, output, 4, 13); + break; + case H_DCT: + fidentity4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fdct4x4_row_neon(buf1, output, 4, 13); + break; + case V_ADST: + fadst4x4_col_neon(input, buf0, stride, 13); + 
transpose_arrays_s16_4x4(buf0, buf1); + fidentity4x4_row_neon(buf1, output, 4, 13); + break; + case H_ADST: + fidentity4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fadst4x4_row_neon(buf1, output, 4, 13); + break; + case V_FLIPADST: + fadst4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + fidentity4x4_row_neon(buf1, output, 4, 13); + break; + case H_FLIPADST: + fidentity4x4_col_neon(input, buf0, stride, 13); + transpose_arrays_s16_4x4(buf0, buf1); + flip_buf_4_neon(buf1, buf0, 4); + fadst4x4_row_neon(buf0, output, 4, 13); + break; + } +} + +static void lowbd_fwd_txfm2d_4x8_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x4_t buf0[8]; + int16x8_t buf1[8]; + const col_transform_1d_lbd_4_neon col_txfm = col_txfm4x8_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_rect_txfm8x4_arr[tx_type]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 8); + col_txfm(input, buf0, stride, 13); + shift_right_1_round_s16_x4(buf0, buf0, 8); + transpose_arrays_s16_4x8(buf0, buf1); + + if (lr_flip) { + int16x8_t buf2[8]; + flip_buf_8_neon(buf1, buf2, 4); + row_txfm(buf2, output, 8, 13); + } else { + row_txfm(buf1, output, 8, 13); + } +} + +static void lowbd_fwd_txfm2d_4x16_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x4_t buf0[16]; + int16x8_t buf1[16]; + const col_transform_1d_lbd_4_neon col_txfm = col_txfm4x16_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_txfm8x4_arr[tx_type]; + int ud_flip, lr_flip; + + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 16); + col_txfm(input, buf0, stride, 13); + shift_right_1_round_s16_x4(buf0, buf0, 16); + transpose_arrays_s16_4x8(buf0, buf1); + transpose_arrays_s16_4x8(buf0 + 8, buf1 + 8); + + for (int i = 0; i < 2; i++) { + if (lr_flip) { + int16x8_t buf2[16]; + flip_buf_8_neon(buf1 + 8 * i, buf2, 4); + row_txfm(buf2, output + 8 * i, 16, 12); + } else { + int16x8_t *buf = buf1 + 8 * i; + row_txfm(buf, output + 8 * i, 16, 12); + } + } +} + +static void lowbd_fwd_txfm2d_8x4_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[8]; + int16x4_t buf1[8]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x4_arr[tx_type]; + const row_transform_1d_lbd_4_neon row_txfm = row_rect_txfm4x8_arr[tx_type]; + int ud_flip, lr_flip; + + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 4); + col_txfm(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 4); + transpose_arrays_s16_8x4(buf0, buf1); + + if (lr_flip) { + int16x4_t buf2[8]; + flip_buf_4_neon(buf1, buf2, 8); + row_txfm(buf2, output, 4, 13); + } else { + row_txfm(buf1, output, 4, 13); + } +} + +static void lowbd_fwd_txfm2d_8x8_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 8); + + int16x8_t buf0[8], buf1[8]; + + switch (tx_type) { + case DCT_DCT: + fdct8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fdct8x8_row_neon(buf1, output, 8, 13); + break; + case ADST_DCT: + fadst8x8_col_neon(input, buf0, stride, 13); + 
shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fdct8x8_row_neon(buf1, output, 8, 13); + break; + case DCT_ADST: + fdct8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fadst8x8_row_neon(buf1, output, 8, 13); + break; + case ADST_ADST: + fadst8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fadst8x8_row_neon(buf1, output, 8, 13); + break; + case FLIPADST_DCT: + fadst8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fdct8x8_row_neon(buf1, output, 8, 13); + break; + case DCT_FLIPADST: + fdct8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + flip_buf_8_neon(buf1, buf0, 8); + fadst8x8_row_neon(buf0, output, 8, 13); + break; + case FLIPADST_FLIPADST: + fadst8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + flip_buf_8_neon(buf1, buf0, 8); + fadst8x8_row_neon(buf0, output, 8, 13); + break; + case ADST_FLIPADST: + fadst8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + flip_buf_8_neon(buf1, buf0, 8); + fadst8x8_row_neon(buf0, output, 8, 13); + break; + case FLIPADST_ADST: + fadst8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fadst8x8_row_neon(buf1, output, 8, 13); + break; + case IDTX: + fidentity8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fidentity8x8_row_neon(buf1, output, 8, 13); + break; + case V_DCT: + fdct8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fidentity8x8_row_neon(buf1, output, 8, 13); + break; + case H_DCT: + fidentity8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fdct8x8_row_neon(buf1, output, 8, 13); + break; + case V_ADST: + fadst8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fidentity8x8_row_neon(buf1, output, 8, 13); + break; + case H_ADST: + fidentity8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fadst8x8_row_neon(buf1, output, 8, 13); + break; + case V_FLIPADST: + fadst8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + fidentity8x8_row_neon(buf1, output, 8, 13); + break; + case H_FLIPADST: + fidentity8x8_col_neon(input, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1); + flip_buf_8_neon(buf1, buf0, 8); + fadst8x8_row_neon(buf0, output, 8, 13); + break; + } +} + +static void lowbd_fwd_txfm2d_8x16_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[16], buf1[16]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x16_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_rect_txfm8x8_arr[tx_type]; + int ud_flip, lr_flip; + + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 
16); + col_txfm(input, buf0, stride, 13); + shift_right_2_round_s16_x8(buf0, buf0, 16); + transpose_arrays_s16_8x8(buf0, buf1); + transpose_arrays_s16_8x8(buf0 + 8, buf1 + 8); + + for (int i = 0; i < 2; i++) { + if (lr_flip) { + flip_buf_8_neon(buf1 + 8 * i, buf0, 8); + row_txfm(buf0, output + 8 * i, 16, 13); + } else { + int16x8_t *buf = buf1 + 8 * i; + row_txfm(buf, output + 8 * i, 16, 13); + } + } +} + +static void lowbd_fwd_txfm2d_8x32_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[32], buf1[32]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x32_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_txfm8x8_arr[tx_type]; + int ud_flip, lr_flip; + + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 32); + col_txfm(input, buf0, stride, 12); + shift_right_2_round_s16_x8(buf0, buf0, 32); + transpose_arrays_s16_8x8(buf0, buf1); + transpose_arrays_s16_8x8(buf0 + 8, buf1 + 8); + transpose_arrays_s16_8x8(buf0 + 16, buf1 + 16); + transpose_arrays_s16_8x8(buf0 + 24, buf1 + 24); + + for (int i = 0; i < 4; i++) { + if (lr_flip) { + flip_buf_8_neon(buf1 + 8 * i, buf0, 8); + row_txfm(buf0, output + 8 * i, 32, 12); + } else { + int16x8_t *buf = buf1 + 8 * i; + row_txfm(buf, output + 8 * i, 32, 12); + } + } +} + +static void lowbd_fwd_txfm2d_16x4_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[16]; + int16x4_t buf1[16]; + int16x4_t buf2[16]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x4_arr[tx_type]; + const row_transform_1d_lbd_4_neon row_txfm = row_txfm4x16_arr[tx_type]; + int ud_flip, lr_flip; + + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 4); + for (int i = 0; i < 2; i++) { + col_txfm(input + 8 * i, buf0, stride, 13); + shift_right_1_round_s16_x8(buf0, buf0, 4); + transpose_arrays_s16_8x4(buf0, buf1 + 8 * i); + } + + if (lr_flip) { + flip_buf_4_neon(buf1, buf2, 16); + row_txfm(buf2, output, 4, 13); + } else { + row_txfm(buf1, output, 4, 13); + } +} + +static void lowbd_fwd_txfm2d_16x8_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[16], buf1[16]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x8_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_rect_txfm8x16_arr[tx_type]; + int ud_flip, lr_flip; + + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 8); + for (int i = 0; i < 2; i++) { + col_txfm(input + 8 * i, buf0, stride, 13); + shift_right_2_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1 + 8 * i); + } + + if (lr_flip) { + flip_buf_8_neon(buf1, buf0, 16); + row_txfm(buf0, output, 8, 13); + } else { + row_txfm(buf1, output, 8, 13); + } +} + +static void lowbd_fwd_txfm2d_16x16_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[16], buf1[32]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x16_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_txfm8x16_arr[tx_type]; + int ud_flip, lr_flip; + + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 16); + for (int i = 0; i < 2; i++) { + col_txfm(input + 8 * i, buf0, stride, 13); + shift_right_2_round_s16_x8(buf0, buf0, 16); + transpose_arrays_s16_8x8(buf0, buf1 + 0 * 16 + 8 * i); + 
transpose_arrays_s16_8x8(buf0 + 8, buf1 + 1 * 16 + 8 * i); + } + + for (int i = 0; i < 2; i++) { + if (lr_flip) { + flip_buf_8_neon(buf1 + 16 * i, buf0, 16); + row_txfm(buf0, output + 8 * i, 16, 12); + } else { + int16x8_t *buf = buf1 + 16 * i; + row_txfm(buf, output + 8 * i, 16, 12); + } + } +} + +static void lowbd_fwd_txfm2d_16x32_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[32], buf1[64]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x32_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_rect_txfm8x16_arr[tx_type]; + + if (col_txfm == NULL || row_txfm == NULL) { + av1_fwd_txfm2d_16x32_c(input, output, stride, tx_type, bd); + return; + } + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 32); + for (int i = 0; i < 2; i++) { + col_txfm(input + 8 * i, buf0, stride, 12); + shift_right_4_round_s16_x8(buf0, buf0, 32); + transpose_arrays_s16_8x8(buf0 + 0 * 8, buf1 + 0 * 16 + 8 * i); + transpose_arrays_s16_8x8(buf0 + 1 * 8, buf1 + 1 * 16 + 8 * i); + transpose_arrays_s16_8x8(buf0 + 2 * 8, buf1 + 2 * 16 + 8 * i); + transpose_arrays_s16_8x8(buf0 + 3 * 8, buf1 + 3 * 16 + 8 * i); + } + + for (int i = 0; i < 4; i++) { + if (lr_flip) { + flip_buf_8_neon(buf1 + 16 * i, buf0, 16); + row_txfm(buf0, output + 8 * i, 32, 13); + } else { + int16x8_t *buf = buf1 + 16 * i; + row_txfm(buf, output + 8 * i, 32, 13); + } + } +} + +static void lowbd_fwd_txfm2d_32x8_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[32], buf1[32]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x8_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_txfm8x32_arr[tx_type]; + + if (col_txfm == NULL || row_txfm == NULL) { + av1_fwd_txfm2d_32x16_c(input, output, stride, tx_type, bd); + return; + } + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 8); + for (int i = 0; i < 4; i++) { + col_txfm(input + 8 * i, buf0, stride, 13); + shift_right_2_round_s16_x8(buf0, buf0, 8); + transpose_arrays_s16_8x8(buf0, buf1 + 0 * 32 + 8 * i); + } + + if (lr_flip) { + flip_buf_8_neon(buf1, buf0, 32); + row_txfm(buf0, output, 8, 12); + } else { + row_txfm(buf1, output, 8, 12); + } +} + +static void lowbd_fwd_txfm2d_32x16_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[32], buf1[64]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x16_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_rect_txfm8x32_arr[tx_type]; + + if (col_txfm == NULL || row_txfm == NULL) { + av1_fwd_txfm2d_32x16_c(input, output, stride, tx_type, bd); + return; + } + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 16); + for (int i = 0; i < 4; i++) { + col_txfm(input + 8 * i, buf0, stride, 13); + shift_right_4_round_s16_x8(buf0, buf0, 16); + transpose_arrays_s16_8x8(buf0, buf1 + 0 * 32 + 8 * i); + transpose_arrays_s16_8x8(buf0 + 8, buf1 + 1 * 32 + 8 * i); + } + + for (int i = 0; i < 2; i++) { + if (lr_flip) { + flip_buf_8_neon(buf1 + 32 * i, buf0, 32); + row_txfm(buf0, output + 8 * i, 16, 13); + } else { + int16x8_t *buf = buf1 + 32 * i; + row_txfm(buf, output + 8 * i, 16, 13); + } + } +} + +static void lowbd_fwd_txfm2d_32x32_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int 
bd) { + (void)bd; + int16x8_t buf0[32], buf1[128]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x32_arr[tx_type]; + const row_transform_1d_lbd_8_neon row_txfm = row_txfm8x32_arr[tx_type]; + + if (col_txfm == NULL || row_txfm == NULL) { + av1_fwd_txfm2d_32x32_c(input, output, stride, tx_type, bd); + return; + } + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 32); + for (int i = 0; i < 4; i++) { + col_txfm(input + 8 * i, buf0, stride, 12); + shift_right_4_round_s16_x8(buf0, buf0, 32); + transpose_arrays_s16_8x8(buf0 + 0 * 8, buf1 + 0 * 32 + 8 * i); + transpose_arrays_s16_8x8(buf0 + 1 * 8, buf1 + 1 * 32 + 8 * i); + transpose_arrays_s16_8x8(buf0 + 2 * 8, buf1 + 2 * 32 + 8 * i); + transpose_arrays_s16_8x8(buf0 + 3 * 8, buf1 + 3 * 32 + 8 * i); + } + + for (int i = 0; i < 4; i++) { + if (lr_flip) { + flip_buf_8_neon(buf1 + 32 * i, buf0, 32); + row_txfm(buf0, output + 8 * i, 32, 12); + } else { + int16x8_t *buf = buf1 + 32 * i; + row_txfm(buf, output + 8 * i, 32, 12); + } + } +} + +static void lowbd_fwd_txfm2d_64x16_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + (void)tx_type; + assert(tx_type == DCT_DCT); + int16x8_t buf0[64], buf1[128]; + const transform_1d_lbd_8_neon col_txfm = fdct8x16_neon; + const transform_1d_lbd_8_neon row_txfm = fdct8x64_neon; + + for (int i = 0; i < 8; i++) { + load_buffer_s16_x8(input + 8 * i, stride, buf0, 16); + shift_left_2_s16_x8(buf0, buf0, 16); + col_txfm(buf0, buf0, 13); + shift_right_4_round_s16_x8(buf0, buf0, 16); + for (int j = 0; j < 2; ++j) { + transpose_arrays_s16_8x8(buf0 + j * 8, buf1 + j * 64 + 8 * i); + } + } + + for (int i = 0; i < 2; i++) { + int16x8_t *buf = buf1 + 64 * i; + row_txfm(buf, buf, 12); + store_buffer_s16_x8(buf, output + 8 * i, 16, 32); + } + // Zero out the bottom 16x32 area. 
+ memset(output + 16 * 32, 0, 16 * 32 * sizeof(*output)); +} + +static void lowbd_fwd_txfm2d_16x64_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + (void)tx_type; + assert(tx_type == DCT_DCT); + int16x8_t buf0[64], buf1[128]; + const transform_1d_lbd_8_neon col_txfm = fdct8x64_neon; + const transform_1d_lbd_8_neon row_txfm = fdct8x16_neon; + + for (int i = 0; i < 2; i++) { + load_buffer_s16_x8(input + 8 * i, stride, buf0, 64); + col_txfm(buf0, buf0, 13); + shift_right_2_round_s16_x8(buf0, buf0, 64); + for (int j = 0; j < 8; ++j) { + transpose_arrays_s16_8x8(buf0 + j * 8, buf1 + j * 16 + 8 * i); + } + } + + for (int i = 0; i < 4; i++) { + int16x8_t *buf = buf1 + 16 * i; + row_txfm(buf, buf, 12); + store_buffer_s16_x8(buf, output + 8 * i, 32, 16); + } +} + +static void fdct32_neon(const int32x4_t *input, int32x4_t *output, + int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi4_12 = vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + const int16x8_t cospi2_6 = vld1q_s16(&cospi[4 * 8]); + const int16x8_t cospi10_14 = vld1q_s16(&cospi[4 * 10]); + const int16x8_t cospi18_22 = vld1q_s16(&cospi[4 * 12]); + const int16x8_t cospi26_30 = vld1q_s16(&cospi[4 * 14]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + const int16x4_t cospi2 = vget_low_s16(cospi2_6); + const int16x4_t cospi6 = vget_high_s16(cospi2_6); + const int16x4_t cospi10 = vget_low_s16(cospi10_14); + const int16x4_t cospi14 = vget_high_s16(cospi10_14); + const int16x4_t cospi18 = vget_low_s16(cospi18_22); + const int16x4_t cospi22 = vget_high_s16(cospi18_22); + const int16x4_t cospi26 = vget_low_s16(cospi26_30); + const int16x4_t cospi30 = vget_high_s16(cospi26_30); + + int32x4_t buf0[32]; + int32x4_t buf1[32]; + + // stage 1 + butterfly_dct_pre_s32_x4(input, buf1, 32); + + // stage 2 + butterfly_dct_pre_s32_x4(buf1, buf0, 16); + buf0[16] = buf1[16]; + buf0[17] = buf1[17]; + buf0[18] = buf1[18]; + buf0[19] = buf1[19]; + butterfly_s32_s32_x4_0112_neon(cospi32, buf1[27], buf1[20], &buf0[27], + &buf0[20]); + butterfly_s32_s32_x4_0112_neon(cospi32, buf1[26], buf1[21], &buf0[26], + &buf0[21]); + butterfly_s32_s32_x4_0112_neon(cospi32, buf1[25], buf1[22], &buf0[25], + &buf0[22]); + butterfly_s32_s32_x4_0112_neon(cospi32, buf1[24], buf1[23], &buf0[24], + &buf0[23]); + buf0[28] = buf1[28]; + buf0[29] = buf1[29]; + buf0[30] = buf1[30]; + buf0[31] = buf1[31]; + + // stage 3 + butterfly_dct_pre_s32_x4(buf0, buf1, 8); + buf1[8] = buf0[8]; + buf1[9] = buf0[9]; + butterfly_s32_s32_x4_0112_neon(cospi32, buf0[13], buf0[10], &buf1[13], + &buf1[10]); + butterfly_s32_s32_x4_0112_neon(cospi32, buf0[12], buf0[11], &buf1[12], + &buf1[11]); + buf1[14] = buf0[14]; + buf1[15] = buf0[15]; + butterfly_dct_post_s32_x4(buf0 + 16, buf0 + 16, buf1 + 16, 16); + + // stage 4 + butterfly_dct_pre_s32_x4(buf1, buf0, 4); + buf0[4] = buf1[4]; + butterfly_s32_s32_x4_0112_neon(cospi32, buf1[6], buf1[5], &buf0[6], &buf0[5]); + buf0[7] = buf1[7]; + 
butterfly_dct_post_s32_x4(buf1 + 8, buf1 + 8, buf0 + 8, 8); + buf0[16] = buf1[16]; + buf0[17] = buf1[17]; + butterfly_s32_s32_x4_0112_neon(cospi16, buf1[29], buf1[18], &buf0[29], + &buf0[18]); + butterfly_s32_s32_x4_0112_neon(cospi16, buf1[28], buf1[19], &buf0[28], + &buf0[19]); + butterfly_s32_s32_x4_1223_neon(cospi16, buf1[27], buf1[20], &buf0[27], + &buf0[20]); + butterfly_s32_s32_x4_1223_neon(cospi16, buf1[26], buf1[21], &buf0[26], + &buf0[21]); + buf0[22] = buf1[22]; + buf0[23] = buf1[23]; + buf0[24] = buf1[24]; + buf0[25] = buf1[25]; + buf0[30] = buf1[30]; + buf0[31] = buf1[31]; + + // stage 5 + butterfly_s32_s32_x4_0112_neon(cospi32, buf0[0], buf0[1], &buf1[0], &buf1[1]); + butterfly_s32_s32_x4_0112_neon(cospi16, buf0[3], buf0[2], &buf1[2], &buf1[3]); + butterfly_dct_post_s32_x4(buf0 + 4, buf0 + 4, buf1 + 4, 4); + buf1[8] = buf0[8]; + butterfly_s32_s32_x4_0112_neon(cospi16, buf0[14], buf0[9], &buf1[14], + &buf1[9]); + butterfly_s32_s32_x4_1223_neon(cospi16, buf0[13], buf0[10], &buf1[13], + &buf1[10]); + buf1[11] = buf0[11]; + buf1[12] = buf0[12]; + buf1[15] = buf0[15]; + butterfly_dct_post_s32_x4(buf0 + 16, buf0 + 16, buf1 + 16, 8); + butterfly_dct_post_s32_x4(buf0 + 24, buf0 + 24, buf1 + 24, 8); + + // stage 6 + buf0[0] = buf1[0]; + buf0[1] = buf1[1]; + buf0[2] = buf1[2]; + buf0[3] = buf1[3]; + butterfly_s32_s32_x4_0112_neon(cospi8, buf1[7], buf1[4], &buf0[4], &buf0[7]); + butterfly_s32_s32_x4_1003_neon(cospi24, buf1[6], buf1[5], &buf0[5], &buf0[6]); + butterfly_dct_post_s32_x4(buf1 + 8, buf1 + 8, buf0 + 8, 4); + butterfly_dct_post_s32_x4(buf1 + 12, buf1 + 12, buf0 + 12, 4); + buf0[16] = buf1[16]; + butterfly_s32_s32_x4_0112_neon(cospi8, buf1[30], buf1[17], &buf0[30], + &buf0[17]); + butterfly_s32_s32_x4_1223_neon(cospi8, buf1[29], buf1[18], &buf0[29], + &buf0[18]); + buf0[19] = buf1[19]; + buf0[20] = buf1[20]; + butterfly_s32_s32_x4_1003_neon(cospi24, buf1[26], buf1[21], &buf0[26], + &buf0[21]); + butterfly_s32_s32_x4_0332_neon(cospi24, buf1[25], buf1[22], &buf0[25], + &buf0[22]); + buf0[23] = buf1[23]; + buf0[24] = buf1[24]; + buf0[27] = buf1[27]; + buf0[28] = buf1[28]; + buf0[31] = buf1[31]; + + // stage 7 + buf1[0] = buf0[0]; + buf1[1] = buf0[1]; + buf1[2] = buf0[2]; + buf1[3] = buf0[3]; + buf1[4] = buf0[4]; + buf1[5] = buf0[5]; + buf1[6] = buf0[6]; + buf1[7] = buf0[7]; + butterfly_s32_s32_x4_0112_neon(cospi4, buf0[15], buf0[8], &buf1[8], + &buf1[15]); + butterfly_s32_s32_x4_1003_neon(cospi28, buf0[14], buf0[9], &buf1[9], + &buf1[14]); + butterfly_s32_s32_x4_0112_neon(cospi20, buf0[13], buf0[10], &buf1[10], + &buf1[13]); + butterfly_s32_s32_x4_1003_neon(cospi12, buf0[12], buf0[11], &buf1[11], + &buf1[12]); + butterfly_dct_post_s32_x4(buf0 + 16, buf0 + 16, buf1 + 16, 4); + butterfly_dct_post_s32_x4(buf0 + 20, buf0 + 20, buf1 + 20, 4); + butterfly_dct_post_s32_x4(buf0 + 24, buf0 + 24, buf1 + 24, 4); + butterfly_dct_post_s32_x4(buf0 + 28, buf0 + 28, buf1 + 28, 4); + + // stage 8 + buf0[0] = buf1[0]; + buf0[1] = buf1[1]; + buf0[2] = buf1[2]; + buf0[3] = buf1[3]; + buf0[4] = buf1[4]; + buf0[5] = buf1[5]; + buf0[6] = buf1[6]; + buf0[7] = buf1[7]; + buf0[8] = buf1[8]; + buf0[9] = buf1[9]; + buf0[10] = buf1[10]; + buf0[11] = buf1[11]; + buf0[12] = buf1[12]; + buf0[13] = buf1[13]; + buf0[14] = buf1[14]; + buf0[15] = buf1[15]; + butterfly_s32_s32_x4_0112_neon(cospi2, buf1[31], buf1[16], &buf0[16], + &buf0[31]); + butterfly_s32_s32_x4_1003_neon(cospi30, buf1[30], buf1[17], &buf0[17], + &buf0[30]); + butterfly_s32_s32_x4_0112_neon(cospi18, buf1[29], buf1[18], &buf0[18], + &buf0[29]); + 
butterfly_s32_s32_x4_1003_neon(cospi14, buf1[28], buf1[19], &buf0[19], + &buf0[28]); + butterfly_s32_s32_x4_0112_neon(cospi10, buf1[27], buf1[20], &buf0[20], + &buf0[27]); + butterfly_s32_s32_x4_1003_neon(cospi22, buf1[26], buf1[21], &buf0[21], + &buf0[26]); + butterfly_s32_s32_x4_0112_neon(cospi26, buf1[25], buf1[22], &buf0[22], + &buf0[25]); + butterfly_s32_s32_x4_1003_neon(cospi6, buf1[24], buf1[23], &buf0[23], + &buf0[24]); + + // stage 9 + output[0] = buf0[0]; + output[1] = buf0[16]; + output[2] = buf0[8]; + output[3] = buf0[24]; + output[4] = buf0[4]; + output[5] = buf0[20]; + output[6] = buf0[12]; + output[7] = buf0[28]; + output[8] = buf0[2]; + output[9] = buf0[18]; + output[10] = buf0[10]; + output[11] = buf0[26]; + output[12] = buf0[6]; + output[13] = buf0[22]; + output[14] = buf0[14]; + output[15] = buf0[30]; + output[16] = buf0[1]; + output[17] = buf0[17]; + output[18] = buf0[9]; + output[19] = buf0[25]; + output[20] = buf0[5]; + output[21] = buf0[21]; + output[22] = buf0[13]; + output[23] = buf0[29]; + output[24] = buf0[3]; + output[25] = buf0[19]; + output[26] = buf0[11]; + output[27] = buf0[27]; + output[28] = buf0[7]; + output[29] = buf0[23]; + output[30] = buf0[15]; + output[31] = buf0[31]; +} + +static void fdct64_neon(const int32x4_t *input, int32x4_t *output, + int cos_bit) { + const int16_t *cospi = cospi_arr_q13(cos_bit); + + const int16x8_t cospi32_16 = vld1q_s16(&cospi[4 * 0]); + const int16x8_t cospi8_24 = vld1q_s16(&cospi[4 * 2]); + const int16x8_t cospi4_12 = vld1q_s16(&cospi[4 * 4]); + const int16x8_t cospi20_28 = vld1q_s16(&cospi[4 * 6]); + const int16x8_t cospi2_6 = vld1q_s16(&cospi[4 * 8]); + const int16x8_t cospi10_14 = vld1q_s16(&cospi[4 * 10]); + const int16x8_t cospi18_22 = vld1q_s16(&cospi[4 * 12]); + const int16x8_t cospi26_30 = vld1q_s16(&cospi[4 * 14]); + const int16x8_t cospi1_3 = vld1q_s16(&cospi[4 * 16]); + const int16x8_t cospi5_7 = vld1q_s16(&cospi[4 * 18]); + const int16x8_t cospi9_11 = vld1q_s16(&cospi[4 * 20]); + const int16x8_t cospi13_15 = vld1q_s16(&cospi[4 * 22]); + const int16x8_t cospi17_19 = vld1q_s16(&cospi[4 * 24]); + const int16x8_t cospi21_23 = vld1q_s16(&cospi[4 * 26]); + const int16x8_t cospi25_27 = vld1q_s16(&cospi[4 * 28]); + const int16x8_t cospi29_31 = vld1q_s16(&cospi[4 * 30]); + + const int16x4_t cospi32 = vget_low_s16(cospi32_16); + const int16x4_t cospi16 = vget_high_s16(cospi32_16); + const int16x4_t cospi8 = vget_low_s16(cospi8_24); + const int16x4_t cospi24 = vget_high_s16(cospi8_24); + const int16x4_t cospi4 = vget_low_s16(cospi4_12); + const int16x4_t cospi12 = vget_high_s16(cospi4_12); + const int16x4_t cospi20 = vget_low_s16(cospi20_28); + const int16x4_t cospi28 = vget_high_s16(cospi20_28); + const int16x4_t cospi2 = vget_low_s16(cospi2_6); + const int16x4_t cospi6 = vget_high_s16(cospi2_6); + const int16x4_t cospi10 = vget_low_s16(cospi10_14); + const int16x4_t cospi14 = vget_high_s16(cospi10_14); + const int16x4_t cospi18 = vget_low_s16(cospi18_22); + const int16x4_t cospi22 = vget_high_s16(cospi18_22); + const int16x4_t cospi26 = vget_low_s16(cospi26_30); + const int16x4_t cospi30 = vget_high_s16(cospi26_30); + const int16x4_t cospi1 = vget_low_s16(cospi1_3); + const int16x4_t cospi3 = vget_high_s16(cospi1_3); + const int16x4_t cospi5 = vget_low_s16(cospi5_7); + const int16x4_t cospi7 = vget_high_s16(cospi5_7); + const int16x4_t cospi9 = vget_low_s16(cospi9_11); + const int16x4_t cospi11 = vget_high_s16(cospi9_11); + const int16x4_t cospi13 = vget_low_s16(cospi13_15); + const int16x4_t cospi15 = 
vget_high_s16(cospi13_15); + const int16x4_t cospi17 = vget_low_s16(cospi17_19); + const int16x4_t cospi19 = vget_high_s16(cospi17_19); + const int16x4_t cospi21 = vget_low_s16(cospi21_23); + const int16x4_t cospi23 = vget_high_s16(cospi21_23); + const int16x4_t cospi25 = vget_low_s16(cospi25_27); + const int16x4_t cospi27 = vget_high_s16(cospi25_27); + const int16x4_t cospi29 = vget_low_s16(cospi29_31); + const int16x4_t cospi31 = vget_high_s16(cospi29_31); + + // stage 1 + int32x4_t x1[64]; + butterfly_dct_pre_s32_x4(input, x1, 64); + + // stage 2 + int32x4_t x2[64]; + butterfly_dct_pre_s32_x4(x1, x2, 32); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[55], x1[40], &x2[55], &x2[40]); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[54], x1[41], &x2[54], &x2[41]); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[53], x1[42], &x2[53], &x2[42]); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[52], x1[43], &x2[52], &x2[43]); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[51], x1[44], &x2[51], &x2[44]); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[50], x1[45], &x2[50], &x2[45]); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[49], x1[46], &x2[49], &x2[46]); + butterfly_s32_s32_x4_0112_neon(cospi32, x1[48], x1[47], &x2[48], &x2[47]); + + // stage 3 + int32x4_t x3[64]; + butterfly_dct_pre_s32_x4(x2, x3, 16); + butterfly_s32_s32_x4_0112_neon(cospi32, x2[27], x2[20], &x3[27], &x3[20]); + butterfly_s32_s32_x4_0112_neon(cospi32, x2[26], x2[21], &x3[26], &x3[21]); + butterfly_s32_s32_x4_0112_neon(cospi32, x2[25], x2[22], &x3[25], &x3[22]); + butterfly_s32_s32_x4_0112_neon(cospi32, x2[24], x2[23], &x3[24], &x3[23]); + butterfly_dct_post_s32_x4(x1 + 32, x2 + 32, x3 + 32, 32); + + // stage 4 + int32x4_t x4[64]; + butterfly_dct_pre_s32_x4(x3, x4, 8); + butterfly_s32_s32_x4_0112_neon(cospi32, x3[13], x3[10], &x4[13], &x4[10]); + butterfly_s32_s32_x4_0112_neon(cospi32, x3[12], x3[11], &x4[12], &x4[11]); + butterfly_dct_post_s32_x4(x2 + 16, x3 + 16, x4 + 16, 16); + butterfly_s32_s32_x4_0112_neon(cospi16, x3[59], x3[36], &x4[59], &x4[36]); + butterfly_s32_s32_x4_0112_neon(cospi16, x3[58], x3[37], &x4[58], &x4[37]); + butterfly_s32_s32_x4_0112_neon(cospi16, x3[57], x3[38], &x4[57], &x4[38]); + butterfly_s32_s32_x4_0112_neon(cospi16, x3[56], x3[39], &x4[56], &x4[39]); + butterfly_s32_s32_x4_1223_neon(cospi16, x3[55], x3[40], &x4[55], &x4[40]); + butterfly_s32_s32_x4_1223_neon(cospi16, x3[54], x3[41], &x4[54], &x4[41]); + butterfly_s32_s32_x4_1223_neon(cospi16, x3[53], x3[42], &x4[53], &x4[42]); + butterfly_s32_s32_x4_1223_neon(cospi16, x3[52], x3[43], &x4[52], &x4[43]); + + // stage 5 + int32x4_t x5[64]; + butterfly_dct_pre_s32_x4(x4, x5, 4); + butterfly_s32_s32_x4_0112_neon(cospi32, x4[6], x4[5], &x5[6], &x5[5]); + butterfly_dct_post_s32_x4(x3 + 8, x4 + 8, x5 + 8, 8); + butterfly_s32_s32_x4_0112_neon(cospi16, x4[29], x4[18], &x5[29], &x5[18]); + butterfly_s32_s32_x4_0112_neon(cospi16, x4[28], x4[19], &x5[28], &x5[19]); + butterfly_s32_s32_x4_1223_neon(cospi16, x4[27], x4[20], &x5[27], &x5[20]); + butterfly_s32_s32_x4_1223_neon(cospi16, x4[26], x4[21], &x5[26], &x5[21]); + butterfly_dct_post_s32_x4(x3 + 32, x4 + 32, x5 + 32, 16); + butterfly_dct_post_s32_x4(x3 + 48, x4 + 48, x5 + 48, 16); + + // stage 6 + int32x4_t x6[64]; + butterfly_s32_s32_x4_0112_neon(cospi32, x5[0], x5[1], &x6[0], &x6[1]); + butterfly_s32_s32_x4_0112_neon(cospi16, x5[3], x5[2], &x6[2], &x6[3]); + butterfly_dct_post_s32_x4(x4 + 4, x5 + 4, x6 + 4, 4); + butterfly_s32_s32_x4_0112_neon(cospi16, x5[14], x5[9], &x6[14], &x6[9]); + 
butterfly_s32_s32_x4_1223_neon(cospi16, x5[13], x5[10], &x6[13], &x6[10]); + butterfly_dct_post_s32_x4(x4 + 16, x5 + 16, x6 + 16, 8); + butterfly_dct_post_s32_x4(x4 + 24, x5 + 24, x6 + 24, 8); + butterfly_s32_s32_x4_0112_neon(cospi8, x5[61], x5[34], &x6[61], &x6[34]); + butterfly_s32_s32_x4_0112_neon(cospi8, x5[60], x5[35], &x6[60], &x6[35]); + butterfly_s32_s32_x4_1223_neon(cospi8, x5[59], x5[36], &x6[59], &x6[36]); + butterfly_s32_s32_x4_1223_neon(cospi8, x5[58], x5[37], &x6[58], &x6[37]); + butterfly_s32_s32_x4_1003_neon(cospi24, x5[53], x5[42], &x6[53], &x6[42]); + butterfly_s32_s32_x4_1003_neon(cospi24, x5[52], x5[43], &x6[52], &x6[43]); + butterfly_s32_s32_x4_0332_neon(cospi24, x5[51], x5[44], &x6[51], &x6[44]); + butterfly_s32_s32_x4_0332_neon(cospi24, x5[50], x5[45], &x6[50], &x6[45]); + + // stage 7 + int32x4_t x7[64]; + butterfly_s32_s32_x4_0112_neon(cospi8, x6[7], x6[4], &x7[4], &x7[7]); + butterfly_s32_s32_x4_1003_neon(cospi24, x6[6], x6[5], &x7[5], &x7[6]); + butterfly_dct_post_s32_x4(x5 + 8, x6 + 8, x7 + 8, 4); + butterfly_dct_post_s32_x4(x5 + 12, x6 + 12, x7 + 12, 4); + butterfly_s32_s32_x4_0112_neon(cospi8, x6[30], x6[17], &x7[30], &x7[17]); + butterfly_s32_s32_x4_1223_neon(cospi8, x6[29], x6[18], &x7[29], &x7[18]); + butterfly_s32_s32_x4_1003_neon(cospi24, x6[26], x6[21], &x7[26], &x7[21]); + butterfly_s32_s32_x4_0332_neon(cospi24, x6[25], x6[22], &x7[25], &x7[22]); + butterfly_dct_post_s32_x4(x5 + 32, x6 + 32, x7 + 32, 8); + butterfly_dct_post_s32_x4(x5 + 40, x6 + 40, x7 + 40, 8); + butterfly_dct_post_s32_x4(x5 + 48, x6 + 48, x7 + 48, 8); + butterfly_dct_post_s32_x4(x5 + 56, x6 + 56, x7 + 56, 8); + + // stage 8 + int32x4_t x8[64]; + butterfly_s32_s32_x4_0112_neon(cospi4, x7[15], x7[8], &x8[8], &x8[15]); + butterfly_s32_s32_x4_1003_neon(cospi28, x7[14], x7[9], &x8[9], &x8[14]); + butterfly_s32_s32_x4_0112_neon(cospi20, x7[13], x7[10], &x8[10], &x8[13]); + butterfly_s32_s32_x4_1003_neon(cospi12, x7[12], x7[11], &x8[11], &x8[12]); + butterfly_dct_post_s32_x4(x6 + 16, x7 + 16, x8 + 16, 4); + butterfly_dct_post_s32_x4(x6 + 20, x7 + 20, x8 + 20, 4); + butterfly_dct_post_s32_x4(x6 + 24, x7 + 24, x8 + 24, 4); + butterfly_dct_post_s32_x4(x6 + 28, x7 + 28, x8 + 28, 4); + butterfly_s32_s32_x4_0112_neon(cospi4, x7[62], x7[33], &x8[62], &x8[33]); + butterfly_s32_s32_x4_1223_neon(cospi4, x7[61], x7[34], &x8[61], &x8[34]); + butterfly_s32_s32_x4_1003_neon(cospi28, x7[58], x7[37], &x8[58], &x8[37]); + butterfly_s32_s32_x4_0332_neon(cospi28, x7[57], x7[38], &x8[57], &x8[38]); + butterfly_s32_s32_x4_0112_neon(cospi20, x7[54], x7[41], &x8[54], &x8[41]); + butterfly_s32_s32_x4_1223_neon(cospi20, x7[53], x7[42], &x8[53], &x8[42]); + butterfly_s32_s32_x4_1003_neon(cospi12, x7[50], x7[45], &x8[50], &x8[45]); + butterfly_s32_s32_x4_0332_neon(cospi12, x7[49], x7[46], &x8[49], &x8[46]); + + // stage 9 + int32x4_t x9[64]; + butterfly_s32_s32_x4_0112_neon(cospi2, x8[31], x8[16], &x9[16], &x9[31]); + butterfly_s32_s32_x4_1003_neon(cospi30, x8[30], x8[17], &x9[17], &x9[30]); + butterfly_s32_s32_x4_0112_neon(cospi18, x8[29], x8[18], &x9[18], &x9[29]); + butterfly_s32_s32_x4_1003_neon(cospi14, x8[28], x8[19], &x9[19], &x9[28]); + butterfly_s32_s32_x4_0112_neon(cospi10, x8[27], x8[20], &x9[20], &x9[27]); + butterfly_s32_s32_x4_1003_neon(cospi22, x8[26], x8[21], &x9[21], &x9[26]); + butterfly_s32_s32_x4_0112_neon(cospi26, x8[25], x8[22], &x9[22], &x9[25]); + butterfly_s32_s32_x4_1003_neon(cospi6, x8[24], x8[23], &x9[23], &x9[24]); + butterfly_dct_post_s32_x4(x7 + 32, x8 + 32, x9 + 32, 4); + 
butterfly_dct_post_s32_x4(x7 + 36, x8 + 36, x9 + 36, 4); + butterfly_dct_post_s32_x4(x7 + 40, x8 + 40, x9 + 40, 4); + butterfly_dct_post_s32_x4(x7 + 44, x8 + 44, x9 + 44, 4); + butterfly_dct_post_s32_x4(x7 + 48, x8 + 48, x9 + 48, 4); + butterfly_dct_post_s32_x4(x7 + 52, x8 + 52, x9 + 52, 4); + butterfly_dct_post_s32_x4(x7 + 56, x8 + 56, x9 + 56, 4); + butterfly_dct_post_s32_x4(x7 + 60, x8 + 60, x9 + 60, 4); + + // stage 10 + int32x4_t x10[64]; + butterfly_s32_s32_x4_0112_neon(cospi1, x9[63], x9[32], &x10[32], &x10[63]); + butterfly_s32_s32_x4_1003_neon(cospi31, x9[62], x9[33], &x10[33], &x10[62]); + butterfly_s32_s32_x4_0112_neon(cospi17, x9[61], x9[34], &x10[34], &x10[61]); + butterfly_s32_s32_x4_1003_neon(cospi15, x9[60], x9[35], &x10[35], &x10[60]); + butterfly_s32_s32_x4_0112_neon(cospi9, x9[59], x9[36], &x10[36], &x10[59]); + butterfly_s32_s32_x4_1003_neon(cospi23, x9[58], x9[37], &x10[37], &x10[58]); + butterfly_s32_s32_x4_0112_neon(cospi25, x9[57], x9[38], &x10[38], &x10[57]); + butterfly_s32_s32_x4_1003_neon(cospi7, x9[56], x9[39], &x10[39], &x10[56]); + butterfly_s32_s32_x4_0112_neon(cospi5, x9[55], x9[40], &x10[40], &x10[55]); + butterfly_s32_s32_x4_1003_neon(cospi27, x9[54], x9[41], &x10[41], &x10[54]); + butterfly_s32_s32_x4_0112_neon(cospi21, x9[53], x9[42], &x10[42], &x10[53]); + butterfly_s32_s32_x4_1003_neon(cospi11, x9[52], x9[43], &x10[43], &x10[52]); + butterfly_s32_s32_x4_0112_neon(cospi13, x9[51], x9[44], &x10[44], &x10[51]); + butterfly_s32_s32_x4_1003_neon(cospi19, x9[50], x9[45], &x10[45], &x10[50]); + butterfly_s32_s32_x4_0112_neon(cospi29, x9[49], x9[46], &x10[46], &x10[49]); + butterfly_s32_s32_x4_1003_neon(cospi3, x9[48], x9[47], &x10[47], &x10[48]); + + // stage 11, only store into the low 32 output indices. + output[0] = x6[0]; + output[1] = x10[32]; + output[2] = x9[16]; + output[3] = x10[48]; + output[4] = x8[8]; + output[5] = x10[40]; + output[6] = x9[24]; + output[7] = x10[56]; + output[8] = x7[4]; + output[9] = x10[36]; + output[10] = x9[20]; + output[11] = x10[52]; + output[12] = x8[12]; + output[13] = x10[44]; + output[14] = x9[28]; + output[15] = x10[60]; + output[16] = x6[2]; + output[17] = x10[34]; + output[18] = x9[18]; + output[19] = x10[50]; + output[20] = x8[10]; + output[21] = x10[42]; + output[22] = x9[26]; + output[23] = x10[58]; + output[24] = x7[6]; + output[25] = x10[38]; + output[26] = x9[22]; + output[27] = x10[54]; + output[28] = x8[14]; + output[29] = x10[46]; + output[30] = x9[30]; + output[31] = x10[62]; +} + +static void lowbd_fwd_txfm2d_64x64_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + (void)tx_type; + assert(tx_type == DCT_DCT); + int16x8_t buf0[64], buf1[512]; + const transform_1d_lbd_8_neon col_txfm = fdct8x64_neon; + + for (int i = 0; i < 8; i++) { + load_buffer_s16_x8(input + 8 * i, stride, buf0, 64); + col_txfm(buf0, buf0, 13); + shift_right_2_round_s16_x8(buf0, buf0, 64); + for (int j = 0; j < 4; ++j) { + transpose_arrays_s16_8x8(buf0 + j * 8, buf1 + j * 64 + 8 * i); + } + } + for (int i = 0; i < 4; i++) { + int32x4_t bufA[64]; + int32x4_t bufB[64]; + int16x8_t *buf = buf1 + 64 * i; + for (int j = 0; j < 64; ++j) { + bufA[j] = vmovl_s16(vget_low_s16(buf[j])); + bufB[j] = vmovl_s16(vget_high_s16(buf[j])); + } + fdct64_neon(bufA, bufA, 10); + fdct64_neon(bufB, bufB, 10); + shift_right_2_round_s32_x4(bufA, bufA, 32); + shift_right_2_round_s32_x4(bufB, bufB, 32); + store_buffer_interleaved_s32_x8(output + i * 8, bufA, bufB, 32, 32); + } +} + +static void 
lowbd_fwd_txfm2d_64x32_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + int16x8_t buf0[64], buf1[256]; + const col_transform_1d_lbd_8_neon col_txfm = col_txfm8x32_arr[tx_type]; + + for (int i = 0; i < 8; i++) { + col_txfm(input + 8 * i, buf0, stride, 12); + shift_right_4_round_s16_x8(buf0, buf0, 32); + for (int j = 0; j < 4; ++j) { + transpose_arrays_s16_8x8(buf0 + j * 8, buf1 + j * 64 + 8 * i); + } + } + assert(tx_type == DCT_DCT); + for (int i = 0; i < 4; i++) { + int32x4_t bufA[64]; + int32x4_t bufB[64]; + int16x8_t *buf = buf1 + 64 * i; + for (int j = 0; j < 64; ++j) { + bufA[j] = vmovl_s16(vget_low_s16(buf[j])); + bufB[j] = vmovl_s16(vget_high_s16(buf[j])); + } + fdct64_neon(bufA, bufA, 11); + fdct64_neon(bufB, bufB, 11); + shift_right_2_round_s32_x4(bufA, bufA, 32); + shift_right_2_round_s32_x4(bufB, bufB, 32); + round_shift_sqrt2_s32_s32_4xn_neon(bufA, bufA, 32); + round_shift_sqrt2_s32_s32_4xn_neon(bufB, bufB, 32); + store_buffer_interleaved_s32_x8(output + i * 8, bufA, bufB, 32, 32); + } +} + +static void lowbd_fwd_txfm2d_32x64_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + (void)tx_type; + assert(tx_type == DCT_DCT); + int16x8_t buf0[64], buf1[256]; + const transform_1d_lbd_8_neon col_txfm = fdct8x64_neon; + + for (int i = 0; i < 4; i++) { + load_buffer_s16_x8(input + 8 * i, stride, buf0, 64); + col_txfm(buf0, buf0, 13); + shift_right_2_round_s16_x8(buf0, buf0, 64); + for (int j = 0; j < 4; ++j) { + transpose_arrays_s16_8x8(buf0 + j * 8, buf1 + j * 32 + 8 * i); + } + } + + for (int i = 0; i < 4; i++) { + int32x4_t bufA[32]; + int32x4_t bufB[32]; + int16x8_t *buf = buf1 + 32 * i; + for (int j = 0; j < 32; ++j) { + bufA[j] = vmovl_s16(vget_low_s16(buf[j])); + bufB[j] = vmovl_s16(vget_high_s16(buf[j])); + } + fdct32_neon(bufA, bufA, 11); + fdct32_neon(bufB, bufB, 11); + shift_right_2_round_s32_x4(bufA, bufA, 32); + shift_right_2_round_s32_x4(bufB, bufB, 32); + round_shift_sqrt2_s32_s32_4xn_neon(bufA, bufA, 32); + round_shift_sqrt2_s32_s32_4xn_neon(bufB, bufB, 32); + store_buffer_interleaved_s32_x8(output + i * 8, bufA, bufB, 32, 32); + } +} + +static FwdTxfm2dFunc lowbd_fwd_txfm_func_ls[TX_SIZES_ALL] = { + lowbd_fwd_txfm2d_4x4_neon, // 4x4 transform + lowbd_fwd_txfm2d_8x8_neon, // 8x8 transform + lowbd_fwd_txfm2d_16x16_neon, // 16x16 transform + lowbd_fwd_txfm2d_32x32_neon, // 32x32 transform + lowbd_fwd_txfm2d_64x64_neon, // 64x64 transform + lowbd_fwd_txfm2d_4x8_neon, // 4x8 transform + lowbd_fwd_txfm2d_8x4_neon, // 8x4 transform + lowbd_fwd_txfm2d_8x16_neon, // 8x16 transform + lowbd_fwd_txfm2d_16x8_neon, // 16x8 transform + lowbd_fwd_txfm2d_16x32_neon, // 16x32 transform + lowbd_fwd_txfm2d_32x16_neon, // 32x16 transform + lowbd_fwd_txfm2d_32x64_neon, // 32x64 transform + lowbd_fwd_txfm2d_64x32_neon, // 64x32 transform + lowbd_fwd_txfm2d_4x16_neon, // 4x16 transform + lowbd_fwd_txfm2d_16x4_neon, // 16x4 transform + lowbd_fwd_txfm2d_8x32_neon, // 8x32 transform + lowbd_fwd_txfm2d_32x8_neon, // 32x8 transform + lowbd_fwd_txfm2d_16x64_neon, // 16x64 transform + lowbd_fwd_txfm2d_64x16_neon, // 64x16 transform +}; + +void av1_lowbd_fwd_txfm_neon(const int16_t *src_diff, tran_low_t *coeff, + int diff_stride, TxfmParam *txfm_param) { + FwdTxfm2dFunc fwd_txfm2d_func = lowbd_fwd_txfm_func_ls[txfm_param->tx_size]; + if (txfm_param->lossless && txfm_param->tx_size == TX_4X4) { + av1_lowbd_fwd_txfm_c(src_diff, coeff, diff_stride, txfm_param); + } else { + fwd_txfm2d_func(src_diff, coeff, 
diff_stride, txfm_param->tx_type, + txfm_param->bd); + } +} diff --git a/third_party/aom/av1/encoder/arm/neon/av1_highbd_quantize_neon.c b/third_party/aom/av1/encoder/arm/neon/av1_highbd_quantize_neon.c new file mode 100644 index 0000000000..11d3def16b --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/av1_highbd_quantize_neon.c @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2022, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "config/aom_config.h" + +#include "aom_dsp/arm/mem_neon.h" + +#include "av1/common/quant_common.h" +#include "av1/encoder/av1_quantize.h" + +static INLINE uint16x4_t quantize_4(const tran_low_t *coeff_ptr, + tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, + int32x4_t v_quant_s32, + int32x4_t v_dequant_s32, + int32x4_t v_round_s32, int log_scale) { + const int32x4_t v_coeff = vld1q_s32(coeff_ptr); + const int32x4_t v_coeff_sign = + vreinterpretq_s32_u32(vcltq_s32(v_coeff, vdupq_n_s32(0))); + const int32x4_t v_log_scale = vdupq_n_s32(log_scale); + const int32x4_t v_abs_coeff = vabsq_s32(v_coeff); + // ((abs_coeff << (1 + log_scale)) >= dequant_ptr[rc01]) + const int32x4_t v_abs_coeff_scaled = + vshlq_s32(v_abs_coeff, vdupq_n_s32(1 + log_scale)); + const uint32x4_t v_mask = vcgeq_s32(v_abs_coeff_scaled, v_dequant_s32); + // const int64_t tmp = vmask ? (int64_t)abs_coeff + log_scaled_round : 0 + const int32x4_t v_tmp = vandq_s32(vaddq_s32(v_abs_coeff, v_round_s32), + vreinterpretq_s32_u32(v_mask)); + // const int abs_qcoeff = (int)((tmp * quant) >> (16 - log_scale)); + const int32x4_t v_abs_qcoeff = + vqdmulhq_s32(vshlq_s32(v_tmp, v_log_scale), v_quant_s32); + // qcoeff_ptr[rc] = (tran_low_t)((abs_qcoeff ^ coeff_sign) - coeff_sign); + const int32x4_t v_qcoeff = + vsubq_s32(veorq_s32(v_abs_qcoeff, v_coeff_sign), v_coeff_sign); + // vshlq_s32 will shift right if shift value is negative. + const int32x4_t v_abs_dqcoeff = + vshlq_s32(vmulq_s32(v_abs_qcoeff, v_dequant_s32), vnegq_s32(v_log_scale)); + // dqcoeff_ptr[rc] = (tran_low_t)((abs_dqcoeff ^ coeff_sign) - coeff_sign); + const int32x4_t v_dqcoeff = + vsubq_s32(veorq_s32(v_abs_dqcoeff, v_coeff_sign), v_coeff_sign); + + vst1q_s32(qcoeff_ptr, v_qcoeff); + vst1q_s32(dqcoeff_ptr, v_dqcoeff); + + // Used to find eob. 
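  // Each lane of the returned 16-bit mask is all-ones when the corresponding
  // quantized coefficient is non-zero and all-zeros otherwise.
  // get_max_lane_eob() below selects (iscan + 1) through this mask and keeps a
  // running vector max, so the final reduction yields
  // last-non-zero-scan-index + 1, which is exactly the value written to
  // *eob_ptr.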
+ const uint32x4_t nz_qcoeff_mask = vcgtq_s32(v_abs_qcoeff, vdupq_n_s32(0)); + return vmovn_u32(nz_qcoeff_mask); +} + +static INLINE int16x8_t get_max_lane_eob(const int16_t *iscan, + int16x8_t v_eobmax, + uint16x8_t v_mask) { + const int16x8_t v_iscan = vld1q_s16(&iscan[0]); + const int16x8_t v_iscan_plus1 = vaddq_s16(v_iscan, vdupq_n_s16(1)); + const int16x8_t v_nz_iscan = vbslq_s16(v_mask, v_iscan_plus1, vdupq_n_s16(0)); + return vmaxq_s16(v_eobmax, v_nz_iscan); +} + +static INLINE uint16_t get_max_eob(int16x8_t v_eobmax) { +#if AOM_ARCH_AARCH64 + return (uint16_t)vmaxvq_s16(v_eobmax); +#else + const int16x4_t v_eobmax_3210 = + vmax_s16(vget_low_s16(v_eobmax), vget_high_s16(v_eobmax)); + const int64x1_t v_eobmax_xx32 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmax_3210), 32); + const int16x4_t v_eobmax_tmp = + vmax_s16(v_eobmax_3210, vreinterpret_s16_s64(v_eobmax_xx32)); + const int64x1_t v_eobmax_xxx3 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmax_tmp), 16); + const int16x4_t v_eobmax_final = + vmax_s16(v_eobmax_tmp, vreinterpret_s16_s64(v_eobmax_xxx3)); + return (uint16_t)vget_lane_s16(v_eobmax_final, 0); +#endif +} + +void av1_highbd_quantize_fp_neon( + const tran_low_t *coeff_ptr, intptr_t count, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, int log_scale) { + (void)scan; + (void)zbin_ptr; + (void)quant_shift_ptr; + + const int16x4_t v_quant = vld1_s16(quant_ptr); + const int16x4_t v_dequant = vld1_s16(dequant_ptr); + const int16x4_t v_zero = vdup_n_s16(0); + const uint16x4_t v_round_select = vcgt_s16(vdup_n_s16(log_scale), v_zero); + const int16x4_t v_round_no_scale = vld1_s16(round_ptr); + const int16x4_t v_round_log_scale = + vqrdmulh_n_s16(v_round_no_scale, (int16_t)(1 << (15 - log_scale))); + const int16x4_t v_round = + vbsl_s16(v_round_select, v_round_log_scale, v_round_no_scale); + int32x4_t v_round_s32 = vaddl_s16(v_round, v_zero); + int32x4_t v_quant_s32 = vshlq_n_s32(vaddl_s16(v_quant, v_zero), 15); + int32x4_t v_dequant_s32 = vaddl_s16(v_dequant, v_zero); + uint16x4_t v_mask_lo, v_mask_hi; + int16x8_t v_eobmax = vdupq_n_s16(-1); + + // DC and first 3 AC + v_mask_lo = quantize_4(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant_s32, + v_dequant_s32, v_round_s32, log_scale); + + // overwrite the DC constants with AC constants + v_round_s32 = vdupq_lane_s32(vget_low_s32(v_round_s32), 1); + v_quant_s32 = vdupq_lane_s32(vget_low_s32(v_quant_s32), 1); + v_dequant_s32 = vdupq_lane_s32(vget_low_s32(v_dequant_s32), 1); + + // 4 more AC + v_mask_hi = quantize_4(coeff_ptr + 4, qcoeff_ptr + 4, dqcoeff_ptr + 4, + v_quant_s32, v_dequant_s32, v_round_s32, log_scale); + + // Find the max lane eob for the first 8 coeffs. + v_eobmax = + get_max_lane_eob(iscan, v_eobmax, vcombine_u16(v_mask_lo, v_mask_hi)); + + count -= 8; + do { + coeff_ptr += 8; + qcoeff_ptr += 8; + dqcoeff_ptr += 8; + iscan += 8; + v_mask_lo = quantize_4(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant_s32, + v_dequant_s32, v_round_s32, log_scale); + v_mask_hi = quantize_4(coeff_ptr + 4, qcoeff_ptr + 4, dqcoeff_ptr + 4, + v_quant_s32, v_dequant_s32, v_round_s32, log_scale); + // Find the max lane eob for 8 coeffs. 
+ v_eobmax = + get_max_lane_eob(iscan, v_eobmax, vcombine_u16(v_mask_lo, v_mask_hi)); + count -= 8; + } while (count); + + *eob_ptr = get_max_eob(v_eobmax); +} diff --git a/third_party/aom/av1/encoder/arm/neon/av1_k_means_neon.c b/third_party/aom/av1/encoder/arm/neon/av1_k_means_neon.c new file mode 100644 index 0000000000..d13cc65ae0 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/av1_k_means_neon.c @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <arm_neon.h> + +#include "aom_dsp/arm/sum_neon.h" +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +static int32x4_t k_means_multiply_add_neon(const int16x8_t a) { + const int32x4_t l = vmull_s16(vget_low_s16(a), vget_low_s16(a)); + const int32x4_t h = vmull_s16(vget_high_s16(a), vget_high_s16(a)); +#if AOM_ARCH_AARCH64 + return vpaddq_s32(l, h); +#else + const int32x2_t dl = vpadd_s32(vget_low_s32(l), vget_high_s32(l)); + const int32x2_t dh = vpadd_s32(vget_low_s32(h), vget_high_s32(h)); + return vcombine_s32(dl, dh); +#endif +} + +void av1_calc_indices_dim1_neon(const int16_t *data, const int16_t *centroids, + uint8_t *indices, int64_t *total_dist, int n, + int k) { + int64x2_t sum = vdupq_n_s64(0); + int16x8_t cents[PALETTE_MAX_SIZE]; + for (int j = 0; j < k; ++j) { + cents[j] = vdupq_n_s16(centroids[j]); + } + + for (int i = 0; i < n; i += 8) { + const int16x8_t in = vld1q_s16(data); + uint16x8_t ind = vdupq_n_u16(0); + // Compute the distance to the first centroid. + int16x8_t dist_min = vabdq_s16(in, cents[0]); + + for (int j = 1; j < k; ++j) { + // Compute the distance to the centroid. + const int16x8_t dist = vabdq_s16(in, cents[j]); + // Compare to the minimal one. + const uint16x8_t cmp = vcgtq_s16(dist_min, dist); + dist_min = vminq_s16(dist_min, dist); + const uint16x8_t ind1 = vdupq_n_u16(j); + ind = vbslq_u16(cmp, ind1, ind); + } + if (total_dist) { + // Square, convert to 32 bit and add together. + const int32x4_t l = + vmull_s16(vget_low_s16(dist_min), vget_low_s16(dist_min)); + const int32x4_t sum32_tmp = + vmlal_s16(l, vget_high_s16(dist_min), vget_high_s16(dist_min)); + // Pairwise sum, convert to 64 bit and add to sum. + sum = vpadalq_s32(sum, sum32_tmp); + } + vst1_u8(indices, vmovn_u16(ind)); + indices += 8; + data += 8; + } + if (total_dist) { + *total_dist = horizontal_add_s64x2(sum); + } +} + +void av1_calc_indices_dim2_neon(const int16_t *data, const int16_t *centroids, + uint8_t *indices, int64_t *total_dist, int n, + int k) { + int64x2_t sum = vdupq_n_s64(0); + uint32x4_t ind[2]; + int16x8_t cents[PALETTE_MAX_SIZE]; + for (int j = 0; j < k; ++j) { + const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1]; + const int16_t cxcy[8] = { cx, cy, cx, cy, cx, cy, cx, cy }; + cents[j] = vld1q_s16(cxcy); + } + + for (int i = 0; i < n; i += 8) { + for (int l = 0; l < 2; ++l) { + const int16x8_t in = vld1q_s16(data); + ind[l] = vdupq_n_u32(0); + // Compute the distance to the first centroid. + int16x8_t d1 = vsubq_s16(in, cents[0]); + int32x4_t dist_min = k_means_multiply_add_neon(d1); + + for (int j = 1; j < k; ++j) { + // Compute the distance to the centroid. 
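  // For the 2-D case each centroid is stored as interleaved (x, y) pairs, so
  // the squared Euclidean distance dx * dx + dy * dy falls out of
  // k_means_multiply_add_neon(): it squares the eight 16-bit components and
  // pairwise-adds neighbouring 32-bit products, leaving one distance per
  // (x, y) pair in each 32-bit lane.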
+ d1 = vsubq_s16(in, cents[j]); + const int32x4_t dist = k_means_multiply_add_neon(d1); + // Compare to the minimal one. + const uint32x4_t cmp = vcgtq_s32(dist_min, dist); + dist_min = vminq_s32(dist_min, dist); + const uint32x4_t ind1 = vdupq_n_u32(j); + ind[l] = vbslq_u32(cmp, ind1, ind[l]); + } + if (total_dist) { + // Pairwise sum, convert to 64 bit and add to sum. + sum = vpadalq_s32(sum, dist_min); + } + data += 8; + } + // Cast to 8 bit and store. + vst1_u8(indices, + vmovn_u16(vcombine_u16(vmovn_u32(ind[0]), vmovn_u32(ind[1])))); + indices += 8; + } + if (total_dist) { + *total_dist = horizontal_add_s64x2(sum); + } +} diff --git a/third_party/aom/av1/encoder/arm/neon/av1_temporal_denoiser_neon.c b/third_party/aom/av1/encoder/arm/neon/av1_temporal_denoiser_neon.c new file mode 100644 index 0000000000..18cd0ce4c0 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/av1_temporal_denoiser_neon.c @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> +#include <assert.h> + +#include "aom/aom_integer.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_ports/mem.h" +#include "config/aom_config.h" +#include "config/av1_rtcd.h" + +#include "av1/common/reconinter.h" +#include "av1/encoder/context_tree.h" +#include "av1/encoder/av1_temporal_denoiser.h" + +// Compute the sum of all pixel differences of this MB. +static INLINE int horizontal_add_s8x16(const int8x16_t v_sum_diff_total) { +#if AOM_ARCH_AARCH64 + return vaddlvq_s8(v_sum_diff_total); +#else + const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff_total); + const int32x4_t fedc_ba98_7654_3210 = vpaddlq_s16(fe_dc_ba_98_76_54_32_10); + const int64x2_t fedcba98_76543210 = vpaddlq_s32(fedc_ba98_7654_3210); + const int64x1_t x = vqadd_s64(vget_high_s64(fedcba98_76543210), + vget_low_s64(fedcba98_76543210)); + const int sum_diff = vget_lane_s32(vreinterpret_s32_s64(x), 0); + return sum_diff; +#endif +} + +// Denoise a 16x1 vector. +static INLINE int8x16_t denoiser_16x1_neon( + const uint8_t *sig, const uint8_t *mc_running_avg_y, uint8_t *running_avg_y, + const uint8x16_t v_level1_threshold, const uint8x16_t v_level2_threshold, + const uint8x16_t v_level3_threshold, const uint8x16_t v_level1_adjustment, + const uint8x16_t v_delta_level_1_and_2, + const uint8x16_t v_delta_level_2_and_3, int8x16_t v_sum_diff_total) { + const uint8x16_t v_sig = vld1q_u8(sig); + const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y); + + /* Calculate absolute difference and sign masks. */ + const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y); + const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg_y); + const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg_y); + + /* Figure out which level that put us in. 
*/ + const uint8x16_t v_level1_mask = vcleq_u8(v_level1_threshold, v_abs_diff); + const uint8x16_t v_level2_mask = vcleq_u8(v_level2_threshold, v_abs_diff); + const uint8x16_t v_level3_mask = vcleq_u8(v_level3_threshold, v_abs_diff); + + /* Calculate absolute adjustments for level 1, 2 and 3. */ + const uint8x16_t v_level2_adjustment = + vandq_u8(v_level2_mask, v_delta_level_1_and_2); + const uint8x16_t v_level3_adjustment = + vandq_u8(v_level3_mask, v_delta_level_2_and_3); + const uint8x16_t v_level1and2_adjustment = + vaddq_u8(v_level1_adjustment, v_level2_adjustment); + const uint8x16_t v_level1and2and3_adjustment = + vaddq_u8(v_level1and2_adjustment, v_level3_adjustment); + + /* Figure adjustment absolute value by selecting between the absolute + * difference if in level0 or the value for level 1, 2 and 3. + */ + const uint8x16_t v_abs_adjustment = + vbslq_u8(v_level1_mask, v_level1and2and3_adjustment, v_abs_diff); + + /* Calculate positive and negative adjustments. Apply them to the signal + * and accumulate them. Adjustments are less than eight and the maximum + * sum of them (7 * 16) can fit in a signed char. + */ + const uint8x16_t v_pos_adjustment = + vandq_u8(v_diff_pos_mask, v_abs_adjustment); + const uint8x16_t v_neg_adjustment = + vandq_u8(v_diff_neg_mask, v_abs_adjustment); + + uint8x16_t v_running_avg_y = vqaddq_u8(v_sig, v_pos_adjustment); + v_running_avg_y = vqsubq_u8(v_running_avg_y, v_neg_adjustment); + + /* Store results. */ + vst1q_u8(running_avg_y, v_running_avg_y); + + /* Sum all the accumulators to have the sum of all pixel differences + * for this macroblock. + */ + { + const int8x16_t v_sum_diff = + vqsubq_s8(vreinterpretq_s8_u8(v_pos_adjustment), + vreinterpretq_s8_u8(v_neg_adjustment)); + v_sum_diff_total = vaddq_s8(v_sum_diff_total, v_sum_diff); + } + return v_sum_diff_total; +} + +static INLINE int8x16_t denoiser_adjust_16x1_neon( + const uint8_t *sig, const uint8_t *mc_running_avg_y, uint8_t *running_avg_y, + const uint8x16_t k_delta, int8x16_t v_sum_diff_total) { + uint8x16_t v_running_avg_y = vld1q_u8(running_avg_y); + const uint8x16_t v_sig = vld1q_u8(sig); + const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y); + + /* Calculate absolute difference and sign masks. */ + const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y); + const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg_y); + const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg_y); + // Clamp absolute difference to delta to get the adjustment. + const uint8x16_t v_abs_adjustment = vminq_u8(v_abs_diff, (k_delta)); + + const uint8x16_t v_pos_adjustment = + vandq_u8(v_diff_pos_mask, v_abs_adjustment); + const uint8x16_t v_neg_adjustment = + vandq_u8(v_diff_neg_mask, v_abs_adjustment); + + v_running_avg_y = vqsubq_u8(v_running_avg_y, v_pos_adjustment); + v_running_avg_y = vqaddq_u8(v_running_avg_y, v_neg_adjustment); + + /* Store results. */ + vst1q_u8(running_avg_y, v_running_avg_y); + + { + const int8x16_t v_sum_diff = + vqsubq_s8(vreinterpretq_s8_u8(v_neg_adjustment), + vreinterpretq_s8_u8(v_pos_adjustment)); + v_sum_diff_total = vaddq_s8(v_sum_diff_total, v_sum_diff); + } + return v_sum_diff_total; +} + +// Denoise 8x8 and 8x16 blocks. 
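Before the 8x8/8x16 wrapper that follows, it may help to see the per-pixel rule that denoiser_16x1_neon above vectorizes. The sketch below is illustrative only: the helper name is made up, and the constants are the ones the callers pass in (base adjustment 3, or 4 + shift_inc, and thresholds 4 + shift_inc, 8 and 16).

#include <stdint.h>
#include <stdlib.h>

// Hypothetical scalar helper (not a libaom function): the per-pixel rule that
// denoiser_16x1_neon applies to 16 pixels at a time. level1_adj is the base
// adjustment, t1/t2/t3 are the three thresholds. Returns the new
// running-average pixel and adds the signed adjustment to *sum_diff (the
// vector code accumulates this in saturating int8 lanes before the
// horizontal add).
static uint8_t denoise_pixel_sketch(uint8_t sig, uint8_t mc_avg,
                                    int level1_adj, int t1, int t2, int t3,
                                    int *sum_diff) {
  const int abs_diff = abs((int)sig - (int)mc_avg);
  // Level 0: small differences are applied in full; above t1 the adjustment
  // is the level-1 base, plus 1 beyond t2 and a further 2 beyond t3.
  int adj = abs_diff;
  if (abs_diff >= t1) {
    adj = level1_adj + (abs_diff >= t2 ? 1 : 0) + (abs_diff >= t3 ? 2 : 0);
  }
  int out = sig;
  if (mc_avg > sig) {         // nudge the running average up toward mc_avg
    out = sig + adj;
    *sum_diff += adj;
  } else if (mc_avg < sig) {  // nudge it down toward mc_avg
    out = sig - adj;
    *sum_diff -= adj;
  }
  // The NEON version gets this clamp for free from vqaddq_u8/vqsubq_u8.
  if (out < 0) out = 0;
  if (out > 255) out = 255;
  return (uint8_t)out;
}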
+static int av1_denoiser_8xN_neon(const uint8_t *sig, int sig_stride, + const uint8_t *mc_running_avg_y, + int mc_avg_y_stride, uint8_t *running_avg_y, + int avg_y_stride, int increase_denoising, + BLOCK_SIZE bs, int motion_magnitude, + int width) { + int sum_diff_thresh, r, sum_diff = 0; + const int shift_inc = + (increase_denoising && motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD) + ? 1 + : 0; + uint8_t sig_buffer[8][16], mc_running_buffer[8][16], running_buffer[8][16]; + + const uint8x16_t v_level1_adjustment = vmovq_n_u8( + (motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD) ? 4 + shift_inc : 3); + const uint8x16_t v_delta_level_1_and_2 = vdupq_n_u8(1); + const uint8x16_t v_delta_level_2_and_3 = vdupq_n_u8(2); + const uint8x16_t v_level1_threshold = vdupq_n_u8(4 + shift_inc); + const uint8x16_t v_level2_threshold = vdupq_n_u8(8); + const uint8x16_t v_level3_threshold = vdupq_n_u8(16); + + const int b_height = block_size_high[bs] >> 1; + + int8x16_t v_sum_diff_total = vdupq_n_s8(0); + + for (r = 0; r < b_height; ++r) { + memcpy(sig_buffer[r], sig, width); + memcpy(sig_buffer[r] + width, sig + sig_stride, width); + memcpy(mc_running_buffer[r], mc_running_avg_y, width); + memcpy(mc_running_buffer[r] + width, mc_running_avg_y + mc_avg_y_stride, + width); + memcpy(running_buffer[r], running_avg_y, width); + memcpy(running_buffer[r] + width, running_avg_y + avg_y_stride, width); + v_sum_diff_total = denoiser_16x1_neon( + sig_buffer[r], mc_running_buffer[r], running_buffer[r], + v_level1_threshold, v_level2_threshold, v_level3_threshold, + v_level1_adjustment, v_delta_level_1_and_2, v_delta_level_2_and_3, + v_sum_diff_total); + { + const uint8x16_t v_running_buffer = vld1q_u8(running_buffer[r]); + const uint8x8_t v_running_buffer_high = vget_high_u8(v_running_buffer); + const uint8x8_t v_running_buffer_low = vget_low_u8(v_running_buffer); + vst1_u8(running_avg_y, v_running_buffer_low); + vst1_u8(running_avg_y + avg_y_stride, v_running_buffer_high); + } + // Update pointers for next iteration. + sig += (sig_stride << 1); + mc_running_avg_y += (mc_avg_y_stride << 1); + running_avg_y += (avg_y_stride << 1); + } + + { + sum_diff = horizontal_add_s8x16(v_sum_diff_total); + sum_diff_thresh = total_adj_strong_thresh(bs, increase_denoising); + if (abs(sum_diff) > sum_diff_thresh) { + // Before returning to copy the block (i.e., apply no denoising), + // check if we can still apply some (weaker) temporal filtering to + // this block, that would otherwise not be denoised at all. Simplest + // is to apply an additional adjustment to running_avg_y to bring it + // closer to sig. The adjustment is capped by a maximum delta, and + // chosen such that in most cases the resulting sum_diff will be + // within the acceptable range given by sum_diff_thresh. + + // The delta is set by the excess of absolute pixel diff over the + // threshold. + const int delta = + ((abs(sum_diff) - sum_diff_thresh) >> num_pels_log2_lookup[bs]) + 1; + // Only apply the adjustment for max delta up to 3. 
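      // denoiser_adjust_16x1_neon() moves each output pixel back toward the
      // source by at most `delta`, so clamping delta to 3 keeps this weaker
      // second pass small enough that the re-checked sum_diff usually lands
      // inside the threshold; larger deltas fall through to COPY_BLOCK.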
+ if (delta < 4) { + const uint8x16_t k_delta = vmovq_n_u8(delta); + running_avg_y -= avg_y_stride * (b_height << 1); + for (r = 0; r < b_height; ++r) { + v_sum_diff_total = denoiser_adjust_16x1_neon( + sig_buffer[r], mc_running_buffer[r], running_buffer[r], k_delta, + v_sum_diff_total); + { + const uint8x16_t v_running_buffer = vld1q_u8(running_buffer[r]); + const uint8x8_t v_running_buffer_high = + vget_high_u8(v_running_buffer); + const uint8x8_t v_running_buffer_low = + vget_low_u8(v_running_buffer); + vst1_u8(running_avg_y, v_running_buffer_low); + vst1_u8(running_avg_y + avg_y_stride, v_running_buffer_high); + } + // Update pointers for next iteration. + running_avg_y += (avg_y_stride << 1); + } + sum_diff = horizontal_add_s8x16(v_sum_diff_total); + if (abs(sum_diff) > sum_diff_thresh) { + return COPY_BLOCK; + } + } else { + return COPY_BLOCK; + } + } + } + + return FILTER_BLOCK; +} + +// Denoise 16x16, to 128x128 blocks. +static int av1_denoiser_NxM_neon(const uint8_t *sig, int sig_stride, + const uint8_t *mc_running_avg_y, + int mc_avg_y_stride, uint8_t *running_avg_y, + int avg_y_stride, int increase_denoising, + BLOCK_SIZE bs, int motion_magnitude) { + const int shift_inc = + (increase_denoising && motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD) + ? 1 + : 0; + const uint8x16_t v_level1_adjustment = vmovq_n_u8( + (motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD) ? 4 + shift_inc : 3); + const uint8x16_t v_delta_level_1_and_2 = vdupq_n_u8(1); + const uint8x16_t v_delta_level_2_and_3 = vdupq_n_u8(2); + const uint8x16_t v_level1_threshold = vmovq_n_u8(4 + shift_inc); + const uint8x16_t v_level2_threshold = vdupq_n_u8(8); + const uint8x16_t v_level3_threshold = vdupq_n_u8(16); + + const int b_width = block_size_wide[bs]; + const int b_height = block_size_high[bs]; + const int b_width_shift4 = b_width >> 4; + + int8x16_t v_sum_diff_total[8][8]; + int r, c, sum_diff = 0; + + for (r = 0; r < 8; ++r) { + for (c = 0; c < b_width_shift4; ++c) { + v_sum_diff_total[c][r] = vdupq_n_s8(0); + } + } + + for (r = 0; r < b_height; ++r) { + for (c = 0; c < b_width_shift4; ++c) { + v_sum_diff_total[c][r >> 4] = denoiser_16x1_neon( + sig, mc_running_avg_y, running_avg_y, v_level1_threshold, + v_level2_threshold, v_level3_threshold, v_level1_adjustment, + v_delta_level_1_and_2, v_delta_level_2_and_3, + v_sum_diff_total[c][r >> 4]); + + // Update pointers for next iteration. + sig += 16; + mc_running_avg_y += 16; + running_avg_y += 16; + } + + if ((r & 0xf) == 0xf || (bs == BLOCK_16X8 && r == 7)) { + for (c = 0; c < b_width_shift4; ++c) { + sum_diff += horizontal_add_s8x16(v_sum_diff_total[c][r >> 4]); + } + } + + // Update pointers for next iteration. + sig = sig - b_width + sig_stride; + mc_running_avg_y = mc_running_avg_y - b_width + mc_avg_y_stride; + running_avg_y = running_avg_y - b_width + avg_y_stride; + } + + { + const int sum_diff_thresh = total_adj_strong_thresh(bs, increase_denoising); + if (abs(sum_diff) > sum_diff_thresh) { + const int delta = + ((abs(sum_diff) - sum_diff_thresh) >> num_pels_log2_lookup[bs]) + 1; + // Only apply the adjustment for max delta up to 3. 
+ if (delta < 4) { + const uint8x16_t k_delta = vdupq_n_u8(delta); + sig -= sig_stride * b_height; + mc_running_avg_y -= mc_avg_y_stride * b_height; + running_avg_y -= avg_y_stride * b_height; + sum_diff = 0; + + for (r = 0; r < b_height; ++r) { + for (c = 0; c < b_width_shift4; ++c) { + v_sum_diff_total[c][r >> 4] = + denoiser_adjust_16x1_neon(sig, mc_running_avg_y, running_avg_y, + k_delta, v_sum_diff_total[c][r >> 4]); + + // Update pointers for next iteration. + sig += 16; + mc_running_avg_y += 16; + running_avg_y += 16; + } + if ((r & 0xf) == 0xf || (bs == BLOCK_16X8 && r == 7)) { + for (c = 0; c < b_width_shift4; ++c) { + sum_diff += horizontal_add_s8x16(v_sum_diff_total[c][r >> 4]); + } + } + + sig = sig - b_width + sig_stride; + mc_running_avg_y = mc_running_avg_y - b_width + mc_avg_y_stride; + running_avg_y = running_avg_y - b_width + avg_y_stride; + } + + if (abs(sum_diff) > sum_diff_thresh) { + return COPY_BLOCK; + } + } else { + return COPY_BLOCK; + } + } + } + return FILTER_BLOCK; +} + +int av1_denoiser_filter_neon(const uint8_t *sig, int sig_stride, + const uint8_t *mc_avg, int mc_avg_stride, + uint8_t *avg, int avg_stride, + int increase_denoising, BLOCK_SIZE bs, + int motion_magnitude) { + // Rank by frequency of the block type to have an early termination. + if (bs == BLOCK_16X16 || bs == BLOCK_32X32 || bs == BLOCK_64X64 || + bs == BLOCK_128X128 || bs == BLOCK_128X64 || bs == BLOCK_64X128 || + bs == BLOCK_16X32 || bs == BLOCK_16X8 || bs == BLOCK_32X16 || + bs == BLOCK_32X64 || bs == BLOCK_64X32) { + return av1_denoiser_NxM_neon(sig, sig_stride, mc_avg, mc_avg_stride, avg, + avg_stride, increase_denoising, bs, + motion_magnitude); + } else if (bs == BLOCK_8X8 || bs == BLOCK_8X16) { + return av1_denoiser_8xN_neon(sig, sig_stride, mc_avg, mc_avg_stride, avg, + avg_stride, increase_denoising, bs, + motion_magnitude, 8); + } + return COPY_BLOCK; +} diff --git a/third_party/aom/av1/encoder/arm/neon/cnn_neon.c b/third_party/aom/av1/encoder/arm/neon/cnn_neon.c new file mode 100644 index 0000000000..8e686260d0 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/cnn_neon.c @@ -0,0 +1,1144 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> +#include <assert.h> +#include <math.h> +#include <stdbool.h> + +#include "config/aom_config.h" +#include "config/av1_rtcd.h" + +#include "aom_dsp/aom_dsp_common.h" +#include "aom_dsp/arm/sum_neon.h" +#include "av1/common/av1_common_int.h" +#include "av1/encoder/cnn.h" +#include "av1/encoder/partition_cnn_weights.h" + +// The CNN weights used in av1_cnn_convolve_no_maxpool_padding_valid are +// declared (av1_intra_mode_cnn_partition_cnn_layer_[01234]_kernel) in +// partition_cnn_weights.h. However, to enable linear memory access, rearrange +// the weight tables here. 
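The point of the rearrangement is that the convolution inner loops can then walk one contiguous weight array front to back instead of hopping between the per-kernel tables declared in partition_cnn_weights.h. A minimal, generic sketch of that kind of sequential access is shown below; the function name is hypothetical and it does not reproduce the actual layer layout or helpers used in this file.

#include <arm_neon.h>

// Generic sketch of the access pattern that a linearized weight table enables:
// the filter taps are read as one sequential stream, so the inner loop is a
// plain streaming multiply-accumulate.
static float dot_contiguous_sketch(const float *in, const float *w, int n) {
  float32x4_t acc = vdupq_n_f32(0.0f);
  int i = 0;
  for (; i + 4 <= n; i += 4) {
    acc = vmlaq_f32(acc, vld1q_f32(in + i), vld1q_f32(w + i));
  }
  // Reduce the four accumulator lanes (portable across ARMv7/AArch64 NEON).
  float32x2_t sum2 = vadd_f32(vget_low_f32(acc), vget_high_f32(acc));
  float sum = vget_lane_f32(vpadd_f32(sum2, sum2), 0);
  for (; i < n; ++i) sum += in[i] * w[i];
  return sum;
}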
+static const float weights_layer_1[] = { + 0.228403f, 0.031690f, -0.251710f, -0.046230f, 0.413294f, -0.236732f, + -0.038291f, 0.210766f, 0.427196f, -0.384319f, -0.439463f, 0.366015f, + 0.112263f, -0.144168f, -0.075017f, 0.119629f, 0.325200f, -0.678246f, + -0.370826f, -0.341362f, -0.503392f, 0.400884f, 0.465214f, -0.360847f, + 0.187100f, -0.190757f, -0.131906f, 0.121492f, -0.303556f, -0.007658f, + 0.380077f, -0.066394f, -0.016043f, -1.490730f, -0.120682f, 0.132062f, + 0.086185f, -0.042766f, -0.087069f, 0.029426f, 0.309583f, -0.029985f, + -0.297429f, -0.018139f, -0.688828f, 0.756607f, 0.706410f, -0.696826f, + -0.087793f, -0.023304f, -0.012332f, -0.018043f, -0.410268f, 0.352143f, + 0.391284f, -0.363178f, -0.295034f, 0.160246f, -0.149446f, 0.260145f, + -0.252249f, 0.190826f, 0.251206f, -0.270796f, -0.979219f, 0.884880f, + 0.962057f, -0.847601f, -0.011053f, 0.118765f, -0.028428f, -0.020138f, + 0.400274f, -0.382845f, -0.462766f, 0.390654f, 0.361223f, -0.320068f, + -0.372084f, 0.313196f, 0.241933f, -0.416614f, -0.008722f, -0.255078f, + 0.078730f, -0.381935f, -0.204577f, 0.159768f, 0.071853f, -0.126294f, + -0.036186f, -0.007900f, 0.380071f, -0.298882f, 0.387941f, -0.267350f, + -0.586802f, 0.477785f, -0.000013f, 0.197296f, -0.079154f, -0.005811f, + -0.044300f, -0.021192f, -0.020879f, -0.005265f, 0.082277f, -0.139132f, + -0.239237f, 0.440234f, -0.542342f, 0.378360f, -0.070974f, 0.272702f, + -0.278939f, -0.044948f, -0.134197f, -0.007172f, -0.353628f, -0.128091f, + 0.357458f, -0.037614f, -0.144983f, 0.220623f, -0.003394f, -0.070166f, + 0.200370f, -0.166037f, 0.224448f, -0.012990f, -0.098853f, 0.008613f, + -0.017669f, 0.070641f, 0.174530f, -0.119822f, -0.065096f, 0.118487f, + -0.024764f, -0.050466f, 0.066631f, -0.075896f, -0.062363f, 0.212604f, + -0.377322f, 0.306306f, -0.399733f, 0.238624f, 0.233571f, -0.344080f, + 0.462491f, -0.565210f, -0.035074f, -0.010459f, 0.084382f, 0.052294f, + 0.065714f, 0.013716f, 0.135036f, 0.000588f, 0.181079f, -0.566344f, + 0.395561f, -0.398509f, 0.450017f, -1.462710f, 1.138280f, -0.447774f, + 0.247936f, -0.417067f, 0.165997f, -0.458632f, -0.018527f, 0.308461f, + 0.541266f, 0.162257f, 0.601786f, -1.275840f, -0.373404f, -0.589747f, + 0.026539f, -0.219327f, 0.142972f, -0.018496f, 0.075204f, -0.775190f, + 0.237307f, -0.348252f, 0.117792f, -0.094332f, 0.363101f, -0.065025f, + 0.816662f, 0.590110f, 0.752202f, -0.308599f, 0.258337f, -0.842085f, + 0.695788f, -0.205615f, 0.093930f, -0.392536f, 0.463093f, -0.432456f, + 0.041660f, -0.827264f, 0.309128f, -0.354658f, 0.451957f, -1.406640f, + 0.773192f, -0.892943f, 0.134856f, -0.467808f, 0.306003f, -0.226560f, + 0.086865f, -0.104102f, 0.148098f, -0.082658f, 0.316655f, -1.028310f, + 0.741566f, -0.345326f, 0.052379f, -0.275613f, 0.191765f, -0.162391f, + 0.000976f, 0.093061f, 0.068649f, 0.033582f, 0.239727f, -0.647769f, + 0.218493f, -0.397120f, 0.268229f, -0.303424f, 0.185393f, -0.314189f, + 0.101728f, -0.163083f, -0.084989f, 0.136783f, -0.264346f, 0.465914f, + 0.220395f, -0.252968f, -0.326661f, 0.271483f, 0.374717f, -0.311570f, + -0.082119f, 0.020870f, 0.091975f, -0.030582f, -0.487148f, 0.198912f, + 0.024554f, -0.749363f, -0.102267f, 0.097787f, 0.141459f, -0.110706f, + 0.079467f, -0.082570f, -0.347567f, 0.341043f, -0.137871f, 0.112319f, + 0.064733f, -0.082869f, 0.269999f, -0.408184f, -0.183443f, 0.180608f, + 0.223345f, -0.357376f, -0.244593f, 0.355348f, -0.072701f, -0.034311f, + 0.096544f, 0.016407f, 0.417550f, -0.367772f, -0.484535f, 0.405977f, + 0.314243f, -0.099622f, -0.192218f, -0.012780f, 0.434551f, -0.399047f, + -0.531499f, 
0.484513f, -0.691352f, 0.872823f, 1.207720f, -1.377490f, + 0.006872f, -0.041453f, 0.007845f, 0.007463f, 0.467299f, -0.476372f, + -0.452606f, 0.452357f, 0.447332f, -0.365632f, -0.332435f, 0.300284f, + -0.290504f, 0.255410f, 0.310921f, -0.293717f, -0.616299f, 0.594207f, + 0.461347f, -0.449439f, 0.278455f, 0.285085f, -1.201340f, -0.016463f, + 0.549095f, 0.610375f, -4.608530f, -1.727390f, 0.150404f, -0.012846f, + -0.481148f, -0.182257f, 0.918796f, 0.213872f, 1.050410f, 0.681526f, + -0.458777f, -0.710395f, -2.347200f, -0.277197f, 0.213294f, 0.337551f, + -0.177710f, -0.152136f, 0.167666f, 0.308403f, -1.248500f, -0.565367f, + 0.122054f, 0.087874f, -0.476556f, -0.083548f, -0.358734f, -0.073131f, + -0.146320f, -2.241960f, 0.697639f, 0.545581f, -1.889700f, -0.267725f, + 0.433045f, 0.298224f, -0.338508f, 0.250226f, 0.405675f, 0.447201f, + -1.184690f, -0.473447f, 0.307403f, 0.711236f, -3.191560f, -1.663980f, + 0.165201f, 0.101360f, -0.624451f, -0.173269f, 0.089795f, 0.227478f, + -0.136664f, 0.007907f, 0.131079f, 0.605374f, -2.991620f, -1.723790f, + 0.082428f, 0.006781f, -0.348732f, -0.019271f, -0.032040f, -0.067078f, + -0.437166f, -0.144472f, 0.069844f, 0.194625f, -0.162284f, -0.374656f, + 0.056472f, -0.236524f, -0.114241f, -0.029161f, -0.222078f, -0.053435f, + -0.313938f, -0.555472f, 1.037550f, 0.689968f, 0.575694f, 0.065826f, + -0.659979f, -0.881351f, -0.626417f, -0.953975f, -0.576106f, -0.258708f, + 0.263004f, -0.229847f, 0.463835f, 1.390960f, -2.614480f, -1.272910f, + 0.065780f, -0.058603f, 0.015612f, 0.104703f, 0.198028f, 0.262792f, + 0.253616f, -0.079126f, -0.587381f, -0.739021f, -0.822676f, -0.795512f, + 0.193644f, 0.234643f, -0.034407f, 0.421478f, -0.572610f, -0.290714f, + -0.257803f, -0.644835f, -0.536938f, -0.375899f, -0.651077f, -0.522576f, + 0.562564f, 0.834616f, 0.513893f, 0.649689f, 0.356530f, 0.400716f, + 0.300606f, 0.290505f, 0.584608f, 0.671574f, 0.564584f, 0.419870f, + 0.062061f, 0.018263f, 0.009831f, 0.084103f, -0.128281f, -0.018818f, + -0.187244f, 0.067210f, 0.437147f, 0.442029f, 0.444939f, 0.226661f, + 0.541609f, 0.444280f, 0.302795f, 0.633026f, -0.180374f, 0.265197f, + 0.210404f, -0.118916f, -0.294013f, -0.692627f, -0.402347f, -0.356287f, + 0.387578f, 0.385496f, 0.789542f, 0.690396f, -0.203542f, -0.688546f, + 0.045319f, -0.448747f, -0.157148f, 0.152581f, 0.022360f, 0.058358f, + 0.593007f, 1.131860f, 0.289006f, 1.015560f, 0.144942f, -0.411577f, + 0.264794f, -0.085791f, 0.156996f, 0.200340f, 0.169264f, 0.267615f, + -0.361015f, -0.601842f, -0.442217f, -0.781086f, 0.112938f, 0.385305f, + 0.482454f, 0.470268f, 1.193390f, 0.589642f, 0.127638f, -0.640946f, + 0.540310f, 0.741498f, 0.686937f, 0.435879f, 0.534523f, 0.693119f, + 0.817577f, 0.783109f, 0.021681f, -0.004973f, 0.201236f, -0.086311f, + 0.028628f, 0.227871f, 0.462751f, 0.126832f, -0.389997f, -0.553965f, + -0.343953f, -0.448517f, 0.053129f, -0.115083f, 0.018138f, -0.067131f, + -0.293468f, -0.220700f, 0.074348f, -0.273153f, 0.263637f, 0.122049f, + 0.153025f, 0.076292f, 0.142320f, 0.286734f, 0.100542f, 0.308660f, + -0.759591f, -0.750938f, -0.788799f, -0.853076f, -0.588019f, -0.990063f, + -0.692327f, -0.722904f, 0.084736f, 0.151068f, 0.159606f, 0.147715f, + 1.610180f, 1.950330f, 1.765670f, 2.265110f, 0.008262f, 0.185584f, + 0.039337f, 0.164721f, 0.479446f, 0.314083f, 0.043969f, 0.291320f, + 0.003400f, -0.551190f, 0.060158f, -0.147591f, 0.089117f, 0.042994f, + 0.042802f, 0.127392f, -0.066172f, 0.078370f, 0.051408f, 0.014004f, + 0.086726f, 0.133334f, -0.046733f, 0.155100f, -0.118223f, -0.100778f, + -0.225245f, -0.460397f, 0.892644f, 
1.003770f, 0.405155f, 0.517477f, + 0.184585f, 0.279090f, -0.036477f, 0.198703f, 0.027139f, -0.055728f, + -0.022396f, -0.147319f, 2.275540f, 2.014990f, 2.296800f, 2.081730f, + -0.088713f, 0.105729f, -0.027871f, -0.095047f, 0.012429f, 0.014244f, + -0.014755f, -0.003017f, 1.332700f, 1.300040f, 1.464250f, 1.305030f, + 0.032568f, 0.118042f, 0.079632f, -0.089405f, 0.163905f, 0.146608f, + 0.026502f, 0.065307f, -0.056909f, -0.065052f, 0.069851f, -0.082958f, + 0.023419f, -0.026293f, 0.037616f, -0.048096f, -0.073701f, -0.208295f, + -0.782095f, 0.000523f, 0.374131f, 0.420946f, 0.466151f, 0.349651f, + -0.679275f, -0.745827f, -0.379918f, -0.900107f, 0.044070f, -0.347536f, + -1.224390f, 0.740113f, -0.779966f, 0.510920f, -0.968597f, -0.095630f, + 0.120805f, 0.676803f, -0.164827f, 0.172996f, -0.106720f, 0.197527f, + 0.337561f, 0.571094f, -0.279090f, -0.396697f, -0.253083f, -0.690170f, + -0.363291f, 0.516921f, 0.489391f, -0.920628f, 0.497572f, 0.483864f, + -0.125696f, -0.338123f, -0.041517f, -0.534630f, -0.388465f, -0.784554f, + 0.215227f, 0.055088f, 0.179638f, 0.086997f, 0.569313f, 0.572926f, + 0.137182f, -0.045485f, 0.118087f, 0.210383f, 0.212664f, 0.482443f, + 0.151921f, 0.307947f, -0.084656f, -0.386206f, 0.542277f, -0.207005f, + 0.073792f, -1.013240f, 0.303581f, 0.270527f, 0.265985f, 0.332702f, + 0.848609f, 0.686757f, 0.767212f, 0.316901f, -0.502460f, -0.567092f, + -0.484799f, -0.173350f, -0.426863f, 0.222375f, -0.200267f, -0.523758f, + 0.265180f, -0.175648f, -0.229754f, 0.148740f, 0.402515f, 0.028243f, + -0.366109f, 0.157232f, -0.131564f, 0.055136f, 0.211046f, -0.115542f, + 0.322379f, -0.137768f, -0.247832f, 0.070394f, 0.058530f, -0.295023f, + -0.196022f, -0.109097f, 0.261285f, -0.273585f, -0.240632f, 0.258326f, + -0.077364f, 0.071405f, -0.014766f, -0.008751f, -0.203622f, 0.177818f, + 0.116726f, -0.116735f, -0.723616f, -0.700154f, 0.145082f, -0.184949f, + -0.287076f, 0.150405f, 0.258075f, -0.157764f, -0.120909f, 0.105459f, + 0.113288f, -0.092963f, 0.328183f, -0.300115f, -0.361289f, 0.319792f, + -0.048875f, 0.135673f, 0.132539f, -0.162481f, 0.002109f, 0.065048f, + -0.135969f, 0.061558f, 1.510670f, -0.884925f, -0.827022f, 0.190311f, + -0.060088f, -0.033362f, 0.013354f, 0.002847f, 0.353479f, -0.462538f, + -0.319638f, 0.424484f, 0.199540f, -0.073843f, -0.140621f, 0.072133f, + -0.098662f, 0.070613f, 0.031150f, -0.021869f, -0.511253f, 0.503412f, + 0.565963f, -0.576146f, -1.081700f, 0.047670f, 0.266687f, 0.524804f, + -2.361150f, 0.147823f, 0.594717f, 0.956842f, -1.048220f, 0.127083f, + 0.079581f, 0.065419f, 0.176783f, 0.653953f, 0.260967f, 0.537892f, + -1.207580f, 0.245983f, -0.727067f, 0.071755f, -0.343025f, -0.173435f, + 0.215289f, 0.268578f, -1.158560f, 0.039263f, -0.132888f, 0.217132f, + -0.622195f, -0.071256f, 0.317333f, 0.157614f, -1.588250f, 0.316432f, + -0.736720f, -0.041698f, -1.959280f, 0.083451f, 0.570584f, 0.327620f, + -1.262200f, -0.026738f, 0.231198f, 0.326861f, -1.644200f, -0.143833f, + -0.079495f, 0.493026f, -2.488090f, -0.034046f, 0.165884f, 1.074260f, + -1.076980f, 0.248198f, -0.017987f, 0.421900f, -0.105860f, 0.076710f, + 0.002072f, 0.070264f, -1.734750f, 0.227145f, 0.209220f, 0.851459f, + -0.142369f, 0.066502f, 0.027816f, 0.044321f, -0.186591f, -0.100340f, + 0.115580f, 0.192252f, -0.892114f, 0.209531f, -0.308243f, 0.367968f, + -0.721770f, 0.220224f, -0.062744f, 0.133754f, 0.040416f, 0.190428f, + -0.035428f, 0.162974f, 0.116427f, 0.669393f, 0.278891f, 0.856676f, + 1.060390f, 0.936983f, 0.863355f, 0.990560f, -0.147111f, -0.217883f, + 0.355794f, -0.186530f, -0.275614f, -0.095719f, 0.167346f, 
0.359078f, + -0.079223f, -0.581596f, -0.213134f, -0.431123f, -0.516443f, -0.388628f, + -0.643821f, -0.202345f, 0.426230f, 0.516923f, 0.548131f, 0.555973f, + 0.022286f, 0.361170f, 0.980065f, 0.648400f, -0.056813f, -0.100310f, + -0.439481f, -0.166454f, 0.412449f, 0.509400f, 0.316208f, 0.470293f, + -0.827838f, -1.078380f, -1.047040f, -1.074560f, 0.274555f, -0.316736f, + 0.128818f, 0.228566f, -0.520967f, -0.731674f, -0.687887f, -0.536388f, + -0.031187f, 0.041404f, 0.047821f, 0.064397f, 0.054230f, 0.105059f, + -0.178671f, 0.176847f, -0.394797f, -0.260255f, -0.333734f, -0.162345f, + -0.444650f, -0.928438f, -0.705840f, -0.833162f, 0.306737f, 0.429699f, + 0.417298f, 0.478469f, 0.420903f, 0.676871f, 0.429677f, 0.616921f, + -0.805199f, -0.643391f, -0.304100f, 0.797599f, -0.172157f, 0.429085f, + -0.750676f, 0.149227f, -0.207898f, -0.022534f, -0.341448f, -0.247976f, + 0.095325f, -0.561120f, 0.599694f, -0.025236f, 0.292346f, -0.312001f, + 0.517478f, 0.301457f, -0.106415f, 0.226263f, -0.184163f, -0.114419f, + -0.322702f, 0.172541f, 0.445573f, 0.157213f, 0.670704f, 0.102174f, + -0.234667f, -0.293311f, 0.769852f, 0.038028f, -0.036741f, -0.228060f, + -0.253335f, 0.424054f, -0.597980f, 0.221007f, -0.114741f, -0.411557f, + -0.592201f, 0.442684f, 0.115491f, -0.106896f, -0.028110f, 0.354751f, + -0.248375f, 0.242570f, -0.155856f, 0.280528f, -0.198742f, 0.588725f, + 0.371065f, 0.078197f, 0.114706f, -0.448021f, 0.065255f, 0.133741f, + -0.227522f, -0.047339f, -0.052849f, 0.309480f, 0.597185f, 0.209182f, + 0.226108f, -0.601036f, -0.431672f, -0.172601f, -0.000174f, 0.194292f, + -0.133937f, 0.130676f, 0.059372f, 0.091381f, 0.098751f, -0.150996f, + 0.170514f, -0.085494f, 0.336576f, 0.484004f, 0.033862f, 0.277473f, + -0.231482f, -0.328385f, -0.332739f, -0.626957f, 0.510167f, 0.575861f, + 0.421494f, 0.482540f, -0.636377f, -0.864661f, -0.694180f, -0.420014f, + -0.132781f, 0.017599f, 0.003538f, 0.486934f, 0.133878f, -0.094622f, + 0.016132f, 0.010117f, 0.156680f, -0.022201f, -0.014621f, 0.228445f, + 0.190826f, 0.171580f, 0.579923f, 0.245428f, 0.322713f, 0.480101f, + 0.406320f, 0.412229f, 0.002334f, -0.022349f, 0.074571f, -0.043828f, + 0.290453f, 0.451749f, 0.530376f, 0.271879f, 0.095144f, 0.169450f, + 0.049482f, 0.114605f, -0.635634f, -0.700768f, -0.558538f, -0.537625f, + 0.190255f, -0.308237f, -0.053703f, 0.212489f, 0.056520f, -0.040019f, + 0.089822f, -0.014155f, -0.376004f, -0.448752f, -0.526717f, -0.571440f, + 0.116482f, 0.162321f, 0.147895f, 0.280527f, 0.159037f, -0.095958f, + 0.007931f, -0.086630f, 0.285625f, 0.514914f, 0.208908f, 0.519251f, + 0.309368f, 0.379777f, 0.350565f, 0.487487f, -0.541494f, -0.421836f, + -0.390001f, -0.500696f, -0.905736f, -0.150439f, -0.942304f, -0.566771f, + 0.484233f, 0.767417f, 0.410477f, 0.670196f, 0.070210f, 0.488836f, + 0.372805f, 0.197631f, 0.337892f, 0.524423f, 0.777219f, -0.260955f, + -0.112981f, -0.060088f, -0.200250f, -0.195671f, 0.007584f, 0.252096f, + 0.235511f, 0.366612f, -0.304979f, -0.211068f, -0.420683f, -0.085370f, + 0.085762f, -0.097549f, -0.802509f, -0.468079f, -0.192787f, -0.069670f, + -0.235162f, -0.077772f, -0.441671f, -0.348479f, -0.431434f, -0.108256f, + -0.133779f, 0.017032f, 0.001964f, -0.120647f, -0.187663f, -0.194985f, + -0.231742f, -0.175288f, -0.162639f, 0.245110f, 0.049951f, 0.104229f, + -0.159634f, -0.076545f, -0.022496f, -0.036532f, -0.147028f, -0.034215f, + 0.028213f, -0.059669f, -0.078259f, 0.062993f, -0.124066f, -0.137362f, + -0.129977f, -0.010532f, -0.049090f, -0.189401f, 0.495471f, 0.615778f, + 0.451437f, 0.803526f, 0.523532f, 0.841339f, 0.699528f, 
0.745129f, + 0.246264f, -0.198290f, -0.283620f, 0.189917f, -0.018306f, -0.419097f, + 0.280363f, -0.098085f, 0.138972f, -0.140867f, -0.117025f, 0.098585f, + 0.130979f, 0.268133f, -0.161731f, -0.176629f, -0.357677f, -0.126379f, + 0.553128f, -0.126821f, -0.001511f, -0.010081f, -0.031162f, 0.079203f, + -0.157731f, 0.072865f, 0.535830f, -0.529989f, -0.570075f, 0.295795f, + 0.595613f, -0.449278f, -0.669756f, 0.941452f, 0.356897f, -0.723720f, + -0.115203f, -0.134479f, 0.133048f, 0.109860f, -0.024250f, -0.049732f, + 0.020098f, 0.048356f, -0.048293f, 0.108754f, 0.062548f, -0.238315f, + 0.182700f, 0.312011f, -0.244377f, -0.118012f, 0.012276f, 0.006089f, + 0.098068f, -0.079280f, -0.423987f, -0.411931f, -0.027425f, 0.870280f, + 0.022825f, -0.024481f, -0.036320f, -0.111189f, 0.364539f, -0.244896f, + -0.373060f, 0.266345f, -0.141778f, 0.277549f, 0.059834f, -0.178242f, + -0.686222f, 0.594535f, 0.354546f, -0.272516f, 1.060730f, -1.059810f, + -0.948126f, 0.993267f, 0.116597f, -0.227574f, -0.436144f, -0.333309f, + -0.575746f, -0.828102f, 0.284561f, 0.351668f, -0.080164f, -0.762518f, + -0.511108f, -0.212855f, 0.293892f, -0.548664f, 0.072057f, 0.006748f, + 1.485110f, 0.124687f, 0.727211f, 1.557560f, -0.064383f, -0.022242f, + 0.002921f, -0.151505f, 0.270926f, 0.173632f, -0.640644f, 0.422410f, + -0.240699f, -0.361980f, -0.279864f, -0.055165f, -1.084140f, 0.231705f, + 0.366172f, -0.347698f, -0.097565f, -0.747227f, -0.243033f, 0.941545f, + -0.207460f, -0.353913f, 0.104303f, -0.403151f, 0.203177f, 0.335893f, + -0.229033f, 0.029096f, -0.409634f, -0.179599f, -0.442397f, 0.649114f, + 0.460774f, 0.170906f, -0.043857f, 0.402066f, -0.226896f, -0.199624f, + 0.016650f, 0.207894f, 0.056954f, 0.220329f, 0.374060f, 0.130361f, + -0.303960f, -0.078863f, 0.195410f, 0.729438f, 0.246818f, 0.287730f, + 0.484876f, 0.111488f, -0.168647f, -0.087878f, -0.070089f, -0.341329f, + -0.330280f, 0.259943f, -0.364205f, 0.256555f, -0.756804f, -0.086915f, + 0.777351f, 0.006136f, 0.110348f, 0.248743f, 0.209326f, -0.362741f, + -0.184416f, 0.422446f, 0.565193f, 0.310072f, -0.011212f, -0.765226f, + 0.039466f, 0.301288f, 0.172907f, -1.539450f, 0.606202f, 0.477469f, + 0.045894f, -0.222180f, -0.013192f, -0.064077f, -0.241551f, 0.192914f, + 0.028004f, -0.540538f, 0.437440f, 0.179087f, -0.753204f, -0.001374f, + 1.185930f, -0.151182f, 1.238580f, -1.389900f, 0.277954f, 0.422208f, + 0.041553f, -0.542284f, 0.139019f, -0.148580f, -0.130705f, 0.361830f, + 0.322953f, -0.092371f, 0.120180f, -0.355299f, -0.028057f, 0.128114f, + 0.250947f, -0.349926f, -0.684633f, 0.246175f, 0.186731f, -0.676313f, + 0.060535f, 0.333371f, -0.021172f, -0.421266f, -0.079650f, 0.031359f, + -0.303658f, -0.298286f, 0.119016f, 0.655585f, 0.200175f, -0.887182f, + -0.197539f, -0.318883f, -0.130250f, 0.522487f, -0.092616f, 0.405930f, + -0.281678f, 0.089728f, 0.081814f, -0.781745f, 0.348878f, 0.082274f, + -0.914136f, 1.098810f, 0.855321f, -1.078170f, -0.268018f, 0.246440f, + 0.238347f, -0.027228f, 0.074111f, -0.061197f, -0.063582f, 0.089462f, + -0.040347f, 0.117082f, 0.122772f, -0.162816f, -0.148668f, -0.342856f, + -0.495604f, -1.453630f, -0.045273f, -0.030463f, 0.043766f, 0.047978f, + 0.016910f, -0.009700f, 0.006288f, -0.042556f, 0.632896f, -0.845744f, + -0.516844f, 0.709439f, 0.486166f, -1.203050f, -0.978381f, 0.631876f, + 0.000705f, 0.123858f, -0.001187f, -0.172312f, -0.422668f, 0.241838f, + 0.437400f, -0.268186f, -0.513259f, 0.450209f, 0.542629f, -0.453810f, + -0.207119f, 0.072598f, 0.085066f, -0.018986f, -0.149512f, 0.149521f, + 0.182105f, -0.227200f, -0.363240f, 0.172670f, -0.502932f, 
0.689256f, + 0.093760f, -0.090207f, -0.066803f, 0.056759f, -0.002243f, -0.050662f, + -0.059324f, 0.152943f, -0.701150f, 0.712540f, 0.660349f, -0.654970f, + 0.351772f, -0.303383f, -0.311177f, 0.247653f, 0.013035f, 0.034648f, + -0.137832f, 0.041197f, 0.410265f, 0.345129f, 0.653338f, 0.047050f, + 0.140399f, 0.018613f, -0.012431f, -0.113632f, -0.029928f, 0.051564f, + -0.031349f, 0.151944f, -0.160340f, 0.326798f, -0.458067f, 0.636235f, + 0.243184f, 0.514072f, 2.414450f, 1.421980f, -0.001474f, -0.141389f, + -0.104817f, -0.141882f, -0.026395f, 0.053014f, 0.143885f, -0.207774f, + -0.563846f, -0.242514f, -0.436574f, -0.456796f, -0.520646f, 0.282550f, + -0.684924f, 0.061105f, -0.315884f, -0.392624f, 0.009805f, -0.256597f, + -0.146732f, 0.331039f, 0.362342f, 0.270851f, 0.067679f, -0.071331f, + -0.222423f, 0.081286f, -0.208192f, -0.193816f, -0.008201f, -0.309340f, + 0.167556f, 0.106071f, 0.172254f, -0.163790f, -0.142205f, -0.043182f, + 0.096145f, 0.145037f, -0.066015f, -0.073194f, 0.132237f, -0.088522f, + -0.044292f, -0.487128f, 0.033389f, -0.573548f, 0.185449f, 0.273593f, + 0.147503f, 0.457049f, -0.021539f, 0.090786f, 0.009147f, 0.000899f, + 0.018088f, 0.115791f, -0.079165f, 0.139388f, +}; + +static const float weights_layer_2[] = { + 0.153048f, 0.112901f, 0.136781f, 0.154580f, 0.091610f, 0.045165f, + 0.088490f, 0.116991f, -0.463766f, -0.596567f, -0.567008f, -0.630565f, + 0.141874f, 0.095726f, 0.175427f, 0.145027f, -0.969824f, -1.018190f, + -1.073300f, -1.041130f, -0.070545f, -0.123600f, -0.114967f, -0.169453f, + -0.267458f, -0.147730f, -0.161419f, -0.164894f, -0.117508f, -0.204389f, + -0.122695f, -0.163107f, -0.003903f, -0.030470f, -0.037433f, -0.059568f, + 0.138243f, 0.091019f, 0.160372f, 0.141650f, -0.544565f, -0.620004f, + -0.504503f, -0.429979f, -0.099491f, -0.096384f, -0.155265f, -0.188536f, + 0.084923f, 0.038345f, 0.066706f, 0.122083f, 0.267087f, 0.184419f, + 0.261478f, 0.255746f, -0.245894f, -0.114980f, -0.193880f, -0.227785f, + 0.087536f, 0.095712f, 0.106105f, 0.099353f, -0.059473f, -0.173247f, + -0.202386f, -0.076010f, 0.125928f, 0.100793f, 0.119638f, 0.129623f, + 0.136593f, 0.102984f, 0.156550f, 0.140558f, 0.122524f, 0.051596f, + 0.084164f, 0.123630f, 0.072542f, 0.096063f, 0.083236f, 0.087630f, + 0.025900f, 0.023738f, 0.036385f, 0.053077f, -0.029501f, 0.010544f, + -0.010026f, -0.051268f, 0.086302f, 0.109909f, 0.101385f, 0.127513f, + -0.031869f, 0.005340f, -0.056267f, -0.032955f, 0.032748f, 0.023162f, + 0.092118f, -0.001780f, -0.123612f, -0.183433f, -0.202377f, -0.317516f, + 0.129052f, 0.208112f, 0.145582f, 0.175502f, 0.018476f, 0.036349f, + 0.072417f, 0.061194f, 0.086985f, 0.117086f, 0.072465f, 0.129068f, + 0.020182f, 0.052114f, 0.017878f, 0.010478f, -0.001381f, -0.034644f, + 0.025135f, -0.037748f, 0.004973f, 0.024778f, 0.041816f, 0.032111f, + 0.080268f, 0.124998f, 0.105719f, 0.177047f, -0.072114f, -0.011864f, + -0.076846f, -0.089840f, 0.069993f, 0.089362f, 0.088035f, 0.120621f, + 0.065916f, 0.100946f, -0.006784f, -0.007751f, 0.122039f, 0.126482f, + 0.078629f, 0.140299f, 0.074034f, 0.092464f, 0.089798f, 0.108968f, + 0.075729f, 0.057128f, 0.013570f, 0.021195f, 0.068901f, 0.054022f, + 0.029781f, 0.031404f, -0.209998f, -0.208731f, -0.198310f, -0.212454f, + -0.579168f, -0.490190f, -0.607567f, -0.520541f, 0.083863f, 0.056612f, + 0.030366f, 0.061790f, -0.004874f, -0.057203f, -0.060429f, -0.049145f, + 0.080086f, 0.138602f, 0.223796f, 0.133279f, -0.495954f, -0.612093f, + -0.545393f, -0.562310f, 0.070672f, 0.037702f, 0.139013f, 0.080192f, + -0.111387f, -0.048165f, 0.074359f, -0.042125f, 0.113633f, 
0.106579f, + 0.042633f, 0.102734f, -0.068220f, 0.128423f, -0.181821f, -0.013260f, + -0.108563f, -0.138667f, -0.109304f, -0.131909f, -0.168667f, -0.126870f, + -0.132533f, -0.167096f, -0.184741f, -0.140890f, -0.125361f, -0.150632f, + 0.309013f, 0.364376f, 0.361102f, 0.271566f, 0.116552f, 0.091160f, + 0.096846f, 0.095954f, 0.046972f, 0.080489f, 0.028766f, -0.012223f, + 0.071379f, 0.041535f, -0.000668f, 0.033698f, -0.013493f, -0.027535f, + -0.025804f, -0.012267f, -0.097465f, -0.099232f, -0.208863f, -0.225201f, + -0.475608f, 0.077358f, -0.002872f, 0.163890f, -0.420298f, 0.072114f, + 0.121601f, -0.016727f, 0.573853f, -0.080196f, 0.193053f, 0.053012f, + -0.454179f, 0.058563f, 0.067265f, 0.141154f, 0.412541f, 0.086933f, + 0.030407f, -0.030413f, 0.478757f, -0.097731f, 0.277072f, -0.086393f, + 0.552604f, -0.334201f, 0.091765f, -0.270262f, -1.395060f, 0.271837f, + -0.005335f, 0.240499f, 0.175442f, -0.326329f, -0.019353f, -0.270338f, + -0.459273f, 0.096183f, 0.153046f, 0.135818f, 0.759028f, -0.177673f, + -0.099966f, 0.103363f, 0.697289f, -0.234184f, -0.048706f, -0.116099f, + -0.282575f, 0.025655f, -0.184759f, 0.040658f, -0.558267f, 0.214087f, + -0.095620f, 0.200522f, 0.278996f, 0.031959f, 0.122936f, -0.209196f, + -0.308217f, 0.092917f, 0.113269f, 0.136274f, -0.037046f, 0.017263f, + -0.194183f, 0.089133f, -0.161244f, 0.042799f, 0.030557f, 0.153545f, + -0.355048f, 0.070928f, -0.152852f, 0.102875f, -0.193649f, 0.007916f, + -0.062952f, 0.050602f, 0.073671f, 0.143045f, -5.978970f, -7.013850f, + 0.058713f, 0.076116f, 0.026445f, -0.056599f, -0.005966f, 0.032234f, + 0.006753f, -0.024528f, 0.120308f, 0.179939f, -6.624630f, -7.638680f, + 0.026359f, 0.020758f, 0.194274f, 0.051489f, -0.008491f, -0.028248f, + -0.061328f, -0.134423f, -0.103951f, -0.110877f, 0.042263f, 0.127016f, + 0.012473f, -0.008595f, 0.031357f, 0.087476f, -0.084022f, -0.015590f, + -0.313546f, 0.120072f, 0.123880f, 0.162148f, -6.596560f, -7.358830f, + 0.004797f, -0.003415f, 0.048455f, 0.026737f, -0.103702f, 0.034416f, + -0.003475f, -0.236827f, 0.005378f, 0.048413f, 0.054612f, -0.079359f, + 0.043707f, 0.001085f, 0.023380f, 0.007785f, 0.025938f, -0.052856f, + -0.033421f, 0.022643f, 0.034161f, 0.127681f, -5.019490f, -5.233580f, + -0.128630f, 0.087741f, -0.239834f, -0.377876f, 0.128082f, 0.142730f, + -0.086819f, -0.350927f, 0.089849f, 0.155776f, -6.155120f, -5.721720f, + 0.056110f, 0.008761f, 0.045579f, 0.016762f, -0.134076f, -0.101551f, + -0.096058f, -0.117146f, 0.003527f, -0.056942f, -0.005578f, 0.071287f, + 0.023776f, -0.028003f, -0.075390f, -0.191160f, -0.089672f, -0.104372f, + -0.104750f, -0.080813f, -0.249824f, -0.124479f, -0.243593f, -0.244284f, + -0.554911f, -0.549095f, -0.564693f, -0.475107f, -0.121771f, -0.143441f, + -0.171170f, -0.120920f, 0.109831f, 0.079708f, 0.327295f, 0.308907f, + -0.178785f, -0.428316f, -0.418882f, -0.366750f, -0.139296f, -0.129645f, + -0.081237f, -0.101533f, -0.006256f, -0.146756f, -0.322110f, -0.338865f, + -0.306085f, -0.319592f, -0.454803f, -0.363560f, -0.018557f, 0.006605f, + -0.131198f, -0.077708f, 0.138160f, 0.119611f, 0.271098f, 0.232168f, + 0.027812f, 0.035390f, -0.202503f, -0.091172f, -0.142020f, -0.159929f, + -0.106404f, -0.107433f, -0.381743f, -0.353222f, -0.484159f, -0.469926f, + -0.234659f, -0.315674f, -0.178327f, -0.213485f, -0.096207f, -0.190944f, + -0.118917f, -0.161288f, 0.015996f, 0.060737f, 0.051390f, 0.060876f, + 0.229289f, 0.282418f, 0.250945f, 0.197273f, 0.045131f, -0.008305f, + 0.072024f, 0.044547f, -0.050010f, 0.055504f, 0.001343f, -0.014445f, + 0.254909f, 0.309091f, 0.228249f, 0.274843f, 
0.089778f, -0.046581f, + 0.072714f, 0.126814f, -0.048931f, -0.045743f, -0.151333f, -0.004490f, + 0.179966f, 0.058150f, -0.178622f, -0.088159f, -0.074416f, -0.005821f, + -0.011799f, -0.002225f, -0.069361f, -0.098937f, -0.081575f, -0.034796f, + 0.253792f, 0.301039f, 0.219163f, 0.256027f, 0.058007f, -0.041431f, + 0.040674f, 0.009019f, -0.099670f, -0.099077f, -0.039437f, 0.017946f, + 0.060717f, 0.045796f, 0.109664f, 0.032138f, -0.071094f, 0.023697f, + 0.011335f, -0.030465f, 0.068677f, 0.039345f, -0.045078f, 0.084037f, + 0.135517f, 0.190417f, 0.175578f, 0.155286f, -0.044505f, 0.010826f, + 0.006717f, -0.134715f, 0.068022f, 0.110095f, 0.079966f, 0.034481f, + 0.185804f, 0.188273f, 0.227283f, 0.135935f, 0.033447f, 0.031571f, + -0.014766f, -0.024565f, 0.021792f, 0.017675f, -0.001333f, -0.040069f, + -0.049384f, -0.045256f, -0.014013f, -0.000107f, -0.096928f, -0.111495f, + -0.051225f, -0.060449f, 0.071446f, 0.017294f, -0.004822f, 0.006932f, + 0.020884f, 0.089425f, 0.061097f, -0.038708f, -0.184029f, -0.089541f, + -0.158035f, -0.214607f, -0.377947f, -0.318586f, -0.336977f, -0.323908f, + 0.181612f, 0.140018f, 0.233524f, 0.193366f, -0.254507f, -0.271902f, + -0.197144f, -0.119539f, 0.042162f, 0.000320f, 0.014708f, -0.014228f, + -0.081119f, -0.089326f, 0.001763f, 0.081009f, -0.142618f, -0.160650f, + -0.214597f, -0.202143f, -0.053495f, -0.012819f, -0.071468f, -0.010883f, + 0.072570f, 0.071507f, 0.091045f, 0.083155f, -0.271237f, -0.289211f, + -0.272345f, -0.299411f, 0.031697f, -0.029795f, -0.030045f, -0.013604f, + -0.106843f, -0.045212f, -0.122459f, -0.096936f, 0.059793f, 0.006157f, + 0.028092f, 0.040589f, -0.014560f, -0.008975f, -0.051404f, -0.014309f, + -0.016883f, 0.018332f, 0.040114f, 0.050348f, 0.044921f, -0.002445f, + -0.112396f, 0.014395f, 0.115160f, 0.145350f, -0.166814f, -0.121449f, + 0.155573f, -0.099446f, -0.161661f, 0.187251f, 0.004711f, 0.024318f, + -0.060871f, -0.028311f, -0.098274f, 0.322030f, -0.069242f, -0.153173f, + -0.227428f, -0.293965f, 0.228491f, 0.111413f, -1.354720f, -0.344235f, + 0.866715f, 0.872344f, 0.078789f, -0.384865f, 0.162388f, 0.109018f, + -0.191549f, -0.002638f, 0.305053f, 0.087337f, 0.066506f, -0.055810f, + -0.010984f, -0.056160f, -0.114617f, -0.058478f, 0.022059f, -0.124368f, + -0.130989f, 0.369432f, -0.248898f, -0.003955f, -0.021578f, 0.115991f, + -0.114163f, -0.065232f, 0.339857f, -0.225997f, 0.006282f, -0.125395f, + 0.235082f, -0.347785f, 0.662321f, -0.529182f, 0.153297f, -0.001326f, + -0.026725f, -0.024677f, -0.088065f, -0.116127f, 0.080896f, 0.212542f, + 0.208421f, 0.032047f, -0.211395f, 0.074997f, 0.096659f, 0.096423f, + -0.078643f, 0.106556f, -0.123860f, 0.075609f, 0.066008f, -0.097275f, + -1.000020f, -0.780154f, -0.856922f, -0.964007f, 0.083135f, -0.018922f, + -0.266214f, -0.151480f, 0.051538f, 0.017802f, 0.066774f, -0.021341f, + -0.869494f, -0.935252f, -0.895836f, -0.853871f, -0.160490f, 0.085850f, + -0.029670f, -0.056675f, 0.159989f, 0.166872f, 0.129970f, 0.194377f, + 0.153294f, 0.199593f, 0.037692f, 0.103391f, 0.029335f, -0.085324f, + -0.079326f, -0.077216f, 0.501561f, 0.366168f, 0.330196f, 0.296432f, + -0.977282f, -0.844295f, -1.014870f, -1.098990f, -0.099858f, -0.129552f, + 0.090051f, -0.013378f, 0.081330f, 0.194911f, 0.286501f, 0.177363f, + -0.148250f, -0.111700f, -0.243081f, -0.102918f, 0.161069f, -0.012655f, + -0.071722f, -0.020329f, -0.077828f, -0.041716f, 0.109247f, 0.062229f, + -0.759722f, -0.742756f, -0.563713f, -0.631187f, 0.005911f, 0.268154f, + -0.263769f, 0.087149f, -0.163623f, -0.359600f, -0.464577f, -0.369352f, + -0.515784f, -0.475822f, -0.523485f, 
-0.649813f, -0.112419f, -0.029285f, + 0.021061f, -0.041515f, 0.149133f, -0.254428f, 0.115776f, -0.061892f, + 0.103675f, -0.283363f, 0.005005f, 0.022034f, -0.178454f, 0.035836f, + -0.113702f, -0.217823f, 0.209407f, -0.296257f, 0.187976f, -0.157370f, + -0.127190f, 0.251780f, 0.055633f, 0.294111f, -0.067773f, 0.467190f, + -0.192625f, -0.071084f, -0.445284f, 0.511090f, -0.319728f, 0.267971f, + 0.494929f, -0.586727f, 0.454543f, -0.520675f, -0.085900f, 0.325989f, + -0.131006f, -0.069501f, 0.199927f, -0.218919f, 0.170055f, -0.106538f, + 0.133312f, 0.127629f, -0.561625f, 0.595666f, -0.090927f, 0.363348f, + -0.249246f, 0.063068f, -0.016458f, -0.291045f, -0.040509f, 0.017866f, + 0.304871f, -0.459214f, 0.214390f, -0.238740f, -0.456541f, 0.545848f, + -0.218026f, 0.202475f, 0.128490f, -0.036417f, 0.173885f, -0.049385f, + 0.235514f, -0.132587f, -0.015066f, 0.164638f, 0.196873f, -0.125330f, + 0.216912f, -0.109398f, 0.121602f, -0.209374f, 0.164400f, -0.123049f, + 0.195520f, -0.212932f, -0.015180f, -0.005784f, 0.049726f, -5.822150f, + 0.124536f, 0.040689f, -0.018560f, -3.155020f, 0.014690f, 0.076202f, + -0.154008f, 1.070630f, -0.071606f, 0.051026f, 0.138285f, -5.836340f, + 0.162173f, 0.085890f, -0.186166f, 0.093221f, 0.019240f, -0.017053f, + -0.090144f, 0.236254f, -0.125344f, 0.056235f, -0.089813f, -0.252281f, + -0.127406f, -0.155088f, 0.009972f, -0.066449f, 0.044222f, 0.025943f, + -0.164921f, 0.165463f, -0.001132f, -0.038386f, 0.115194f, -5.757100f, + 0.163386f, 0.061226f, 0.024626f, 0.132750f, 0.107279f, -0.001622f, + -0.107860f, -0.356009f, -0.138935f, -0.145173f, -0.061198f, -0.646138f, + 0.034279f, 0.078187f, 0.108138f, -0.490444f, 0.074719f, 0.034984f, + -0.109303f, 0.741785f, -0.066939f, 0.015558f, 0.114229f, -4.001080f, + 0.130772f, 0.044675f, -0.165162f, -0.274810f, -0.042987f, -0.048579f, + 0.156603f, -1.288370f, 0.076198f, 0.035065f, 0.032043f, -5.002520f, + 0.086900f, -0.010886f, 0.030850f, -0.782259f, 0.056211f, -0.097759f, + 0.118988f, 0.106638f, 0.091419f, 0.079920f, 0.062325f, 0.097116f, + 0.126035f, 0.122530f, -0.278299f, -0.083314f, -0.300563f, -0.197946f, + 0.081664f, 0.089925f, 0.074754f, 0.074628f, 0.102338f, 0.088845f, + 0.105841f, 0.102381f, 0.003087f, 0.061599f, 0.098326f, 0.040119f, + -0.005298f, -0.028834f, 0.059938f, -0.013668f, -0.585882f, -0.631436f, + -0.742673f, -0.736666f, 0.025071f, 0.066851f, 0.075046f, 0.091360f, + 0.099045f, 0.098261f, 0.106413f, 0.099487f, -0.016742f, -0.097334f, + -0.086152f, -0.212444f, -0.028043f, -0.007362f, 0.003914f, -0.055864f, + 0.034756f, 0.081361f, 0.080183f, 0.061319f, 0.193396f, 0.173716f, + 0.207765f, 0.231701f, -0.074565f, -0.073257f, -0.086470f, -0.083114f, + 0.081489f, 0.078477f, 0.033452f, 0.058835f, -0.069665f, -0.031691f, + -0.111255f, -0.167754f, 0.184179f, 0.174673f, 0.160288f, 0.190893f, + 0.110930f, 0.103495f, 0.098408f, 0.102918f, 0.053764f, 0.089994f, + 0.140308f, 0.124867f, 0.074176f, 0.117460f, -0.160775f, -0.144132f, + -0.099373f, -0.035913f, 0.081237f, 0.062247f, -0.166421f, 0.062125f, + 0.276479f, 0.060955f, 0.066627f, 0.455347f, 0.219953f, 0.109912f, + 0.273931f, 0.233153f, 0.102236f, 0.447606f, -0.352243f, 0.499236f, + -0.931206f, 0.248595f, 0.254047f, 0.061542f, 0.268804f, 0.309517f, + -0.084414f, -0.245828f, -0.144882f, -0.296579f, -0.091628f, -0.142202f, + -0.541764f, -0.407470f, 0.053481f, 0.238955f, 0.150188f, -0.060598f, + 0.196118f, -0.215617f, -0.086238f, -0.263420f, 0.206877f, 0.241788f, + -0.122544f, -0.448790f, 0.286917f, 0.112063f, -0.268408f, -0.041770f, + 0.089161f, 0.355811f, -0.078245f, -0.148490f, 
-0.407301f, -1.296870f, + -0.633421f, 0.124253f, 0.275402f, 0.223048f, 0.077016f, 0.160766f, + 0.115374f, 0.061053f, -0.231872f, -0.515052f, -0.278331f, -0.235912f, + -0.416372f, -0.284106f, -0.055942f, 0.110698f, -0.428288f, -0.298137f, + -0.018101f, 0.102677f, -0.019639f, 0.013479f, 0.038549f, 0.048682f, + 0.128684f, 0.116416f, 0.044852f, 0.008133f, 0.061597f, 0.083582f, + 0.014953f, 0.063716f, -0.155318f, -0.061732f, 0.084855f, 0.129505f, + 0.068249f, 0.193775f, -0.088631f, -0.446398f, -0.075710f, -0.061327f, + 0.278715f, 0.540366f, 0.618715f, 0.538374f, -0.037843f, 0.062370f, + -0.033184f, 0.119901f, -0.008641f, -0.064789f, 0.087498f, 0.043486f, + 0.247085f, 0.419992f, 0.299935f, 0.234276f, 0.089283f, 0.070357f, + 0.068888f, 0.134311f, 0.109823f, 0.072431f, 0.081676f, 0.091366f, + -1.707980f, -2.213110f, -2.149930f, -1.556870f, 0.226598f, 0.191675f, + 0.192207f, 0.159566f, -0.070194f, -0.136070f, -0.015172f, -0.204272f, + -0.162191f, -0.043313f, -0.158007f, -0.227210f, 0.040398f, 0.043014f, + 0.039439f, -0.035439f, 0.245558f, 0.439691f, 0.219659f, 0.138210f, + -0.048129f, 0.004954f, -0.102860f, -0.185376f, 0.035548f, 0.006821f, + 0.079199f, 0.032901f, 0.039218f, 0.068113f, 0.023075f, -0.037582f, + 0.225181f, 0.164562f, 0.106718f, 0.032684f, 0.013402f, 0.018797f, + 0.076606f, 0.046512f, -0.070024f, 0.099921f, -0.051231f, 0.074167f, + 0.173313f, 0.220212f, 0.142665f, 0.069809f, -0.195130f, -0.007912f, + -0.006764f, -0.063687f, 0.306374f, 0.402035f, 0.273759f, 0.449469f, + 0.114597f, 0.210745f, 0.355326f, 0.271307f, -0.109943f, -0.171912f, + -0.070726f, -0.128932f, 0.138770f, 0.164971f, 0.308516f, 0.332536f, + 0.081537f, 0.096939f, 0.054136f, 0.052226f, 0.109489f, 0.010223f, + 0.168072f, -0.106279f, 0.525568f, 0.704816f, 0.588942f, 0.473398f, + 0.149497f, 0.120835f, 0.080049f, 0.151340f, -0.182038f, -0.191091f, + -0.196505f, -0.198309f, -0.801819f, -1.441620f, -1.107780f, -1.025650f, + 0.035750f, 0.018049f, -0.029033f, -0.067255f, 0.192049f, 0.009664f, + -0.043741f, 0.051557f, 0.082815f, 0.069547f, -0.073379f, 0.010584f, + 0.192128f, 0.208586f, 0.141904f, 0.100763f, 0.046183f, 0.044776f, + -0.033611f, -0.005812f, 0.012966f, 0.030301f, 0.100665f, 0.103641f, + -0.294776f, -0.361573f, -0.420156f, -0.388743f, 0.239287f, 0.191975f, + 0.089644f, 0.117591f, 0.069563f, 0.021480f, 0.100287f, 0.174159f, + -0.013571f, 0.090960f, 0.010232f, -0.034760f, -0.077205f, 0.060632f, + -0.145527f, -0.391110f, -0.143052f, -0.236448f, -0.103902f, -0.188463f, + 0.071311f, -0.080171f, 0.021987f, 0.041767f, -0.419487f, -0.515479f, + -0.205470f, -0.732132f, 0.150901f, 0.107202f, 0.156307f, 0.143672f, + 0.474682f, 0.178137f, 0.150063f, 0.414515f, 0.559891f, 0.697019f, + 0.541231f, 0.505310f, -0.478101f, -0.444267f, -0.586539f, -0.445996f, + -0.451873f, -0.530085f, -0.447980f, -0.364955f, 0.372435f, 0.318894f, + 0.351211f, 0.193961f, 0.212295f, 0.212842f, 0.220003f, 0.243743f, + -0.388628f, -0.789620f, -0.536618f, -0.430691f, 0.247004f, 0.266489f, + 0.261033f, 0.263692f, 0.050089f, 0.048958f, 0.065207f, 0.120180f, + -0.526230f, -0.481969f, -0.422411f, -0.272292f, 0.155593f, 0.229614f, + 0.139579f, 0.171805f, -0.251924f, -0.302067f, -0.126157f, -0.346650f, + -1.195450f, -1.281100f, -0.987911f, -1.478440f, 0.285667f, 0.284802f, + 0.301887f, 0.259556f, -0.194127f, -0.090440f, -0.257959f, -0.259572f, + -0.012273f, -0.049993f, -0.099431f, 0.012506f, 0.081526f, 0.166279f, + 0.042594f, 0.185121f, 0.148830f, 0.073161f, 0.201728f, 0.125747f, + -0.295065f, -0.187585f, -0.333066f, -0.312291f, 0.253458f, 0.321585f, + 0.178844f, 
0.219944f, -0.763475f, -0.943374f, -0.816825f, -0.709901f, + -0.166132f, 0.129186f, 0.015405f, -0.065623f, -0.246006f, -0.340385f, + -0.118155f, -0.384905f, -0.233883f, -0.400666f, -0.228597f, -0.228428f, + -0.559083f, -0.377784f, -0.541458f, -0.542870f, 0.067400f, 0.122987f, + 0.180901f, 0.186004f, -0.482910f, -0.424823f, -0.477831f, -0.394719f, + 0.091558f, 0.049248f, 0.049370f, 0.160429f, 0.133641f, 0.096625f, + 0.104429f, 0.100782f, -0.238252f, -0.221459f, -0.196974f, -0.250393f, + -3.071750f, -2.418450f, -0.861410f, -1.051580f, 0.071263f, 0.118014f, + -0.028430f, -0.072073f, -0.074463f, 0.034168f, 0.044089f, -0.091109f, + -3.153840f, -2.945850f, -1.977360f, -1.498850f, -0.083429f, 0.131835f, + -0.063865f, -0.065785f, -0.069346f, -0.015520f, -0.119551f, 0.044881f, + -0.105280f, 0.127516f, 0.005255f, -0.142777f, 0.061055f, -0.117250f, + 0.020454f, 0.157879f, -0.213812f, -0.151783f, 0.028583f, 0.137759f, + -3.248250f, -3.005940f, -1.510540f, -1.475390f, 0.081874f, -0.171465f, + -0.135690f, -0.001989f, -0.227574f, -0.132799f, -0.359742f, -0.137197f, + 0.066324f, 0.039194f, -0.050857f, 0.095166f, 0.044475f, 0.011221f, + 0.054904f, 0.061414f, -0.039189f, 0.123751f, -0.017171f, -0.008494f, + -2.598220f, -2.832670f, -1.622030f, -1.201990f, 0.154313f, -0.021436f, + 0.042190f, 0.143947f, -0.090623f, 0.086853f, 0.143137f, 0.099821f, + -1.732820f, -1.429730f, -0.775125f, -0.648036f, 0.082176f, 0.079448f, + -0.040575f, 0.024511f, -0.064105f, -0.117122f, -0.190323f, -0.182589f, + -0.076430f, -0.095615f, -0.112513f, -0.101581f, 0.143037f, 0.148180f, + 0.430958f, 0.359225f, 0.001403f, -0.080541f, -0.295001f, -0.156706f, + 0.426623f, 0.475597f, 0.455210f, 0.454352f, 0.074365f, 0.099440f, + 0.066348f, -0.007078f, 0.008335f, -0.097116f, -0.133687f, -0.110535f, + 0.204145f, 0.281478f, 0.078886f, 0.112857f, -0.103620f, -0.068247f, + 0.191147f, 0.227593f, -0.011816f, -0.058755f, -0.149477f, -0.101828f, + 0.079878f, 0.304949f, 0.557555f, 0.305288f, -0.150955f, -0.118610f, + 0.052073f, 0.064707f, -0.121728f, -0.151132f, -0.193987f, -0.175046f, + 0.043655f, 0.105270f, -0.120715f, -0.040976f, 0.047776f, -0.004443f, + 0.149606f, 0.111240f, -0.047502f, -0.064146f, -0.151858f, -0.151872f, + -0.160207f, -0.113846f, -0.081585f, -0.006708f, -0.203760f, -0.068597f, + -0.179979f, -0.127779f, -0.062460f, -0.064513f, -0.121479f, -0.111122f, + -0.212384f, -0.229157f, -0.283428f, -0.184891f, +}; + +static const float weights_layer_3[] = { + -0.039388f, 0.033048f, -0.113003f, -0.011642f, 0.170478f, 0.145713f, + 0.040189f, -0.280129f, -0.049050f, -0.043788f, -0.157425f, 0.323829f, + -0.250725f, -0.166349f, 0.101650f, -0.049690f, 0.205606f, 0.281131f, + 0.623204f, 0.993452f, -0.015115f, -0.138995f, 0.009473f, 0.157673f, + -0.024687f, -0.067214f, 0.125566f, -0.317619f, 0.057002f, 0.031202f, + -0.018167f, 0.068542f, 0.011609f, -0.020233f, -0.000428f, -0.035956f, + -0.843274f, -0.800587f, -0.214917f, -0.221250f, 0.031255f, -0.077330f, + -0.074902f, -0.063979f, -0.055562f, 0.679495f, 0.146609f, 1.315330f, + -0.118399f, -0.034539f, -0.050377f, 0.172867f, -0.204607f, -0.034930f, + 0.176014f, 0.089747f, -0.003889f, 0.044980f, 0.002386f, -0.141723f, + -0.035828f, -0.204701f, 0.099813f, 0.123580f, 0.209851f, -0.110989f, + -0.043655f, -0.461118f, -0.139664f, 0.026855f, -0.081714f, 0.207623f, + 0.089942f, 0.253082f, 0.680568f, 0.811360f, -0.090528f, -0.116818f, + -0.432361f, -0.075588f, -0.269924f, -0.276810f, -0.289192f, -0.282570f, + 0.245566f, 0.267216f, 0.238622f, 0.286528f, -0.157605f, -0.200401f, + -0.138924f, -0.185006f, 
0.215203f, 0.203316f, 0.209532f, 0.293135f, + 0.928046f, 0.733323f, -0.094120f, 0.036918f, -0.126643f, -0.083371f, + -0.147530f, -0.153195f, 0.097097f, 0.101852f, 0.109160f, 0.105129f, + -0.051869f, -0.064359f, -0.073469f, -0.059591f, 0.102431f, 0.109444f, + 0.113614f, 0.105617f, 0.383311f, 0.325783f, 0.393234f, 0.382508f, + 0.194720f, 0.189672f, 0.217477f, 0.177786f, 0.326461f, 0.114789f, + 0.317061f, 0.048291f, -0.061143f, -0.134641f, -0.067895f, -0.108446f, + 0.082592f, 0.029918f, -0.006580f, 0.015533f, -0.053583f, -0.055540f, + -0.063395f, -0.023157f, -0.064955f, -0.073981f, -0.115452f, -0.086626f, + -0.036616f, 0.008454f, 0.012029f, -0.008039f, -0.207395f, -0.216419f, + -0.205363f, -0.249099f, 0.343308f, 0.413215f, -0.009918f, -0.109978f, + -0.059711f, -0.045089f, -0.029130f, -0.038483f, -0.070323f, -0.099409f, + -0.008849f, -0.063527f, 0.175963f, 0.185335f, 0.149151f, 0.199997f, + -0.027516f, -0.039812f, -0.027760f, -0.047910f, -0.007337f, 0.071065f, + 0.086225f, 0.125539f, 0.151390f, 0.215488f, 0.203450f, 0.045380f, + 0.095761f, 0.107809f, 0.103918f, 0.122383f, 0.116287f, 0.135455f, + 0.115446f, 0.155673f, -0.044648f, -0.027455f, -0.015473f, -0.026657f, + 0.089852f, 0.077459f, 0.077631f, 0.082507f, -0.102761f, -0.054669f, + -0.132223f, -0.024768f, 0.111573f, 0.060467f, 0.107883f, 0.056621f, + 0.219357f, -0.161153f, 0.074379f, -0.118743f, -0.169931f, -0.153995f, + -0.220003f, -0.200186f, 0.032318f, -0.060687f, -0.087550f, -0.038022f, + 0.026633f, -0.005534f, 0.029532f, 0.027081f, 0.011926f, 0.058412f, + 0.010631f, 0.003068f, -0.014911f, 0.063070f, 0.065271f, 0.089550f, + 0.012885f, 0.005320f, -0.037494f, -0.019849f, -0.009624f, -0.059090f, + -0.021222f, -0.088033f, -0.055261f, -0.055113f, -0.047598f, -0.055478f, + -0.023648f, -0.046827f, -0.036572f, -0.057655f, 0.104194f, 0.179800f, + 0.175751f, 0.192851f, -0.016950f, -0.073650f, -0.028592f, -0.088219f, + 0.011130f, 0.061825f, 0.025643f, 0.034183f, 0.095548f, 0.001457f, + -0.132869f, 0.032981f, -0.140178f, -0.105343f, -0.161799f, -0.161983f, + 0.177746f, 0.132903f, 0.135627f, 0.152489f, -0.012532f, -0.068747f, + -0.085849f, -0.095434f, 0.087037f, 0.139497f, 0.111899f, 0.100189f, + -0.024649f, -0.092003f, 0.020783f, -0.115807f, 0.092039f, 0.093943f, + 0.109466f, 0.049639f, -0.133727f, 0.128430f, -0.050546f, 0.190632f, + 0.123733f, 0.082305f, 0.114878f, 0.122572f, 0.201618f, 0.137588f, + 0.065582f, 0.125161f, -0.095179f, -0.120719f, -0.127126f, -0.101961f, + -0.118120f, -0.104833f, -0.179632f, -0.131764f, -0.138096f, -0.147861f, + -0.131512f, -0.153905f, -0.201816f, -0.206641f, -0.196707f, -0.160013f, + -0.212605f, -0.093998f, -0.186258f, -0.076137f, -0.065340f, -0.006969f, + -0.071383f, -0.075005f, +}; + +static const float weights_layer_4[] = { + -0.016102f, -0.022836f, 0.624049f, 0.273485f, 0.222800f, -0.290175f, + -0.518415f, 0.413484f, -0.264495f, 0.498083f, -0.450145f, -0.106419f, + 0.095103f, -0.187451f, 0.145933f, -0.371542f, -0.088871f, 0.184017f, + -0.429625f, -0.110882f, 0.292781f, 0.289588f, 0.185127f, 0.326017f, + -0.432009f, -0.342663f, -0.312206f, 0.004004f, -1.114290f, 0.028497f, + -0.264944f, -0.419611f, 0.046336f, 0.138232f, -0.869528f, 0.425557f, + -0.954838f, -0.186830f, -0.464622f, -0.757107f, -0.432686f, -0.125978f, + -0.402633f, -0.172266f, -0.041749f, -0.822238f, -0.118486f, 0.238617f, + -0.198037f, 0.146347f, 0.405257f, 0.513303f, -0.078876f, -0.300385f, + -0.010293f, -0.183962f, 0.155738f, 0.186797f, -0.086814f, 0.000179f, + 0.123467f, 0.362523f, 0.068805f, 0.371834f, 0.038122f, -0.117867f, + -0.120445f, 
-0.422322f, -0.131402f, 0.285449f, 0.038957f, 0.008844f, + -0.020197f, 0.187723f, 0.190433f, 0.146532f, -0.091068f, -0.270865f, + -0.194231f, -0.226777f, 0.013548f, 0.248351f, 0.537685f, 0.056316f, + -0.171540f, -0.003865f, 0.406439f, 0.126507f, 0.192780f, 0.149335f, + -0.149602f, 0.255202f, -0.015426f, 0.032335f, -1.791330f, -0.894602f, + -0.196641f, -0.282846f, -0.391100f, -0.040969f, 0.049934f, 0.056348f, + -0.041426f, -0.075159f, -0.658335f, -0.827270f, -0.175029f, -0.427235f, + 0.311201f, 0.560413f, 0.363408f, 0.374580f, -0.433531f, -0.180580f, + 0.142142f, 0.194768f, -0.054118f, -0.376541f, -0.366185f, -0.308782f, + -0.273143f, -0.074097f, 0.009000f, -0.182198f, -0.015616f, -0.003882f, + -0.174340f, -0.354866f, 0.527972f, 0.348355f, 0.091381f, -0.419828f, + -0.530529f, 0.159899f, -0.511867f, -0.104237f, -0.286079f, -0.659039f, + -0.266596f, -0.256557f, -0.600437f, -0.446333f, -0.229629f, 0.024931f, + -0.143716f, -0.415754f, -0.003760f, -0.107195f, -0.666165f, -0.697312f, + -0.650255f, -0.703877f, 0.243402f, 0.426710f, 0.217210f, 0.260255f, + 0.027416f, 0.163147f, 0.132188f, 0.142374f, 0.558627f, 0.065717f, + 0.382781f, -1.192240f, 0.195492f, 0.028439f, 0.278252f, -0.491806f, + 0.497701f, -0.448835f, -0.245079f, -0.014336f, -0.174907f, -0.409633f, + 0.207548f, 0.433813f, 0.459889f, 0.431728f, 0.605050f, 0.485520f, + 0.218548f, 0.437307f, 0.027023f, -0.204251f, 0.012100f, 0.150677f, + -1.097980f, 0.086866f, -1.293130f, -0.372575f, -0.876264f, -0.021818f, + 0.322864f, -0.231043f, -0.271608f, 0.132782f, -0.314895f, 0.396800f, + 0.262788f, -0.317212f, -0.666308f, 0.830742f, 0.319409f, -0.564373f, + -0.178656f, 0.306993f, 0.265634f, -0.332480f, -0.491514f, -0.186745f, + -0.063044f, -0.009321f, 0.074944f, -0.372082f, -0.029479f, 0.081548f, + 0.028172f, -0.233148f, -0.337938f, -0.087695f, 0.596556f, 0.559530f, + 0.139332f, 0.107223f, -0.190915f, 0.137401f, -0.150625f, -0.225484f, + -0.191344f, -0.232535f, 0.126510f, 0.296323f, -0.547901f, -0.653080f, + 0.358514f, 0.726289f, -0.421725f, -0.243620f, 0.236206f, 0.390823f, + -0.076560f, -0.282329f, -0.012460f, -0.428484f, 0.349469f, 0.394629f, + 0.421537f, 0.219632f, -0.117550f, -0.087894f, 0.077155f, 0.016000f, + -0.289137f, -0.092937f, -0.014518f, -0.027111f, 0.210329f, -0.159678f, + 0.013288f, -0.039268f, 0.008112f, 0.003152f, 0.030084f, -0.039859f, + 0.322028f, -0.407797f, 0.447087f, -0.381562f, 0.529297f, -0.520298f, + 0.562865f, -0.616878f, 0.689389f, 0.754262f, 0.138475f, 0.750697f, + -0.760157f, -0.383740f, 0.074219f, 0.556257f, 0.087827f, -0.511826f, + -0.305507f, -0.638214f, 0.114833f, -0.444022f, 0.526612f, -0.604984f, + -0.100415f, 0.037824f, -0.106264f, 0.337615f, 0.070743f, 0.031129f, + 0.281954f, 0.176144f, -0.032833f, -0.073902f, -0.285492f, -0.803803f, + -0.015589f, 0.186077f, -0.033351f, 0.517269f, -1.878800f, -1.685210f, + -0.416581f, 0.158476f, -0.071929f, -0.624353f, -0.122069f, -0.075065f, + 0.311816f, 0.506305f, 0.383896f, 0.259450f, -0.308232f, -0.094221f, + -0.421885f, -0.293573f, +}; + +static const float weights_layer_5[] = { + 0.131894f, 0.078431f, 0.323121f, -0.230680f, -0.684740f, 0.020895f, + 0.364983f, 0.121656f, 0.132448f, -0.731198f, 0.071148f, 0.739642f, + 0.318437f, -0.033021f, -1.037080f, 0.135335f, 0.383582f, 0.287332f, + 0.054042f, -0.825482f, 0.418533f, 0.305606f, 0.041549f, 0.432422f, + -0.826878f, -0.593536f, 0.105657f, 0.125357f, 0.408567f, -0.293338f, + 0.233905f, -0.039609f, 0.547727f, -0.435806f, 0.036160f, 0.220275f, + -0.020337f, -0.619403f, -0.455858f, 0.681455f, 0.543846f, -0.495084f, + 
0.251496f, -0.085686f, 0.091395f, -0.476696f, 0.453628f, -0.109663f, + 0.383493f, -0.456563f, -0.212935f, 0.020567f, -0.719564f, -0.377813f, + -0.737511f, 0.765965f, 0.624309f, -0.063679f, -0.055681f, -0.475969f, + -0.069902f, 0.725690f, 0.641094f, 0.439922f, -0.111544f, -0.309061f, + 0.280091f, 0.381416f, 0.481168f, 0.483543f, -0.901267f, -0.499230f, + 0.043449f, -0.372395f, 0.021216f, -0.002200f, -0.524089f, -0.071485f, + -0.273974f, -0.462654f, 0.042369f, -0.138679f, -0.330060f, 0.021886f, + -0.306075f, -0.011130f, -0.260224f, -0.288435f, -0.104039f, -0.183563f, + 0.118990f, -0.531160f, 0.339632f, -0.028374f, 0.159084f, -0.008824f, + -0.791388f, 0.245242f, 0.356510f, 0.469867f, -0.396949f, -0.476146f, + -0.168472f, 1.068400f, 0.474629f, -0.117554f, -0.142453f, -0.306604f, + 0.348525f, -0.111929f, -0.435384f, 0.019952f, -0.260185f, 0.373376f, + 0.109729f, -0.639168f, 0.033392f, -0.082573f, -0.196018f, 0.301637f, + -0.124210f, -0.202515f, -1.221920f, -0.253690f, -0.144864f, 0.287753f, + -0.161206f, -0.213246f, 0.373968f, 0.141397f, -0.248237f, 0.283090f, + -0.008977f, -0.172960f, -0.234146f, -0.720014f, -0.322451f, 0.181083f, + 0.310659f, -0.422646f, -0.719994f, -0.354339f, 0.352739f, 0.230923f, + 0.427013f, -0.660316f, 0.232140f, 0.685896f, 0.660208f, 0.225748f, + -0.918750f, -0.650790f, -0.674525f, -0.450305f, -0.152529f, 0.498480f, + 0.895092f, 0.688242f, 0.669057f, 0.612669f, 0.593484f, 0.318204f, + -0.169294f, 0.388789f, -0.529777f, -0.219706f, -0.044916f, 0.161697f, + -0.145288f, 0.196153f, -0.022212f, -0.434209f, -0.208115f, -0.117745f, + -0.279029f, -0.009506f, 0.137474f, 0.330148f, 0.439258f, 0.345879f, + -0.845131f, -0.215713f, 0.094463f, 0.638604f, 0.882254f, -0.964082f, + -0.383920f, 0.292645f, 0.266341f, 0.747473f, -0.645631f, -0.538896f, + -0.319764f, 0.521880f, 0.460091f, -0.470898f, -0.778283f, -0.061622f, + -0.142433f, 0.210520f, 0.804197f, 0.285840f, -0.138414f, -0.381846f, + -0.499991f, 0.223648f, 0.439025f, 0.321508f, -0.099560f, -0.622893f, + 0.750925f, 0.740994f, 0.140405f, 0.074631f, -0.270223f, -0.829049f, + -0.753355f, -0.258015f, 0.006285f, -0.730573f, -1.107390f, -0.538015f, + -1.005520f, -0.724115f, -0.440183f, -0.395239f, 0.508768f, 0.204620f, + -0.267331f, 0.001740f, -0.838709f, 0.659333f, 0.043739f, -0.024099f, + 0.262431f, 0.252433f, -0.265215f, 0.057289f, -0.428192f, -0.114350f, + -0.011475f, 0.463995f, 0.668833f, -0.604556f, -0.122780f, -0.441645f, + 0.145769f, 0.310450f, -1.003500f, 0.936069f, 0.516604f, -0.643386f, + -0.518571f, 0.306130f, 0.337387f, 0.583400f, -0.366025f, -0.560035f, + -0.262332f, 0.465242f, 0.964332f, -0.545410f, -0.637428f, -0.202695f, + 0.378931f, 0.834604f, 0.000970f, -0.553303f, -0.562879f, 0.221665f, + 0.395160f, 0.446281f, -0.184394f, -0.591780f, 0.170595f, 1.164390f, + 0.227068f, -0.150910f, -0.393690f, -0.131151f, 0.309956f, -0.413518f, + -0.768334f, -0.548975f, 0.245384f, -0.256904f, -0.514790f, -0.102616f, + -0.347625f, 0.420456f, 0.037804f, -0.283200f, -0.578815f, 0.319282f, + 0.674622f, -0.011791f, -0.339329f, 0.466705f, 0.563444f, 0.409660f, + 0.445784f, -0.899507f, -0.605116f, 0.622438f, 0.427385f, -0.062509f, + 0.666570f, 0.057105f, 0.357894f, -0.811016f, -0.421715f, -0.458397f, + 0.288955f, 0.005857f, 0.236331f, 0.107957f, 0.587276f, -0.375800f, + 0.323799f, -0.623363f, 0.254122f, -0.198478f, -0.098436f, -0.282531f, + 0.452453f, -0.163349f, -0.413382f, -0.448732f, -0.528770f, -0.457449f, + -0.619619f, -0.265919f, -0.042760f, 0.438730f, 0.501798f, -0.403851f, + 0.519564f, 0.817314f, 0.366203f, 0.492610f, 0.546929f, 
0.853094f, + 0.289000f, 0.453941f, -0.076152f, 0.007226f, -0.183717f, -0.506252f, + -0.599989f, -0.576006f, 0.746488f, 0.631466f, -0.475599f, -0.334991f, + -0.879614f, 0.918957f, 0.473471f, -0.043781f, -0.688234f, -0.925875f, + -0.188081f, 0.050918f, 0.116855f, 0.221413f, -0.066680f, -0.674395f, + -0.481985f, 0.247368f, 0.271129f, 0.637979f, -1.006970f, -0.855441f, + 0.144874f, 0.507424f, 1.506960f, -0.338910f, 0.398203f, 0.738000f, + 0.263193f, -0.425908f, 0.358271f, -1.072900f, -0.816209f, -0.425519f, + 0.264373f, 0.694014f, 0.036333f, 0.635532f, 0.518856f, 0.047585f, + -0.854817f, -0.138202f, 0.006811f, -0.052020f, -0.468498f, 0.489080f, + -0.105778f, 0.357038f, -0.782875f, 0.649049f, -0.562652f, -0.544392f, + -0.328526f, -0.402121f, -0.263172f, -0.668459f, -0.526702f, -0.395829f, + 0.190986f, 0.307766f, -1.001830f, -0.293051f, 0.283334f, 0.572450f, + 0.906095f, -1.144300f, 0.180989f, 0.421092f, 0.684571f, 0.527276f, + -0.122287f, 0.575067f, 0.675221f, 0.755029f, 0.094957f, 0.481403f, + 0.825155f, 0.755035f, 0.641420f, 0.034497f, 0.518783f, 0.283800f, + 0.293733f, -0.074778f, -0.268720f, 0.798921f, 0.317714f, -0.236391f, + -0.375071f, -0.414600f, 0.223413f, -0.349044f, -0.191033f, -0.391779f, + -0.596894f, -0.378608f, -0.185920f, -0.822171f, -0.754962f, -0.167706f, + 0.755378f, 0.671847f, 0.969414f, 0.793048f, 1.078610f, -0.418963f, + 0.367648f, 0.217645f, 0.294232f, 0.113027f, 0.060312f, -0.327488f, + -0.305035f, -0.243600f, -0.020588f, -0.326324f, -0.417534f, -0.425868f, + -0.404614f, -0.346750f, -0.339145f, -0.348094f, -0.527290f, -0.617825f, + -0.258342f, -0.200753f, -0.249779f, -0.321039f, -0.023117f, -0.004167f, + -0.206788f, -0.612420f, -0.646428f, -0.548969f, -0.158875f, 0.213814f, + -0.084040f, -0.217365f, -0.511895f, -0.653285f, 0.440971f, 0.455591f, + -0.123900f, 0.134097f, -0.251241f, 0.682463f, 0.740614f, 0.991212f, + 0.565984f, 0.592690f, +}; + +static INLINE float32x4_t add_f32x4_x4(const float32x4_t a[4]) { + float32x4_t sum01 = vaddq_f32(a[0], a[1]); + float32x4_t sum23 = vaddq_f32(a[2], a[3]); + return vaddq_f32(sum01, sum23); +} + +static INLINE void av1_cnn_convolve_no_maxpool_padding_valid_2x2_large_neon( + const float **input, int in_width, int in_height, int in_stride, + const float *bias, const int skip_width, const int skip_height, + const int filter_width, const int filter_height, const int in_channels, + const int out_channels, float **output, int out_stride, int start_idx, + const float *weights) { + assert(filter_height == 2 && filter_width == 2); + assert(skip_width == 2 && skip_height == 2); + assert(in_width >= 16); + const int in_size = in_height * in_width; + + do { + const float32x4_t bias_v = vdupq_n_f32(bias[0]); + const float *weight_ptr0 = weights; + const float *in_ptr0 = *input; + float *out_ptr0 = *output; + int h = 0; + + do { + const float *in_ptr1 = in_ptr0; + float *out_ptr1 = out_ptr0; + int w = 0; + + do { + const float *weight_ptr1 = weight_ptr0; + const float *in_ptr2 = in_ptr1; + int k = 0; + float32x4_t sum0[4] = { bias_v, vdupq_n_f32(0), vdupq_n_f32(0), + vdupq_n_f32(0) }; + float32x4_t sum1[4] = { bias_v, vdupq_n_f32(0), vdupq_n_f32(0), + vdupq_n_f32(0) }; + + do { + const float32x4_t weights0 = vld1q_f32(weight_ptr1); + const float32x4_t weights1 = vld1q_f32(weight_ptr1 + 4); + const float32x2_t weights0_lo = vget_low_f32(weights0); + const float32x2_t weights0_hi = vget_high_f32(weights0); + const float32x2_t weights1_lo = vget_low_f32(weights1); + const float32x2_t weights1_hi = vget_high_f32(weights1); + + const float32x4x2_t 
in0_lo_0 = vld2q_f32(in_ptr2); + const float32x4x2_t in0_hi_0 = vld2q_f32(in_ptr2 + in_stride); + const float32x4x2_t in1_lo_0 = vld2q_f32(in_ptr2 + in_size); + const float32x4x2_t in1_hi_0 = + vld2q_f32(in_ptr2 + in_size + in_stride); + + sum0[0] = vmlaq_lane_f32(sum0[0], in0_lo_0.val[0], weights0_lo, 0); + sum0[0] = vmlaq_lane_f32(sum0[0], in0_lo_0.val[1], weights0_lo, 1); + + sum0[1] = vmlaq_lane_f32(sum0[1], in0_hi_0.val[0], weights0_hi, 0); + sum0[1] = vmlaq_lane_f32(sum0[1], in0_hi_0.val[1], weights0_hi, 1); + + sum0[2] = vmlaq_lane_f32(sum0[2], in1_lo_0.val[0], weights1_lo, 0); + sum0[2] = vmlaq_lane_f32(sum0[2], in1_lo_0.val[1], weights1_lo, 1); + + sum0[3] = vmlaq_lane_f32(sum0[3], in1_hi_0.val[0], weights1_hi, 0); + sum0[3] = vmlaq_lane_f32(sum0[3], in1_hi_0.val[1], weights1_hi, 1); + + const float32x4x2_t in0_lo_1 = vld2q_f32(in_ptr2 + 8); + const float32x4x2_t in0_hi_1 = vld2q_f32(in_ptr2 + in_stride + 8); + const float32x4x2_t in1_lo_1 = vld2q_f32(in_ptr2 + in_size + 8); + const float32x4x2_t in1_hi_1 = + vld2q_f32(in_ptr2 + in_size + in_stride + 8); + + sum1[0] = vmlaq_lane_f32(sum1[0], in0_lo_1.val[0], weights0_lo, 0); + sum1[0] = vmlaq_lane_f32(sum1[0], in0_lo_1.val[1], weights0_lo, 1); + + sum1[1] = vmlaq_lane_f32(sum1[1], in0_hi_1.val[0], weights0_hi, 0); + sum1[1] = vmlaq_lane_f32(sum1[1], in0_hi_1.val[1], weights0_hi, 1); + + sum1[2] = vmlaq_lane_f32(sum1[2], in1_lo_1.val[0], weights1_lo, 0); + sum1[2] = vmlaq_lane_f32(sum1[2], in1_lo_1.val[1], weights1_lo, 1); + + sum1[3] = vmlaq_lane_f32(sum1[3], in1_hi_1.val[0], weights1_hi, 0); + sum1[3] = vmlaq_lane_f32(sum1[3], in1_hi_1.val[1], weights1_hi, 1); + + weight_ptr1 += 8; + in_ptr2 += 2 * in_size; + k += 2; + } while (k < in_channels); + + vst1q_f32(out_ptr1, add_f32x4_x4(sum0)); + vst1q_f32(out_ptr1 + 4, add_f32x4_x4(sum1)); + + out_ptr1 += 8; + in_ptr1 += 8 * skip_width; + w += 8 * skip_width; + } while (w < in_width - filter_width + 1); + + out_ptr0 += out_stride; + in_ptr0 += skip_height * in_stride; + h += skip_height; + } while (h < in_height - filter_height + 1); + + ++bias; + ++output; + weights += in_channels * filter_height * filter_width; + } while (++start_idx < out_channels); +} + +static INLINE void av1_cnn_convolve_no_maxpool_padding_valid_2x2_neon( + const float **input, int in_width, int in_height, int in_stride, + const float *bias, const int skip_width, const int skip_height, + const int filter_width, const int filter_height, const int in_channels, + const int out_channels, float **output, int out_stride, int start_idx, + const float *weights) { + assert(filter_height == 2 && filter_width == 2); + assert(skip_width == 2 && skip_height == 2); + assert(in_width == 8); + const int in_size = in_height * in_width; + do { + const float32x4_t bias_v = vdupq_n_f32(*bias); + const float *weight_ptr0 = weights; + const float *in_ptr0 = *input; + float *out_ptr0 = *output; + int h = 0; + + do { + const float *in_ptr1 = in_ptr0; + float *out_ptr1 = out_ptr0; + int w = 0; + + do { + const float *weight_ptr1 = weight_ptr0; + const float *in_ptr2 = in_ptr1; + int k = 0; + float32x4_t sum[4] = { bias_v, vdupq_n_f32(0), vdupq_n_f32(0), + vdupq_n_f32(0) }; + + do { + const float32x4_t weights0 = vld1q_f32(weight_ptr1); + const float32x4_t weights1 = vld1q_f32(weight_ptr1 + 4); + const float32x2_t weights0_lo = vget_low_f32(weights0); + const float32x2_t weights0_hi = vget_high_f32(weights0); + const float32x2_t weights1_lo = vget_low_f32(weights1); + const float32x2_t weights1_hi = vget_high_f32(weights1); + + const 
float32x4x2_t in0_lo = vld2q_f32(in_ptr2); + const float32x4x2_t in0_hi = vld2q_f32(in_ptr2 + in_stride); + const float32x4x2_t in1_lo = vld2q_f32(in_ptr2 + in_size); + const float32x4x2_t in1_hi = vld2q_f32(in_ptr2 + in_size + in_stride); + + sum[0] = vmlaq_lane_f32(sum[0], in0_lo.val[0], weights0_lo, 0); + sum[0] = vmlaq_lane_f32(sum[0], in0_lo.val[1], weights0_lo, 1); + + sum[1] = vmlaq_lane_f32(sum[1], in0_hi.val[0], weights0_hi, 0); + sum[1] = vmlaq_lane_f32(sum[1], in0_hi.val[1], weights0_hi, 1); + + sum[2] = vmlaq_lane_f32(sum[2], in1_lo.val[0], weights1_lo, 0); + sum[2] = vmlaq_lane_f32(sum[2], in1_lo.val[1], weights1_lo, 1); + + sum[3] = vmlaq_lane_f32(sum[3], in1_hi.val[0], weights1_hi, 0); + sum[3] = vmlaq_lane_f32(sum[3], in1_hi.val[1], weights1_hi, 1); + + weight_ptr1 += 8; + in_ptr2 += 2 * in_size; + k += 2; + } while (k < in_channels); + + vst1q_f32(out_ptr1, add_f32x4_x4(sum)); + + out_ptr1 += 4; + in_ptr1 += 4 * skip_width; + w += 4 * skip_width; + } while (w < in_width - filter_width + 1); + + out_ptr0 += out_stride; + in_ptr0 += skip_height * in_stride; + h += skip_height; + } while (h < in_height - filter_height + 1); + + ++bias; + ++output; + weights += in_channels * filter_height * filter_width; + } while (++start_idx < out_channels); +} + +static INLINE void av1_cnn_convolve_no_maxpool_padding_valid_5x5_neon( + const float **input, int in_width, int in_height, int in_stride, + const float *bias, const int skip_width, const int skip_height, + const int filter_width, const int filter_height, const int in_channels, + const int out_channels, float **output, int out_stride, int start_idx, + const float *weights) { + assert(filter_height == 5 && filter_width == 5); + assert(skip_width == 4 && skip_height == 4); + assert(in_width >= 16); + assert(in_channels == 1); + (void)in_channels; + + do { + const float32x4_t bias_v = vdupq_n_f32(*bias); + const float *in_ptr0 = *input; + const float *weights_ptr0 = weights; + float *out_ptr0 = *output; + int h = 0; + + do { + const float *in_ptr1 = in_ptr0; + float *out_ptr1 = out_ptr0; + int w = 0; + + do { + float32x4_t sum[2] = { bias_v, vdupq_n_f32(0) }; + + const float32x4_t weight_0_3 = vld1q_f32(weights_ptr0); + const float32x4_t weight_4_7 = vld1q_f32(weights_ptr0 + 4); + const float32x4_t weight_8_11 = vld1q_f32(weights_ptr0 + 8); + const float32x4_t weight_12_15 = vld1q_f32(weights_ptr0 + 12); + const float32x4_t weight_16_19 = vld1q_f32(weights_ptr0 + 16); + const float32x4_t weight_20_23 = vld1q_f32(weights_ptr0 + 20); + + const float32x2_t weight_0_3_lo = vget_low_f32(weight_0_3); + const float32x2_t weight_0_3_hi = vget_high_f32(weight_0_3); + const float32x2_t weight_4_7_lo = vget_low_f32(weight_4_7); + const float32x2_t weight_4_7_hi = vget_high_f32(weight_4_7); + const float32x2_t weight_8_11_lo = vget_low_f32(weight_8_11); + const float32x2_t weight_8_11_hi = vget_high_f32(weight_8_11); + const float32x2_t weight_12_15_lo = vget_low_f32(weight_12_15); + const float32x2_t weight_12_15_hi = vget_high_f32(weight_12_15); + const float32x2_t weight_16_19_lo = vget_low_f32(weight_16_19); + const float32x2_t weight_16_19_hi = vget_high_f32(weight_16_19); + const float32x2_t weight_20_23_lo = vget_low_f32(weight_20_23); + const float32x2_t weight_20_23_hi = vget_high_f32(weight_20_23); + + const float32x4x4_t in0 = vld4q_f32(in_ptr1 + 0 * in_stride); + const float32x4x4_t in1 = vld4q_f32(in_ptr1 + 1 * in_stride); + const float32x4x4_t in2 = vld4q_f32(in_ptr1 + 2 * in_stride); + const float32x4x4_t in3 = vld4q_f32(in_ptr1 + 3 
* in_stride); + const float32x4x4_t in4 = vld4q_f32(in_ptr1 + 4 * in_stride); + + const float32x4_t in0_4 = vextq_f32( + in0.val[0], vdupq_n_f32(*(in_ptr1 + 16 + 0 * in_stride)), 1); + const float32x4_t in1_4 = vextq_f32( + in1.val[0], vdupq_n_f32(*(in_ptr1 + 16 + 1 * in_stride)), 1); + const float32x4_t in2_4 = vextq_f32( + in2.val[0], vdupq_n_f32(*(in_ptr1 + 16 + 2 * in_stride)), 1); + const float32x4_t in3_4 = vextq_f32( + in3.val[0], vdupq_n_f32(*(in_ptr1 + 16 + 3 * in_stride)), 1); + const float32x4_t in4_4 = vextq_f32( + in4.val[0], vdupq_n_f32(*(in_ptr1 + 16 + 4 * in_stride)), 1); + + // Kernel row 0. + sum[0] = vmlaq_lane_f32(sum[0], in0.val[0], weight_0_3_lo, 0); + sum[1] = vmlaq_lane_f32(sum[1], in0.val[1], weight_0_3_lo, 1); + sum[0] = vmlaq_lane_f32(sum[0], in0.val[2], weight_0_3_hi, 0); + sum[1] = vmlaq_lane_f32(sum[1], in0.val[3], weight_0_3_hi, 1); + sum[0] = vmlaq_lane_f32(sum[0], in0_4, weight_4_7_lo, 0); + + // Kernel row 1. + sum[1] = vmlaq_lane_f32(sum[1], in1.val[0], weight_4_7_lo, 1); + sum[0] = vmlaq_lane_f32(sum[0], in1.val[1], weight_4_7_hi, 0); + sum[1] = vmlaq_lane_f32(sum[1], in1.val[2], weight_4_7_hi, 1); + sum[0] = vmlaq_lane_f32(sum[0], in1.val[3], weight_8_11_lo, 0); + sum[1] = vmlaq_lane_f32(sum[1], in1_4, weight_8_11_lo, 1); + + // Kernel row 2. + sum[0] = vmlaq_lane_f32(sum[0], in2.val[0], weight_8_11_hi, 0); + sum[1] = vmlaq_lane_f32(sum[1], in2.val[1], weight_8_11_hi, 1); + sum[0] = vmlaq_lane_f32(sum[0], in2.val[2], weight_12_15_lo, 0); + sum[1] = vmlaq_lane_f32(sum[1], in2.val[3], weight_12_15_lo, 1); + sum[0] = vmlaq_lane_f32(sum[0], in2_4, weight_12_15_hi, 0); + + // Kernel row 3. + sum[1] = vmlaq_lane_f32(sum[1], in3.val[0], weight_12_15_hi, 1); + sum[0] = vmlaq_lane_f32(sum[0], in3.val[1], weight_16_19_lo, 0); + sum[1] = vmlaq_lane_f32(sum[1], in3.val[2], weight_16_19_lo, 1); + sum[0] = vmlaq_lane_f32(sum[0], in3.val[3], weight_16_19_hi, 0); + sum[1] = vmlaq_lane_f32(sum[1], in3_4, weight_16_19_hi, 1); + + // Kernel row 4. + sum[0] = vmlaq_lane_f32(sum[0], in4.val[0], weight_20_23_lo, 0); + sum[1] = vmlaq_lane_f32(sum[1], in4.val[1], weight_20_23_lo, 1); + sum[0] = vmlaq_lane_f32(sum[0], in4.val[2], weight_20_23_hi, 0); + sum[1] = vmlaq_lane_f32(sum[1], in4.val[3], weight_20_23_hi, 1); + sum[0] = vmlaq_f32(sum[0], vdupq_n_f32(*(weights_ptr0 + 24)), in4_4); + + vst1q_f32(out_ptr1, vaddq_f32(sum[0], sum[1])); + + out_ptr1 += 4; + in_ptr1 += 4 * skip_width; + w += 4 * skip_width; + } while (w < in_width - filter_width + 1); + + out_ptr0 += out_stride; + in_ptr0 += skip_height * in_stride; + h += skip_height; + } while (h < in_height - filter_height + 1); + + ++output; + ++bias; + weights += 25; + } while (++start_idx < out_channels); +} + +// Neon variant of av1_cnn_convolve_no_maxpool_padding_valid_c(). +// As per the current encoder, av1_cnn_convolve function gets called for +// block size equal to 64x64. av1_cnn_convolve() uses layer config values +// set by av1_intra_mode_cnn_partition_cnn_config. The following are a few +// details related to each layer's config parameters. +// Layer_Number in_size out_size filter_wd filter_ht skip_wd skip_ht +// 0 64x64 16x16 5 5 4 4 +// 1 16x16 8x8 2 2 2 2 +// 2 8x8 4x4 2 2 2 2 +// 3 4x4 2x2 2 2 2 2 +// 4 2x2 1x1 2 2 2 2 +// Here, +// filter_wd = filter_width and filter_ht = filter_height, +// skip_wd = skip_width and skip_ht = skip_height. 
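+// A minimal scalar sketch of the valid-padding, strided convolution that the
+// NEON kernels above vectorize, given here for reference. It assumes the data
+// layout implied by those kernels: the in_channels input planes are contiguous
+// starting at (*input), each plane being in_height * in_width floats with rows
+// in_stride apart; output channel c is written through output[c] with rows
+// out_stride apart; and the weights for output channel c are
+// in_channels * filter_height * filter_width floats in row-major filter order.
+// The helper name cnn_valid_conv_scalar_sketch is illustrative only (it is not
+// part of the libaom API), and the #if 0 guard keeps the sketch out of the
+// build.
+#if 0
+static void cnn_valid_conv_scalar_sketch(
+    const float **input, int in_width, int in_height, int in_stride,
+    const float *bias, int skip_width, int skip_height, int filter_width,
+    int filter_height, int in_channels, int out_channels, float **output,
+    int out_stride, const float *weights) {
+  const int in_size = in_height * in_width;
+  for (int c = 0; c < out_channels; ++c) {
+    // Weights for this output channel, laid out per input channel, then
+    // row-major over the filter window.
+    const float *wc = weights + c * in_channels * filter_height * filter_width;
+    float *out = output[c];
+    int out_row = 0;
+    for (int y = 0; y + filter_height <= in_height; y += skip_height) {
+      int out_col = 0;
+      for (int x = 0; x + filter_width <= in_width; x += skip_width) {
+        float sum = bias[c];
+        for (int k = 0; k < in_channels; ++k) {
+          const float *in_k = *input + k * in_size;
+          const float *wk = wc + k * filter_height * filter_width;
+          for (int fy = 0; fy < filter_height; ++fy) {
+            for (int fx = 0; fx < filter_width; ++fx) {
+              sum += wk[fy * filter_width + fx] *
+                     in_k[(y + fy) * in_stride + (x + fx)];
+            }
+          }
+        }
+        out[out_row * out_stride + out_col] = sum;
+        ++out_col;
+      }
+      ++out_row;
+    }
+  }
+}
+#endif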
+void av1_cnn_convolve_no_maxpool_padding_valid_neon( + const float **input, int in_width, int in_height, int in_stride, + const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride, + int start_idx, int cstep, int channel_step) { + assert((layer_config->skip_height == 1 && layer_config->skip_width == 1) || + !layer_config->maxpool); + assert(layer_config->filter_height > 1 || layer_config->filter_width > 1); + assert(layer_config->pad == PADDING_VALID); + assert(channel_step == 1); + assert(cstep == layer_config->in_channels * layer_config->out_channels); + + if (layer_config->filter_width == 5 && layer_config->filter_height == 5 && + layer_config->skip_width == 4 && layer_config->skip_height == 4) { + av1_cnn_convolve_no_maxpool_padding_valid_5x5_neon( + input, in_width, in_height, in_stride, layer_config->bias, + layer_config->skip_width, layer_config->skip_height, + layer_config->filter_width, layer_config->filter_height, + layer_config->in_channels, layer_config->out_channels, output, + out_stride, start_idx, weights_layer_5); + } else if (layer_config->filter_width == 2 && + layer_config->filter_height == 2 && + layer_config->skip_width == 2 && layer_config->skip_height == 2) { + const float *weights = weights_layer_1; + if (layer_config->output_num == + av1_intra_mode_cnn_partition_cnn_config.layer_config[2].output_num) { + weights = weights_layer_2; + } else if ((layer_config->output_num == + av1_intra_mode_cnn_partition_cnn_config.layer_config[3] + .output_num)) { + weights = weights_layer_3; + } else if ((layer_config->output_num == + av1_intra_mode_cnn_partition_cnn_config.layer_config[4] + .output_num)) { + weights = weights_layer_4; + } + if (in_width >= 16) { + av1_cnn_convolve_no_maxpool_padding_valid_2x2_large_neon( + input, in_width, in_height, in_stride, layer_config->bias, + layer_config->skip_width, layer_config->skip_height, + layer_config->filter_width, layer_config->filter_height, + layer_config->in_channels, layer_config->out_channels, output, + out_stride, start_idx, weights); + } else if (in_width == 8) { + av1_cnn_convolve_no_maxpool_padding_valid_2x2_neon( + input, in_width, in_height, in_stride, layer_config->bias, + layer_config->skip_width, layer_config->skip_height, + layer_config->filter_width, layer_config->filter_height, + layer_config->in_channels, layer_config->out_channels, output, + out_stride, start_idx, weights); + } else { + av1_cnn_convolve_no_maxpool_padding_valid_c( + input, in_width, in_height, in_stride, layer_config, output, + out_stride, start_idx, cstep, channel_step); + } + } else { + av1_cnn_convolve_no_maxpool_padding_valid_c( + input, in_width, in_height, in_stride, layer_config, output, out_stride, + start_idx, cstep, channel_step); + } +} diff --git a/third_party/aom/av1/encoder/arm/neon/encodetxb_neon.c b/third_party/aom/av1/encoder/arm/neon/encodetxb_neon.c new file mode 100644 index 0000000000..582863a27c --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/encodetxb_neon.c @@ -0,0 +1,646 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include <arm_neon.h> +#include <assert.h> +#include <math.h> + +#include "config/aom_config.h" + +#include "aom_dsp/arm/mem_neon.h" +#include "av1/common/txb_common.h" +#include "av1/encoder/encodetxb.h" + +void av1_txb_init_levels_neon(const tran_low_t *const coeff, const int width, + const int height, uint8_t *const levels) { + const int stride = height + TX_PAD_HOR; + memset(levels - TX_PAD_TOP * stride, 0, + sizeof(*levels) * TX_PAD_TOP * stride); + memset(levels + stride * width, 0, + sizeof(*levels) * (TX_PAD_BOTTOM * stride + TX_PAD_END)); + + const int32x4_t zeros = vdupq_n_s32(0); + int i = 0; + uint8_t *ls = levels; + const tran_low_t *cf = coeff; + if (height == 4) { + do { + const int32x4_t coeffA = vld1q_s32(cf); + const int32x4_t coeffB = vld1q_s32(cf + height); + const int16x8_t coeffAB = + vcombine_s16(vqmovn_s32(coeffA), vqmovn_s32(coeffB)); + const int16x8_t absAB = vqabsq_s16(coeffAB); + const int8x8_t absABs = vqmovn_s16(absAB); +#if AOM_ARCH_AARCH64 + const int8x16_t absAB8 = + vcombine_s8(absABs, vreinterpret_s8_s32(vget_low_s32(zeros))); + const uint8x16_t lsAB = + vreinterpretq_u8_s32(vzip1q_s32(vreinterpretq_s32_s8(absAB8), zeros)); +#else + const int32x2x2_t absAB8 = + vzip_s32(vreinterpret_s32_s8(absABs), vget_low_s32(zeros)); + const uint8x16_t lsAB = + vreinterpretq_u8_s32(vcombine_s32(absAB8.val[0], absAB8.val[1])); +#endif + vst1q_u8(ls, lsAB); + ls += (stride << 1); + cf += (height << 1); + i += 2; + } while (i < width); + } else if (height == 8) { + do { + const int16x8_t coeffAB = load_tran_low_to_s16q(cf); + const int16x8_t absAB = vqabsq_s16(coeffAB); + const uint8x16_t absAB8 = vreinterpretq_u8_s8(vcombine_s8( + vqmovn_s16(absAB), vreinterpret_s8_s32(vget_low_s32(zeros)))); + vst1q_u8(ls, absAB8); + ls += stride; + cf += height; + i += 1; + } while (i < width); + } else { + do { + int j = 0; + do { + const int16x8_t coeffAB = load_tran_low_to_s16q(cf); + const int16x8_t coeffCD = load_tran_low_to_s16q(cf + 8); + const int16x8_t absAB = vqabsq_s16(coeffAB); + const int16x8_t absCD = vqabsq_s16(coeffCD); + const uint8x16_t absABCD = vreinterpretq_u8_s8( + vcombine_s8(vqmovn_s16(absAB), vqmovn_s16(absCD))); + vst1q_u8((ls + j), absABCD); + j += 16; + cf += 16; + } while (j < height); + *(int32_t *)(ls + height) = 0; + ls += stride; + i += 1; + } while (i < width); + } +} + +// get_4_nz_map_contexts_2d coefficients: +static const DECLARE_ALIGNED(16, uint8_t, c_4_po_2d[2][16]) = { + { 0, 1, 6, 6, 1, 6, 6, 21, 6, 6, 21, 21, 6, 21, 21, 21 }, + { 0, 16, 16, 16, 16, 16, 16, 16, 6, 6, 21, 21, 6, 21, 21, 21 } +}; + +// get_4_nz_map_contexts_hor coefficients: +/* clang-format off */ +#define SIG_COEF_CONTEXTS_2D_X4_051010 \ + (SIG_COEF_CONTEXTS_2D + ((SIG_COEF_CONTEXTS_2D + 5) << 8) + \ + ((SIG_COEF_CONTEXTS_2D + 10) << 16) + ((SIG_COEF_CONTEXTS_2D + 10) << 24)) +/* clang-format on */ + +// get_4_nz_map_contexts_ver coefficients: +static const DECLARE_ALIGNED(16, uint8_t, c_4_po_hor[16]) = { + SIG_COEF_CONTEXTS_2D + 0, SIG_COEF_CONTEXTS_2D + 0, + SIG_COEF_CONTEXTS_2D + 0, SIG_COEF_CONTEXTS_2D + 0, + SIG_COEF_CONTEXTS_2D + 5, SIG_COEF_CONTEXTS_2D + 5, + SIG_COEF_CONTEXTS_2D + 5, SIG_COEF_CONTEXTS_2D + 5, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10 +}; + +// get_8_coeff_contexts_2d coefficients: +// if (width == 8) +static const DECLARE_ALIGNED(16, uint8_t, c_8_po_2d_8[2][16]) = { + 
{ 0, 1, 6, 6, 21, 21, 21, 21, 1, 6, 6, 21, 21, 21, 21, 21 }, + { 6, 6, 21, 21, 21, 21, 21, 21, 6, 21, 21, 21, 21, 21, 21, 21 } +}; +// if (width < 8) +static const DECLARE_ALIGNED(16, uint8_t, c_8_po_2d_l[2][16]) = { + { 0, 11, 6, 6, 21, 21, 21, 21, 11, 11, 6, 21, 21, 21, 21, 21 }, + { 11, 11, 21, 21, 21, 21, 21, 21, 11, 11, 21, 21, 21, 21, 21, 21 } +}; + +// if (width > 8) +static const DECLARE_ALIGNED(16, uint8_t, c_8_po_2d_g[2][16]) = { + { 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16 }, + { 6, 6, 21, 21, 21, 21, 21, 21, 6, 21, 21, 21, 21, 21, 21, 21 } +}; + +// get_4_nz_map_contexts_ver coefficients: +static const DECLARE_ALIGNED(16, uint8_t, c_8_po_ver[16]) = { + SIG_COEF_CONTEXTS_2D + 0, SIG_COEF_CONTEXTS_2D + 5, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 0, SIG_COEF_CONTEXTS_2D + 5, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10 +}; + +// get_16n_coeff_contexts_2d coefficients: +// real_width == real_height +static const DECLARE_ALIGNED(16, uint8_t, c_16_po_2d_e[4][16]) = { + { 0, 1, 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }, + { 1, 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }, + { 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }, + { 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 } +}; + +// real_width < real_height +static const DECLARE_ALIGNED(16, uint8_t, c_16_po_2d_g[3][16]) = { + { 0, 11, 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }, + { 11, 11, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }, + { 11, 11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 } +}; + +// real_width > real_height +static const DECLARE_ALIGNED(16, uint8_t, c_16_po_2d_l[3][16]) = { + { 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16 }, + { 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }, + { 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 } +}; + +// get_16n_coeff_contexts_hor coefficients: +static const DECLARE_ALIGNED(16, uint8_t, c_16_po_ver[16]) = { + SIG_COEF_CONTEXTS_2D + 0, SIG_COEF_CONTEXTS_2D + 5, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10, + SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10 +}; + +// end of coefficients declaration area + +static INLINE uint8x16_t load_8bit_4x4_to_1_reg(const uint8_t *const src, + const int byte_stride) { +#if AOM_ARCH_AARCH64 + uint32x4_t v_data = vld1q_u32((uint32_t *)src); + v_data = vld1q_lane_u32((uint32_t *)(src + 1 * byte_stride), v_data, 1); + v_data = vld1q_lane_u32((uint32_t *)(src + 2 * byte_stride), v_data, 2); + v_data = vld1q_lane_u32((uint32_t *)(src + 3 * byte_stride), v_data, 3); + + return vreinterpretq_u8_u32(v_data); +#else + return load_unaligned_u8q(src, byte_stride); +#endif +} + +static INLINE uint8x16_t load_8bit_8x2_to_1_reg(const uint8_t *const src, + const int byte_stride) { +#if AOM_ARCH_AARCH64 + uint64x2_t v_data = vld1q_u64((uint64_t *)src); + v_data = vld1q_lane_u64((uint64_t *)(src + 1 * byte_stride), v_data, 1); + + return 
vreinterpretq_u8_u64(v_data); +#else + uint8x8_t v_data_low = vld1_u8(src); + uint8x8_t v_data_high = vld1_u8(src + byte_stride); + + return vcombine_u8(v_data_low, v_data_high); +#endif +} + +static INLINE uint8x16_t load_8bit_16x1_to_1_reg(const uint8_t *const src, + const int byte_stride) { + (void)byte_stride; + return vld1q_u8(src); +} + +static INLINE void load_levels_4x4x5(const uint8_t *const src, const int stride, + const ptrdiff_t *const offsets, + uint8x16_t *const level) { + level[0] = load_8bit_4x4_to_1_reg(&src[1], stride); + level[1] = load_8bit_4x4_to_1_reg(&src[stride], stride); + level[2] = load_8bit_4x4_to_1_reg(&src[offsets[0]], stride); + level[3] = load_8bit_4x4_to_1_reg(&src[offsets[1]], stride); + level[4] = load_8bit_4x4_to_1_reg(&src[offsets[2]], stride); +} + +static INLINE void load_levels_8x2x5(const uint8_t *const src, const int stride, + const ptrdiff_t *const offsets, + uint8x16_t *const level) { + level[0] = load_8bit_8x2_to_1_reg(&src[1], stride); + level[1] = load_8bit_8x2_to_1_reg(&src[stride], stride); + level[2] = load_8bit_8x2_to_1_reg(&src[offsets[0]], stride); + level[3] = load_8bit_8x2_to_1_reg(&src[offsets[1]], stride); + level[4] = load_8bit_8x2_to_1_reg(&src[offsets[2]], stride); +} + +static INLINE void load_levels_16x1x5(const uint8_t *const src, + const int stride, + const ptrdiff_t *const offsets, + uint8x16_t *const level) { + level[0] = load_8bit_16x1_to_1_reg(&src[1], stride); + level[1] = load_8bit_16x1_to_1_reg(&src[stride], stride); + level[2] = load_8bit_16x1_to_1_reg(&src[offsets[0]], stride); + level[3] = load_8bit_16x1_to_1_reg(&src[offsets[1]], stride); + level[4] = load_8bit_16x1_to_1_reg(&src[offsets[2]], stride); +} + +static INLINE uint8x16_t get_coeff_contexts_kernel(uint8x16_t *const level) { + const uint8x16_t const_3 = vdupq_n_u8(3); + const uint8x16_t const_4 = vdupq_n_u8(4); + uint8x16_t count; + + count = vminq_u8(level[0], const_3); + level[1] = vminq_u8(level[1], const_3); + level[2] = vminq_u8(level[2], const_3); + level[3] = vminq_u8(level[3], const_3); + level[4] = vminq_u8(level[4], const_3); + count = vaddq_u8(count, level[1]); + count = vaddq_u8(count, level[2]); + count = vaddq_u8(count, level[3]); + count = vaddq_u8(count, level[4]); + + count = vrshrq_n_u8(count, 1); + count = vminq_u8(count, const_4); + return count; +} + +static INLINE void get_4_nz_map_contexts_2d(const uint8_t *levels, + const int width, + const ptrdiff_t *const offsets, + uint8_t *const coeff_contexts) { + const int stride = 4 + TX_PAD_HOR; + const uint8x16_t pos_to_offset_large = vdupq_n_u8(21); + + uint8x16_t pos_to_offset = + (width == 4) ? 
vld1q_u8(c_4_po_2d[0]) : vld1q_u8(c_4_po_2d[1]); + + uint8x16_t count; + uint8x16_t level[5]; + uint8_t *cc = coeff_contexts; + + assert(!(width % 4)); + + int col = width; + do { + load_levels_4x4x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset); + vst1q_u8(cc, count); + pos_to_offset = pos_to_offset_large; + levels += 4 * stride; + cc += 16; + col -= 4; + } while (col); + + coeff_contexts[0] = 0; +} + +static INLINE void get_4_nz_map_contexts_ver(const uint8_t *levels, + const int width, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = 4 + TX_PAD_HOR; + + const uint8x16_t pos_to_offset = + vreinterpretq_u8_u32(vdupq_n_u32(SIG_COEF_CONTEXTS_2D_X4_051010)); + + uint8x16_t count; + uint8x16_t level[5]; + + assert(!(width % 4)); + + int col = width; + do { + load_levels_4x4x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset); + vst1q_u8(coeff_contexts, count); + levels += 4 * stride; + coeff_contexts += 16; + col -= 4; + } while (col); +} + +static INLINE void get_4_nz_map_contexts_hor(const uint8_t *levels, + const int width, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = 4 + TX_PAD_HOR; + const uint8x16_t pos_to_offset_large = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10); + + uint8x16_t pos_to_offset = vld1q_u8(c_4_po_hor); + + uint8x16_t count; + uint8x16_t level[5]; + + assert(!(width % 4)); + + int col = width; + do { + load_levels_4x4x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset); + vst1q_u8(coeff_contexts, count); + pos_to_offset = pos_to_offset_large; + levels += 4 * stride; + coeff_contexts += 16; + col -= 4; + } while (col); +} + +static INLINE void get_8_coeff_contexts_2d(const uint8_t *levels, + const int width, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = 8 + TX_PAD_HOR; + uint8_t *cc = coeff_contexts; + uint8x16_t count; + uint8x16_t level[5]; + uint8x16_t pos_to_offset[3]; + + assert(!(width % 2)); + + if (width == 8) { + pos_to_offset[0] = vld1q_u8(c_8_po_2d_8[0]); + pos_to_offset[1] = vld1q_u8(c_8_po_2d_8[1]); + } else if (width < 8) { + pos_to_offset[0] = vld1q_u8(c_8_po_2d_l[0]); + pos_to_offset[1] = vld1q_u8(c_8_po_2d_l[1]); + } else { + pos_to_offset[0] = vld1q_u8(c_8_po_2d_g[0]); + pos_to_offset[1] = vld1q_u8(c_8_po_2d_g[1]); + } + pos_to_offset[2] = vdupq_n_u8(21); + + int col = width; + do { + load_levels_8x2x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset[0]); + vst1q_u8(cc, count); + pos_to_offset[0] = pos_to_offset[1]; + pos_to_offset[1] = pos_to_offset[2]; + levels += 2 * stride; + cc += 16; + col -= 2; + } while (col); + + coeff_contexts[0] = 0; +} + +static INLINE void get_8_coeff_contexts_ver(const uint8_t *levels, + const int width, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = 8 + TX_PAD_HOR; + + const uint8x16_t pos_to_offset = vld1q_u8(c_8_po_ver); + + uint8x16_t count; + uint8x16_t level[5]; + + assert(!(width % 2)); + + int col = width; + do { + load_levels_8x2x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset); + vst1q_u8(coeff_contexts, count); + levels += 2 * stride; + coeff_contexts += 16; + col -= 2; + } while (col); +} + +static INLINE void 
get_8_coeff_contexts_hor(const uint8_t *levels, + const int width, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = 8 + TX_PAD_HOR; + const uint8x16_t pos_to_offset_large = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10); + + uint8x16_t pos_to_offset = vcombine_u8(vdup_n_u8(SIG_COEF_CONTEXTS_2D + 0), + vdup_n_u8(SIG_COEF_CONTEXTS_2D + 5)); + + uint8x16_t count; + uint8x16_t level[5]; + + assert(!(width % 2)); + + int col = width; + do { + load_levels_8x2x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset); + vst1q_u8(coeff_contexts, count); + pos_to_offset = pos_to_offset_large; + levels += 2 * stride; + coeff_contexts += 16; + col -= 2; + } while (col); +} + +static INLINE void get_16n_coeff_contexts_2d(const uint8_t *levels, + const int real_width, + const int real_height, + const int width, const int height, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = height + TX_PAD_HOR; + uint8_t *cc = coeff_contexts; + int col = width; + uint8x16_t pos_to_offset[5]; + uint8x16_t pos_to_offset_large[3]; + uint8x16_t count; + uint8x16_t level[5]; + + assert(!(height % 16)); + + pos_to_offset_large[2] = vdupq_n_u8(21); + if (real_width == real_height) { + pos_to_offset[0] = vld1q_u8(c_16_po_2d_e[0]); + pos_to_offset[1] = vld1q_u8(c_16_po_2d_e[1]); + pos_to_offset[2] = vld1q_u8(c_16_po_2d_e[2]); + pos_to_offset[3] = vld1q_u8(c_16_po_2d_e[3]); + pos_to_offset[4] = pos_to_offset_large[0] = pos_to_offset_large[1] = + pos_to_offset_large[2]; + } else if (real_width < real_height) { + pos_to_offset[0] = vld1q_u8(c_16_po_2d_g[0]); + pos_to_offset[1] = vld1q_u8(c_16_po_2d_g[1]); + pos_to_offset[2] = pos_to_offset[3] = pos_to_offset[4] = + vld1q_u8(c_16_po_2d_g[2]); + pos_to_offset_large[0] = pos_to_offset_large[1] = pos_to_offset_large[2]; + } else { // real_width > real_height + pos_to_offset[0] = pos_to_offset[1] = vld1q_u8(c_16_po_2d_l[0]); + pos_to_offset[2] = vld1q_u8(c_16_po_2d_l[1]); + pos_to_offset[3] = vld1q_u8(c_16_po_2d_l[2]); + pos_to_offset[4] = pos_to_offset_large[2]; + pos_to_offset_large[0] = pos_to_offset_large[1] = vdupq_n_u8(16); + } + + do { + int h = height; + + do { + load_levels_16x1x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset[0]); + vst1q_u8(cc, count); + levels += 16; + cc += 16; + h -= 16; + pos_to_offset[0] = pos_to_offset_large[0]; + } while (h); + + pos_to_offset[0] = pos_to_offset[1]; + pos_to_offset[1] = pos_to_offset[2]; + pos_to_offset[2] = pos_to_offset[3]; + pos_to_offset[3] = pos_to_offset[4]; + pos_to_offset_large[0] = pos_to_offset_large[1]; + pos_to_offset_large[1] = pos_to_offset_large[2]; + levels += TX_PAD_HOR; + } while (--col); + + coeff_contexts[0] = 0; +} + +static INLINE void get_16n_coeff_contexts_ver(const uint8_t *levels, + const int width, const int height, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = height + TX_PAD_HOR; + + const uint8x16_t pos_to_offset_large = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10); + + uint8x16_t count; + uint8x16_t level[5]; + + assert(!(height % 16)); + + int col = width; + do { + uint8x16_t pos_to_offset = vld1q_u8(c_16_po_ver); + + int h = height; + do { + load_levels_16x1x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset); + vst1q_u8(coeff_contexts, count); + pos_to_offset = pos_to_offset_large; + levels += 16; + 
coeff_contexts += 16; + h -= 16; + } while (h); + + levels += TX_PAD_HOR; + } while (--col); +} + +static INLINE void get_16n_coeff_contexts_hor(const uint8_t *levels, + const int width, const int height, + const ptrdiff_t *const offsets, + uint8_t *coeff_contexts) { + const int stride = height + TX_PAD_HOR; + + uint8x16_t pos_to_offset[3]; + uint8x16_t count; + uint8x16_t level[5]; + + assert(!(height % 16)); + + pos_to_offset[0] = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 0); + pos_to_offset[1] = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 5); + pos_to_offset[2] = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10); + + int col = width; + do { + int h = height; + do { + load_levels_16x1x5(levels, stride, offsets, level); + count = get_coeff_contexts_kernel(level); + count = vaddq_u8(count, pos_to_offset[0]); + vst1q_u8(coeff_contexts, count); + levels += 16; + coeff_contexts += 16; + h -= 16; + } while (h); + + pos_to_offset[0] = pos_to_offset[1]; + pos_to_offset[1] = pos_to_offset[2]; + levels += TX_PAD_HOR; + } while (--col); +} + +// Note: levels[] must be in the range [0, 127], inclusive. +void av1_get_nz_map_contexts_neon(const uint8_t *const levels, + const int16_t *const scan, const uint16_t eob, + const TX_SIZE tx_size, + const TX_CLASS tx_class, + int8_t *const coeff_contexts) { + const int last_idx = eob - 1; + if (!last_idx) { + coeff_contexts[0] = 0; + return; + } + + uint8_t *const coefficients = (uint8_t *const)coeff_contexts; + + const int real_width = tx_size_wide[tx_size]; + const int real_height = tx_size_high[tx_size]; + const int width = get_txb_wide(tx_size); + const int height = get_txb_high(tx_size); + const int stride = height + TX_PAD_HOR; + ptrdiff_t offsets[3]; + + /* coeff_contexts must be 16 byte aligned. */ + assert(!((intptr_t)coeff_contexts & 0xf)); + + if (tx_class == TX_CLASS_2D) { + offsets[0] = 0 * stride + 2; + offsets[1] = 1 * stride + 1; + offsets[2] = 2 * stride + 0; + + if (height == 4) { + get_4_nz_map_contexts_2d(levels, width, offsets, coefficients); + } else if (height == 8) { + get_8_coeff_contexts_2d(levels, width, offsets, coefficients); + } else { + get_16n_coeff_contexts_2d(levels, real_width, real_height, width, height, + offsets, coefficients); + } + } else if (tx_class == TX_CLASS_HORIZ) { + offsets[0] = 2 * stride; + offsets[1] = 3 * stride; + offsets[2] = 4 * stride; + if (height == 4) { + get_4_nz_map_contexts_hor(levels, width, offsets, coefficients); + } else if (height == 8) { + get_8_coeff_contexts_hor(levels, width, offsets, coefficients); + } else { + get_16n_coeff_contexts_hor(levels, width, height, offsets, coefficients); + } + } else { // TX_CLASS_VERT + offsets[0] = 2; + offsets[1] = 3; + offsets[2] = 4; + if (height == 4) { + get_4_nz_map_contexts_ver(levels, width, offsets, coefficients); + } else if (height == 8) { + get_8_coeff_contexts_ver(levels, width, offsets, coefficients); + } else { + get_16n_coeff_contexts_ver(levels, width, height, offsets, coefficients); + } + } + + const int bhl = get_txb_bhl(tx_size); + const int pos = scan[last_idx]; + if (last_idx <= (width << bhl) / 8) + coeff_contexts[pos] = 1; + else if (last_idx <= (width << bhl) / 4) + coeff_contexts[pos] = 2; + else + coeff_contexts[pos] = 3; +} diff --git a/third_party/aom/av1/encoder/arm/neon/highbd_fwd_txfm_neon.c b/third_party/aom/av1/encoder/arm/neon/highbd_fwd_txfm_neon.c new file mode 100644 index 0000000000..aa64a38902 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/highbd_fwd_txfm_neon.c @@ -0,0 +1,2619 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. 
All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> +#include <assert.h> + +#include "aom_dsp/arm/transpose_neon.h" +#include "aom_dsp/txfm_common.h" +#include "aom_ports/mem.h" +#include "av1/common/av1_txfm.h" +#include "av1/encoder/av1_fwd_txfm1d_cfg.h" +#include "config/aom_config.h" +#include "config/av1_rtcd.h" +#include "shift_neon.h" +#include "txfm_neon.h" + +static AOM_FORCE_INLINE void transpose_arrays_s32_64x64(const int32x4_t *in, + int32x4_t *out) { + // This is not quite the same as the other transposes defined in + // transpose_neon.h: We only write the low 64x32 sub-matrix since the rest is + // unused by the following row transform. + for (int j = 0; j < 8; ++j) { + for (int i = 0; i < 16; ++i) { + transpose_arrays_s32_4x4(in + 64 * i + 4 * j, out + 64 * j + 4 * i); + } + } +} + +// A note on butterfly helper naming: +// +// butterfly_[weight_indices]_neon +// e.g. butterfly_0312_neon +// ^ Weights are applied as indices 0, 3, 2, 1 +// (see more detail below) +// +// Weight indices are treated as an index into the 4-tuple of the weight +// itself, plus related and negated constants: w=(w0, 1-w0, -w0, w0-1). +// This is then represented in the helper naming by referring to the lane index +// in the loaded tuple that each multiply is performed with: +// +// in0 in1 +// /------------ +// out0 | w[0] w[1] ==> out0 = in0 * w[0] + in1 * w[1] +// out1 | w[2] w[3] ==> out1 = in0 * w[2] + in1 * w[3] +// +// So for indices 0321 from the earlier example, we end up with: +// +// in0 in1 +// /------------------ +// out0 | (lane 0) (lane 3) ==> out0 = in0 * w0 + in1 * (w0-1) +// out1 | (lane 2) (lane 1) ==> out1 = in0 * -w0 + in1 * (1-w0) + +#define butterfly_half_neon(wvec, lane0, lane1, in0, in1, out, v_bit) \ + do { \ + int32x2x2_t wvecs = { { wvec, vneg_s32(wvec) } }; \ + int32x4_t x = vmulq_lane_s32(in0, wvecs.val[lane0 / 2], lane0 % 2); \ + x = vmlaq_lane_s32(x, in1, wvecs.val[lane1 / 2], lane1 % 2); \ + *out = vrshlq_s32(x, v_bit); \ + } while (false) + +static AOM_FORCE_INLINE void butterfly_0112_neon( + const int32_t *cospi, const int widx0, const int32x4_t n0, + const int32x4_t n1, int32x4_t *out0, int32x4_t *out1, + const int32x4_t v_bit) { + int32x2_t w01 = vld1_s32(cospi + 2 * widx0); + butterfly_half_neon(w01, 0, 1, n0, n1, out0, v_bit); + butterfly_half_neon(w01, 1, 2, n0, n1, out1, v_bit); +} + +static AOM_FORCE_INLINE void butterfly_2312_neon( + const int32_t *cospi, const int widx0, const int32x4_t n0, + const int32x4_t n1, int32x4_t *out0, int32x4_t *out1, + const int32x4_t v_bit) { + int32x2_t w01 = vld1_s32(cospi + 2 * widx0); + butterfly_half_neon(w01, 2, 3, n0, n1, out0, v_bit); + butterfly_half_neon(w01, 1, 2, n0, n1, out1, v_bit); +} + +static AOM_FORCE_INLINE void butterfly_0332_neon( + const int32_t *cospi, const int widx0, const int32x4_t n0, + const int32x4_t n1, int32x4_t *out0, int32x4_t *out1, + const int32x4_t v_bit) { + int32x2_t w01 = vld1_s32(cospi + 2 * widx0); + butterfly_half_neon(w01, 0, 3, n0, n1, out0, v_bit); + butterfly_half_neon(w01, 3, 2, n0, n1, out1, v_bit); +} + +static AOM_FORCE_INLINE 
void butterfly_0130_neon( + const int32_t *cospi, const int widx0, const int32x4_t n0, + const int32x4_t n1, int32x4_t *out0, int32x4_t *out1, + const int32x4_t v_bit) { + int32x2_t w01 = vld1_s32(cospi + 2 * widx0); + butterfly_half_neon(w01, 0, 1, n0, n1, out0, v_bit); + butterfly_half_neon(w01, 3, 0, n0, n1, out1, v_bit); +} + +static AOM_FORCE_INLINE void butterfly_cospi32_0002_neon( + const int32_t *cospi, const int32x4_t n0, const int32x4_t n1, + int32x4_t *out0, int32x4_t *out1, const int32x4_t v_bit) { + int32x2_t w01 = vld1_s32(cospi + 2 * 32); + butterfly_half_neon(w01, 0, 0, n0, n1, out0, v_bit); + butterfly_half_neon(w01, 0, 2, n0, n1, out1, v_bit); +} + +static AOM_FORCE_INLINE void butterfly_cospi32_0222_neon( + const int32_t *cospi, const int32x4_t n0, const int32x4_t n1, + int32x4_t *out0, int32x4_t *out1, const int32x4_t v_bit) { + int32x2_t w01 = vld1_s32(cospi + 2 * 32); + butterfly_half_neon(w01, 0, 2, n0, n1, out0, v_bit); + butterfly_half_neon(w01, 2, 2, n0, n1, out1, v_bit); +} + +static AOM_FORCE_INLINE void round_rect_array_s32_neon(const int32x4_t *input, + int32x4_t *output, + const int size) { + const int32x4_t sqrt2 = vdupq_n_s32(NewSqrt2); + int i = 0; + do { + const int32x4_t r1 = vmulq_s32(input[i], sqrt2); + output[i] = vrshrq_n_s32(r1, NewSqrt2Bits); + } while (++i < size); +} + +static AOM_FORCE_INLINE void round_shift2_rect_array_s32_neon( + const int32x4_t *input, int32x4_t *output, const int size) { + const int32x4_t sqrt2 = vdupq_n_s32(NewSqrt2); + int i = 0; + do { + const int32x4_t r0 = vrshrq_n_s32(input[i], 2); + const int32x4_t r1 = vmulq_s32(r0, sqrt2); + output[i] = vrshrq_n_s32(r1, NewSqrt2Bits); + } while (++i < size); +} + +#define LOAD_BUFFER_4XH(h) \ + static AOM_FORCE_INLINE void load_buffer_4x##h( \ + const int16_t *input, int32x4_t *in, int stride, int fliplr) { \ + if (fliplr) { \ + for (int i = 0; i < (h); ++i) { \ + int16x4_t a = vld1_s16(input + i * stride); \ + a = vrev64_s16(a); \ + in[i] = vshll_n_s16(a, 2); \ + } \ + } else { \ + for (int i = 0; i < (h); ++i) { \ + int16x4_t a = vld1_s16(input + i * stride); \ + in[i] = vshll_n_s16(a, 2); \ + } \ + } \ + } + +// AArch32 does not permit the argument to vshll_n_s16 to be zero, so need to +// avoid the expression even though the compiler can prove that the code path +// is never taken if `shift == 0`. +#define shift_left_long_s16(a, shift) \ + ((shift) == 0 ? vmovl_s16(a) : vshll_n_s16((a), (shift) == 0 ? 
1 : (shift))) + +#define LOAD_BUFFER_WXH(w, h, shift) \ + static AOM_FORCE_INLINE void load_buffer_##w##x##h( \ + const int16_t *input, int32x4_t *in, int stride, int fliplr) { \ + assert(w >= 8); \ + if (fliplr) { \ + for (int i = 0; i < (h); ++i) { \ + for (int j = 0; j < (w) / 8; ++j) { \ + int16x8_t a = vld1q_s16(input + i * stride + j * 8); \ + a = vrev64q_s16(a); \ + int j2 = (w) / 8 - j - 1; \ + in[i + (h) * (2 * j2 + 0)] = \ + shift_left_long_s16(vget_high_s16(a), (shift)); \ + in[i + (h) * (2 * j2 + 1)] = \ + shift_left_long_s16(vget_low_s16(a), (shift)); \ + } \ + } \ + } else { \ + for (int i = 0; i < (h); ++i) { \ + for (int j = 0; j < (w) / 8; ++j) { \ + int16x8_t a = vld1q_s16(input + i * stride + j * 8); \ + in[i + (h) * (2 * j + 0)] = \ + shift_left_long_s16(vget_low_s16(a), (shift)); \ + in[i + (h) * (2 * j + 1)] = \ + shift_left_long_s16(vget_high_s16(a), (shift)); \ + } \ + } \ + } \ + } + +LOAD_BUFFER_4XH(4) +LOAD_BUFFER_4XH(8) +LOAD_BUFFER_4XH(16) +LOAD_BUFFER_4XH(32) +LOAD_BUFFER_WXH(8, 8, 2) +LOAD_BUFFER_WXH(16, 16, 2) +LOAD_BUFFER_WXH(32, 64, 0) +LOAD_BUFFER_WXH(64, 32, 2) +LOAD_BUFFER_WXH(64, 64, 0) + +#if !CONFIG_REALTIME_ONLY +LOAD_BUFFER_WXH(16, 64, 0) +LOAD_BUFFER_WXH(64, 16, 2) +#endif // !CONFIG_REALTIME_ONLY + +#define STORE_BUFFER_WXH(w, h) \ + static AOM_FORCE_INLINE void store_buffer_##w##x##h( \ + const int32x4_t *in, int32_t *out, int stride) { \ + for (int i = 0; i < (w); ++i) { \ + for (int j = 0; j < (h) / 4; ++j) { \ + vst1q_s32(&out[i * stride + j * 4], in[i + j * (w)]); \ + } \ + } \ + } + +STORE_BUFFER_WXH(4, 4) +STORE_BUFFER_WXH(8, 4) +STORE_BUFFER_WXH(8, 8) +STORE_BUFFER_WXH(16, 4) +STORE_BUFFER_WXH(16, 16) +STORE_BUFFER_WXH(32, 4) +STORE_BUFFER_WXH(32, 32) +STORE_BUFFER_WXH(64, 32) + +#if !CONFIG_REALTIME_ONLY +STORE_BUFFER_WXH(16, 32) +STORE_BUFFER_WXH(64, 16) +#endif // !CONFIG_REALTIME_ONLY + +static AOM_FORCE_INLINE void highbd_fdct4_x4_neon(const int32x4_t *in, + int32x4_t *out, int bit) { + const int32_t *const cospi = cospi_arr_s32(bit); + const int32x4_t cospi32 = vdupq_n_s32(cospi[2 * 32]); + const int32x2_t cospi16_48 = vld1_s32(&cospi[2 * 16]); + + const int32x4_t a0 = vaddq_s32(in[0], in[3]); + const int32x4_t a1 = vsubq_s32(in[0], in[3]); + const int32x4_t a2 = vaddq_s32(in[1], in[2]); + const int32x4_t a3 = vsubq_s32(in[1], in[2]); + + const int32x4_t b0 = vmulq_s32(a0, cospi32); + const int32x4_t b1 = vmulq_lane_s32(a1, cospi16_48, 1); + const int32x4_t b2 = vmulq_s32(a2, cospi32); + const int32x4_t b3 = vmulq_lane_s32(a3, cospi16_48, 1); + + const int32x4_t c0 = vaddq_s32(b0, b2); + const int32x4_t c1 = vsubq_s32(b0, b2); + const int32x4_t c2 = vmlaq_lane_s32(b3, a1, cospi16_48, 0); + const int32x4_t c3 = vmlsq_lane_s32(b1, a3, cospi16_48, 0); + + const int32x4_t v_bit = vdupq_n_s32(-bit); + const int32x4_t d0 = vrshlq_s32(c0, v_bit); + const int32x4_t d1 = vrshlq_s32(c1, v_bit); + const int32x4_t d2 = vrshlq_s32(c2, v_bit); + const int32x4_t d3 = vrshlq_s32(c3, v_bit); + + out[0] = d0; + out[1] = d2; + out[2] = d1; + out[3] = d3; +} + +static AOM_FORCE_INLINE void highbd_fadst4_x4_neon(const int32x4_t *in, + int32x4_t *out, int bit) { + const int32x4_t sinpi = vld1q_s32(sinpi_arr(bit) + 1); + + const int32x4_t a0 = vaddq_s32(in[0], in[1]); + const int32x4_t a1 = vmulq_lane_s32(in[0], vget_low_s32(sinpi), 0); + const int32x4_t a2 = vmulq_lane_s32(in[0], vget_high_s32(sinpi), 1); + const int32x4_t a3 = vmulq_lane_s32(in[2], vget_high_s32(sinpi), 0); + + const int32x4_t b0 = vmlaq_lane_s32(a1, in[1], vget_low_s32(sinpi), 1); + 
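+  // b2 (computed below) is in[0] + in[1] - in[3]; scaled by sinpi[3] it
+  // becomes out[1].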
const int32x4_t b1 = vmlsq_lane_s32(a2, in[1], vget_low_s32(sinpi), 0); + const int32x4_t b2 = vsubq_s32(a0, in[3]); + + const int32x4_t c0 = vmlaq_lane_s32(b0, in[3], vget_high_s32(sinpi), 1); + const int32x4_t c1 = vmlaq_lane_s32(b1, in[3], vget_low_s32(sinpi), 1); + const int32x4_t c2 = vmulq_lane_s32(b2, vget_high_s32(sinpi), 0); + + const int32x4_t d0 = vaddq_s32(c0, a3); + const int32x4_t d1 = vsubq_s32(c1, a3); + const int32x4_t d2 = vsubq_s32(c1, c0); + + const int32x4_t e0 = vaddq_s32(d2, a3); + + const int32x4_t v_bit = vdupq_n_s32(-bit); + out[0] = vrshlq_s32(d0, v_bit); + out[1] = vrshlq_s32(c2, v_bit); + out[2] = vrshlq_s32(d1, v_bit); + out[3] = vrshlq_s32(e0, v_bit); +} + +static AOM_FORCE_INLINE void highbd_fidentity4_x4_neon(const int32x4_t *in, + int32x4_t *out, + int bit) { + (void)bit; + int32x4_t fact = vdupq_n_s32(NewSqrt2); + + for (int i = 0; i < 4; i++) { + const int32x4_t a_low = vmulq_s32(in[i], fact); + out[i] = vrshrq_n_s32(a_low, NewSqrt2Bits); + } +} + +void av1_fwd_txfm2d_4x4_neon(const int16_t *input, int32_t *coeff, + int input_stride, TX_TYPE tx_type, int bd) { + (void)bd; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &input_stride, 4); + + // Workspace for column/row-wise transforms. + int32x4_t buf[4]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case ADST_DCT: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case DCT_ADST: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case ADST_ADST: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case FLIPADST_DCT: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case DCT_FLIPADST: + load_buffer_4x4(input, buf, input_stride, 1); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case FLIPADST_FLIPADST: + load_buffer_4x4(input, buf, input_stride, 1); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case ADST_FLIPADST: + load_buffer_4x4(input, buf, input_stride, 1); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, 
av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case FLIPADST_ADST: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case IDTX: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case V_DCT: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case H_DCT: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fdct4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case V_ADST: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case H_ADST: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_col[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case V_FLIPADST: + load_buffer_4x4(input, buf, input_stride, 0); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + case H_FLIPADST: + load_buffer_4x4(input, buf, input_stride, 1); + highbd_fidentity4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + transpose_arrays_s32_4x4(buf, buf); + highbd_fadst4_x4_neon(buf, buf, av1_fwd_cos_bit_row[0][0]); + store_buffer_4x4(buf, coeff, /*stride=*/4); + break; + default: assert(0); + } +} + +// Butterfly pre-processing: +// e.g. n=4: +// out[0] = in[0] + in[3] +// out[1] = in[1] + in[2] +// out[2] = in[1] - in[2] +// out[3] = in[0] - in[3] + +static AOM_FORCE_INLINE void butterfly_dct_pre(const int32x4_t *input, + int32x4_t *output, int n) { + for (int i = 0; i < n / 2; ++i) { + output[i] = vaddq_s32(input[i], input[n - i - 1]); + } + for (int i = 0; i < n / 2; ++i) { + output[n / 2 + i] = vsubq_s32(input[n / 2 - i - 1], input[n / 2 + i]); + } +} + +// Butterfly post-processing: +// e.g. 
n=8: +// out[0] = in0[0] + in1[3]; +// out[1] = in0[1] + in1[2]; +// out[2] = in0[1] - in1[2]; +// out[3] = in0[0] - in1[3]; +// out[4] = in0[7] - in1[4]; +// out[5] = in0[6] - in1[5]; +// out[6] = in0[6] + in1[5]; +// out[7] = in0[7] + in1[4]; + +static AOM_FORCE_INLINE void butterfly_dct_post(const int32x4_t *in0, + const int32x4_t *in1, + int32x4_t *output, int n) { + for (int i = 0; i < n / 4; ++i) { + output[i] = vaddq_s32(in0[i], in1[n / 2 - i - 1]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 4 + i] = vsubq_s32(in0[n / 4 - i - 1], in1[n / 4 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[n / 2 + i] = vsubq_s32(in0[n - i - 1], in1[n / 2 + i]); + } + for (int i = 0; i < n / 4; ++i) { + output[(3 * n) / 4 + i] = + vaddq_s32(in0[(3 * n) / 4 + i], in1[(3 * n) / 4 - i - 1]); + } +} + +static AOM_FORCE_INLINE void highbd_fdct8_x4_neon(const int32x4_t *in, + int32x4_t *out, int bit) { + const int32_t *const cospi = cospi_arr_s32(bit); + const int32x4_t v_bit = vdupq_n_s32(-bit); + + // stage 1 + int32x4_t a[8]; + butterfly_dct_pre(in, a, 8); + + // stage 2 + int32x4_t b[8]; + butterfly_dct_pre(a, b, 4); + butterfly_0130_neon(cospi, 32, a[5], a[6], &b[6], &b[5], v_bit); + + // stage 3 + int32x4_t c[8]; + butterfly_0130_neon(cospi, 32, b[1], b[0], &c[0], &c[1], v_bit); + butterfly_0112_neon(cospi, 16, b[3], b[2], &c[2], &c[3], v_bit); + butterfly_dct_post(a + 4, b + 4, c + 4, 4); + + // stage 4-5 + butterfly_0112_neon(cospi, 8, c[7], c[4], &out[1], &out[7], v_bit); + butterfly_0130_neon(cospi, 24, c[5], c[6], &out[5], &out[3], v_bit); + + out[0] = c[0]; + out[2] = c[2]; + out[4] = c[1]; + out[6] = c[3]; +} + +static AOM_FORCE_INLINE void highbd_fadst8_x4_neon(const int32x4_t *in, + int32x4_t *out, int bit) { + const int32_t *const cospi = cospi_arr_s32(bit); + const int32x4_t v_bit = vdupq_n_s32(-bit); + + int32x4_t u0, u1, u2, u3, u4, u5, u6, u7; + int32x4_t v0, v1, v2, v3, v4, v5, v6, v7; + + // stage 0-1 + u0 = in[0]; + u1 = in[7]; + u2 = in[3]; + u3 = in[4]; + u4 = in[1]; + u5 = in[6]; + u6 = in[2]; + u7 = in[5]; + + // stage 2 + v0 = u0; + v1 = u1; + butterfly_cospi32_0222_neon(cospi, u3, u2, &v2, &v3, v_bit); + v4 = u4; + v5 = u5; + butterfly_cospi32_0002_neon(cospi, u6, u7, &v7, &v6, v_bit); + + // stage 3 + u0 = vaddq_s32(v0, v2); + u1 = vsubq_s32(v3, v1); + u2 = vsubq_s32(v0, v2); + u3 = vaddq_s32(v1, v3); + u4 = vsubq_s32(v6, v4); + u5 = vaddq_s32(v5, v7); + u6 = vaddq_s32(v4, v6); + u7 = vsubq_s32(v5, v7); + + // stage 4 + v0 = u0; + v1 = u1; + v2 = u2; + v3 = u3; + + butterfly_0112_neon(cospi, 16, u4, u5, &v4, &v5, v_bit); + butterfly_0112_neon(cospi, 16, u7, u6, &v6, &v7, v_bit); + + // stage 5 + u0 = vaddq_s32(v0, v4); + u1 = vaddq_s32(v1, v5); + u2 = vaddq_s32(v2, v6); + u3 = vsubq_s32(v7, v3); + u4 = vsubq_s32(v0, v4); + u5 = vsubq_s32(v1, v5); + u6 = vsubq_s32(v2, v6); + u7 = vaddq_s32(v3, v7); + + // stage 6 + butterfly_0112_neon(cospi, 4, u0, u1, &v0, &v1, v_bit); + butterfly_0112_neon(cospi, 20, u2, u3, &v2, &v3, v_bit); + butterfly_0130_neon(cospi, 28, u5, u4, &v4, &v5, v_bit); + butterfly_0112_neon(cospi, 12, u6, u7, &v7, &v6, v_bit); + + // stage 7 + out[0] = v1; + out[1] = v6; + out[2] = v3; + out[3] = v4; + out[4] = v5; + out[5] = v2; + out[6] = v7; + out[7] = v0; +} + +static AOM_FORCE_INLINE void highbd_fidentity8_x4_neon(const int32x4_t *in, + int32x4_t *out, + int bit) { + (void)bit; + out[0] = vshlq_n_s32(in[0], 1); + out[1] = vshlq_n_s32(in[1], 1); + out[2] = vshlq_n_s32(in[2], 1); + out[3] = vshlq_n_s32(in[3], 1); + out[4] = vshlq_n_s32(in[4], 
1); + out[5] = vshlq_n_s32(in[5], 1); + out[6] = vshlq_n_s32(in[6], 1); + out[7] = vshlq_n_s32(in[7], 1); +} + +static AOM_FORCE_INLINE void highbd_fdct8_xn_neon(const int32x4_t *in, + int32x4_t *out, int bit, + int howmany) { + const int stride = 8; + int i = 0; + do { + highbd_fdct8_x4_neon(in + i * stride, out + i * stride, bit); + } while (++i < howmany); +} + +static AOM_FORCE_INLINE void highbd_fadst8_xn_neon(const int32x4_t *in, + int32x4_t *out, int bit, + int howmany) { + const int stride = 8; + int i = 0; + do { + highbd_fadst8_x4_neon(in + i * stride, out + i * stride, bit); + } while (++i < howmany); +} + +static AOM_FORCE_INLINE void highbd_fidentity8_xn_neon(const int32x4_t *in, + int32x4_t *out, int bit, + int howmany) { + (void)bit; + const int stride = 8; + int i = 0; + do { + highbd_fidentity8_x4_neon(in + i * stride, out + i * stride, bit); + } while (++i < howmany); +} + +void av1_fwd_txfm2d_8x8_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 8); + + // Workspaces for column/row-wise transforms. + int32x4_t buf0[16], buf1[16]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fdct8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fdct8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case ADST_DCT: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fdct8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case DCT_ADST: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fdct8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case ADST_ADST: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case FLIPADST_DCT: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fdct8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case DCT_FLIPADST: + load_buffer_8x8(input, buf0, stride, 1); + highbd_fdct8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case FLIPADST_FLIPADST: + load_buffer_8x8(input, buf0, stride, 1); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, 
av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case ADST_FLIPADST: + load_buffer_8x8(input, buf0, stride, 1); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case FLIPADST_ADST: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case IDTX: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fidentity8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fidentity8_xn_neon(buf1, buf1, av1_fwd_cos_bit_col[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case V_DCT: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fdct8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fidentity8_xn_neon(buf1, buf1, av1_fwd_cos_bit_col[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case H_DCT: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fidentity8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fdct8_xn_neon(buf1, buf1, av1_fwd_cos_bit_col[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case V_ADST: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fidentity8_xn_neon(buf1, buf1, av1_fwd_cos_bit_col[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case H_ADST: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fidentity8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, av1_fwd_cos_bit_col[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case V_FLIPADST: + load_buffer_8x8(input, buf0, stride, 0); + highbd_fadst8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fidentity8_xn_neon(buf1, buf1, av1_fwd_cos_bit_col[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + case H_FLIPADST: + load_buffer_8x8(input, buf0, stride, 1); + highbd_fidentity8_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[1][1], 2); + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_8x8(buf0, buf1); + highbd_fadst8_xn_neon(buf1, buf1, av1_fwd_cos_bit_col[1][1], 2); + store_buffer_8x8(buf1, coeff, /*stride=*/8); + break; + default: assert(0); + } +} + +static void highbd_fdct16_x4_neon(const int32x4_t *in, int32x4_t *out, + int bit) { + const int32_t *const cospi = cospi_arr_s32(bit); + const int32x4_t v_bit = vdupq_n_s32(-bit); + + int32x4_t u[16], v[16]; + + // stage 1 + butterfly_dct_pre(in, u, 16); + + // stage 2 + butterfly_dct_pre(u, v, 8); + v[8] = u[8]; + v[9] = u[9]; + butterfly_cospi32_0002_neon(cospi, u[13], u[10], 
&v[13], &v[10], v_bit); + butterfly_cospi32_0002_neon(cospi, u[12], u[11], &v[12], &v[11], v_bit); + v[14] = u[14]; + v[15] = u[15]; + + // stage 3 + butterfly_dct_pre(v, u, 4); + u[4] = v[4]; + butterfly_cospi32_0002_neon(cospi, v[6], v[5], &u[6], &u[5], v_bit); + u[7] = v[7]; + butterfly_dct_post(v + 8, v + 8, u + 8, 8); + + // stage 4 + butterfly_cospi32_0002_neon(cospi, u[0], u[1], &v[0], &v[1], v_bit); + butterfly_0112_neon(cospi, 16, u[3], u[2], &v[2], &v[3], v_bit); + butterfly_dct_post(u + 4, u + 4, v + 4, 4); + v[8] = u[8]; + butterfly_0112_neon(cospi, 16, u[14], u[9], &v[14], &v[9], v_bit); + butterfly_2312_neon(cospi, 16, u[13], u[10], &v[10], &v[13], v_bit); + v[11] = u[11]; + v[12] = u[12]; + v[15] = u[15]; + + // stage 5 + u[0] = v[0]; + u[1] = v[1]; + u[2] = v[2]; + u[3] = v[3]; + butterfly_0112_neon(cospi, 8, v[7], v[4], &u[4], &u[7], v_bit); + butterfly_0130_neon(cospi, 24, v[5], v[6], &u[5], &u[6], v_bit); + butterfly_dct_post(v + 8, v + 8, u + 8, 4); + butterfly_dct_post(v + 12, v + 12, u + 12, 4); + + // stage 6 + v[0] = u[0]; + v[1] = u[1]; + v[2] = u[2]; + v[3] = u[3]; + v[4] = u[4]; + v[5] = u[5]; + v[6] = u[6]; + v[7] = u[7]; + butterfly_0112_neon(cospi, 4, u[15], u[8], &v[8], &v[15], v_bit); + butterfly_0130_neon(cospi, 28, u[9], u[14], &v[9], &v[14], v_bit); + butterfly_0112_neon(cospi, 20, u[13], u[10], &v[10], &v[13], v_bit); + butterfly_0130_neon(cospi, 12, u[11], u[12], &v[11], &v[12], v_bit); + + out[0] = v[0]; + out[1] = v[8]; + out[2] = v[4]; + out[3] = v[12]; + out[4] = v[2]; + out[5] = v[10]; + out[6] = v[6]; + out[7] = v[14]; + out[8] = v[1]; + out[9] = v[9]; + out[10] = v[5]; + out[11] = v[13]; + out[12] = v[3]; + out[13] = v[11]; + out[14] = v[7]; + out[15] = v[15]; +} + +static void highbd_fadst16_x4_neon(const int32x4_t *in, int32x4_t *out, + int bit) { + const int32_t *const cospi = cospi_arr_s32(bit); + const int32x4_t v_bit = vdupq_n_s32(-bit); + + int32x4_t u[16], v[16]; + + // stage 0-1 + u[0] = in[0]; + u[1] = in[15]; + u[2] = in[7]; + u[3] = in[8]; + u[4] = in[3]; + u[5] = in[12]; + u[6] = in[4]; + u[7] = in[11]; + u[8] = in[1]; + u[9] = in[14]; + u[10] = in[6]; + u[11] = in[9]; + u[12] = in[2]; + u[13] = in[13]; + u[14] = in[5]; + u[15] = in[10]; + + // stage 2 + v[0] = u[0]; + v[1] = u[1]; + butterfly_cospi32_0222_neon(cospi, u[3], u[2], &v[2], &v[3], v_bit); + v[4] = u[4]; + v[5] = u[5]; + butterfly_cospi32_0002_neon(cospi, u[6], u[7], &v[7], &v[6], v_bit); + v[8] = u[8]; + v[9] = u[9]; + butterfly_cospi32_0002_neon(cospi, u[10], u[11], &v[11], &v[10], v_bit); + v[12] = u[12]; + v[13] = u[13]; + butterfly_cospi32_0222_neon(cospi, u[15], u[14], &v[14], &v[15], v_bit); + + // stage 3 + u[0] = vaddq_s32(v[0], v[2]); + u[1] = vsubq_s32(v[3], v[1]); + u[2] = vsubq_s32(v[0], v[2]); + u[3] = vaddq_s32(v[1], v[3]); + u[4] = vsubq_s32(v[6], v[4]); + u[5] = vaddq_s32(v[5], v[7]); + u[6] = vaddq_s32(v[4], v[6]); + u[7] = vsubq_s32(v[5], v[7]); + u[8] = vsubq_s32(v[10], v[8]); + u[9] = vaddq_s32(v[9], v[11]); + u[10] = vaddq_s32(v[8], v[10]); + u[11] = vsubq_s32(v[9], v[11]); + u[12] = vaddq_s32(v[12], v[14]); + u[13] = vsubq_s32(v[15], v[13]); + u[14] = vsubq_s32(v[12], v[14]); + u[15] = vaddq_s32(v[13], v[15]); + + // stage 4 + v[0] = u[0]; + v[1] = u[1]; + v[2] = u[2]; + v[3] = u[3]; + butterfly_0112_neon(cospi, 16, u[4], u[5], &v[4], &v[5], v_bit); + butterfly_0112_neon(cospi, 16, u[7], u[6], &v[6], &v[7], v_bit); + + v[8] = u[8]; + v[9] = u[9]; + v[10] = u[10]; + v[11] = u[11]; + + butterfly_0112_neon(cospi, 16, u[12], u[13], &v[12], &v[13], 
v_bit); + butterfly_0332_neon(cospi, 16, u[14], u[15], &v[15], &v[14], v_bit); + + // stage 5 + u[0] = vaddq_s32(v[0], v[4]); + u[1] = vaddq_s32(v[1], v[5]); + u[2] = vaddq_s32(v[2], v[6]); + u[3] = vsubq_s32(v[7], v[3]); + u[4] = vsubq_s32(v[0], v[4]); + u[5] = vsubq_s32(v[1], v[5]); + u[6] = vsubq_s32(v[2], v[6]); + u[7] = vaddq_s32(v[3], v[7]); + u[8] = vaddq_s32(v[8], v[12]); + u[9] = vaddq_s32(v[9], v[13]); + u[10] = vsubq_s32(v[14], v[10]); + u[11] = vaddq_s32(v[11], v[15]); + u[12] = vsubq_s32(v[8], v[12]); + u[13] = vsubq_s32(v[9], v[13]); + u[14] = vaddq_s32(v[10], v[14]); + u[15] = vsubq_s32(v[11], v[15]); + + // stage 6 + v[0] = u[0]; + v[1] = u[1]; + v[2] = u[2]; + v[3] = u[3]; + v[4] = u[4]; + v[5] = u[5]; + v[6] = u[6]; + v[7] = u[7]; + + butterfly_0112_neon(cospi, 8, u[8], u[9], &v[8], &v[9], v_bit); + butterfly_0130_neon(cospi, 8, u[12], u[13], &v[13], &v[12], v_bit); + butterfly_0130_neon(cospi, 24, u[11], u[10], &v[10], &v[11], v_bit); + butterfly_0130_neon(cospi, 24, u[14], u[15], &v[14], &v[15], v_bit); + + // stage 7 + u[0] = vaddq_s32(v[0], v[8]); + u[1] = vaddq_s32(v[1], v[9]); + u[2] = vaddq_s32(v[2], v[10]); + u[3] = vaddq_s32(v[3], v[11]); + u[4] = vaddq_s32(v[4], v[12]); + u[5] = vaddq_s32(v[5], v[13]); + u[6] = vaddq_s32(v[6], v[14]); + u[7] = vsubq_s32(v[15], v[7]); + u[8] = vsubq_s32(v[0], v[8]); + u[9] = vsubq_s32(v[1], v[9]); + u[10] = vsubq_s32(v[2], v[10]); + u[11] = vsubq_s32(v[3], v[11]); + u[12] = vsubq_s32(v[4], v[12]); + u[13] = vsubq_s32(v[5], v[13]); + u[14] = vsubq_s32(v[6], v[14]); + u[15] = vaddq_s32(v[7], v[15]); + + // stage 8 + butterfly_0112_neon(cospi, 2, u[0], u[1], &v[0], &v[1], v_bit); + butterfly_0112_neon(cospi, 10, u[2], u[3], &v[2], &v[3], v_bit); + butterfly_0112_neon(cospi, 18, u[4], u[5], &v[4], &v[5], v_bit); + butterfly_0112_neon(cospi, 26, u[6], u[7], &v[6], &v[7], v_bit); + butterfly_0130_neon(cospi, 30, u[9], u[8], &v[8], &v[9], v_bit); + butterfly_0130_neon(cospi, 22, u[11], u[10], &v[10], &v[11], v_bit); + butterfly_0130_neon(cospi, 14, u[13], u[12], &v[12], &v[13], v_bit); + butterfly_0112_neon(cospi, 6, u[14], u[15], &v[15], &v[14], v_bit); + + // stage 9 + out[0] = v[1]; + out[1] = v[14]; + out[2] = v[3]; + out[3] = v[12]; + out[4] = v[5]; + out[5] = v[10]; + out[6] = v[7]; + out[7] = v[8]; + out[8] = v[9]; + out[9] = v[6]; + out[10] = v[11]; + out[11] = v[4]; + out[12] = v[13]; + out[13] = v[2]; + out[14] = v[15]; + out[15] = v[0]; +} + +static void highbd_fidentity16_x4_neon(const int32x4_t *in, int32x4_t *out, + int bit) { + (void)bit; + const int32x4_t fact = vdupq_n_s32(2 * NewSqrt2); + const int32x4_t offset = vdupq_n_s32(1 << (NewSqrt2Bits - 1)); + + for (int i = 0; i < 16; i++) { + int32x4_t a = vmulq_s32(in[i], fact); + a = vaddq_s32(a, offset); + out[i] = vshrq_n_s32(a, NewSqrt2Bits); + } +} + +static void highbd_fdct16_xn_neon(const int32x4_t *in, int32x4_t *out, int bit, + const int howmany) { + const int stride = 16; + int i = 0; + do { + highbd_fdct16_x4_neon(in + i * stride, out + i * stride, bit); + } while (++i < howmany); +} + +static void highbd_fadst16_xn_neon(const int32x4_t *in, int32x4_t *out, int bit, + int howmany) { + const int stride = 16; + int i = 0; + do { + highbd_fadst16_x4_neon(in + i * stride, out + i * stride, bit); + } while (++i < howmany); +} + +static void highbd_fidentity16_xn_neon(const int32x4_t *in, int32x4_t *out, + int bit, int howmany) { + const int stride = 16; + int i = 0; + do { + highbd_fidentity16_x4_neon(in + i * stride, out + i * stride, bit); + } while (++i < howmany); 
+} + +void av1_fwd_txfm2d_16x16_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 16); + + // Workspaces for column/row-wise transforms. + int32x4_t buf0[64], buf1[64]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fdct16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fdct16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case ADST_DCT: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fdct16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case DCT_ADST: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fdct16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case ADST_ADST: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case FLIPADST_DCT: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fdct16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case DCT_FLIPADST: + load_buffer_16x16(input, buf0, stride, 1); + highbd_fdct16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case FLIPADST_FLIPADST: + load_buffer_16x16(input, buf0, stride, 1); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case ADST_FLIPADST: + load_buffer_16x16(input, buf0, stride, 1); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case FLIPADST_ADST: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case IDTX: + load_buffer_16x16(input, buf0, stride, 0); + 
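+      // IDTX: the scaled-identity kernel is applied both before and after
+      // the transpose.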
highbd_fidentity16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fidentity16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case V_DCT: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fdct16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fidentity16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case H_DCT: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fidentity16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fdct16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case V_ADST: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fidentity16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case H_ADST: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fidentity16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case V_FLIPADST: + load_buffer_16x16(input, buf0, stride, 0); + highbd_fadst16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fidentity16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + case H_FLIPADST: + load_buffer_16x16(input, buf0, stride, 1); + highbd_fidentity16_xn_neon(buf0, buf0, av1_fwd_cos_bit_col[2][2], 4); + shift_right_2_round_s32_x4(buf0, buf0, 64); + transpose_arrays_s32_16x16(buf0, buf1); + highbd_fadst16_xn_neon(buf1, buf1, av1_fwd_cos_bit_row[2][2], 4); + store_buffer_16x16(buf1, coeff, /*stride=*/16); + break; + default: assert(0); + } +} + +typedef void (*fwd_transform_1d_col_neon)(const int16_t *in, int32x4_t *out, + int stride, int bit, int lr_flip); +typedef void (*fwd_transform_1d_col_many_neon)(const int16_t *in, + int32x4_t *out, int stride, + int bit, int lr_flip, + int howmany, int hm_stride); + +typedef void (*fwd_transform_1d_row_neon)(const int32x4_t *in, int32_t *out, + int bit, int stride); +typedef void (*fwd_transform_1d_row_many_neon)(const int32x4_t *in, + int32_t *out, int bit, + int howmany, int hm_stride, + int stride); + +// Construct component kernels that include the load_buffer and store_buffer +// stages to avoid the need to spill loaded data to the stack between these and +// the txfm kernel calls. +// The TRANSFORM_*_ONE cases are only ever called in situations where the +// howmany parameter would be one, so no need for the loop at all in these +// cases. 
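As a reading aid, the sketch below shows roughly what one instantiation of the TRANSFORM_COL_MANY macro defined just below expands to, using fdct8 as the example; it reuses load_buffer_4x8 and highbd_fdct8_x4_neon from earlier in this file. Here howmany is the number of 4-wide column strips to process and hm_stride is the distance between their outputs in the destination buffer; the *_ROW_RECT_* variants differ only in applying round_rect_array_s32_neon (the NewSqrt2 scaling used for rectangular transform sizes) before storing.

// Approximate expansion of TRANSFORM_COL_MANY(fdct8, 8): process `howmany`
// 4-wide strips of 8 rows each, flipping left/right on load if requested,
// and write each strip's column transform hm_stride elements apart.
static void highbd_fdct8_col_many_neon(const int16_t *input, int32x4_t *output,
                                       int stride, int cos_bit, int lr_flip,
                                       int howmany, int hm_stride) {
  int i = 0;
  do {
    int32x4_t buf0[8];
    load_buffer_4x8(input + 4 * i, buf0, stride, lr_flip);
    highbd_fdct8_x4_neon(buf0, output + i * hm_stride, cos_bit);
  } while (++i < howmany);
}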
+ +#define TRANSFORM_COL_ONE(name, n) \ + static void highbd_##name##_col_neon(const int16_t *input, \ + int32x4_t *output, int stride, \ + int cos_bit, int lr_flip) { \ + int32x4_t buf0[n]; \ + load_buffer_4x##n(input, buf0, stride, lr_flip); \ + highbd_##name##_x4_neon(buf0, output, cos_bit); \ + } + +#define TRANSFORM_COL_MANY(name, n) \ + static void highbd_##name##_col_many_neon( \ + const int16_t *input, int32x4_t *output, int stride, int cos_bit, \ + int lr_flip, int howmany, int hm_stride) { \ + int i = 0; \ + do { \ + int32x4_t buf0[n]; \ + load_buffer_4x##n(input + 4 * i, buf0, stride, lr_flip); \ + highbd_##name##_x4_neon(buf0, output + i * hm_stride, cos_bit); \ + } while (++i < howmany); \ + } + +#define TRANSFORM_ROW_ONE(name, n) \ + static void highbd_##name##_row_neon( \ + const int32x4_t *input, int32_t *output, int cos_bit, int stride) { \ + int32x4_t buf0[n]; \ + highbd_##name##_x4_neon(input, buf0, cos_bit); \ + store_buffer_##n##x4(buf0, output, stride); \ + } + +#define TRANSFORM_ROW_RECT_ONE(name, n) \ + static void highbd_##name##_row_rect_neon( \ + const int32x4_t *input, int32_t *output, int cos_bit, int stride) { \ + int32x4_t buf0[n]; \ + highbd_##name##_x4_neon(input, buf0, cos_bit); \ + round_rect_array_s32_neon(buf0, buf0, (n)); \ + store_buffer_##n##x4(buf0, output, stride); \ + } + +#define TRANSFORM_ROW_MANY(name, n) \ + static void highbd_##name##_row_many_neon( \ + const int32x4_t *input, int32_t *output, int cos_bit, int howmany, \ + int hm_stride, int stride) { \ + int i = 0; \ + do { \ + int32x4_t buf0[n]; \ + highbd_##name##_x4_neon(input + hm_stride * i, buf0, cos_bit); \ + store_buffer_##n##x4(buf0, output + 4 * i, stride); \ + } while (++i < howmany); \ + } + +#define TRANSFORM_ROW_RECT_MANY(name, n) \ + static void highbd_##name##_row_rect_many_neon( \ + const int32x4_t *input, int32_t *output, int cos_bit, int howmany, \ + int hm_stride, int stride) { \ + int i = 0; \ + do { \ + int32x4_t buf0[n]; \ + highbd_##name##_x4_neon(input + hm_stride * i, buf0, cos_bit); \ + round_rect_array_s32_neon(buf0, buf0, (n)); \ + store_buffer_##n##x4(buf0, output + 4 * i, stride); \ + } while (++i < howmany); \ + } + +TRANSFORM_COL_ONE(fdct8, 8) +TRANSFORM_COL_ONE(fadst8, 8) +TRANSFORM_COL_ONE(fidentity8, 8) + +TRANSFORM_COL_MANY(fdct4, 4) +TRANSFORM_COL_MANY(fdct8, 8) +TRANSFORM_COL_MANY(fdct16, 16) +TRANSFORM_COL_MANY(fadst4, 4) +TRANSFORM_COL_MANY(fadst8, 8) +TRANSFORM_COL_MANY(fadst16, 16) +TRANSFORM_COL_MANY(fidentity4, 4) +TRANSFORM_COL_MANY(fidentity8, 8) +TRANSFORM_COL_MANY(fidentity16, 16) + +TRANSFORM_ROW_ONE(fdct16, 16) +TRANSFORM_ROW_ONE(fadst16, 16) +TRANSFORM_ROW_ONE(fidentity16, 16) + +TRANSFORM_ROW_RECT_ONE(fdct8, 8) +TRANSFORM_ROW_RECT_ONE(fadst8, 8) +TRANSFORM_ROW_RECT_ONE(fidentity8, 8) + +#if !CONFIG_REALTIME_ONLY +TRANSFORM_ROW_MANY(fdct4, 4) +TRANSFORM_ROW_MANY(fdct8, 8) +TRANSFORM_ROW_MANY(fadst4, 4) +TRANSFORM_ROW_MANY(fadst8, 8) +TRANSFORM_ROW_MANY(fidentity4, 4) +TRANSFORM_ROW_MANY(fidentity8, 8) +#endif + +TRANSFORM_ROW_RECT_MANY(fdct4, 4) +TRANSFORM_ROW_RECT_MANY(fdct8, 8) +TRANSFORM_ROW_RECT_MANY(fdct16, 16) +TRANSFORM_ROW_RECT_MANY(fadst4, 4) +TRANSFORM_ROW_RECT_MANY(fadst8, 8) +TRANSFORM_ROW_RECT_MANY(fadst16, 16) +TRANSFORM_ROW_RECT_MANY(fidentity4, 4) +TRANSFORM_ROW_RECT_MANY(fidentity8, 8) +TRANSFORM_ROW_RECT_MANY(fidentity16, 16) + +static const fwd_transform_1d_col_many_neon + col_highbd_txfm8_xn_arr[TX_TYPES] = { + highbd_fdct8_col_many_neon, // DCT_DCT + highbd_fadst8_col_many_neon, // ADST_DCT + highbd_fdct8_col_many_neon, 
// DCT_ADST + highbd_fadst8_col_many_neon, // ADST_ADST + highbd_fadst8_col_many_neon, // FLIPADST_DCT + highbd_fdct8_col_many_neon, // DCT_FLIPADST + highbd_fadst8_col_many_neon, // FLIPADST_FLIPADST + highbd_fadst8_col_many_neon, // ADST_FLIPADST + highbd_fadst8_col_many_neon, // FLIPADST_ADST + highbd_fidentity8_col_many_neon, // IDTX + highbd_fdct8_col_many_neon, // V_DCT + highbd_fidentity8_col_many_neon, // H_DCT + highbd_fadst8_col_many_neon, // V_ADST + highbd_fidentity8_col_many_neon, // H_ADST + highbd_fadst8_col_many_neon, // V_FLIPADST + highbd_fidentity8_col_many_neon // H_FLIPADST + }; + +static const fwd_transform_1d_col_neon col_highbd_txfm8_x4_arr[TX_TYPES] = { + highbd_fdct8_col_neon, // DCT_DCT + highbd_fadst8_col_neon, // ADST_DCT + highbd_fdct8_col_neon, // DCT_ADST + highbd_fadst8_col_neon, // ADST_ADST + highbd_fadst8_col_neon, // FLIPADST_DCT + highbd_fdct8_col_neon, // DCT_FLIPADST + highbd_fadst8_col_neon, // FLIPADST_FLIPADST + highbd_fadst8_col_neon, // ADST_FLIPADST + highbd_fadst8_col_neon, // FLIPADST_ADST + highbd_fidentity8_col_neon, // IDTX + highbd_fdct8_col_neon, // V_DCT + highbd_fidentity8_col_neon, // H_DCT + highbd_fadst8_col_neon, // V_ADST + highbd_fidentity8_col_neon, // H_ADST + highbd_fadst8_col_neon, // V_FLIPADST + highbd_fidentity8_col_neon // H_FLIPADST +}; + +static const fwd_transform_1d_col_many_neon + col_highbd_txfm16_xn_arr[TX_TYPES] = { + highbd_fdct16_col_many_neon, // DCT_DCT + highbd_fadst16_col_many_neon, // ADST_DCT + highbd_fdct16_col_many_neon, // DCT_ADST + highbd_fadst16_col_many_neon, // ADST_ADST + highbd_fadst16_col_many_neon, // FLIPADST_DCT + highbd_fdct16_col_many_neon, // DCT_FLIPADST + highbd_fadst16_col_many_neon, // FLIPADST_FLIPADST + highbd_fadst16_col_many_neon, // ADST_FLIPADST + highbd_fadst16_col_many_neon, // FLIPADST_ADST + highbd_fidentity16_col_many_neon, // IDTX + highbd_fdct16_col_many_neon, // V_DCT + highbd_fidentity16_col_many_neon, // H_DCT + highbd_fadst16_col_many_neon, // V_ADST + highbd_fidentity16_col_many_neon, // H_ADST + highbd_fadst16_col_many_neon, // V_FLIPADST + highbd_fidentity16_col_many_neon // H_FLIPADST + }; + +static const fwd_transform_1d_col_many_neon + col_highbd_txfm4_xn_arr[TX_TYPES] = { + highbd_fdct4_col_many_neon, // DCT_DCT + highbd_fadst4_col_many_neon, // ADST_DCT + highbd_fdct4_col_many_neon, // DCT_ADST + highbd_fadst4_col_many_neon, // ADST_ADST + highbd_fadst4_col_many_neon, // FLIPADST_DCT + highbd_fdct4_col_many_neon, // DCT_FLIPADST + highbd_fadst4_col_many_neon, // FLIPADST_FLIPADST + highbd_fadst4_col_many_neon, // ADST_FLIPADST + highbd_fadst4_col_many_neon, // FLIPADST_ADST + highbd_fidentity4_col_many_neon, // IDTX + highbd_fdct4_col_many_neon, // V_DCT + highbd_fidentity4_col_many_neon, // H_DCT + highbd_fadst4_col_many_neon, // V_ADST + highbd_fidentity4_col_many_neon, // H_ADST + highbd_fadst4_col_many_neon, // V_FLIPADST + highbd_fidentity4_col_many_neon // H_FLIPADST + }; + +static const fwd_transform_1d_row_neon row_highbd_txfm16_xn_arr[TX_TYPES] = { + highbd_fdct16_row_neon, // DCT_DCT + highbd_fdct16_row_neon, // ADST_DCT + highbd_fadst16_row_neon, // DCT_ADST + highbd_fadst16_row_neon, // ADST_ADST + highbd_fdct16_row_neon, // FLIPADST_DCT + highbd_fadst16_row_neon, // DCT_FLIPADST + highbd_fadst16_row_neon, // FLIPADST_FLIPADST + highbd_fadst16_row_neon, // ADST_FLIPADST + highbd_fadst16_row_neon, // FLIPADST_ADST + highbd_fidentity16_row_neon, // IDTX + highbd_fidentity16_row_neon, // V_DCT + highbd_fdct16_row_neon, // H_DCT + 
highbd_fidentity16_row_neon, // V_ADST + highbd_fadst16_row_neon, // H_ADST + highbd_fidentity16_row_neon, // V_FLIPADST + highbd_fadst16_row_neon // H_FLIPADST +}; + +static const fwd_transform_1d_row_many_neon + row_rect_highbd_txfm16_xn_arr[TX_TYPES] = { + highbd_fdct16_row_rect_many_neon, // DCT_DCT + highbd_fdct16_row_rect_many_neon, // ADST_DCT + highbd_fadst16_row_rect_many_neon, // DCT_ADST + highbd_fadst16_row_rect_many_neon, // ADST_ADST + highbd_fdct16_row_rect_many_neon, // FLIPADST_DCT + highbd_fadst16_row_rect_many_neon, // DCT_FLIPADST + highbd_fadst16_row_rect_many_neon, // FLIPADST_FLIPADST + highbd_fadst16_row_rect_many_neon, // ADST_FLIPADST + highbd_fadst16_row_rect_many_neon, // FLIPADST_ADST + highbd_fidentity16_row_rect_many_neon, // IDTX + highbd_fidentity16_row_rect_many_neon, // V_DCT + highbd_fdct16_row_rect_many_neon, // H_DCT + highbd_fidentity16_row_rect_many_neon, // V_ADST + highbd_fadst16_row_rect_many_neon, // H_ADST + highbd_fidentity16_row_rect_many_neon, // V_FLIPADST + highbd_fadst16_row_rect_many_neon // H_FLIPADST + }; + +#if !CONFIG_REALTIME_ONLY +static const fwd_transform_1d_row_many_neon + row_highbd_txfm8_xn_arr[TX_TYPES] = { + highbd_fdct8_row_many_neon, // DCT_DCT + highbd_fdct8_row_many_neon, // ADST_DCT + highbd_fadst8_row_many_neon, // DCT_ADST + highbd_fadst8_row_many_neon, // ADST_ADST + highbd_fdct8_row_many_neon, // FLIPADST_DCT + highbd_fadst8_row_many_neon, // DCT_FLIPADST + highbd_fadst8_row_many_neon, // FLIPADST_FLIPADST + highbd_fadst8_row_many_neon, // ADST_FLIPADST + highbd_fadst8_row_many_neon, // FLIPADST_ADST + highbd_fidentity8_row_many_neon, // IDTX + highbd_fidentity8_row_many_neon, // V_DCT + highbd_fdct8_row_many_neon, // H_DCT + highbd_fidentity8_row_many_neon, // V_ADST + highbd_fadst8_row_many_neon, // H_ADST + highbd_fidentity8_row_many_neon, // V_FLIPADST + highbd_fadst8_row_many_neon // H_FLIPADST + }; +#endif + +static const fwd_transform_1d_row_many_neon + row_rect_highbd_txfm8_xn_arr[TX_TYPES] = { + highbd_fdct8_row_rect_many_neon, // DCT_DCT + highbd_fdct8_row_rect_many_neon, // ADST_DCT + highbd_fadst8_row_rect_many_neon, // DCT_ADST + highbd_fadst8_row_rect_many_neon, // ADST_ADST + highbd_fdct8_row_rect_many_neon, // FLIPADST_DCT + highbd_fadst8_row_rect_many_neon, // DCT_FLIPADST + highbd_fadst8_row_rect_many_neon, // FLIPADST_FLIPADST + highbd_fadst8_row_rect_many_neon, // ADST_FLIPADST + highbd_fadst8_row_rect_many_neon, // FLIPADST_ADST + highbd_fidentity8_row_rect_many_neon, // IDTX + highbd_fidentity8_row_rect_many_neon, // V_DCT + highbd_fdct8_row_rect_many_neon, // H_DCT + highbd_fidentity8_row_rect_many_neon, // V_ADST + highbd_fadst8_row_rect_many_neon, // H_ADST + highbd_fidentity8_row_rect_many_neon, // V_FLIPADST + highbd_fadst8_row_rect_many_neon // H_FLIPADST + }; + +static const fwd_transform_1d_row_neon row_highbd_txfm8_x4_arr[TX_TYPES] = { + highbd_fdct8_row_rect_neon, // DCT_DCT + highbd_fdct8_row_rect_neon, // ADST_DCT + highbd_fadst8_row_rect_neon, // DCT_ADST + highbd_fadst8_row_rect_neon, // ADST_ADST + highbd_fdct8_row_rect_neon, // FLIPADST_DCT + highbd_fadst8_row_rect_neon, // DCT_FLIPADST + highbd_fadst8_row_rect_neon, // FLIPADST_FLIPADST + highbd_fadst8_row_rect_neon, // ADST_FLIPADST + highbd_fadst8_row_rect_neon, // FLIPADST_ADST + highbd_fidentity8_row_rect_neon, // IDTX + highbd_fidentity8_row_rect_neon, // V_DCT + highbd_fdct8_row_rect_neon, // H_DCT + highbd_fidentity8_row_rect_neon, // V_ADST + highbd_fadst8_row_rect_neon, // H_ADST + highbd_fidentity8_row_rect_neon, // 
V_FLIPADST + highbd_fadst8_row_rect_neon // H_FLIPADST +}; + +#if !CONFIG_REALTIME_ONLY +static const fwd_transform_1d_row_many_neon + row_highbd_txfm4_xn_arr[TX_TYPES] = { + highbd_fdct4_row_many_neon, // DCT_DCT + highbd_fdct4_row_many_neon, // ADST_DCT + highbd_fadst4_row_many_neon, // DCT_ADST + highbd_fadst4_row_many_neon, // ADST_ADST + highbd_fdct4_row_many_neon, // FLIPADST_DCT + highbd_fadst4_row_many_neon, // DCT_FLIPADST + highbd_fadst4_row_many_neon, // FLIPADST_FLIPADST + highbd_fadst4_row_many_neon, // ADST_FLIPADST + highbd_fadst4_row_many_neon, // FLIPADST_ADST + highbd_fidentity4_row_many_neon, // IDTX + highbd_fidentity4_row_many_neon, // V_DCT + highbd_fdct4_row_many_neon, // H_DCT + highbd_fidentity4_row_many_neon, // V_ADST + highbd_fadst4_row_many_neon, // H_ADST + highbd_fidentity4_row_many_neon, // V_FLIPADST + highbd_fadst4_row_many_neon // H_FLIPADST + }; +#endif + +static const fwd_transform_1d_row_many_neon + row_rect_highbd_txfm4_xn_arr[TX_TYPES] = { + highbd_fdct4_row_rect_many_neon, // DCT_DCT + highbd_fdct4_row_rect_many_neon, // ADST_DCT + highbd_fadst4_row_rect_many_neon, // DCT_ADST + highbd_fadst4_row_rect_many_neon, // ADST_ADST + highbd_fdct4_row_rect_many_neon, // FLIPADST_DCT + highbd_fadst4_row_rect_many_neon, // DCT_FLIPADST + highbd_fadst4_row_rect_many_neon, // FLIPADST_FLIPADST + highbd_fadst4_row_rect_many_neon, // ADST_FLIPADST + highbd_fadst4_row_rect_many_neon, // FLIPADST_ADST + highbd_fidentity4_row_rect_many_neon, // IDTX + highbd_fidentity4_row_rect_many_neon, // V_DCT + highbd_fdct4_row_rect_many_neon, // H_DCT + highbd_fidentity4_row_rect_many_neon, // V_ADST + highbd_fadst4_row_rect_many_neon, // H_ADST + highbd_fidentity4_row_rect_many_neon, // V_FLIPADST + highbd_fadst4_row_rect_many_neon // H_FLIPADST + }; + +static void highbd_fdct32_x4_neon(const int32x4_t *input, int32x4_t *output, + int cos_bit) { + const int32_t *const cospi = cospi_arr_s32(cos_bit); + const int32x4_t v_cos_bit = vdupq_n_s32(-cos_bit); + + // Workspaces for intermediate transform steps. 
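+  // buf0 and buf1 are used in ping-pong fashion: odd-numbered stages write
+  // buf1 and even-numbered stages write buf0, so each stage reads one
+  // workspace and writes the other before the final permutation into output.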
+ int32x4_t buf0[32]; + int32x4_t buf1[32]; + + // stage 1 + butterfly_dct_pre(input, buf1, 32); + + // stage 2 + butterfly_dct_pre(buf1, buf0, 16); + buf0[16] = buf1[16]; + buf0[17] = buf1[17]; + buf0[18] = buf1[18]; + buf0[19] = buf1[19]; + butterfly_0112_neon(cospi, 32, buf1[27], buf1[20], &buf0[27], &buf0[20], + v_cos_bit); + butterfly_0112_neon(cospi, 32, buf1[26], buf1[21], &buf0[26], &buf0[21], + v_cos_bit); + butterfly_0112_neon(cospi, 32, buf1[25], buf1[22], &buf0[25], &buf0[22], + v_cos_bit); + butterfly_0112_neon(cospi, 32, buf1[24], buf1[23], &buf0[24], &buf0[23], + v_cos_bit); + buf0[28] = buf1[28]; + buf0[29] = buf1[29]; + buf0[30] = buf1[30]; + buf0[31] = buf1[31]; + + // stage 3 + butterfly_dct_pre(buf0, buf1, 8); + buf1[8] = buf0[8]; + buf1[9] = buf0[9]; + butterfly_0112_neon(cospi, 32, buf0[13], buf0[10], &buf1[13], &buf1[10], + v_cos_bit); + butterfly_0112_neon(cospi, 32, buf0[12], buf0[11], &buf1[12], &buf1[11], + v_cos_bit); + buf1[14] = buf0[14]; + buf1[15] = buf0[15]; + butterfly_dct_post(buf0 + 16, buf0 + 16, buf1 + 16, 16); + + // stage 4 + butterfly_dct_pre(buf1, buf0, 4); + buf0[4] = buf1[4]; + butterfly_0112_neon(cospi, 32, buf1[6], buf1[5], &buf0[6], &buf0[5], + v_cos_bit); + buf0[7] = buf1[7]; + butterfly_dct_post(buf1 + 8, buf1 + 8, buf0 + 8, 8); + buf0[16] = buf1[16]; + buf0[17] = buf1[17]; + butterfly_0112_neon(cospi, 16, buf1[29], buf1[18], &buf0[29], &buf0[18], + v_cos_bit); + butterfly_0112_neon(cospi, 16, buf1[28], buf1[19], &buf0[28], &buf0[19], + v_cos_bit); + butterfly_2312_neon(cospi, 16, buf1[27], buf1[20], &buf0[20], &buf0[27], + v_cos_bit); + butterfly_2312_neon(cospi, 16, buf1[26], buf1[21], &buf0[21], &buf0[26], + v_cos_bit); + buf0[22] = buf1[22]; + buf0[23] = buf1[23]; + buf0[24] = buf1[24]; + buf0[25] = buf1[25]; + buf0[30] = buf1[30]; + buf0[31] = buf1[31]; + + // stage 5 + butterfly_0112_neon(cospi, 32, buf0[0], buf0[1], &buf1[0], &buf1[1], + v_cos_bit); + butterfly_0112_neon(cospi, 16, buf0[3], buf0[2], &buf1[2], &buf1[3], + v_cos_bit); + butterfly_dct_post(buf0 + 4, buf0 + 4, buf1 + 4, 4); + buf1[8] = buf0[8]; + butterfly_0112_neon(cospi, 16, buf0[14], buf0[9], &buf1[14], &buf1[9], + v_cos_bit); + butterfly_2312_neon(cospi, 16, buf0[13], buf0[10], &buf1[10], &buf1[13], + v_cos_bit); + buf1[11] = buf0[11]; + buf1[12] = buf0[12]; + buf1[15] = buf0[15]; + butterfly_dct_post(buf0 + 16, buf0 + 16, buf1 + 16, 8); + butterfly_dct_post(buf0 + 24, buf0 + 24, buf1 + 24, 8); + + // stage 6 + buf0[0] = buf1[0]; + buf0[1] = buf1[1]; + buf0[2] = buf1[2]; + buf0[3] = buf1[3]; + + butterfly_0112_neon(cospi, 8, buf1[7], buf1[4], &buf0[4], &buf0[7], + v_cos_bit); + butterfly_0112_neon(cospi, 8, buf1[30], buf1[17], &buf0[30], &buf0[17], + v_cos_bit); + butterfly_2312_neon(cospi, 8, buf1[29], buf1[18], &buf0[18], &buf0[29], + v_cos_bit); + butterfly_dct_post(buf1 + 8, buf1 + 8, buf0 + 8, 4); + butterfly_dct_post(buf1 + 12, buf1 + 12, buf0 + 12, 4); + buf0[16] = buf1[16]; + buf0[19] = buf1[19]; + buf0[20] = buf1[20]; + + butterfly_0130_neon(cospi, 24, buf1[5], buf1[6], &buf0[5], &buf0[6], + v_cos_bit); + butterfly_0130_neon(cospi, 24, buf1[21], buf1[26], &buf0[26], &buf0[21], + v_cos_bit); + butterfly_0332_neon(cospi, 24, buf1[25], buf1[22], &buf0[25], &buf0[22], + v_cos_bit); + + buf0[23] = buf1[23]; + buf0[24] = buf1[24]; + buf0[27] = buf1[27]; + buf0[28] = buf1[28]; + buf0[31] = buf1[31]; + + // stage 7 + buf1[0] = buf0[0]; + buf1[1] = buf0[1]; + buf1[2] = buf0[2]; + buf1[3] = buf0[3]; + buf1[4] = buf0[4]; + buf1[5] = buf0[5]; + buf1[6] = buf0[6]; + 
buf1[7] = buf0[7]; + butterfly_0112_neon(cospi, 4, buf0[15], buf0[8], &buf1[8], &buf1[15], + v_cos_bit); + butterfly_0130_neon(cospi, 28, buf0[9], buf0[14], &buf1[9], &buf1[14], + v_cos_bit); + butterfly_0112_neon(cospi, 20, buf0[13], buf0[10], &buf1[10], &buf1[13], + v_cos_bit); + butterfly_0130_neon(cospi, 12, buf0[11], buf0[12], &buf1[11], &buf1[12], + v_cos_bit); + butterfly_dct_post(buf0 + 16, buf0 + 16, buf1 + 16, 4); + butterfly_dct_post(buf0 + 20, buf0 + 20, buf1 + 20, 4); + butterfly_dct_post(buf0 + 24, buf0 + 24, buf1 + 24, 4); + butterfly_dct_post(buf0 + 28, buf0 + 28, buf1 + 28, 4); + + // stage 8 + buf0[0] = buf1[0]; + buf0[1] = buf1[1]; + buf0[2] = buf1[2]; + buf0[3] = buf1[3]; + buf0[4] = buf1[4]; + buf0[5] = buf1[5]; + buf0[6] = buf1[6]; + buf0[7] = buf1[7]; + buf0[8] = buf1[8]; + buf0[9] = buf1[9]; + buf0[10] = buf1[10]; + buf0[11] = buf1[11]; + buf0[12] = buf1[12]; + buf0[13] = buf1[13]; + buf0[14] = buf1[14]; + buf0[15] = buf1[15]; + butterfly_0112_neon(cospi, 2, buf1[31], buf1[16], &buf0[16], &buf0[31], + v_cos_bit); + butterfly_0130_neon(cospi, 30, buf1[17], buf1[30], &buf0[17], &buf0[30], + v_cos_bit); + butterfly_0112_neon(cospi, 18, buf1[29], buf1[18], &buf0[18], &buf0[29], + v_cos_bit); + butterfly_0130_neon(cospi, 14, buf1[19], buf1[28], &buf0[19], &buf0[28], + v_cos_bit); + butterfly_0112_neon(cospi, 10, buf1[27], buf1[20], &buf0[20], &buf0[27], + v_cos_bit); + butterfly_0130_neon(cospi, 22, buf1[21], buf1[26], &buf0[21], &buf0[26], + v_cos_bit); + butterfly_0112_neon(cospi, 26, buf1[25], buf1[22], &buf0[22], &buf0[25], + v_cos_bit); + butterfly_0130_neon(cospi, 6, buf1[23], buf1[24], &buf0[23], &buf0[24], + v_cos_bit); + + // stage 9 + output[0] = buf0[0]; + output[1] = buf0[16]; + output[2] = buf0[8]; + output[3] = buf0[24]; + output[4] = buf0[4]; + output[5] = buf0[20]; + output[6] = buf0[12]; + output[7] = buf0[28]; + output[8] = buf0[2]; + output[9] = buf0[18]; + output[10] = buf0[10]; + output[11] = buf0[26]; + output[12] = buf0[6]; + output[13] = buf0[22]; + output[14] = buf0[14]; + output[15] = buf0[30]; + output[16] = buf0[1]; + output[17] = buf0[17]; + output[18] = buf0[9]; + output[19] = buf0[25]; + output[20] = buf0[5]; + output[21] = buf0[21]; + output[22] = buf0[13]; + output[23] = buf0[29]; + output[24] = buf0[3]; + output[25] = buf0[19]; + output[26] = buf0[11]; + output[27] = buf0[27]; + output[28] = buf0[7]; + output[29] = buf0[23]; + output[30] = buf0[15]; + output[31] = buf0[31]; +} + +static void highbd_fdct64_x4_neon(const int32x4_t *input, int32x4_t *output, + int8_t cos_bit) { + const int32_t *const cospi = cospi_arr_s32(cos_bit); + const int32x4_t v_cos_bit = vdupq_n_s32(-cos_bit); + + // stage 1 + int32x4_t x1[64]; + butterfly_dct_pre(input, x1, 64); + + // stage 2 + int32x4_t x2[64]; + butterfly_dct_pre(x1, x2, 32); + x2[32] = x1[32]; + x2[33] = x1[33]; + x2[34] = x1[34]; + x2[35] = x1[35]; + x2[36] = x1[36]; + x2[37] = x1[37]; + x2[38] = x1[38]; + x2[39] = x1[39]; + butterfly_0112_neon(cospi, 32, x1[55], x1[40], &x2[55], &x2[40], v_cos_bit); + butterfly_0112_neon(cospi, 32, x1[54], x1[41], &x2[54], &x2[41], v_cos_bit); + butterfly_0112_neon(cospi, 32, x1[53], x1[42], &x2[53], &x2[42], v_cos_bit); + butterfly_0112_neon(cospi, 32, x1[52], x1[43], &x2[52], &x2[43], v_cos_bit); + butterfly_0112_neon(cospi, 32, x1[51], x1[44], &x2[51], &x2[44], v_cos_bit); + butterfly_0112_neon(cospi, 32, x1[50], x1[45], &x2[50], &x2[45], v_cos_bit); + butterfly_0112_neon(cospi, 32, x1[49], x1[46], &x2[49], &x2[46], v_cos_bit); + butterfly_0112_neon(cospi, 
32, x1[48], x1[47], &x2[48], &x2[47], v_cos_bit); + x2[56] = x1[56]; + x2[57] = x1[57]; + x2[58] = x1[58]; + x2[59] = x1[59]; + x2[60] = x1[60]; + x2[61] = x1[61]; + x2[62] = x1[62]; + x2[63] = x1[63]; + + // stage 3 + int32x4_t x3[64]; + butterfly_dct_pre(x2, x3, 16); + x3[16] = x2[16]; + x3[17] = x2[17]; + x3[18] = x2[18]; + x3[19] = x2[19]; + butterfly_0112_neon(cospi, 32, x2[27], x2[20], &x3[27], &x3[20], v_cos_bit); + butterfly_0112_neon(cospi, 32, x2[26], x2[21], &x3[26], &x3[21], v_cos_bit); + butterfly_0112_neon(cospi, 32, x2[25], x2[22], &x3[25], &x3[22], v_cos_bit); + butterfly_0112_neon(cospi, 32, x2[24], x2[23], &x3[24], &x3[23], v_cos_bit); + x3[28] = x2[28]; + x3[29] = x2[29]; + x3[30] = x2[30]; + x3[31] = x2[31]; + butterfly_dct_post(x2 + 32, x2 + 32, x3 + 32, 32); + + // stage 4 + int32x4_t x4[64]; + butterfly_dct_pre(x3, x4, 8); + x4[8] = x3[8]; + x4[9] = x3[9]; + butterfly_0112_neon(cospi, 32, x3[13], x3[10], &x4[13], &x4[10], v_cos_bit); + butterfly_0112_neon(cospi, 32, x3[12], x3[11], &x4[12], &x4[11], v_cos_bit); + x4[14] = x3[14]; + x4[15] = x3[15]; + butterfly_dct_post(x3 + 16, x3 + 16, x4 + 16, 16); + x4[32] = x3[32]; + x4[33] = x3[33]; + x4[34] = x3[34]; + x4[35] = x3[35]; + butterfly_0112_neon(cospi, 16, x3[59], x3[36], &x4[59], &x4[36], v_cos_bit); + butterfly_0112_neon(cospi, 16, x3[58], x3[37], &x4[58], &x4[37], v_cos_bit); + butterfly_0112_neon(cospi, 16, x3[57], x3[38], &x4[57], &x4[38], v_cos_bit); + butterfly_0112_neon(cospi, 16, x3[56], x3[39], &x4[56], &x4[39], v_cos_bit); + butterfly_2312_neon(cospi, 16, x3[55], x3[40], &x4[40], &x4[55], v_cos_bit); + butterfly_2312_neon(cospi, 16, x3[54], x3[41], &x4[41], &x4[54], v_cos_bit); + butterfly_2312_neon(cospi, 16, x3[53], x3[42], &x4[42], &x4[53], v_cos_bit); + butterfly_2312_neon(cospi, 16, x3[52], x3[43], &x4[43], &x4[52], v_cos_bit); + x4[44] = x3[44]; + x4[45] = x3[45]; + x4[46] = x3[46]; + x4[47] = x3[47]; + x4[48] = x3[48]; + x4[49] = x3[49]; + x4[50] = x3[50]; + x4[51] = x3[51]; + x4[60] = x3[60]; + x4[61] = x3[61]; + x4[62] = x3[62]; + x4[63] = x3[63]; + + // stage 5 + int32x4_t x5[64]; + butterfly_dct_pre(x4, x5, 4); + x5[4] = x4[4]; + butterfly_0112_neon(cospi, 32, x4[6], x4[5], &x5[6], &x5[5], v_cos_bit); + x5[7] = x4[7]; + butterfly_dct_post(x4 + 8, x4 + 8, x5 + 8, 8); + x5[16] = x4[16]; + x5[17] = x4[17]; + butterfly_0112_neon(cospi, 16, x4[29], x4[18], &x5[29], &x5[18], v_cos_bit); + butterfly_0112_neon(cospi, 16, x4[28], x4[19], &x5[28], &x5[19], v_cos_bit); + butterfly_2312_neon(cospi, 16, x4[27], x4[20], &x5[20], &x5[27], v_cos_bit); + butterfly_2312_neon(cospi, 16, x4[26], x4[21], &x5[21], &x5[26], v_cos_bit); + x5[22] = x4[22]; + x5[23] = x4[23]; + x5[24] = x4[24]; + x5[25] = x4[25]; + x5[30] = x4[30]; + x5[31] = x4[31]; + butterfly_dct_post(x4 + 32, x4 + 32, x5 + 32, 16); + butterfly_dct_post(x4 + 48, x4 + 48, x5 + 48, 16); + + // stage 6 + int32x4_t x6[64]; + butterfly_0112_neon(cospi, 32, x5[0], x5[1], &x6[0], &x6[1], v_cos_bit); + butterfly_0112_neon(cospi, 16, x5[3], x5[2], &x6[2], &x6[3], v_cos_bit); + butterfly_dct_post(x5 + 4, x5 + 4, x6 + 4, 4); + x6[8] = x5[8]; + butterfly_0112_neon(cospi, 16, x5[14], x5[9], &x6[14], &x6[9], v_cos_bit); + butterfly_2312_neon(cospi, 16, x5[13], x5[10], &x6[10], &x6[13], v_cos_bit); + x6[11] = x5[11]; + x6[12] = x5[12]; + x6[15] = x5[15]; + butterfly_dct_post(x5 + 16, x5 + 16, x6 + 16, 8); + butterfly_dct_post(x5 + 24, x5 + 24, x6 + 24, 8); + x6[32] = x5[32]; + x6[33] = x5[33]; + butterfly_0112_neon(cospi, 8, x5[61], x5[34], &x6[61], &x6[34], 
v_cos_bit); + butterfly_0112_neon(cospi, 8, x5[60], x5[35], &x6[60], &x6[35], v_cos_bit); + butterfly_2312_neon(cospi, 8, x5[59], x5[36], &x6[36], &x6[59], v_cos_bit); + butterfly_2312_neon(cospi, 8, x5[58], x5[37], &x6[37], &x6[58], v_cos_bit); + x6[38] = x5[38]; + x6[39] = x5[39]; + x6[40] = x5[40]; + x6[41] = x5[41]; + butterfly_0130_neon(cospi, 24, x5[42], x5[53], &x6[53], &x6[42], v_cos_bit); + butterfly_0130_neon(cospi, 24, x5[43], x5[52], &x6[52], &x6[43], v_cos_bit); + butterfly_0332_neon(cospi, 24, x5[51], x5[44], &x6[51], &x6[44], v_cos_bit); + butterfly_0332_neon(cospi, 24, x5[50], x5[45], &x6[50], &x6[45], v_cos_bit); + x6[46] = x5[46]; + x6[47] = x5[47]; + x6[48] = x5[48]; + x6[49] = x5[49]; + x6[54] = x5[54]; + x6[55] = x5[55]; + x6[56] = x5[56]; + x6[57] = x5[57]; + x6[62] = x5[62]; + x6[63] = x5[63]; + + // stage 7 + int32x4_t x7[64]; + x7[0] = x6[0]; + x7[1] = x6[1]; + x7[2] = x6[2]; + x7[3] = x6[3]; + butterfly_0112_neon(cospi, 8, x6[7], x6[4], &x7[4], &x7[7], v_cos_bit); + butterfly_0130_neon(cospi, 24, x6[5], x6[6], &x7[5], &x7[6], v_cos_bit); + butterfly_dct_post(x6 + 8, x6 + 8, x7 + 8, 4); + butterfly_dct_post(x6 + 12, x6 + 12, x7 + 12, 4); + x7[16] = x6[16]; + butterfly_0112_neon(cospi, 8, x6[30], x6[17], &x7[30], &x7[17], v_cos_bit); + butterfly_2312_neon(cospi, 8, x6[29], x6[18], &x7[18], &x7[29], v_cos_bit); + x7[19] = x6[19]; + x7[20] = x6[20]; + butterfly_0130_neon(cospi, 24, x6[21], x6[26], &x7[26], &x7[21], v_cos_bit); + butterfly_0332_neon(cospi, 24, x6[25], x6[22], &x7[25], &x7[22], v_cos_bit); + x7[23] = x6[23]; + x7[24] = x6[24]; + x7[27] = x6[27]; + x7[28] = x6[28]; + x7[31] = x6[31]; + butterfly_dct_post(x6 + 32, x6 + 32, x7 + 32, 8); + butterfly_dct_post(x6 + 40, x6 + 40, x7 + 40, 8); + butterfly_dct_post(x6 + 48, x6 + 48, x7 + 48, 8); + butterfly_dct_post(x6 + 56, x6 + 56, x7 + 56, 8); + + // stage 8 + int32x4_t x8[64]; + x8[0] = x7[0]; + x8[1] = x7[1]; + x8[2] = x7[2]; + x8[3] = x7[3]; + x8[4] = x7[4]; + x8[5] = x7[5]; + x8[6] = x7[6]; + x8[7] = x7[7]; + + butterfly_0112_neon(cospi, 4, x7[15], x7[8], &x8[8], &x8[15], v_cos_bit); + butterfly_0130_neon(cospi, 28, x7[9], x7[14], &x8[9], &x8[14], v_cos_bit); + butterfly_0112_neon(cospi, 20, x7[13], x7[10], &x8[10], &x8[13], v_cos_bit); + butterfly_0130_neon(cospi, 12, x7[11], x7[12], &x8[11], &x8[12], v_cos_bit); + butterfly_dct_post(x7 + 16, x7 + 16, x8 + 16, 4); + butterfly_dct_post(x7 + 20, x7 + 20, x8 + 20, 4); + butterfly_dct_post(x7 + 24, x7 + 24, x8 + 24, 4); + butterfly_dct_post(x7 + 28, x7 + 28, x8 + 28, 4); + x8[32] = x7[32]; + butterfly_0112_neon(cospi, 4, x7[62], x7[33], &x8[62], &x8[33], v_cos_bit); + butterfly_2312_neon(cospi, 4, x7[61], x7[34], &x8[34], &x8[61], v_cos_bit); + x8[35] = x7[35]; + x8[36] = x7[36]; + butterfly_0130_neon(cospi, 28, x7[37], x7[58], &x8[58], &x8[37], v_cos_bit); + butterfly_0332_neon(cospi, 28, x7[57], x7[38], &x8[57], &x8[38], v_cos_bit); + x8[39] = x7[39]; + x8[40] = x7[40]; + butterfly_0112_neon(cospi, 20, x7[54], x7[41], &x8[54], &x8[41], v_cos_bit); + butterfly_2312_neon(cospi, 20, x7[53], x7[42], &x8[42], &x8[53], v_cos_bit); + x8[43] = x7[43]; + x8[44] = x7[44]; + butterfly_0130_neon(cospi, 12, x7[45], x7[50], &x8[50], &x8[45], v_cos_bit); + butterfly_0332_neon(cospi, 12, x7[49], x7[46], &x8[49], &x8[46], v_cos_bit); + x8[47] = x7[47]; + x8[48] = x7[48]; + x8[51] = x7[51]; + x8[52] = x7[52]; + x8[55] = x7[55]; + x8[56] = x7[56]; + x8[59] = x7[59]; + x8[60] = x7[60]; + x8[63] = x7[63]; + + // stage 9 + int32x4_t x9[64]; + x9[0] = x8[0]; + x9[1] = x8[1]; + 
x9[2] = x8[2]; + x9[3] = x8[3]; + x9[4] = x8[4]; + x9[5] = x8[5]; + x9[6] = x8[6]; + x9[7] = x8[7]; + x9[8] = x8[8]; + x9[9] = x8[9]; + x9[10] = x8[10]; + x9[11] = x8[11]; + x9[12] = x8[12]; + x9[13] = x8[13]; + x9[14] = x8[14]; + x9[15] = x8[15]; + butterfly_0112_neon(cospi, 2, x8[31], x8[16], &x9[16], &x9[31], v_cos_bit); + butterfly_0130_neon(cospi, 30, x8[17], x8[30], &x9[17], &x9[30], v_cos_bit); + butterfly_0112_neon(cospi, 18, x8[29], x8[18], &x9[18], &x9[29], v_cos_bit); + butterfly_0130_neon(cospi, 14, x8[19], x8[28], &x9[19], &x9[28], v_cos_bit); + butterfly_0112_neon(cospi, 10, x8[27], x8[20], &x9[20], &x9[27], v_cos_bit); + butterfly_0130_neon(cospi, 22, x8[21], x8[26], &x9[21], &x9[26], v_cos_bit); + butterfly_0112_neon(cospi, 26, x8[25], x8[22], &x9[22], &x9[25], v_cos_bit); + butterfly_0130_neon(cospi, 6, x8[23], x8[24], &x9[23], &x9[24], v_cos_bit); + butterfly_dct_post(x8 + 32, x8 + 32, x9 + 32, 4); + butterfly_dct_post(x8 + 36, x8 + 36, x9 + 36, 4); + butterfly_dct_post(x8 + 40, x8 + 40, x9 + 40, 4); + butterfly_dct_post(x8 + 44, x8 + 44, x9 + 44, 4); + butterfly_dct_post(x8 + 48, x8 + 48, x9 + 48, 4); + butterfly_dct_post(x8 + 52, x8 + 52, x9 + 52, 4); + butterfly_dct_post(x8 + 56, x8 + 56, x9 + 56, 4); + butterfly_dct_post(x8 + 60, x8 + 60, x9 + 60, 4); + + // stage 10 + int32x4_t x10[64]; + x10[0] = x9[0]; + x10[1] = x9[1]; + x10[2] = x9[2]; + x10[3] = x9[3]; + x10[4] = x9[4]; + x10[5] = x9[5]; + x10[6] = x9[6]; + x10[7] = x9[7]; + x10[8] = x9[8]; + x10[9] = x9[9]; + x10[10] = x9[10]; + x10[11] = x9[11]; + x10[12] = x9[12]; + x10[13] = x9[13]; + x10[14] = x9[14]; + x10[15] = x9[15]; + x10[16] = x9[16]; + x10[17] = x9[17]; + x10[18] = x9[18]; + x10[19] = x9[19]; + x10[20] = x9[20]; + x10[21] = x9[21]; + x10[22] = x9[22]; + x10[23] = x9[23]; + x10[24] = x9[24]; + x10[25] = x9[25]; + x10[26] = x9[26]; + x10[27] = x9[27]; + x10[28] = x9[28]; + x10[29] = x9[29]; + x10[30] = x9[30]; + x10[31] = x9[31]; + butterfly_0112_neon(cospi, 1, x9[63], x9[32], &x10[32], &x10[63], v_cos_bit); + butterfly_0130_neon(cospi, 31, x9[33], x9[62], &x10[33], &x10[62], v_cos_bit); + butterfly_0112_neon(cospi, 17, x9[61], x9[34], &x10[34], &x10[61], v_cos_bit); + butterfly_0130_neon(cospi, 15, x9[35], x9[60], &x10[35], &x10[60], v_cos_bit); + butterfly_0112_neon(cospi, 9, x9[59], x9[36], &x10[36], &x10[59], v_cos_bit); + butterfly_0130_neon(cospi, 23, x9[37], x9[58], &x10[37], &x10[58], v_cos_bit); + butterfly_0112_neon(cospi, 25, x9[57], x9[38], &x10[38], &x10[57], v_cos_bit); + butterfly_0130_neon(cospi, 7, x9[39], x9[56], &x10[39], &x10[56], v_cos_bit); + butterfly_0112_neon(cospi, 5, x9[55], x9[40], &x10[40], &x10[55], v_cos_bit); + butterfly_0130_neon(cospi, 27, x9[41], x9[54], &x10[41], &x10[54], v_cos_bit); + butterfly_0112_neon(cospi, 21, x9[53], x9[42], &x10[42], &x10[53], v_cos_bit); + butterfly_0130_neon(cospi, 11, x9[43], x9[52], &x10[43], &x10[52], v_cos_bit); + butterfly_0112_neon(cospi, 13, x9[51], x9[44], &x10[44], &x10[51], v_cos_bit); + butterfly_0130_neon(cospi, 19, x9[45], x9[50], &x10[45], &x10[50], v_cos_bit); + butterfly_0112_neon(cospi, 29, x9[49], x9[46], &x10[46], &x10[49], v_cos_bit); + butterfly_0130_neon(cospi, 3, x9[47], x9[48], &x10[47], &x10[48], v_cos_bit); + + // stage 11 + output[0] = x10[0]; + output[1] = x10[32]; + output[2] = x10[16]; + output[3] = x10[48]; + output[4] = x10[8]; + output[5] = x10[40]; + output[6] = x10[24]; + output[7] = x10[56]; + output[8] = x10[4]; + output[9] = x10[36]; + output[10] = x10[20]; + output[11] = x10[52]; + output[12] = 
x10[12]; + output[13] = x10[44]; + output[14] = x10[28]; + output[15] = x10[60]; + output[16] = x10[2]; + output[17] = x10[34]; + output[18] = x10[18]; + output[19] = x10[50]; + output[20] = x10[10]; + output[21] = x10[42]; + output[22] = x10[26]; + output[23] = x10[58]; + output[24] = x10[6]; + output[25] = x10[38]; + output[26] = x10[22]; + output[27] = x10[54]; + output[28] = x10[14]; + output[29] = x10[46]; + output[30] = x10[30]; + output[31] = x10[62]; + output[32] = x10[1]; + output[33] = x10[33]; + output[34] = x10[17]; + output[35] = x10[49]; + output[36] = x10[9]; + output[37] = x10[41]; + output[38] = x10[25]; + output[39] = x10[57]; + output[40] = x10[5]; + output[41] = x10[37]; + output[42] = x10[21]; + output[43] = x10[53]; + output[44] = x10[13]; + output[45] = x10[45]; + output[46] = x10[29]; + output[47] = x10[61]; + output[48] = x10[3]; + output[49] = x10[35]; + output[50] = x10[19]; + output[51] = x10[51]; + output[52] = x10[11]; + output[53] = x10[43]; + output[54] = x10[27]; + output[55] = x10[59]; + output[56] = x10[7]; + output[57] = x10[39]; + output[58] = x10[23]; + output[59] = x10[55]; + output[60] = x10[15]; + output[61] = x10[47]; + output[62] = x10[31]; + output[63] = x10[63]; +} + +static void highbd_fidentity32_x4_neon(const int32x4_t *input, + int32x4_t *output, int cos_bit) { + (void)cos_bit; + for (int i = 0; i < 32; i++) { + output[i] = vshlq_n_s32(input[i], 2); + } +} + +TRANSFORM_COL_MANY(fdct32, 32) +TRANSFORM_COL_MANY(fidentity32, 32) + +static const fwd_transform_1d_col_many_neon + col_highbd_txfm32_x4_arr[TX_TYPES] = { + highbd_fdct32_col_many_neon, // DCT_DCT + NULL, // ADST_DCT + NULL, // DCT_ADST + NULL, // ADST_ADST + NULL, // FLIPADST_DCT + NULL, // DCT_FLIPADST + NULL, // FLIPADST_FLIPADST + NULL, // ADST_FLIPADST + NULL, // FLIPADST_ADST + highbd_fidentity32_col_many_neon, // IDTX + NULL, // V_DCT + NULL, // H_DCT + NULL, // V_ADST + NULL, // H_ADST + NULL, // V_FLIPADST + NULL // H_FLIPADST + }; + +TRANSFORM_ROW_MANY(fdct32, 32) +TRANSFORM_ROW_MANY(fidentity32, 32) + +static const fwd_transform_1d_row_many_neon + row_highbd_txfm32_x4_arr[TX_TYPES] = { + highbd_fdct32_row_many_neon, // DCT_DCT + NULL, // ADST_DCT + NULL, // DCT_ADST + NULL, // ADST_ADST + NULL, // FLIPADST_DCT + NULL, // DCT_FLIPADST + NULL, // FLIPADST_FLIPADST + NULL, // ADST_FLIPADST + NULL, // FLIPADST_ADST + highbd_fidentity32_row_many_neon, // IDTX + NULL, // V_DCT + NULL, // H_DCT + NULL, // V_ADST + NULL, // H_ADST + NULL, // V_FLIPADST + NULL // H_FLIPADST + }; + +TRANSFORM_ROW_RECT_MANY(fdct32, 32) +TRANSFORM_ROW_RECT_MANY(fidentity32, 32) + +static const fwd_transform_1d_row_many_neon + row_rect_highbd_txfm32_x4_arr[TX_TYPES] = { + highbd_fdct32_row_rect_many_neon, // DCT_DCT + NULL, // ADST_DCT + NULL, // DCT_ADST + NULL, // ADST_ADST + NULL, // FLIPADST_DCT + NULL, // DCT_FLIPADST + NULL, // FLIPADST_FLIPADST + NULL, // ADST_FLIPADST + NULL, // FLIPADST_ADST + highbd_fidentity32_row_rect_many_neon, // IDTX + NULL, // V_DCT + NULL, // H_DCT + NULL, // V_ADST + NULL, // H_ADST + NULL, // V_FLIPADST + NULL // H_FLIPADST + }; + +void av1_fwd_txfm2d_16x8_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm8_xn_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_rect_highbd_txfm16_xn_arr[tx_type]; + int bit = av1_fwd_cos_bit_col[2][1]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + 
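+  // ud_flip is handled by adjusting the input pointer and negating the
+  // stride so that rows are read bottom-up; lr_flip is handled below by
+  // writing the column-transform batches back-to-front (base pointer
+  // buf0 + 3 * 8 with a negative hm_stride).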
ud_adjust_input_and_stride(ud_flip, &input, &stride, 8); + + // Column-wise transform. + int32x4_t buf0[32]; + if (lr_flip) { + col_txfm(input, buf0 + 3 * 8, stride, bit, /*lr_flip=*/1, /*howmany=*/4, + /*hm_stride=*/-8); + } else { + col_txfm(input, buf0, stride, bit, /*lr_flip=*/0, /*howmany=*/4, + /*hm_stride=*/8); + } + shift_right_2_round_s32_x4(buf0, buf0, 32); + + int32x4_t buf1[32]; + transpose_arrays_s32_16x8(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bit, /*howmany=*/2, /*hm_stride=*/16, /*stride=*/8); +} + +void av1_fwd_txfm2d_8x16_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm16_xn_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_rect_highbd_txfm8_xn_arr[tx_type]; + int bit = av1_fwd_cos_bit_col[1][2]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 16); + + // Column-wise transform. + int32x4_t buf0[32]; + if (lr_flip) { + col_txfm(input, buf0 + 16, stride, bit, /*lr_flip=*/1, /*howmany=*/2, + /*hm_stride=*/-16); + } else { + col_txfm(input, buf0, stride, bit, /*lr_flip=*/0, /*howmany=*/2, + /*hm_stride=*/16); + } + shift_right_2_round_s32_x4(buf0, buf0, 32); + + int32x4_t buf1[32]; + transpose_arrays_s32_8x16(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bit, /*howmany=*/4, /*hm_stride=*/8, /*stride=*/16); +} + +#if !CONFIG_REALTIME_ONLY +void av1_fwd_txfm2d_4x16_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + int bitcol = av1_fwd_cos_bit_col[0][2]; + int bitrow = av1_fwd_cos_bit_row[0][2]; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm16_xn_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_highbd_txfm4_xn_arr[tx_type]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 16); + + // Column-wise transform. + int32x4_t buf0[16]; + if (lr_flip) { + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/1, /*howmany=*/1, + /*hm_stride=*/0); + } else { + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/0, /*howmany=*/1, + /*hm_stride=*/0); + } + shift_right_1_round_s32_x4(buf0, buf0, 16); + + int32x4_t buf1[16]; + transpose_arrays_s32_4x16(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bitrow, /*howmany=*/4, /*hm_stride=*/4, /*stride=*/16); +} +#endif + +void av1_fwd_txfm2d_16x4_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + int bitcol = av1_fwd_cos_bit_col[2][0]; + int bitrow = av1_fwd_cos_bit_row[2][0]; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm4_xn_arr[tx_type]; + const fwd_transform_1d_row_neon row_txfm = row_highbd_txfm16_xn_arr[tx_type]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 4); + + // Column-wise transform. + int32x4_t buf0[16]; + if (lr_flip) { + col_txfm(input, buf0 + 3 * 4, stride, bitcol, /*lr_flip=*/1, /*howmany=*/4, + /*hm_stride=*/-4); + } else { + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/0, /*howmany=*/4, + /*hm_stride=*/4); + } + + shift_right_1_round_s32_x4(buf0, buf0, 16); + transpose_arrays_s32_4x16(buf0, buf0); + + // Row-wise transform. 
+ row_txfm(buf0, coeff, bitrow, /*stride=*/4); +} + +void av1_fwd_txfm2d_16x32_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm32_x4_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_rect_highbd_txfm16_xn_arr[tx_type]; + int bitcol = av1_fwd_cos_bit_col[2][3]; + int bitrow = av1_fwd_cos_bit_row[2][3]; + + // Column-wise transform. + int32x4_t buf0[128]; + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/0, /*howmany=*/4, + /*hm_stride=*/32); + shift_right_4_round_s32_x4(buf0, buf0, 128); + + int32x4_t buf1[128]; + transpose_arrays_s32_16x32(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bitrow, /*howmany=*/8, /*hm_stride=*/16, /*stride=*/32); +} + +void av1_fwd_txfm2d_32x64_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + (void)tx_type; + int bitcol = av1_fwd_cos_bit_col[3][4]; + int bitrow = av1_fwd_cos_bit_row[3][4]; + + // Column-wise transform. + int32x4_t buf0[512]; + load_buffer_32x64(input, buf0, stride, 0); + for (int i = 0; i < 8; i++) { + highbd_fdct64_x4_neon(buf0 + i * 64, buf0 + i * 64, bitcol); + } + shift_right_2_round_s32_x4(buf0, buf0, 512); + + int32x4_t buf1[512]; + transpose_arrays_s32_32x64(buf0, buf1); + + // Row-wise transform. + for (int i = 0; i < 16; i++) { + highbd_fdct32_x4_neon(buf1 + i * 32, buf1 + i * 32, bitrow); + } + round_shift2_rect_array_s32_neon(buf1, buf1, 512); + store_buffer_32x32(buf1, coeff, /*stride=*/32); +} + +void av1_fwd_txfm2d_64x32_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + (void)tx_type; + int bitcol = av1_fwd_cos_bit_col[4][3]; + int bitrow = av1_fwd_cos_bit_row[4][3]; + + // Column-wise transform. + int32x4_t buf0[512]; + load_buffer_64x32(input, buf0, stride, 0); + for (int i = 0; i < 16; i++) { + highbd_fdct32_x4_neon(buf0 + i * 32, buf0 + i * 32, bitcol); + } + shift_right_4_round_s32_x4(buf0, buf0, 512); + + int32x4_t buf1[512]; + transpose_arrays_s32_64x32(buf0, buf1); + + // Row-wise transform. + for (int i = 0; i < 8; i++) { + highbd_fdct64_x4_neon(buf1 + i * 64, buf1 + i * 64, bitrow); + } + round_shift2_rect_array_s32_neon(buf1, buf1, 512); + store_buffer_64x32(buf1, coeff, /*stride=*/32); +} + +void av1_fwd_txfm2d_32x16_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm16_xn_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_rect_highbd_txfm32_x4_arr[tx_type]; + int bitcol = av1_fwd_cos_bit_col[3][2]; + int bitrow = av1_fwd_cos_bit_row[3][2]; + + // Column-wise transform. + int32x4_t buf0[128]; + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/0, /*howmany=*/8, + /*hm_stride=*/16); + shift_right_4_round_s32_x4(buf0, buf0, 128); + + int32x4_t buf1[128]; + transpose_arrays_s32_32x16(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bitrow, /*howmany=*/4, /*hm_stride=*/32, /*stride=*/16); +} + +#if !CONFIG_REALTIME_ONLY +void av1_fwd_txfm2d_8x32_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm32_x4_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_highbd_txfm8_xn_arr[tx_type]; + int bitcol = av1_fwd_cos_bit_col[1][3]; + int bitrow = av1_fwd_cos_bit_row[1][3]; + + // Column-wise transform. 
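+  // The *_many helpers run the 1-D transform on `howmany` batches of four
+  // columns (or rows), with hm_stride giving the offset between consecutive
+  // batches in the transform buffer. Here the 8 input columns form 2 batches
+  // of 4, each producing 32 vectors written hm_stride = 32 apart in buf0.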
+ int32x4_t buf0[64]; + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/0, /*howmany=*/2, + /*hm_stride=*/32); + shift_right_2_round_s32_x4(buf0, buf0, 64); + + int32x4_t buf1[64]; + transpose_arrays_s32_8x32(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bitrow, /*howmany=*/8, /*hm_stride=*/8, /*stride=*/32); +} + +void av1_fwd_txfm2d_32x8_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm8_xn_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_highbd_txfm32_x4_arr[tx_type]; + int bitcol = av1_fwd_cos_bit_col[3][1]; + int bitrow = av1_fwd_cos_bit_row[3][1]; + + // Column-wise transform. + int32x4_t buf0[64]; + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/0, /*howmany=*/8, + /*hm_stride=*/8); + shift_right_2_round_s32_x4(buf0, buf0, 64); + + int32x4_t buf1[64]; + transpose_arrays_s32_32x8(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bitrow, /*howmany=*/2, /*hm_stride=*/32, /*stride=*/8); +} +#endif + +void av1_fwd_txfm2d_4x8_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + int bitcol = av1_fwd_cos_bit_col[0][1]; + int bitrow = av1_fwd_cos_bit_row[0][1]; + const fwd_transform_1d_col_neon col_txfm = col_highbd_txfm8_x4_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_rect_highbd_txfm4_xn_arr[tx_type]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 8); + + // Column-wise transform. + int32x4_t buf0[8]; + col_txfm(input, buf0, stride, bitcol, lr_flip); + shift_right_1_round_s32_x4(buf0, buf0, 8); + + int32x4_t buf1[8]; + transpose_arrays_s32_4x8(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bitrow, /*howmany=*/2, /*hm_stride=*/4, /*stride=*/8); +} + +void av1_fwd_txfm2d_8x4_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const int bitcol = av1_fwd_cos_bit_col[1][0]; + const int bitrow = av1_fwd_cos_bit_row[1][0]; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm4_xn_arr[tx_type]; + const fwd_transform_1d_row_neon row_txfm = row_highbd_txfm8_x4_arr[tx_type]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 4); + + // Column-wise transform. + int32x4_t buf0[8]; + if (lr_flip) { + col_txfm(input, buf0 + 4, stride, bitcol, /*lr_flip=*/1, /*howmany=*/2, + /*hm_stride=*/-4); + } else { + col_txfm(input, buf0, stride, bitcol, /*lr_flip=*/0, /*howmany=*/2, + /*hm_stride=*/4); + } + + shift_right_1_round_s32_x4(buf0, buf0, 8); + + int32x4_t buf1[8]; + transpose_arrays_s32_8x4(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, coeff, bitrow, /*stride=*/4); +} + +#if !CONFIG_REALTIME_ONLY +void av1_fwd_txfm2d_16x64_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const int bitcol = av1_fwd_cos_bit_col[2][4]; + const int bitrow = av1_fwd_cos_bit_row[2][4]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 64); + + // Column-wise transform. 
+ int32x4_t buf0[256]; + load_buffer_16x64(input, buf0, stride, lr_flip); + for (int i = 0; i < 4; i++) { + highbd_fdct64_x4_neon(buf0 + i * 64, buf0 + i * 64, bitcol); + } + shift_right_2_round_s32_x4(buf0, buf0, 256); + + int32x4_t buf1[256]; + transpose_arrays_s32_16x64(buf0, buf1); + + // Row-wise transform. + highbd_fdct16_xn_neon(buf1, buf1, bitrow, 8); + store_buffer_16x32(buf1, coeff, /*stride=*/32); +} + +void av1_fwd_txfm2d_64x16_neon(const int16_t *input, int32_t *coeff, int stride, + TX_TYPE tx_type, int bd) { + (void)bd; + const int bitcol = av1_fwd_cos_bit_col[4][2]; + const int bitrow = av1_fwd_cos_bit_row[4][2]; + + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + ud_adjust_input_and_stride(ud_flip, &input, &stride, 16); + + // Column-wise transform. + int32x4_t buf0[256]; + load_buffer_64x16(input, buf0, stride, lr_flip); + highbd_fdct16_xn_neon(buf0, buf0, bitcol, 16); + shift_right_4_round_s32_x4(buf0, buf0, 256); + + int32x4_t buf1[256]; + transpose_arrays_s32_64x16(buf0, buf1); + + // Row-wise transform. + for (int i = 0; i < 4; i++) { + highbd_fdct64_x4_neon(buf1 + i * 64, buf1 + i * 64, bitrow); + } + store_buffer_64x16(buf1, coeff, /*stride=*/16); + memset(coeff + 16 * 32, 0, 16 * 32 * sizeof(*coeff)); +} +#endif + +void av1_fwd_txfm2d_32x32_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + const fwd_transform_1d_col_many_neon col_txfm = + col_highbd_txfm32_x4_arr[tx_type]; + const fwd_transform_1d_row_many_neon row_txfm = + row_highbd_txfm32_x4_arr[tx_type]; + + // Column-wise transform. + int32x4_t buf0[256]; + col_txfm(input, buf0, stride, /*cos_bit=*/12, /*lr_flip=*/0, /*howmany=*/8, + /*hm_stride=*/32); + shift_right_4_round_s32_x4(buf0, buf0, 256); + + int32x4_t buf1[256]; + transpose_arrays_s32_32x32(buf0, buf1); + + // Row-wise transform. + row_txfm(buf1, output, /*cos_bit=*/12, /*howmany=*/8, /*hm_stride=*/32, + /*stride=*/32); +} + +void av1_fwd_txfm2d_64x64_neon(const int16_t *input, int32_t *output, + int stride, TX_TYPE tx_type, int bd) { + (void)bd; + (void)tx_type; + + // Column-wise transform. + int32x4_t buf0[1024]; + load_buffer_64x64(input, buf0, stride, 0); + for (int col = 0; col < 16; col++) { + highbd_fdct64_x4_neon(buf0 + col * 64, buf0 + col * 64, 13); + } + shift_right_2_round_s32_x4(buf0, buf0, 1024); + + int32x4_t buf1[1024]; + transpose_arrays_s32_64x64(buf0, buf1); + + // Row-wise transform. + for (int col = 0; col < 8; col++) { + highbd_fdct64_x4_neon(buf1 + col * 64, buf1 + col * 64, 10); + } + shift_right_2_round_s32_x4(buf1, buf1, 512); + store_buffer_64x32(buf1, output, /*stride=*/32); +} diff --git a/third_party/aom/av1/encoder/arm/neon/highbd_pickrst_neon.c b/third_party/aom/av1/encoder/arm/neon/highbd_pickrst_neon.c new file mode 100644 index 0000000000..47b5f5cfb7 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/highbd_pickrst_neon.c @@ -0,0 +1,1207 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include <arm_neon.h> +#include <assert.h> +#include <stdint.h> + +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "av1/encoder/arm/neon/pickrst_neon.h" +#include "av1/encoder/pickrst.h" + +static INLINE void highbd_calc_proj_params_r0_r1_neon( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) { + assert(width % 8 == 0); + const int size = width * height; + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8); + + int64x2_t h00_lo = vdupq_n_s64(0); + int64x2_t h00_hi = vdupq_n_s64(0); + int64x2_t h11_lo = vdupq_n_s64(0); + int64x2_t h11_hi = vdupq_n_s64(0); + int64x2_t h01_lo = vdupq_n_s64(0); + int64x2_t h01_hi = vdupq_n_s64(0); + int64x2_t c0_lo = vdupq_n_s64(0); + int64x2_t c0_hi = vdupq_n_s64(0); + int64x2_t c1_lo = vdupq_n_s64(0); + int64x2_t c1_hi = vdupq_n_s64(0); + + do { + const uint16_t *src_ptr = src; + const uint16_t *dat_ptr = dat; + int32_t *flt0_ptr = flt0; + int32_t *flt1_ptr = flt1; + int w = width; + + do { + uint16x8_t s = vld1q_u16(src_ptr); + uint16x8_t d = vld1q_u16(dat_ptr); + int32x4_t f0_lo = vld1q_s32(flt0_ptr); + int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4); + int32x4_t f1_lo = vld1q_s32(flt1_ptr); + int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4); + + int32x4_t u_lo = + vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS)); + int32x4_t u_hi = vreinterpretq_s32_u32( + vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS)); + int32x4_t s_lo = + vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS)); + int32x4_t s_hi = vreinterpretq_s32_u32( + vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS)); + s_lo = vsubq_s32(s_lo, u_lo); + s_hi = vsubq_s32(s_hi, u_hi); + + f0_lo = vsubq_s32(f0_lo, u_lo); + f0_hi = vsubq_s32(f0_hi, u_hi); + f1_lo = vsubq_s32(f1_lo, u_lo); + f1_hi = vsubq_s32(f1_hi, u_hi); + + h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo)); + h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo)); + h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi)); + h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi)); + + h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo)); + h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo)); + h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi)); + h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi)); + + h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo)); + h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo)); + h01_hi = vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi)); + h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi)); + + c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo)); + c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo)); + c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi)); + c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi)); + + c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo)); + c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo)); + c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi)); + c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi)); + + src_ptr += 8; + dat_ptr += 8; + flt0_ptr += 8; + flt1_ptr += 8; + w -= 8; + } while (w != 0); + + src += 
src_stride; + dat += dat_stride; + flt0 += flt0_stride; + flt1 += flt1_stride; + } while (--height != 0); + + H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size; + H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size; + H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size; + H[1][0] = H[0][1]; + C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size; + C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size; +} + +static INLINE void highbd_calc_proj_params_r0_neon( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, + int64_t H[2][2], int64_t C[2]) { + assert(width % 8 == 0); + const int size = width * height; + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8); + + int64x2_t h00_lo = vdupq_n_s64(0); + int64x2_t h00_hi = vdupq_n_s64(0); + int64x2_t c0_lo = vdupq_n_s64(0); + int64x2_t c0_hi = vdupq_n_s64(0); + + do { + const uint16_t *src_ptr = src; + const uint16_t *dat_ptr = dat; + int32_t *flt0_ptr = flt0; + int w = width; + + do { + uint16x8_t s = vld1q_u16(src_ptr); + uint16x8_t d = vld1q_u16(dat_ptr); + int32x4_t f0_lo = vld1q_s32(flt0_ptr); + int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4); + + int32x4_t u_lo = + vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS)); + int32x4_t u_hi = vreinterpretq_s32_u32( + vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS)); + int32x4_t s_lo = + vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS)); + int32x4_t s_hi = vreinterpretq_s32_u32( + vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS)); + s_lo = vsubq_s32(s_lo, u_lo); + s_hi = vsubq_s32(s_hi, u_hi); + + f0_lo = vsubq_s32(f0_lo, u_lo); + f0_hi = vsubq_s32(f0_hi, u_hi); + + h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo)); + h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo)); + h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi)); + h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi)); + + c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo)); + c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo)); + c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi)); + c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi)); + + src_ptr += 8; + dat_ptr += 8; + flt0_ptr += 8; + w -= 8; + } while (w != 0); + + src += src_stride; + dat += dat_stride; + flt0 += flt0_stride; + } while (--height != 0); + + H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size; + C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size; +} + +static INLINE void highbd_calc_proj_params_r1_neon( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt1, int flt1_stride, + int64_t H[2][2], int64_t C[2]) { + assert(width % 8 == 0); + const int size = width * height; + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8); + + int64x2_t h11_lo = vdupq_n_s64(0); + int64x2_t h11_hi = vdupq_n_s64(0); + int64x2_t c1_lo = vdupq_n_s64(0); + int64x2_t c1_hi = vdupq_n_s64(0); + + do { + const uint16_t *src_ptr = src; + const uint16_t *dat_ptr = dat; + int32_t *flt1_ptr = flt1; + int w = width; + + do { + uint16x8_t s = vld1q_u16(src_ptr); + uint16x8_t d = vld1q_u16(dat_ptr); + int32x4_t f1_lo = vld1q_s32(flt1_ptr); + int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4); + + int32x4_t u_lo = + 
vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(d), SGRPROJ_RST_BITS)); + int32x4_t u_hi = vreinterpretq_s32_u32( + vshll_n_u16(vget_high_u16(d), SGRPROJ_RST_BITS)); + int32x4_t s_lo = + vreinterpretq_s32_u32(vshll_n_u16(vget_low_u16(s), SGRPROJ_RST_BITS)); + int32x4_t s_hi = vreinterpretq_s32_u32( + vshll_n_u16(vget_high_u16(s), SGRPROJ_RST_BITS)); + s_lo = vsubq_s32(s_lo, u_lo); + s_hi = vsubq_s32(s_hi, u_hi); + + f1_lo = vsubq_s32(f1_lo, u_lo); + f1_hi = vsubq_s32(f1_hi, u_hi); + + h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo)); + h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo)); + h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi)); + h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi)); + + c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo)); + c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo)); + c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi)); + c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi)); + + src_ptr += 8; + dat_ptr += 8; + flt1_ptr += 8; + w -= 8; + } while (w != 0); + + src += src_stride; + dat += dat_stride; + flt1 += flt1_stride; + } while (--height != 0); + + H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size; + C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size; +} + +// The function calls 3 subfunctions for the following cases : +// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements +// of C and H need to be computed. +// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are +// non-zero and need to be computed. +// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are +// non-zero and need to be computed. +void av1_calc_proj_params_high_bd_neon(const uint8_t *src8, int width, + int height, int src_stride, + const uint8_t *dat8, int dat_stride, + int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, + int64_t H[2][2], int64_t C[2], + const sgr_params_type *params) { + if ((params->r[0] > 0) && (params->r[1] > 0)) { + highbd_calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8, + dat_stride, flt0, flt0_stride, flt1, + flt1_stride, H, C); + } else if (params->r[0] > 0) { + highbd_calc_proj_params_r0_neon(src8, width, height, src_stride, dat8, + dat_stride, flt0, flt0_stride, H, C); + } else if (params->r[1] > 0) { + highbd_calc_proj_params_r1_neon(src8, width, height, src_stride, dat8, + dat_stride, flt1, flt1_stride, H, C); + } +} + +static INLINE int16x8_t tbl2q(int16x8_t a, int16x8_t b, uint8x16_t idx) { +#if AOM_ARCH_AARCH64 + uint8x16x2_t table = { { vreinterpretq_u8_s16(a), vreinterpretq_u8_s16(b) } }; + return vreinterpretq_s16_u8(vqtbl2q_u8(table, idx)); +#else + uint8x8x4_t table = { { vreinterpret_u8_s16(vget_low_s16(a)), + vreinterpret_u8_s16(vget_high_s16(a)), + vreinterpret_u8_s16(vget_low_s16(b)), + vreinterpret_u8_s16(vget_high_s16(b)) } }; + return vreinterpretq_s16_u8(vcombine_u8(vtbl4_u8(table, vget_low_u8(idx)), + vtbl4_u8(table, vget_high_u8(idx)))); +#endif +} + +static INLINE int16x8_t tbl3q(int16x8_t a, int16x8_t b, int16x8_t c, + uint8x16_t idx) { +#if AOM_ARCH_AARCH64 + uint8x16x3_t table = { { vreinterpretq_u8_s16(a), vreinterpretq_u8_s16(b), + vreinterpretq_u8_s16(c) } }; + return vreinterpretq_s16_u8(vqtbl3q_u8(table, idx)); +#else + // This is a specific implementation working only for compute stats with + // wiener_win == 5. 
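+  // vtbl3_u8 can only index 24 bytes, so the lookup is split in two: the low
+  // half of idx indexes a window starting at a, and the high half (rebased
+  // by 16) indexes a window starting at b. This relies on every index in the
+  // high half of idx being >= 16, which holds for the shuffle_stats5_highbd
+  // tables used with this helper.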
+ uint8x8x3_t table_lo = { { vreinterpret_u8_s16(vget_low_s16(a)), + vreinterpret_u8_s16(vget_high_s16(a)), + vreinterpret_u8_s16(vget_low_s16(b)) } }; + uint8x8x3_t table_hi = { { vreinterpret_u8_s16(vget_low_s16(b)), + vreinterpret_u8_s16(vget_high_s16(b)), + vreinterpret_u8_s16(vget_low_s16(c)) } }; + return vreinterpretq_s16_u8(vcombine_u8( + vtbl3_u8(table_lo, vget_low_u8(idx)), + vtbl3_u8(table_hi, vsub_u8(vget_high_u8(idx), vdup_n_u8(16))))); +#endif +} + +static INLINE int64_t div_shift_s64(int64_t x, int power) { + return (x < 0 ? x + (1ll << power) - 1 : x) >> power; +} + +// The M matrix is accumulated in a bitdepth-dependent number of steps to +// speed up the computation. This function computes the final M from the +// accumulated (src_s64) and the residual parts (src_s32). It also transposes +// the result as the output needs to be column-major. +static INLINE void acc_transpose_M(int64_t *dst, const int64_t *src_s64, + const int32_t *src_s32, const int wiener_win, + int shift) { + for (int i = 0; i < wiener_win; ++i) { + for (int j = 0; j < wiener_win; ++j) { + int tr_idx = j * wiener_win + i; + *dst++ = div_shift_s64(src_s64[tr_idx] + src_s32[tr_idx], shift); + } + } +} + +// The resulting H is a column-major matrix accumulated from the transposed +// (column-major) samples of the filter kernel (5x5 or 7x7) viewed as a single +// vector. For the 7x7 filter case: H(49x49) = [49 x 1] x [1 x 49]. This +// function transforms back to the originally expected format (double +// transpose). The H matrix is accumulated in a bitdepth-dependent number of +// steps to speed up the computation. This function computes the final H from +// the accumulated (src_s64) and the residual parts (src_s32). The computed H is +// only an upper triangle matrix, this function also fills the lower triangle of +// the resulting matrix. +static INLINE void update_H(int64_t *dst, const int64_t *src_s64, + const int32_t *src_s32, const int wiener_win, + int stride, int shift) { + // For a simplified theoretical 3x3 case where `wiener_win` is 3 and + // `wiener_win2` is 9, the M matrix is 3x3: + // 0, 3, 6 + // 1, 4, 7 + // 2, 5, 8 + // + // This is viewed as a vector to compute H (9x9) by vector outer product: + // 0, 3, 6, 1, 4, 7, 2, 5, 8 + // + // Double transpose and upper triangle remapping for 3x3 -> 9x9 case: + // 0, 3, 6, 1, 4, 7, 2, 5, 8, + // 3, 30, 33, 12, 31, 34, 21, 32, 35, + // 6, 33, 60, 15, 42, 61, 24, 51, 62, + // 1, 12, 15, 10, 13, 16, 11, 14, 17, + // 4, 31, 42, 13, 40, 43, 22, 41, 44, + // 7, 34, 61, 16, 43, 70, 25, 52, 71, + // 2, 21, 24, 11, 22, 25, 20, 23, 26, + // 5, 32, 51, 14, 41, 52, 23, 50, 53, + // 8, 35, 62, 17, 44, 71, 26, 53, 80, + const int wiener_win2 = wiener_win * wiener_win; + + // Loop through the indices according to the remapping above, along the + // columns: + // 0, wiener_win, 2 * wiener_win, ..., 1, 1 + 2 * wiener_win, ..., + // wiener_win - 1, wiener_win - 1 + wiener_win, ... + // For the 3x3 case `j` will be: 0, 3, 6, 1, 4, 7, 2, 5, 8. + for (int i = 0; i < wiener_win; ++i) { + for (int j = i; j < wiener_win2; j += wiener_win) { + // These two inner loops are the same as the two outer loops, but running + // along rows instead of columns. For the 3x3 case `l` will be: + // 0, 3, 6, 1, 4, 7, 2, 5, 8. 
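+      // For example, the output element at (j, l) = (6, 1) is read from the
+      // upper-right triangle at tr_idx = stride * 1 + 6 below, i.e. from its
+      // mirrored element (1, 6).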
+ for (int k = 0; k < wiener_win; ++k) { + for (int l = k; l < wiener_win2; l += wiener_win) { + // The nominal double transpose indexing would be: + // int idx = stride * j + l; + // However we need the upper-right triangle, it is easy with some + // min/max operations. + int tr_idx = stride * AOMMIN(j, l) + AOMMAX(j, l); + + // Resulting matrix is filled by combining the 64-bit and the residual + // 32-bit matrices together with scaling. + *dst++ = div_shift_s64(src_s64[tr_idx] + src_s32[tr_idx], shift); + } + } + } + } +} + +// Load 7x7 matrix into 7 128-bit vectors from consecutive rows, the last load +// address is offset to prevent out-of-bounds access. +static INLINE void load_and_pack_s16_8x7(int16x8_t dst[7], const int16_t *src, + ptrdiff_t stride) { + dst[0] = vld1q_s16(src); + src += stride; + dst[1] = vld1q_s16(src); + src += stride; + dst[2] = vld1q_s16(src); + src += stride; + dst[3] = vld1q_s16(src); + src += stride; + dst[4] = vld1q_s16(src); + src += stride; + dst[5] = vld1q_s16(src); + src += stride; + dst[6] = vld1q_s16(src - 1); +} + +static INLINE void highbd_compute_stats_win7_neon( + const uint16_t *dgd, const uint16_t *src, int avg, int width, int height, + int dgd_stride, int src_stride, int64_t *M, int64_t *H, + aom_bit_depth_t bit_depth) { + // Matrix names are capitalized to help readability. + DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int32_t, H_s32[WIENER_WIN2 * WIENER_WIN2_ALIGN2]); + DECLARE_ALIGNED(64, int64_t, H_s64[WIENER_WIN2 * WIENER_WIN2_ALIGN2]); + + memset(M_s32, 0, sizeof(M_s32)); + memset(M_s64, 0, sizeof(M_s64)); + memset(H_s32, 0, sizeof(H_s32)); + memset(H_s64, 0, sizeof(H_s64)); + + // Look-up tables to create 8x6 matrix with consecutive elements from two 7x7 + // matrices. 
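+  // Each table entry is a byte index for tbl2q: every int16 element is
+  // selected by a pair of byte indices, with values 0..15 addressing the
+  // first source register and 16..31 the second. For example, the pair
+  // { 0, 1 } picks element 0 of the first register and { 16, 17 } element 0
+  // of the second.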
+ // clang-format off + DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats7_highbd[192]) = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 19, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 21, + 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 10, 11, 12, 13, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, + 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + }; + // clang-format on + + const uint8x16_t lut0 = vld1q_u8(shuffle_stats7_highbd + 0); + const uint8x16_t lut1 = vld1q_u8(shuffle_stats7_highbd + 16); + const uint8x16_t lut2 = vld1q_u8(shuffle_stats7_highbd + 32); + const uint8x16_t lut3 = vld1q_u8(shuffle_stats7_highbd + 48); + const uint8x16_t lut4 = vld1q_u8(shuffle_stats7_highbd + 64); + const uint8x16_t lut5 = vld1q_u8(shuffle_stats7_highbd + 80); + const uint8x16_t lut6 = vld1q_u8(shuffle_stats7_highbd + 96); + const uint8x16_t lut7 = vld1q_u8(shuffle_stats7_highbd + 112); + const uint8x16_t lut8 = vld1q_u8(shuffle_stats7_highbd + 128); + const uint8x16_t lut9 = vld1q_u8(shuffle_stats7_highbd + 144); + const uint8x16_t lut10 = vld1q_u8(shuffle_stats7_highbd + 160); + const uint8x16_t lut11 = vld1q_u8(shuffle_stats7_highbd + 176); + + // We can accumulate up to 65536/4096/256 8/10/12-bit multiplication results + // in 32-bit. We are processing 2 pixels at a time, so the accumulator max can + // be as high as 32768/2048/128 for the compute stats. + const int acc_cnt_max = (1 << (32 - 2 * bit_depth)) >> 1; + int acc_cnt = acc_cnt_max; + const int src_next = src_stride - width; + const int dgd_next = dgd_stride - width; + const int16x8_t avg_s16 = vdupq_n_s16(avg); + + do { + int j = width; + while (j >= 2) { + // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the + // middle 6x7 elements being shared. + int16x8_t dgd_rows[7]; + load_and_pack_s16_8x7(dgd_rows, (const int16_t *)dgd, dgd_stride); + + const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 6; + dgd += 2; + + dgd_rows[0] = vsubq_s16(dgd_rows[0], avg_s16); + dgd_rows[1] = vsubq_s16(dgd_rows[1], avg_s16); + dgd_rows[2] = vsubq_s16(dgd_rows[2], avg_s16); + dgd_rows[3] = vsubq_s16(dgd_rows[3], avg_s16); + dgd_rows[4] = vsubq_s16(dgd_rows[4], avg_s16); + dgd_rows[5] = vsubq_s16(dgd_rows[5], avg_s16); + dgd_rows[6] = vsubq_s16(dgd_rows[6], avg_s16); + + // Re-arrange the combined 8x7 matrix to have the 2 whole 7x7 matrices (1 + // for each of the 2 pixels) separated into distinct int16x8_t[6] arrays. + // These arrays contain 48 elements of the 49 (7x7). Compute `dgd - avg` + // for both buffers. Each DGD_AVG buffer contains 49 consecutive elements. 
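+      // For example, dgd_avg0[0] holds elements 0..6 of the first window row
+      // followed by element 0 of the second row, while dgd_avg1[0] holds the
+      // same data shifted one column to the right (elements 1..7 of the first
+      // row and element 1 of the second).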
+ int16x8_t dgd_avg0[6]; + int16x8_t dgd_avg1[6]; + + dgd_avg0[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut0); + dgd_avg1[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut6); + dgd_avg0[1] = tbl2q(dgd_rows[1], dgd_rows[2], lut1); + dgd_avg1[1] = tbl2q(dgd_rows[1], dgd_rows[2], lut7); + dgd_avg0[2] = tbl2q(dgd_rows[2], dgd_rows[3], lut2); + dgd_avg1[2] = tbl2q(dgd_rows[2], dgd_rows[3], lut8); + dgd_avg0[3] = tbl2q(dgd_rows[3], dgd_rows[4], lut3); + dgd_avg1[3] = tbl2q(dgd_rows[3], dgd_rows[4], lut9); + dgd_avg0[4] = tbl2q(dgd_rows[4], dgd_rows[5], lut4); + dgd_avg1[4] = tbl2q(dgd_rows[4], dgd_rows[5], lut10); + dgd_avg0[5] = tbl2q(dgd_rows[5], dgd_rows[6], lut5); + dgd_avg1[5] = tbl2q(dgd_rows[5], dgd_rows[6], lut11); + + vst1q_s16(DGD_AVG0, dgd_avg0[0]); + vst1q_s16(DGD_AVG1, dgd_avg1[0]); + vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]); + vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]); + vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]); + vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]); + vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]); + vst1q_s16(DGD_AVG1 + 24, dgd_avg1[3]); + vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]); + vst1q_s16(DGD_AVG1 + 32, dgd_avg1[4]); + vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]); + vst1q_s16(DGD_AVG1 + 40, dgd_avg1[5]); + + // The remaining last (49th) elements of `dgd - avg`. + DGD_AVG0[48] = dgd_ptr[6] - avg; + DGD_AVG1[48] = dgd_ptr[7] - avg; + + // Accumulate into row-major variant of matrix M (cross-correlation) for 2 + // output pixels at a time. M is of size 7 * 7. It needs to be filled such + // that multiplying one element from src with each element of a row of the + // wiener window will fill one column of M. However this is not very + // convenient in terms of memory access, as it means we do contiguous + // loads of dgd but strided stores to M. As a result, we use an + // intermediate matrix M_s32 which is instead filled such that one row of + // the wiener window gives one row of M_s32. Once fully computed, M_s32 is + // then transposed to return M. + int src_avg0 = *src++ - avg; + int src_avg1 = *src++ - avg; + int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0); + int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1); + update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0], + dgd_avg1[0]); + update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1], + dgd_avg1[1]); + update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2], + dgd_avg1[2]); + update_M_2pixels(M_s32 + 24, src_avg0_s16, src_avg1_s16, dgd_avg0[3], + dgd_avg1[3]); + update_M_2pixels(M_s32 + 32, src_avg0_s16, src_avg1_s16, dgd_avg0[4], + dgd_avg1[4]); + update_M_2pixels(M_s32 + 40, src_avg0_s16, src_avg1_s16, dgd_avg0[5], + dgd_avg1[5]); + + // Last (49th) element of M_s32 can be computed as scalar more efficiently + // for 2 output pixels. + M_s32[48] += DGD_AVG0[48] * src_avg0 + DGD_AVG1[48] * src_avg1; + + // Start accumulating into row-major version of matrix H + // (auto-covariance), it expects the DGD_AVG[01] matrices to also be + // row-major. H is of size 49 * 49. It is filled by multiplying every pair + // of elements of the wiener window together (vector outer product). Since + // it is a symmetric matrix, we only compute the upper-right triangle, and + // then copy it down to the lower-left later. The upper triangle is + // covered by 4x4 tiles. The original algorithm assumes the M matrix is + // column-major and the resulting H matrix is also expected to be + // column-major. It is not efficient to work with column-major matrices, + // so we accumulate into a row-major matrix H_s32. 
At the end of the + // algorithm a double transpose transformation will convert H_s32 back to + // the expected output layout. + update_H_7x7_2pixels(H_s32, DGD_AVG0, DGD_AVG1); + + // The last element of the triangle of H_s32 matrix can be computed as a + // scalar more efficiently. + H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += + DGD_AVG0[48] * DGD_AVG0[48] + DGD_AVG1[48] * DGD_AVG1[48]; + + // Accumulate into 64-bit after a bit depth dependent number of iterations + // to prevent overflow. + if (--acc_cnt == 0) { + acc_cnt = acc_cnt_max; + + accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_ALIGN2); + + // The widening accumulation is only needed for the upper triangle part + // of the matrix. + int64_t *lh = H_s64; + int32_t *lh32 = H_s32; + for (int k = 0; k < WIENER_WIN2; ++k) { + // The widening accumulation is only run for the relevant parts + // (upper-right triangle) in a row 4-element aligned. + int k4 = k / 4 * 4; + accumulate_and_clear(lh + k4, lh32 + k4, 48 - k4); + + // Last element of the row is computed separately. + lh[48] += lh32[48]; + lh32[48] = 0; + + lh += WIENER_WIN2_ALIGN2; + lh32 += WIENER_WIN2_ALIGN2; + } + } + + j -= 2; + } + + // Computations for odd pixel in the row. + if (width & 1) { + // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the + // middle 6x7 elements being shared. + int16x8_t dgd_rows[7]; + load_and_pack_s16_8x7(dgd_rows, (const int16_t *)dgd, dgd_stride); + + const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 6; + ++dgd; + + // Re-arrange the combined 8x7 matrix to have a whole 7x7 matrix tightly + // packed into a int16x8_t[6] array. This array contains 48 elements of + // the 49 (7x7). Compute `dgd - avg` for the whole buffer. The DGD_AVG + // buffer contains 49 consecutive elements. + int16x8_t dgd_avg0[6]; + + dgd_avg0[0] = vsubq_s16(tbl2q(dgd_rows[0], dgd_rows[1], lut0), avg_s16); + dgd_avg0[1] = vsubq_s16(tbl2q(dgd_rows[1], dgd_rows[2], lut1), avg_s16); + dgd_avg0[2] = vsubq_s16(tbl2q(dgd_rows[2], dgd_rows[3], lut2), avg_s16); + dgd_avg0[3] = vsubq_s16(tbl2q(dgd_rows[3], dgd_rows[4], lut3), avg_s16); + dgd_avg0[4] = vsubq_s16(tbl2q(dgd_rows[4], dgd_rows[5], lut4), avg_s16); + dgd_avg0[5] = vsubq_s16(tbl2q(dgd_rows[5], dgd_rows[6], lut5), avg_s16); + + vst1q_s16(DGD_AVG0, dgd_avg0[0]); + vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]); + vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]); + vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]); + vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]); + vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]); + + // The remaining last (49th) element of `dgd - avg`. + DGD_AVG0[48] = dgd_ptr[6] - avg; + + // Accumulate into row-major order variant of matrix M (cross-correlation) + // for 1 output pixel at a time. M is of size 7 * 7. It needs to be filled + // such that multiplying one element from src with each element of a row + // of the wiener window will fill one column of M. However this is not + // very convenient in terms of memory access, as it means we do + // contiguous loads of dgd but strided stores to M. As a result, we use an + // intermediate matrix M_s32 which is instead filled such that one row of + // the wiener window gives one row of M_s32. Once fully computed, M_s32 is + // then transposed to return M. 
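+      // Conceptually, the six update_M_1pixel calls below accumulate
+      //   M_s32[k] += src_avg0 * DGD_AVG0[k]
+      // for k = 0..47, eight elements at a time; the 49th element is handled
+      // as a scalar afterwards.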
+ int src_avg0 = *src++ - avg;
+ int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
+ update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
+ update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
+ update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
+ update_M_1pixel(M_s32 + 24, src_avg0_s16, dgd_avg0[3]);
+ update_M_1pixel(M_s32 + 32, src_avg0_s16, dgd_avg0[4]);
+ update_M_1pixel(M_s32 + 40, src_avg0_s16, dgd_avg0[5]);
+
+ // Last (49th) element of M_s32 can be computed as scalar more efficiently
+ // for 1 output pixel.
+ M_s32[48] += DGD_AVG0[48] * src_avg0;
+
+ // Start accumulating into row-major order version of matrix H
+ // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
+ // H is of size 49 * 49. It is filled by multiplying every pair of
+ // elements of the wiener window together (vector outer product). Since it
+ // is a symmetric matrix, we only compute the upper-right triangle, and
+ // then copy it down to the lower-left later. The upper triangle is
+ // covered by 4x4 tiles. The original algorithm assumes the M matrix is
+ // column-major and the resulting H matrix is also expected to be
+ // column-major. It is not efficient to work with column-major matrices,
+ // so we accumulate into a row-major matrix H_s32. At the end of the
+ // algorithm a double transpose transformation will convert H_s32 back to
+ // the expected output layout.
+ update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_ALIGN2, 48);
+
+ // The last element of the triangle of H_s32 matrix can be computed as a
+ // scalar more efficiently.
+ H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += DGD_AVG0[48] * DGD_AVG0[48];
+ }
+
+ src += src_next;
+ dgd += dgd_next;
+ } while (--height != 0);
+
+ int bit_depth_shift = bit_depth - AOM_BITS_8;
+
+ acc_transpose_M(M, M_s64, M_s32, WIENER_WIN, bit_depth_shift);
+
+ update_H(H, H_s64, H_s32, WIENER_WIN, WIENER_WIN2_ALIGN2, bit_depth_shift);
+}
+
+// Load 5x5 matrix into 5 128-bit vectors from consecutive rows; the last load
+// address is offset to prevent out-of-bounds access.
+static INLINE void load_and_pack_s16_6x5(int16x8_t dst[5], const int16_t *src,
+ ptrdiff_t stride) {
+ dst[0] = vld1q_s16(src);
+ src += stride;
+ dst[1] = vld1q_s16(src);
+ src += stride;
+ dst[2] = vld1q_s16(src);
+ src += stride;
+ dst[3] = vld1q_s16(src);
+ src += stride;
+ dst[4] = vld1q_s16(src - 3);
+}
+
+static void highbd_compute_stats_win5_neon(const uint16_t *dgd,
+ const uint16_t *src, int avg,
+ int width, int height,
+ int dgd_stride, int src_stride,
+ int64_t *M, int64_t *H,
+ aom_bit_depth_t bit_depth) {
+ // Matrix names are capitalized to help readability.
+ DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_REDUCED_ALIGN3]);
+ DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_REDUCED_ALIGN3]);
+ DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_REDUCED_ALIGN3]);
+ DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_REDUCED_ALIGN3]);
+ DECLARE_ALIGNED(64, int32_t,
+ H_s32[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
+ DECLARE_ALIGNED(64, int64_t,
+ H_s64[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
+
+ memset(M_s32, 0, sizeof(M_s32));
+ memset(M_s64, 0, sizeof(M_s64));
+ memset(H_s32, 0, sizeof(H_s32));
+ memset(H_s64, 0, sizeof(H_s64));
+
+ // Look-up tables to create 8x3 matrix with consecutive elements from 5x5
+ // matrix.
+ DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats5_highbd[96]) = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17, 18, 19, 20, 21, + 6, 7, 8, 9, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 32, 33, + 2, 3, 4, 5, 6, 7, 8, 9, 22, 23, 24, 25, 26, 27, 28, 29, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 34, 35, + 4, 5, 6, 7, 8, 9, 10, 11, 24, 25, 26, 27, 28, 29, 30, 31, + }; + + const uint8x16_t lut0 = vld1q_u8(shuffle_stats5_highbd + 0); + const uint8x16_t lut1 = vld1q_u8(shuffle_stats5_highbd + 16); + const uint8x16_t lut2 = vld1q_u8(shuffle_stats5_highbd + 32); + const uint8x16_t lut3 = vld1q_u8(shuffle_stats5_highbd + 48); + const uint8x16_t lut4 = vld1q_u8(shuffle_stats5_highbd + 64); + const uint8x16_t lut5 = vld1q_u8(shuffle_stats5_highbd + 80); + + // We can accumulate up to 65536/4096/256 8/10/12-bit multiplication results + // in 32-bit. We are processing 2 pixels at a time, so the accumulator max can + // be as high as 32768/2048/128 for the compute stats. + const int acc_cnt_max = (1 << (32 - 2 * bit_depth)) >> 1; + int acc_cnt = acc_cnt_max; + const int src_next = src_stride - width; + const int dgd_next = dgd_stride - width; + const int16x8_t avg_s16 = vdupq_n_s16(avg); + + do { + int j = width; + while (j >= 2) { + // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the + // middle 4x5 elements being shared. + int16x8_t dgd_rows[5]; + load_and_pack_s16_6x5(dgd_rows, (const int16_t *)dgd, dgd_stride); + + const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 4; + dgd += 2; + + dgd_rows[0] = vsubq_s16(dgd_rows[0], avg_s16); + dgd_rows[1] = vsubq_s16(dgd_rows[1], avg_s16); + dgd_rows[2] = vsubq_s16(dgd_rows[2], avg_s16); + dgd_rows[3] = vsubq_s16(dgd_rows[3], avg_s16); + dgd_rows[4] = vsubq_s16(dgd_rows[4], avg_s16); + + // Re-arrange the combined 6x5 matrix to have the 2 whole 5x5 matrices (1 + // for each of the 2 pixels) separated into distinct int16x8_t[3] arrays. + // These arrays contain 24 elements of the 25 (5x5). Compute `dgd - avg` + // for both buffers. Each DGD_AVG buffer contains 25 consecutive elements. + int16x8_t dgd_avg0[3]; + int16x8_t dgd_avg1[3]; + + dgd_avg0[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut0); + dgd_avg1[0] = tbl2q(dgd_rows[0], dgd_rows[1], lut3); + dgd_avg0[1] = tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut1); + dgd_avg1[1] = tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut4); + dgd_avg0[2] = tbl2q(dgd_rows[3], dgd_rows[4], lut2); + dgd_avg1[2] = tbl2q(dgd_rows[3], dgd_rows[4], lut5); + + vst1q_s16(DGD_AVG0, dgd_avg0[0]); + vst1q_s16(DGD_AVG1, dgd_avg1[0]); + vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]); + vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]); + vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]); + vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]); + + // The remaining last (25th) elements of `dgd - avg`. + DGD_AVG0[24] = dgd_ptr[4] - avg; + DGD_AVG1[24] = dgd_ptr[5] - avg; + + // Accumulate into row-major variant of matrix M (cross-correlation) for 2 + // output pixels at a time. M is of size 5 * 5. It needs to be filled such + // that multiplying one element from src with each element of a row of the + // wiener window will fill one column of M. However this is not very + // convenient in terms of memory access, as it means we do contiguous + // loads of dgd but strided stores to M. As a result, we use an + // intermediate matrix M_s32 which is instead filled such that one row of + // the wiener window gives one row of M_s32. Once fully computed, M_s32 is + // then transposed to return M. 
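// Illustrative note (not part of the upstream patch): the 32-bit accumulator
// budget quoted near the top of this function, worked through. A product of
// two b-bit samples needs up to 2 * b bits, so the code treats a 32-bit lane
// as having room for 1 << (32 - 2 * b) such products, halved here because two
// output pixels are folded into each loop iteration:
//   b = 8  : (1 << 16) >> 1 = 32768 iterations
//   b = 10 : (1 << 12) >> 1 = 2048  iterations
//   b = 12 : (1 << 8)  >> 1 = 128   iterations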
+ int src_avg0 = *src++ - avg; + int src_avg1 = *src++ - avg; + int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0); + int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1); + update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0], + dgd_avg1[0]); + update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1], + dgd_avg1[1]); + update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2], + dgd_avg1[2]); + + // Last (25th) element of M_s32 can be computed as scalar more efficiently + // for 2 output pixels. + M_s32[24] += DGD_AVG0[24] * src_avg0 + DGD_AVG1[24] * src_avg1; + + // Start accumulating into row-major version of matrix H + // (auto-covariance), it expects the DGD_AVG[01] matrices to also be + // row-major. H is of size 25 * 25. It is filled by multiplying every pair + // of elements of the wiener window together (vector outer product). Since + // it is a symmetric matrix, we only compute the upper-right triangle, and + // then copy it down to the lower-left later. The upper triangle is + // covered by 4x4 tiles. The original algorithm assumes the M matrix is + // column-major and the resulting H matrix is also expected to be + // column-major. It is not efficient to work with column-major matrices, + // so we accumulate into a row-major matrix H_s32. At the end of the + // algorithm a double transpose transformation will convert H_s32 back to + // the expected output layout. + update_H_5x5_2pixels(H_s32, DGD_AVG0, DGD_AVG1); + + // The last element of the triangle of H_s32 matrix can be computed as a + // scalar more efficiently. + H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] += + DGD_AVG0[24] * DGD_AVG0[24] + DGD_AVG1[24] * DGD_AVG1[24]; + + // Accumulate into 64-bit after a bit depth dependent number of iterations + // to prevent overflow. + if (--acc_cnt == 0) { + acc_cnt = acc_cnt_max; + + accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_REDUCED_ALIGN2); + + // The widening accumulation is only needed for the upper triangle part + // of the matrix. + int64_t *lh = H_s64; + int32_t *lh32 = H_s32; + for (int k = 0; k < WIENER_WIN2_REDUCED; ++k) { + // The widening accumulation is only run for the relevant parts + // (upper-right triangle) in a row 4-element aligned. + int k4 = k / 4 * 4; + accumulate_and_clear(lh + k4, lh32 + k4, 24 - k4); + + // Last element of the row is computed separately. + lh[24] += lh32[24]; + lh32[24] = 0; + + lh += WIENER_WIN2_REDUCED_ALIGN2; + lh32 += WIENER_WIN2_REDUCED_ALIGN2; + } + } + + j -= 2; + } + + // Computations for odd pixel in the row. + if (width & 1) { + // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the + // middle 4x5 elements being shared. + int16x8_t dgd_rows[5]; + load_and_pack_s16_6x5(dgd_rows, (const int16_t *)dgd, dgd_stride); + + const int16_t *dgd_ptr = (const int16_t *)dgd + dgd_stride * 4; + ++dgd; + + // Re-arrange (and widen) the combined 6x5 matrix to have a whole 5x5 + // matrix tightly packed into a int16x8_t[3] array. This array contains + // 24 elements of the 25 (5x5). Compute `dgd - avg` for the whole buffer. + // The DGD_AVG buffer contains 25 consecutive elements. 
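// Illustrative sketch (not part of the upstream patch): a standalone scalar
// equivalent of the table-lookup packing used below for one output pixel.
// `dgd_col` stands for the top-left of the current 5x5 window; the lookups
// produce the same 25 contiguous `dgd - avg` values, 24 of them in three
// vectors and the last one handled as a scalar. Names are illustrative only.
static void sketch_pack_win5_scalar(int16_t *dgd_avg, const int16_t *dgd_col,
                                    ptrdiff_t stride, int16_t avg) {
  for (int r = 0; r < 5; ++r) {
    for (int c = 0; c < 5; ++c) {
      dgd_avg[r * 5 + c] = (int16_t)(dgd_col[r * stride + c] - avg);
    }
  }
}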
+ int16x8_t dgd_avg0[3];
+
+ dgd_avg0[0] = vsubq_s16(tbl2q(dgd_rows[0], dgd_rows[1], lut0), avg_s16);
+ dgd_avg0[1] = vsubq_s16(
+ tbl3q(dgd_rows[1], dgd_rows[2], dgd_rows[3], lut1), avg_s16);
+ dgd_avg0[2] = vsubq_s16(tbl2q(dgd_rows[3], dgd_rows[4], lut2), avg_s16);
+
+ vst1q_s16(DGD_AVG0, dgd_avg0[0]);
+ vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
+ vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
+
+ // The remaining last (25th) element of `dgd - avg`.
+ DGD_AVG0[24] = dgd_ptr[4] - avg;
+
+ // Accumulate into row-major order variant of matrix M (cross-correlation)
+ // for 1 output pixel at a time. M is of size 5 * 5. It needs to be filled
+ // such that multiplying one element from src with each element of a row
+ // of the wiener window will fill one column of M. However this is not
+ // very convenient in terms of memory access, as it means we do
+ // contiguous loads of dgd but strided stores to M. As a result, we use an
+ // intermediate matrix M_s32 which is instead filled such that one row of
+ // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
+ // then transposed to return M.
+ int src_avg0 = *src++ - avg;
+ int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
+ update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
+ update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
+ update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
+
+ // Last (25th) element of M_s32 can be computed as scalar more efficiently
+ // for 1 output pixel.
+ M_s32[24] += DGD_AVG0[24] * src_avg0;
+
+ // Start accumulating into row-major order version of matrix H
+ // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major.
+ // H is of size 25 * 25. It is filled by multiplying every pair of
+ // elements of the wiener window together (vector outer product). Since it
+ // is a symmetric matrix, we only compute the upper-right triangle, and
+ // then copy it down to the lower-left later. The upper triangle is
+ // covered by 4x4 tiles. The original algorithm assumes the M matrix is
+ // column-major and the resulting H matrix is also expected to be
+ // column-major. It is not efficient to work with column-major matrices,
+ // so we accumulate into a row-major matrix H_s32. At the end of the
+ // algorithm a double transpose transformation will convert H_s32 back to
+ // the expected output layout.
+ update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_REDUCED_ALIGN2, 24);
+
+ // The last element of the triangle of H_s32 matrix can be computed as a
+ // scalar more efficiently.
+ H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] += + DGD_AVG0[24] * DGD_AVG0[24]; + } + + src += src_next; + dgd += dgd_next; + } while (--height != 0); + + int bit_depth_shift = bit_depth - AOM_BITS_8; + + acc_transpose_M(M, M_s64, M_s32, WIENER_WIN_REDUCED, bit_depth_shift); + + update_H(H, H_s64, H_s32, WIENER_WIN_REDUCED, WIENER_WIN2_REDUCED_ALIGN2, + bit_depth_shift); +} + +static uint16_t highbd_find_average_neon(const uint16_t *src, int src_stride, + int width, int height) { + assert(width > 0); + assert(height > 0); + + uint64x2_t sum_u64 = vdupq_n_u64(0); + uint64_t sum = 0; + + int h = height; + do { + uint32x4_t sum_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) }; + + int w = width; + const uint16_t *row = src; + while (w >= 32) { + uint16x8_t s0 = vld1q_u16(row + 0); + uint16x8_t s1 = vld1q_u16(row + 8); + uint16x8_t s2 = vld1q_u16(row + 16); + uint16x8_t s3 = vld1q_u16(row + 24); + + s0 = vaddq_u16(s0, s1); + s2 = vaddq_u16(s2, s3); + sum_u32[0] = vpadalq_u16(sum_u32[0], s0); + sum_u32[1] = vpadalq_u16(sum_u32[1], s2); + + row += 32; + w -= 32; + } + + if (w >= 16) { + uint16x8_t s0 = vld1q_u16(row + 0); + uint16x8_t s1 = vld1q_u16(row + 8); + + s0 = vaddq_u16(s0, s1); + sum_u32[0] = vpadalq_u16(sum_u32[0], s0); + + row += 16; + w -= 16; + } + + if (w >= 8) { + uint16x8_t s0 = vld1q_u16(row); + sum_u32[1] = vpadalq_u16(sum_u32[1], s0); + + row += 8; + w -= 8; + } + + if (w >= 4) { + uint16x8_t s0 = vcombine_u16(vld1_u16(row), vdup_n_u16(0)); + sum_u32[0] = vpadalq_u16(sum_u32[0], s0); + + row += 4; + w -= 4; + } + + while (w-- > 0) { + sum += *row++; + } + + sum_u64 = vpadalq_u32(sum_u64, vaddq_u32(sum_u32[0], sum_u32[1])); + + src += src_stride; + } while (--h != 0); + + return (uint16_t)((horizontal_add_u64x2(sum_u64) + sum) / (height * width)); +} + +void av1_compute_stats_highbd_neon(int wiener_win, const uint8_t *dgd8, + const uint8_t *src8, int h_start, int h_end, + int v_start, int v_end, int dgd_stride, + int src_stride, int64_t *M, int64_t *H, + aom_bit_depth_t bit_depth) { + assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_REDUCED); + + const int wiener_halfwin = wiener_win >> 1; + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8); + const int height = v_end - v_start; + const int width = h_end - h_start; + + const uint16_t *dgd_start = dgd + h_start + v_start * dgd_stride; + const uint16_t *src_start = src + h_start + v_start * src_stride; + + // The wiener window will slide along the dgd frame, centered on each pixel. + // For the top left pixel and all the pixels on the side of the frame this + // means half of the window will be outside of the frame. As such the actual + // buffer that we need to subtract the avg from will be 2 * wiener_halfwin + // wider and 2 * wiener_halfwin higher than the original dgd buffer. 
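// Illustrative sketch (not part of the upstream patch): a standalone scalar
// reference for the average computed below by highbd_find_average_neon(). The
// mean is taken over the restoration-unit interior only, even though the
// windows later reach wiener_halfwin pixels beyond it on every side. Names
// are illustrative only.
static uint16_t sketch_find_average_scalar(const uint16_t *src, int stride,
                                           int width, int height) {
  uint64_t sum = 0;
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) sum += src[i * stride + j];
  }
  return (uint16_t)(sum / ((uint64_t)width * height));
}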
+ const int vert_offset = v_start - wiener_halfwin; + const int horiz_offset = h_start - wiener_halfwin; + const uint16_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride; + + uint16_t avg = highbd_find_average_neon(dgd_start, dgd_stride, width, height); + + if (wiener_win == WIENER_WIN) { + highbd_compute_stats_win7_neon(dgd_win, src_start, avg, width, height, + dgd_stride, src_stride, M, H, bit_depth); + } else { + highbd_compute_stats_win5_neon(dgd_win, src_start, avg, width, height, + dgd_stride, src_stride, M, H, bit_depth); + } +} + +int64_t av1_highbd_pixel_proj_error_neon( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) { + const uint16_t *src = CONVERT_TO_SHORTPTR(src8); + const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8); + int64_t sse = 0; + int64x2_t sse_s64 = vdupq_n_s64(0); + + if (params->r[0] > 0 && params->r[1] > 0) { + int32x2_t xq_v = vld1_s32(xq); + int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), 4); + + do { + int j = 0; + int32x4_t sse_s32 = vdupq_n_s32(0); + + do { + const uint16x8_t d = vld1q_u16(&dat[j]); + const uint16x8_t s = vld1q_u16(&src[j]); + int32x4_t flt0_0 = vld1q_s32(&flt0[j]); + int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]); + int32x4_t flt1_0 = vld1q_s32(&flt1[j]); + int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]); + + int32x4_t d_s32_lo = vreinterpretq_s32_u32( + vmull_lane_u16(vget_low_u16(d), vreinterpret_u16_s32(xq_sum_v), 0)); + int32x4_t d_s32_hi = vreinterpretq_s32_u32(vmull_lane_u16( + vget_high_u16(d), vreinterpret_u16_s32(xq_sum_v), 0)); + + int32x4_t v0 = vsubq_s32( + vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), + d_s32_lo); + int32x4_t v1 = vsubq_s32( + vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), + d_s32_hi); + + v0 = vmlaq_lane_s32(v0, flt0_0, xq_v, 0); + v1 = vmlaq_lane_s32(v1, flt0_1, xq_v, 0); + v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1); + v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1); + + int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + + int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), + vreinterpretq_s16_u16(vsubq_u16(d, s))); + int16x4_t e_lo = vget_low_s16(e); + int16x4_t e_hi = vget_high_s16(e); + + sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo); + sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi); + + j += 8; + } while (j <= width - 8); + + for (int k = j; k < width; ++k) { + int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1); + v += xq[0] * (flt0[k]) + xq[1] * (flt1[k]); + v -= (xq[1] + xq[0]) * (int32_t)(dat[k] << 4); + int32_t e = + (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k]; + sse += ((int64_t)e * e); + } + + sse_s64 = vpadalq_s32(sse_s64, sse_s32); + + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + flt1 += flt1_stride; + } while (--height != 0); + } else if (params->r[0] > 0 || params->r[1] > 0) { + int xq_active = (params->r[0] > 0) ? xq[0] : xq[1]; + int32_t *flt = (params->r[0] > 0) ? flt0 : flt1; + int flt_stride = (params->r[0] > 0) ? 
flt0_stride : flt1_stride; + int32x4_t xq_v = vdupq_n_s32(xq_active); + + do { + int j = 0; + int32x4_t sse_s32 = vdupq_n_s32(0); + do { + const uint16x8_t d0 = vld1q_u16(&dat[j]); + const uint16x8_t s0 = vld1q_u16(&src[j]); + int32x4_t flt0_0 = vld1q_s32(&flt[j]); + int32x4_t flt0_1 = vld1q_s32(&flt[j + 4]); + + uint16x8_t d_u16 = vshlq_n_u16(d0, 4); + int32x4_t sub0 = vreinterpretq_s32_u32( + vsubw_u16(vreinterpretq_u32_s32(flt0_0), vget_low_u16(d_u16))); + int32x4_t sub1 = vreinterpretq_s32_u32( + vsubw_u16(vreinterpretq_u32_s32(flt0_1), vget_high_u16(d_u16))); + + int32x4_t v0 = vmlaq_s32( + vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub0, + xq_v); + int32x4_t v1 = vmlaq_s32( + vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)), sub1, + xq_v); + + int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + + int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), + vreinterpretq_s16_u16(vsubq_u16(d0, s0))); + int16x4_t e_lo = vget_low_s16(e); + int16x4_t e_hi = vget_high_s16(e); + + sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo); + sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi); + + j += 8; + } while (j <= width - 8); + + for (int k = j; k < width; ++k) { + int32_t v = 1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1); + v += xq_active * (int32_t)((uint32_t)flt[j] - (uint16_t)(dat[k] << 4)); + const int32_t e = + (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k]; + sse += ((int64_t)e * e); + } + + sse_s64 = vpadalq_s32(sse_s64, sse_s32); + + dat += dat_stride; + flt += flt_stride; + src += src_stride; + } while (--height != 0); + } else { + do { + int j = 0; + + do { + const uint16x8_t d = vld1q_u16(&dat[j]); + const uint16x8_t s = vld1q_u16(&src[j]); + + uint16x8_t diff = vabdq_u16(d, s); + uint16x4_t diff_lo = vget_low_u16(diff); + uint16x4_t diff_hi = vget_high_u16(diff); + + uint32x4_t sqr_lo = vmull_u16(diff_lo, diff_lo); + uint32x4_t sqr_hi = vmull_u16(diff_hi, diff_hi); + + sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_lo)); + sse_s64 = vpadalq_s32(sse_s64, vreinterpretq_s32_u32(sqr_hi)); + + j += 8; + } while (j <= width - 8); + + for (int k = j; k < width; ++k) { + int32_t e = dat[k] - src[k]; + sse += e * e; + } + + dat += dat_stride; + src += src_stride; + } while (--height != 0); + } + + sse += horizontal_add_s64x2(sse_s64); + return sse; +} diff --git a/third_party/aom/av1/encoder/arm/neon/highbd_rdopt_neon.c b/third_party/aom/av1/encoder/arm/neon/highbd_rdopt_neon.c new file mode 100644 index 0000000000..4bf7ae6ce4 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/highbd_rdopt_neon.c @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include <assert.h> +#include <arm_neon.h> + +#include "config/av1_rtcd.h" +#include "aom_dsp/arm/sum_neon.h" + +int64_t av1_highbd_block_error_neon(const tran_low_t *coeff, + const tran_low_t *dqcoeff, + intptr_t block_size, int64_t *ssz, int bd) { + uint64x2_t err_u64 = vdupq_n_u64(0); + int64x2_t ssz_s64 = vdupq_n_s64(0); + + const int shift = 2 * (bd - 8); + const int rounding = shift > 0 ? 1 << (shift - 1) : 0; + + assert(block_size >= 16); + assert((block_size % 16) == 0); + + do { + const int32x4_t c = vld1q_s32(coeff); + const int32x4_t d = vld1q_s32(dqcoeff); + + const uint32x4_t diff = vreinterpretq_u32_s32(vabdq_s32(c, d)); + + err_u64 = vmlal_u32(err_u64, vget_low_u32(diff), vget_low_u32(diff)); + err_u64 = vmlal_u32(err_u64, vget_high_u32(diff), vget_high_u32(diff)); + + ssz_s64 = vmlal_s32(ssz_s64, vget_low_s32(c), vget_low_s32(c)); + ssz_s64 = vmlal_s32(ssz_s64, vget_high_s32(c), vget_high_s32(c)); + + coeff += 4; + dqcoeff += 4; + block_size -= 4; + } while (block_size != 0); + + *ssz = (horizontal_add_s64x2(ssz_s64) + rounding) >> shift; + return ((int64_t)horizontal_add_u64x2(err_u64) + rounding) >> shift; +} diff --git a/third_party/aom/av1/encoder/arm/neon/highbd_temporal_filter_neon.c b/third_party/aom/av1/encoder/arm/neon/highbd_temporal_filter_neon.c new file mode 100644 index 0000000000..88e176f56c --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/highbd_temporal_filter_neon.c @@ -0,0 +1,562 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/av1_rtcd.h" +#include "av1/encoder/encoder.h" +#include "av1/encoder/temporal_filter.h" +#include "aom_dsp/mathutils.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +static INLINE void get_squared_error( + const uint16_t *frame1, const uint32_t stride1, const uint16_t *frame2, + const uint32_t stride2, const uint32_t block_width, + const uint32_t block_height, uint32_t *frame_sse, + const unsigned int dst_stride) { + uint32_t *dst = frame_sse; + + uint32_t i = 0; + do { + uint32_t j = 0; + do { + uint16x8_t s = vld1q_u16(frame1 + i * stride1 + j); + uint16x8_t r = vld1q_u16(frame2 + i * stride2 + j); + + uint16x8_t abs_diff = vabdq_u16(s, r); + uint32x4_t sse_lo = + vmull_u16(vget_low_u16(abs_diff), vget_low_u16(abs_diff)); + uint32x4_t sse_hi = + vmull_u16(vget_high_u16(abs_diff), vget_high_u16(abs_diff)); + + vst1q_u32(dst + j, sse_lo); + vst1q_u32(dst + j + 4, sse_hi); + + j += 8; + } while (j < block_width); + + dst += dst_stride; + i++; + } while (i < block_height); +} + +static uint32_t sum_kernel5x5_mask_single(const uint32x4_t vsrc[5][2], + const uint32x4_t mask_single) { + uint32x4_t vsums = vmulq_u32(vsrc[0][0], mask_single); + vsums = vmlaq_u32(vsums, vsrc[1][0], mask_single); + vsums = vmlaq_u32(vsums, vsrc[2][0], mask_single); + vsums = vmlaq_u32(vsums, vsrc[3][0], mask_single); + vsums = vmlaq_u32(vsums, vsrc[4][0], mask_single); + return horizontal_add_u32x4(vsums); +} + +static uint32x4_t sum_kernel5x5_mask_double(const uint32x4_t vsrc[5][2], + const uint32x4_t mask1, + const uint32x4_t mask2) { + uint32x4_t vsums = vmulq_u32(vsrc[0][0], mask1); + vsums = vmlaq_u32(vsums, vsrc[1][0], mask1); + vsums = vmlaq_u32(vsums, vsrc[2][0], mask1); + vsums = vmlaq_u32(vsums, vsrc[3][0], mask1); + vsums = vmlaq_u32(vsums, vsrc[4][0], mask1); + vsums = vmlaq_u32(vsums, vsrc[0][1], mask2); + vsums = vmlaq_u32(vsums, vsrc[1][1], mask2); + vsums = vmlaq_u32(vsums, vsrc[2][1], mask2); + vsums = vmlaq_u32(vsums, vsrc[3][1], mask2); + vsums = vmlaq_u32(vsums, vsrc[4][1], mask2); + return vsums; +} + +static void highbd_apply_temporal_filter( + const uint16_t *frame, const unsigned int stride, + const uint32_t block_width, const uint32_t block_height, + const int *subblock_mses, unsigned int *accumulator, uint16_t *count, + const uint32_t *frame_sse, const uint32_t frame_sse_stride, + const uint32_t *luma_sse_sum, const double inv_num_ref_pixels, + const double decay_factor, const double inv_factor, + const double weight_factor, const double *d_factor, int tf_wgt_calc_lvl, + int bd) { + assert(((block_width == 16) || (block_width == 32)) && + ((block_height == 16) || (block_height == 32))); + + uint32_t acc_5x5_neon[BH][BW] = { 0 }; + const int half_window = TF_WINDOW_LENGTH >> 1; + + uint32x4_t vsrc[5][2] = { 0 }; + const uint32x4_t k0000 = vdupq_n_u32(0); + const uint32x4_t k1111 = vdupq_n_u32(1); + const uint32_t k3110_u32[4] = { 0, 1, 1, 3 }; + const uint32_t k2111_u32[4] = { 1, 1, 1, 2 }; + const uint32_t k1112_u32[4] = { 2, 1, 1, 1 }; + const uint32_t k0113_u32[4] = { 3, 1, 1, 0 }; + const uint32x4_t k3110 = vld1q_u32(k3110_u32); + const uint32x4_t k2111 = vld1q_u32(k2111_u32); + const uint32x4_t k1112 = vld1q_u32(k1112_u32); + const uint32x4_t k0113 = vld1q_u32(k0113_u32); + + uint32x4_t vmask1[4], vmask2[4]; + vmask1[0] = k1111; + vmask2[0] = vextq_u32(k1111, k0000, 3); + vmask1[1] = vextq_u32(k0000, k1111, 3); + vmask2[1] = vextq_u32(k1111, k0000, 2); + vmask1[2] = 
vextq_u32(k0000, k1111, 2); + vmask2[2] = vextq_u32(k1111, k0000, 1); + vmask1[3] = vextq_u32(k0000, k1111, 1); + vmask2[3] = k1111; + + uint32_t row = 0; + do { + uint32_t col = 0; + const uint32_t *src = frame_sse + row * frame_sse_stride; + if (row == 0) { + vsrc[2][0] = vld1q_u32(src); + vsrc[3][0] = vld1q_u32(src + frame_sse_stride); + vsrc[4][0] = vld1q_u32(src + 2 * frame_sse_stride); + + // First 2 rows of the 5x5 matrix are padded from the 1st. + vsrc[0][0] = vsrc[2][0]; + vsrc[1][0] = vsrc[2][0]; + } else if (row == 1) { + vsrc[1][0] = vld1q_u32(src - frame_sse_stride); + vsrc[2][0] = vld1q_u32(src); + vsrc[3][0] = vld1q_u32(src + frame_sse_stride); + vsrc[4][0] = vld1q_u32(src + 2 * frame_sse_stride); + + // First row of the 5x5 matrix are padded from the 1st. + vsrc[0][0] = vsrc[1][0]; + } else if (row == block_height - 2) { + vsrc[0][0] = vld1q_u32(src - 2 * frame_sse_stride); + vsrc[1][0] = vld1q_u32(src - frame_sse_stride); + vsrc[2][0] = vld1q_u32(src); + vsrc[3][0] = vld1q_u32(src + frame_sse_stride); + + // Last row of the 5x5 matrix are padded from the one before. + vsrc[4][0] = vsrc[3][0]; + } else if (row == block_height - 1) { + vsrc[0][0] = vld1q_u32(src - 2 * frame_sse_stride); + vsrc[1][0] = vld1q_u32(src - frame_sse_stride); + vsrc[2][0] = vld1q_u32(src); + + // Last 2 rows of the 5x5 matrix are padded from the 3rd. + vsrc[3][0] = vsrc[2][0]; + vsrc[4][0] = vsrc[2][0]; + } else { + vsrc[0][0] = vld1q_u32(src - 2 * frame_sse_stride); + vsrc[1][0] = vld1q_u32(src - frame_sse_stride); + vsrc[2][0] = vld1q_u32(src); + vsrc[3][0] = vld1q_u32(src + frame_sse_stride); + vsrc[4][0] = vld1q_u32(src + 2 * frame_sse_stride); + } + + acc_5x5_neon[row][0] = sum_kernel5x5_mask_single(vsrc, k0113); + acc_5x5_neon[row][1] = sum_kernel5x5_mask_single(vsrc, k1112); + + col += 4; + src += 4; + // Traverse 4 columns at a time + do { + if (row == 0) { + vsrc[2][1] = vld1q_u32(src); + vsrc[3][1] = vld1q_u32(src + frame_sse_stride); + vsrc[4][1] = vld1q_u32(src + 2 * frame_sse_stride); + + // First 2 rows of the 5x5 matrix are padded from the 1st. + vsrc[0][1] = vsrc[2][1]; + vsrc[1][1] = vsrc[2][1]; + } else if (row == 1) { + vsrc[1][1] = vld1q_u32(src - frame_sse_stride); + vsrc[2][1] = vld1q_u32(src); + vsrc[3][1] = vld1q_u32(src + frame_sse_stride); + vsrc[4][1] = vld1q_u32(src + 2 * frame_sse_stride); + + // First row of the 5x5 matrix are padded from the 1st. + vsrc[0][1] = vsrc[1][1]; + } else if (row == block_height - 2) { + vsrc[0][1] = vld1q_u32(src - 2 * frame_sse_stride); + vsrc[1][1] = vld1q_u32(src - frame_sse_stride); + vsrc[2][1] = vld1q_u32(src); + vsrc[3][1] = vld1q_u32(src + frame_sse_stride); + + // Last row of the 5x5 matrix are padded from the one before. + vsrc[4][1] = vsrc[3][1]; + } else if (row == block_height - 1) { + vsrc[0][1] = vld1q_u32(src - 2 * frame_sse_stride); + vsrc[1][1] = vld1q_u32(src - frame_sse_stride); + vsrc[2][1] = vld1q_u32(src); + + // Last 2 rows of the 5x5 matrix are padded from the 3rd. 
+ vsrc[3][1] = vsrc[2][1]; + vsrc[4][1] = vsrc[2][1]; + } else { + vsrc[0][1] = vld1q_u32(src - 2 * frame_sse_stride); + vsrc[1][1] = vld1q_u32(src - frame_sse_stride); + vsrc[2][1] = vld1q_u32(src); + vsrc[3][1] = vld1q_u32(src + frame_sse_stride); + vsrc[4][1] = vld1q_u32(src + 2 * frame_sse_stride); + } + + uint32x4_t sums[4]; + sums[0] = sum_kernel5x5_mask_double(vsrc, vmask1[0], vmask2[0]); + sums[1] = sum_kernel5x5_mask_double(vsrc, vmask1[1], vmask2[1]); + sums[2] = sum_kernel5x5_mask_double(vsrc, vmask1[2], vmask2[2]); + sums[3] = sum_kernel5x5_mask_double(vsrc, vmask1[3], vmask2[3]); + vst1q_u32(&acc_5x5_neon[row][col - half_window], + horizontal_add_4d_u32x4(sums)); + + vsrc[0][0] = vsrc[0][1]; + vsrc[1][0] = vsrc[1][1]; + vsrc[2][0] = vsrc[2][1]; + vsrc[3][0] = vsrc[3][1]; + vsrc[4][0] = vsrc[4][1]; + + src += 4; + col += 4; + } while (col <= block_width - 4); + + acc_5x5_neon[row][col - half_window] = + sum_kernel5x5_mask_single(vsrc, k2111); + acc_5x5_neon[row][col - half_window + 1] = + sum_kernel5x5_mask_single(vsrc, k3110); + + row++; + } while (row < block_height); + + // Perform filtering. + if (tf_wgt_calc_lvl == 0) { + for (unsigned int i = 0, k = 0; i < block_height; i++) { + for (unsigned int j = 0; j < block_width; j++, k++) { + const int pixel_value = frame[i * stride + j]; + // Scale down the difference for high bit depth input. + const uint32_t diff_sse = + (acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j]) >> ((bd - 8) * 2); + + const double window_error = diff_sse * inv_num_ref_pixels; + const int subblock_idx = + (i >= block_height / 2) * 2 + (j >= block_width / 2); + const double block_error = (double)subblock_mses[subblock_idx]; + const double combined_error = + weight_factor * window_error + block_error * inv_factor; + // Compute filter weight. + double scaled_error = + combined_error * d_factor[subblock_idx] * decay_factor; + scaled_error = AOMMIN(scaled_error, 7); + const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE); + accumulator[k] += weight * pixel_value; + count[k] += weight; + } + } + } else { + for (unsigned int i = 0, k = 0; i < block_height; i++) { + for (unsigned int j = 0; j < block_width; j++, k++) { + const int pixel_value = frame[i * stride + j]; + // Scale down the difference for high bit depth input. + const uint32_t diff_sse = + (acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j]) >> ((bd - 8) * 2); + + const double window_error = diff_sse * inv_num_ref_pixels; + const int subblock_idx = + (i >= block_height / 2) * 2 + (j >= block_width / 2); + const double block_error = (double)subblock_mses[subblock_idx]; + const double combined_error = + weight_factor * window_error + block_error * inv_factor; + // Compute filter weight. 
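// Illustrative note (not part of the upstream patch): the per-pixel weight
// computed just below, written out in plain form. This branch substitutes the
// cheaper approx_exp() for exp(); the variables follow the code above:
//   diff_sse       = (window SSE + luma SSE) >> (2 * (bd - 8))
//   window_error   = diff_sse * inv_num_ref_pixels
//   combined_error = weight_factor * window_error + subblock MSE * inv_factor
//   scaled_error   = min(combined_error * d_factor[subblock] * decay_factor, 7)
//   weight         = round(approx_exp(-scaled_error) * TF_WEIGHT_SCALE)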
+ double scaled_error = + combined_error * d_factor[subblock_idx] * decay_factor; + scaled_error = AOMMIN(scaled_error, 7); + const float fweight = + approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE; + const int weight = iroundpf(fweight); + accumulator[k] += weight * pixel_value; + count[k] += weight; + } + } + } +} + +void av1_highbd_apply_temporal_filter_neon( + const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd, + const BLOCK_SIZE block_size, const int mb_row, const int mb_col, + const int num_planes, const double *noise_levels, const MV *subblock_mvs, + const int *subblock_mses, const int q_factor, const int filter_strength, + int tf_wgt_calc_lvl, const uint8_t *pred8, uint32_t *accum, + uint16_t *count) { + const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH; + assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with Neon!"); + assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE); + (void)is_high_bitdepth; + assert(is_high_bitdepth); + + // Block information. + const int mb_height = block_size_high[block_size]; + const int mb_width = block_size_wide[block_size]; + // Frame information. + const int frame_height = frame_to_filter->y_crop_height; + const int frame_width = frame_to_filter->y_crop_width; + const int min_frame_size = AOMMIN(frame_height, frame_width); + // Variables to simplify combined error calculation. + const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) * + TF_SEARCH_ERROR_NORM_WEIGHT); + const double weight_factor = + (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor; + // Adjust filtering based on q. + // Larger q -> stronger filtering -> larger weight. + // Smaller q -> weaker filtering -> smaller weight. + double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2); + q_decay = CLIP(q_decay, 1e-5, 1); + if (q_factor >= TF_QINDEX_CUTOFF) { + // Max q_factor is 255, therefore the upper bound of q_decay is 8. + // We do not need a clip here. + q_decay = 0.5 * pow((double)q_factor / 64, 2); + } + // Smaller strength -> smaller filtering weight. + double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2); + s_decay = CLIP(s_decay, 1e-5, 1); + double d_factor[4] = { 0 }; + uint32_t frame_sse[BW * BH] = { 0 }; + uint32_t luma_sse_sum[BW * BH] = { 0 }; + uint16_t *pred = CONVERT_TO_SHORTPTR(pred8); + + for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) { + // Larger motion vector -> smaller filtering weight. + const MV mv = subblock_mvs[subblock_idx]; + const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2)); + double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD; + distance_threshold = AOMMAX(distance_threshold, 1); + d_factor[subblock_idx] = distance / distance_threshold; + d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1); + } + + // Handle planes in sequence. + int plane_offset = 0; + for (int plane = 0; plane < num_planes; ++plane) { + const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y; + const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x; + const uint32_t frame_stride = + frame_to_filter->strides[plane == AOM_PLANE_Y ? 
0 : 1]; + const uint32_t frame_sse_stride = plane_w; + const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w; + + const uint16_t *ref = + CONVERT_TO_SHORTPTR(frame_to_filter->buffers[plane]) + frame_offset; + const int ss_x_shift = + mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x; + const int ss_y_shift = + mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y; + const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH + + ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0); + const double inv_num_ref_pixels = 1.0 / num_ref_pixels; + // Larger noise -> larger filtering weight. + const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0); + // Decay factors for non-local mean approach. + const double decay_factor = 1 / (n_decay * q_decay * s_decay); + + // Filter U-plane and V-plane using Y-plane. This is because motion + // search is only done on Y-plane, so the information from Y-plane + // will be more accurate. The luma sse sum is reused in both chroma + // planes. + if (plane == AOM_PLANE_U) { + for (unsigned int i = 0; i < plane_h; i++) { + for (unsigned int j = 0; j < plane_w; j++) { + for (int ii = 0; ii < (1 << ss_y_shift); ++ii) { + for (int jj = 0; jj < (1 << ss_x_shift); ++jj) { + const int yy = (i << ss_y_shift) + ii; // Y-coord on Y-plane. + const int xx = (j << ss_x_shift) + jj; // X-coord on Y-plane. + const int ww = frame_sse_stride + << ss_x_shift; // Width of Y-plane. + luma_sse_sum[i * BW + j] += frame_sse[yy * ww + xx]; + } + } + } + } + } + get_squared_error(ref, frame_stride, pred + plane_offset, plane_w, plane_w, + plane_h, frame_sse, frame_sse_stride); + + highbd_apply_temporal_filter( + pred + plane_offset, plane_w, plane_w, plane_h, subblock_mses, + accum + plane_offset, count + plane_offset, frame_sse, frame_sse_stride, + luma_sse_sum, inv_num_ref_pixels, decay_factor, inv_factor, + weight_factor, d_factor, tf_wgt_calc_lvl, mbd->bd); + + plane_offset += plane_h * plane_w; + } +} + +double av1_highbd_estimate_noise_from_single_plane_neon(const uint16_t *src, + int height, int width, + int stride, + int bitdepth, + int edge_thresh) { + uint16x8_t thresh = vdupq_n_u16(edge_thresh); + uint64x2_t acc = vdupq_n_u64(0); + // Count is in theory positive as it counts the number of times we're under + // the threshold, but it will be counted negatively in order to make best use + // of the vclt instruction, which sets every bit of a lane to 1 when the + // condition is true. + int32x4_t count = vdupq_n_s32(0); + int final_count = 0; + uint64_t final_acc = 0; + const uint16_t *src_start = src + stride + 1; + int h = 1; + + do { + int w = 1; + const uint16_t *src_ptr = src_start; + + while (w <= (width - 1) - 8) { + uint16x8_t mat[3][3]; + mat[0][0] = vld1q_u16(src_ptr - stride - 1); + mat[0][1] = vld1q_u16(src_ptr - stride); + mat[0][2] = vld1q_u16(src_ptr - stride + 1); + mat[1][0] = vld1q_u16(src_ptr - 1); + mat[1][1] = vld1q_u16(src_ptr); + mat[1][2] = vld1q_u16(src_ptr + 1); + mat[2][0] = vld1q_u16(src_ptr + stride - 1); + mat[2][1] = vld1q_u16(src_ptr + stride); + mat[2][2] = vld1q_u16(src_ptr + stride + 1); + + // Compute Sobel gradients. 
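// Illustrative note (not part of the upstream patch): the vector code below
// forms |gx| and |gy| as absolute differences of two non-negative sums, which
// is equivalent to the signed Sobel used in the scalar tail loop further down:
//   gxa = m[0][0] + 2 * m[1][0] + m[2][0]   (left column)
//   gxb = m[0][2] + 2 * m[1][2] + m[2][2]   (right column)
//   gya = m[0][0] + 2 * m[0][1] + m[0][2]   (top row)
//   gyb = m[2][0] + 2 * m[2][1] + m[2][2]   (bottom row)
//   ga  = |gxa - gxb| + |gya - gyb| = |gx| + |gy|
// The result is then rounding-shifted by (bitdepth - 8) to an 8-bit scale
// before the threshold comparison.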
+ uint16x8_t gxa = vaddq_u16(mat[0][0], mat[2][0]); + uint16x8_t gxb = vaddq_u16(mat[0][2], mat[2][2]); + gxa = vaddq_u16(gxa, vaddq_u16(mat[1][0], mat[1][0])); + gxb = vaddq_u16(gxb, vaddq_u16(mat[1][2], mat[1][2])); + + uint16x8_t gya = vaddq_u16(mat[0][0], mat[0][2]); + uint16x8_t gyb = vaddq_u16(mat[2][0], mat[2][2]); + gya = vaddq_u16(gya, vaddq_u16(mat[0][1], mat[0][1])); + gyb = vaddq_u16(gyb, vaddq_u16(mat[2][1], mat[2][1])); + + uint16x8_t ga = vabaq_u16(vabdq_u16(gxa, gxb), gya, gyb); + ga = vrshlq_u16(ga, vdupq_n_s16(8 - bitdepth)); + + // Check which vector elements are under the threshold. The Laplacian is + // then unconditionnally computed and we accumulate zeros if we're not + // under the threshold. This is much faster than using an if statement. + uint16x8_t thresh_u16 = vcltq_u16(ga, thresh); + + uint16x8_t center = vshlq_n_u16(mat[1][1], 2); + + uint16x8_t adj0 = vaddq_u16(mat[0][1], mat[2][1]); + uint16x8_t adj1 = vaddq_u16(mat[1][0], mat[1][2]); + uint16x8_t adj = vaddq_u16(adj0, adj1); + adj = vaddq_u16(adj, adj); + + uint16x8_t diag0 = vaddq_u16(mat[0][0], mat[0][2]); + uint16x8_t diag1 = vaddq_u16(mat[2][0], mat[2][2]); + uint16x8_t diag = vaddq_u16(diag0, diag1); + + uint16x8_t v = vabdq_u16(vaddq_u16(center, diag), adj); + v = vandq_u16(vrshlq_u16(v, vdupq_n_s16(8 - bitdepth)), thresh_u16); + uint32x4_t v_u32 = vpaddlq_u16(v); + + acc = vpadalq_u32(acc, v_u32); + // Add -1 for each lane where the gradient is under the threshold. + count = vpadalq_s16(count, vreinterpretq_s16_u16(thresh_u16)); + + w += 8; + src_ptr += 8; + } + + if (w <= (width - 1) - 4) { + uint16x4_t mat[3][3]; + mat[0][0] = vld1_u16(src_ptr - stride - 1); + mat[0][1] = vld1_u16(src_ptr - stride); + mat[0][2] = vld1_u16(src_ptr - stride + 1); + mat[1][0] = vld1_u16(src_ptr - 1); + mat[1][1] = vld1_u16(src_ptr); + mat[1][2] = vld1_u16(src_ptr + 1); + mat[2][0] = vld1_u16(src_ptr + stride - 1); + mat[2][1] = vld1_u16(src_ptr + stride); + mat[2][2] = vld1_u16(src_ptr + stride + 1); + + // Compute Sobel gradients. + uint16x4_t gxa = vadd_u16(mat[0][0], mat[2][0]); + uint16x4_t gxb = vadd_u16(mat[0][2], mat[2][2]); + gxa = vadd_u16(gxa, vadd_u16(mat[1][0], mat[1][0])); + gxb = vadd_u16(gxb, vadd_u16(mat[1][2], mat[1][2])); + + uint16x4_t gya = vadd_u16(mat[0][0], mat[0][2]); + uint16x4_t gyb = vadd_u16(mat[2][0], mat[2][2]); + gya = vadd_u16(gya, vadd_u16(mat[0][1], mat[0][1])); + gyb = vadd_u16(gyb, vadd_u16(mat[2][1], mat[2][1])); + + uint16x4_t ga = vaba_u16(vabd_u16(gxa, gxb), gya, gyb); + ga = vrshl_u16(ga, vdup_n_s16(8 - bitdepth)); + + // Check which vector elements are under the threshold. The Laplacian is + // then unconditionnally computed and we accumulate zeros if we're not + // under the threshold. This is much faster than using an if statement. + uint16x4_t thresh_u16 = vclt_u16(ga, vget_low_u16(thresh)); + + uint16x4_t center = vshl_n_u16(mat[1][1], 2); + + uint16x4_t adj0 = vadd_u16(mat[0][1], mat[2][1]); + uint16x4_t adj1 = vadd_u16(mat[1][0], mat[1][2]); + uint16x4_t adj = vadd_u16(adj0, adj1); + adj = vadd_u16(adj, adj); + + uint16x4_t diag0 = vadd_u16(mat[0][0], mat[0][2]); + uint16x4_t diag1 = vadd_u16(mat[2][0], mat[2][2]); + uint16x4_t diag = vadd_u16(diag0, diag1); + + uint16x4_t v = vabd_u16(vadd_u16(center, diag), adj); + v = vand_u16(v, thresh_u16); + uint32x4_t v_u32 = vmovl_u16(vrshl_u16(v, vdup_n_s16(8 - bitdepth))); + + acc = vpadalq_u32(acc, v_u32); + // Add -1 for each lane where the gradient is under the threshold. 
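// Illustrative note (not part of the upstream patch): vclt/vcltq set every bit
// of a matching lane, so each lane under the threshold reads as -1 when the
// mask is reinterpreted as signed. Accumulating the mask therefore counts
// matches negatively, e.g. 3 qualifying lanes contribute -3, and the true
// count is recovered at the end via final_count -= horizontal_add_s32x4(count).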
+ count = vaddw_s16(count, vreinterpret_s16_u16(thresh_u16)); + + w += 4; + src_ptr += 4; + } + + while (w < width - 1) { + int mat[3][3]; + mat[0][0] = *(src_ptr - stride - 1); + mat[0][1] = *(src_ptr - stride); + mat[0][2] = *(src_ptr - stride + 1); + mat[1][0] = *(src_ptr - 1); + mat[1][1] = *(src_ptr); + mat[1][2] = *(src_ptr + 1); + mat[2][0] = *(src_ptr + stride - 1); + mat[2][1] = *(src_ptr + stride); + mat[2][2] = *(src_ptr + stride + 1); + + // Compute Sobel gradients. + const int gx = (mat[0][0] - mat[0][2]) + (mat[2][0] - mat[2][2]) + + 2 * (mat[1][0] - mat[1][2]); + const int gy = (mat[0][0] - mat[2][0]) + (mat[0][2] - mat[2][2]) + + 2 * (mat[0][1] - mat[2][1]); + const int ga = ROUND_POWER_OF_TWO(abs(gx) + abs(gy), bitdepth - 8); + + // Accumulate Laplacian. + const int is_under = ga < edge_thresh; + const int v = 4 * mat[1][1] - + 2 * (mat[0][1] + mat[2][1] + mat[1][0] + mat[1][2]) + + (mat[0][0] + mat[0][2] + mat[2][0] + mat[2][2]); + final_acc += ROUND_POWER_OF_TWO(abs(v), bitdepth - 8) * is_under; + final_count += is_under; + + src_ptr++; + w++; + } + src_start += stride; + } while (++h < height - 1); + + // We counted negatively, so subtract to get the final value. + final_count -= horizontal_add_s32x4(count); + final_acc += horizontal_add_u64x2(acc); + return (final_count < 16) + ? -1.0 + : (double)final_acc / (6 * final_count) * SQRT_PI_BY_2; +} diff --git a/third_party/aom/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c b/third_party/aom/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c new file mode 100644 index 0000000000..6cf835a243 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/hybrid_fwd_txfm_neon.c @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "aom_dsp/txfm_common.h" + +static void transpose4x4(int16x8_t in[2], int16x4_t out[4]) { + int32x4x2_t b0 = + vtrnq_s32(vreinterpretq_s32_s16(in[0]), vreinterpretq_s32_s16(in[1])); + int16x4x2_t c0 = vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[0])), + vreinterpret_s16_s32(vget_high_s32(b0.val[0]))); + int16x4x2_t c1 = vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[1])), + vreinterpret_s16_s32(vget_high_s32(b0.val[1]))); + out[0] = c0.val[0]; + out[1] = c0.val[1]; + out[2] = c1.val[0]; + out[3] = c1.val[1]; +} + +void av1_fwht4x4_neon(const int16_t *input, tran_low_t *output, int stride) { + // Load the 4x4 source in transposed form. + int16x4_t a1, b1, c1, d1, e; + a1 = vld1_s16(&input[0]); + b1 = vld1_s16(&input[1 * stride]); + c1 = vld1_s16(&input[2 * stride]); + d1 = vld1_s16(&input[3 * stride]); + + // WHT. + + // Row transforms. + a1 = vadd_s16(a1, b1); + d1 = vsub_s16(d1, c1); + e = vhsub_s16(a1, d1); + b1 = vsub_s16(e, b1); + c1 = vsub_s16(e, c1); + a1 = vsub_s16(a1, c1); + d1 = vadd_s16(d1, b1); + + int16x8_t x[2]; + x[0] = vcombine_s16(a1, c1); + x[1] = vcombine_s16(d1, b1); + + int16x4_t s[4]; + transpose4x4(x, s); + + a1 = s[0]; + b1 = s[1]; + c1 = s[2]; + d1 = s[3]; + + // Row transforms. 
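// Illustrative note (not part of the upstream patch): scalar form of the
// 4-point butterfly repeated below (and in the identical block above). The
// same butterfly is applied once per dimension of the 4x4 Walsh-Hadamard
// transform, with the transpose in between switching dimensions; vhsub_s16
// provides the halving subtraction:
//   a1 += b1;            d1 -= c1;
//   e   = (a1 - d1) >> 1;
//   b1  = e - b1;        c1  = e - c1;
//   a1 -= c1;            d1 += b1;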
+ a1 = vadd_s16(a1, b1); + d1 = vsub_s16(d1, c1); + e = vhsub_s16(a1, d1); + b1 = vsub_s16(e, b1); + c1 = vsub_s16(e, c1); + a1 = vsub_s16(a1, c1); + d1 = vadd_s16(d1, b1); + + vst1q_s32(&output[0], vshll_n_s16(a1, UNIT_QUANT_SHIFT)); + vst1q_s32(&output[4], vshll_n_s16(c1, UNIT_QUANT_SHIFT)); + vst1q_s32(&output[8], vshll_n_s16(d1, UNIT_QUANT_SHIFT)); + vst1q_s32(&output[12], vshll_n_s16(b1, UNIT_QUANT_SHIFT)); +} diff --git a/third_party/aom/av1/encoder/arm/neon/ml_neon.c b/third_party/aom/av1/encoder/arm/neon/ml_neon.c new file mode 100644 index 0000000000..be6ddfd763 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/ml_neon.c @@ -0,0 +1,339 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <stdbool.h> +#include <assert.h> +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/av1_rtcd.h" +#include "av1/encoder/ml.h" + +static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l, + const float32x4_t *zero) { + *out_h = vmaxq_f32(*out_h, *zero); + *out_l = vmaxq_f32(*out_l, *zero); +} + +static void nn_activate4(float32x4_t *x, const float32x4_t *zero) { + *x = vmaxq_f32(*x, *zero); +} + +#define CLAMP_0(x) (x = x > 0 ? x : 0) + +static void nn_propagate_8to1(int num_inputs, const float *const inputs, + const float *const weights, + const float *layer_bias, + float *const output_nodes, bool output_layer) { + const float32x4_t zero = vdupq_n_f32(0); + float32x4_t vadd = zero; + float total = *layer_bias; + + for (int in = 0; in < num_inputs; in += 8) { + const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]); + const float32x4_t inputs_l = vld1q_f32(&inputs[in]); + + const float32x4_t weights_h = vld1q_f32(&weights[in + 4]); + const float32x4_t weights_l = vld1q_f32(&weights[in]); + + vadd = vmlaq_f32(vadd, inputs_h, weights_h); + vadd = vmlaq_f32(vadd, inputs_l, weights_l); + } +#if AOM_ARCH_AARCH64 + total += vaddvq_f32(vadd); +#else + float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd)); + vadd_lo = vpadd_f32(vadd_lo, vadd_lo); + total += vget_lane_f32(vadd_lo, 0); +#endif + + if (!output_layer) CLAMP_0(total); + *output_nodes = total; +} + +static void nn_propagate_xto1(int num_inputs, const float *const inputs, + const float *const weights, + const float *layer_bias, + float *const output_nodes) { + float32x4_t vadd = vdupq_n_f32(0); + + float total = *layer_bias; + int j = num_inputs; + int in = 0; + while (j > 7) { + const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]); + const float32x4_t inputs_l = vld1q_f32(&inputs[in]); + + const float32x4_t weights_h = vld1q_f32(&weights[in + 4]); + const float32x4_t weights_l = vld1q_f32(&weights[in]); + + vadd = vmlaq_f32(vadd, inputs_h, weights_h); + vadd = vmlaq_f32(vadd, inputs_l, weights_l); + in += 8; + j -= 8; + } + +#if AOM_ARCH_AARCH64 + total += vaddvq_f32(vadd); + +#else + float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd)); + vadd_lo = vpadd_f32(vadd_lo, vadd_lo); + total += vget_lane_f32(vadd_lo, 0); +#endif + for (; in < num_inputs; in++) total += weights[in] * 
inputs[in]; + + *output_nodes = CLAMP_0(total); +} + +static void nn_propagate_xsto1(int num_inputs, const float *const inputs, + const float *const weights, + const float *layer_bias, + float *const output_nodes) { + float total = *layer_bias; +#if AOM_ARCH_AARCH64 + const float32x4_t v_inputs = vld1q_f32(inputs); + const float32x4_t v_weights = vld1q_f32(weights); + const float32x4_t vadd = vmulq_f32(v_inputs, v_weights); + total += vaddvq_f32(vadd); + int in = 4; +#else + int in = 0; +#endif + for (; in < num_inputs; in++) total += weights[in] * inputs[in]; + + *output_nodes = CLAMP_0(total); +} + +static void nn_propagate_4to1(int num_inputs, const float *const inputs, + const float *const weights, + const float *layer_bias, + float *const output_nodes, bool output_layer) { + const float32x4_t zero = vdupq_n_f32(0); + float32x4_t vadd = zero; + float total = *layer_bias; + + for (int in = 0; in < num_inputs; in += 4) { + const float32x4_t v_inputs = vld1q_f32(&inputs[in]); + const float32x4_t v_weights = vld1q_f32(&weights[in]); + vadd = vmlaq_f32(vadd, v_inputs, v_weights); + } + +#if AOM_ARCH_AARCH64 + total += vaddvq_f32(vadd); +#else + float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd)); + vadd_lo = vpadd_f32(vadd_lo, vadd_lo); + total += vget_lane_f32(vadd_lo, 0); +#endif + + if (!output_layer) CLAMP_0(total); + *output_nodes = total; +} + +static void nn_propagate_4to4(int num_inputs, const float *const inputs, + const float *const weights, + const float *layer_bias, + float *const output_nodes, bool output_layer) { + float32x4_t outputs = vld1q_f32(layer_bias); + const float32x4_t zero = vdupq_n_f32(0); + + float32x4_t mul0[2] = { zero, zero }; + float32x4_t mul1[2] = { zero, zero }; + for (int in = 0; in < num_inputs; in += 4) { + const float32x4_t v_input = vld1q_f32(&inputs[in]); + + for (int i = 0; i < 2; i++) { + const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]); + mul0[i] = vmlaq_f32(mul0[i], weight0, v_input); + const float32x4_t weight1 = + vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]); + mul1[i] = vmlaq_f32(mul1[i], weight1, v_input); + } + } + for (int i = 0; i < 2; i++) +#if AOM_ARCH_AARCH64 + mul0[i] = vpaddq_f32(mul0[i], mul1[i]); + const float32x4_t hh = vpaddq_f32(mul0[0], mul0[1]); +#else + mul0[i] = + vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])), + vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i]))); + const float32x4_t hh = + vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])), + vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1]))); +#endif + + outputs = vaddq_f32(outputs, hh); + if (!output_layer) nn_activate4(&outputs, &zero); + vst1q_f32(output_nodes, outputs); +} + +static void nn_propagate_4to8(const int num_inputs, const float *const inputs, + const float *const weights, + const float *layer_bias, + float *const output_nodes, bool output_layer) { + float32x4_t out_h = vld1q_f32(&layer_bias[4]); + float32x4_t out_l = vld1q_f32(layer_bias); + const float32x4_t zero = vdupq_n_f32(0); + float32x4_t mul0[4] = { zero, zero, zero, zero }; + float32x4_t mul1[4] = { zero, zero, zero, zero }; + + for (int in = 0; in < num_inputs; in += 4) { + const float32x4_t v_input = vld1q_f32(&inputs[in]); + for (int i = 0; i < 4; i++) { + const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]); + const float32x4_t weight1 = + vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]); + mul0[i] = vmlaq_f32(mul0[i], v_input, weight0); + mul1[i] = vmlaq_f32(mul1[i], 
v_input, weight1); + } + } + for (int i = 0; i < 4; i++) +#if AOM_ARCH_AARCH64 + mul0[i] = vpaddq_f32(mul0[i], mul1[i]); + const float32x4_t hh0 = vpaddq_f32(mul0[0], mul0[1]); + const float32x4_t hh1 = vpaddq_f32(mul0[2], mul0[3]); +#else + mul0[i] = + vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])), + vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i]))); + const float32x4_t hh0 = + vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])), + vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1]))); + const float32x4_t hh1 = + vcombine_f32(vpadd_f32(vget_low_f32(mul0[2]), vget_high_f32(mul0[2])), + vpadd_f32(vget_low_f32(mul0[3]), vget_high_f32(mul0[3]))); +#endif + + out_h = vaddq_f32(out_h, hh1); + out_l = vaddq_f32(out_l, hh0); + + if (!output_layer) nn_activate8(&out_h, &out_l, &zero); + vst1q_f32(&output_nodes[4], out_h); + vst1q_f32(output_nodes, out_l); +} + +static void nn_propagate_8to4(const int num_inputs, const float *const inputs, + const float *const weights, + const float *layer_bias, + float *const output_nodes, bool output_layer) { + float32x4_t outputs = vld1q_f32(layer_bias); + const float32x4_t zero = vdupq_n_f32(0); + float32x4_t add[4] = { zero, zero, zero, zero }; + for (int in = 0; in < num_inputs; in += 8) { + const float32x4_t inputs_l = vld1q_f32(&inputs[in]); + const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]); + + for (int i = 0; i < 4; i++) { + const float32x4_t weight_l = vld1q_f32(&weights[in + i * num_inputs]); + const float32x4_t weight_h = vld1q_f32(&weights[in + i * num_inputs + 4]); + add[i] = vmlaq_f32(add[i], inputs_l, weight_l); + add[i] = vmlaq_f32(add[i], inputs_h, weight_h); + } + } +#if AOM_ARCH_AARCH64 + const float32x4_t hadd_h = vpaddq_f32(add[2], add[3]); + const float32x4_t hadd_l = vpaddq_f32(add[0], add[1]); + const float32x4_t haddhadd = vpaddq_f32(hadd_l, hadd_h); +#else + const float32x4_t hadd_h = + vcombine_f32(vpadd_f32(vget_low_f32(add[2]), vget_high_f32(add[2])), + vpadd_f32(vget_low_f32(add[3]), vget_high_f32(add[3]))); + const float32x4_t hadd_l = + vcombine_f32(vpadd_f32(vget_low_f32(add[0]), vget_high_f32(add[0])), + vpadd_f32(vget_low_f32(add[1]), vget_high_f32(add[1]))); + const float32x4_t haddhadd = + vcombine_f32(vpadd_f32(vget_low_f32(hadd_l), vget_high_f32(hadd_l)), + vpadd_f32(vget_low_f32(hadd_h), vget_high_f32(hadd_h))); +#endif + + outputs = vaddq_f32(outputs, haddhadd); + if (!output_layer) nn_activate4(&outputs, &zero); + vst1q_f32(output_nodes, outputs); +} + +// Calculate prediction based on the given input features and neural net config. +// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden +// layer. +void av1_nn_predict_neon(const float *input_nodes, + const NN_CONFIG *const nn_config, int reduce_prec, + float *const output) { + float buf[2][NN_MAX_NODES_PER_LAYER]; + int buf_index = 0; + int num_inputs = nn_config->num_inputs; + // Hidden layers, except the final iteration is the output layer. + for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) { + const float *layer_weights = nn_config->weights[layer]; + const float *layer_bias = nn_config->bias[layer]; + bool output_layer = (layer == nn_config->num_hidden_layers); + float *const output_nodes = output_layer ? output : buf[buf_index]; + const int num_outputs = output_layer ? 
nn_config->num_outputs + : nn_config->num_hidden_nodes[layer]; + + if (num_inputs % 4 == 0 && num_outputs % 8 == 0) { + for (int out = 0; out < num_outputs; out += 8) { + nn_propagate_4to8(num_inputs, input_nodes, + &layer_weights[out * num_inputs], &layer_bias[out], + &output_nodes[out], output_layer); + } + } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) { + for (int out = 0; out < num_outputs; out += 4) { + nn_propagate_8to4(num_inputs, input_nodes, + &layer_weights[out * num_inputs], &layer_bias[out], + &output_nodes[out], output_layer); + } + } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) { + for (int out = 0; out < num_outputs; out += 4) { + nn_propagate_4to4(num_inputs, input_nodes, + &layer_weights[out * num_inputs], &layer_bias[out], + &output_nodes[out], output_layer); + } + } else if (num_inputs % 8 == 0) { + for (int out = 0; out < num_outputs; out++) { + nn_propagate_8to1(num_inputs, input_nodes, + &layer_weights[out * num_inputs], &layer_bias[out], + &output_nodes[out], output_layer); + } + } else if (num_inputs % 4 == 0) { + for (int out = 0; out < num_outputs; out++) { + nn_propagate_4to1(num_inputs, input_nodes, + &layer_weights[out * num_inputs], &layer_bias[out], + &output_nodes[out], output_layer); + } + } else if (num_inputs > 8) { + for (int out = 0; out < num_outputs; out++) { + nn_propagate_xto1(num_inputs, input_nodes, + &layer_weights[out * num_inputs], &layer_bias[out], + &output_nodes[out]); + } + } else if (num_inputs >= 4) { + for (int out = 0; out < num_outputs; out++) { + nn_propagate_xsto1(num_inputs, input_nodes, + &layer_weights[out * num_inputs], &layer_bias[out], + &output_nodes[out]); + } + } else { + for (int node = 0; node < num_outputs; ++node) { + float val = layer_bias[node]; + for (int i = 0; i < num_inputs; ++i) + val += layer_weights[node * num_inputs + i] * input_nodes[i]; + // ReLU as activation function. + val = val > 0.0f ? val : 0.0f; // Could use AOMMAX(). + output_nodes[node] = val; + } + } + input_nodes = output_nodes; + num_inputs = num_outputs; + buf_index = 1 - buf_index; + } + if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs); +} diff --git a/third_party/aom/av1/encoder/arm/neon/pickrst_neon.c b/third_party/aom/av1/encoder/arm/neon/pickrst_neon.c new file mode 100644 index 0000000000..2e4761f9a4 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/pickrst_neon.c @@ -0,0 +1,1217 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/av1_rtcd.h" + +#include "aom_dsp/arm/sum_neon.h" +#include "av1/common/restoration.h" +#include "av1/encoder/arm/neon/pickrst_neon.h" +#include "av1/encoder/pickrst.h" + +int64_t av1_lowbd_pixel_proj_error_neon( + const uint8_t *src, int width, int height, int src_stride, + const uint8_t *dat, int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) { + int64_t sse = 0; + int64x2_t sse_s64 = vdupq_n_s64(0); + + if (params->r[0] > 0 && params->r[1] > 0) { + int32x2_t xq_v = vld1_s32(xq); + int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), SGRPROJ_RST_BITS); + + do { + int j = 0; + int32x4_t sse_s32 = vdupq_n_s32(0); + + do { + const uint8x8_t d = vld1_u8(&dat[j]); + const uint8x8_t s = vld1_u8(&src[j]); + int32x4_t flt0_0 = vld1q_s32(&flt0[j]); + int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]); + int32x4_t flt1_0 = vld1q_s32(&flt1[j]); + int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]); + + int32x4_t offset = + vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)); + int32x4_t v0 = vmlaq_lane_s32(offset, flt0_0, xq_v, 0); + int32x4_t v1 = vmlaq_lane_s32(offset, flt0_1, xq_v, 0); + + v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1); + v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1); + + int16x8_t d_s16 = vreinterpretq_s16_u16(vmovl_u8(d)); + v0 = vmlsl_lane_s16(v0, vget_low_s16(d_s16), + vreinterpret_s16_s32(xq_sum_v), 0); + v1 = vmlsl_lane_s16(v1, vget_high_s16(d_s16), + vreinterpret_s16_s32(xq_sum_v), 0); + + int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s)); + int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff); + int16x4_t e_lo = vget_low_s16(e); + int16x4_t e_hi = vget_high_s16(e); + + sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo); + sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi); + + j += 8; + } while (j <= width - 8); + + for (int k = j; k < width; ++k) { + int32_t u = (dat[k] << SGRPROJ_RST_BITS); + int32_t v = (1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)) + + xq[0] * flt0[k] + xq[1] * flt1[k] - u * (xq[0] + xq[1]); + int32_t e = + (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k]; + sse += e * e; + } + + sse_s64 = vpadalq_s32(sse_s64, sse_s32); + + dat += dat_stride; + src += src_stride; + flt0 += flt0_stride; + flt1 += flt1_stride; + } while (--height != 0); + } else if (params->r[0] > 0 || params->r[1] > 0) { + int xq_active = (params->r[0] > 0) ? xq[0] : xq[1]; + int32_t *flt = (params->r[0] > 0) ? flt0 : flt1; + int flt_stride = (params->r[0] > 0) ? 
flt0_stride : flt1_stride; + int32x2_t xq_v = vdup_n_s32(xq_active); + + do { + int32x4_t sse_s32 = vdupq_n_s32(0); + int j = 0; + + do { + const uint8x8_t d = vld1_u8(&dat[j]); + const uint8x8_t s = vld1_u8(&src[j]); + int32x4_t flt_0 = vld1q_s32(&flt[j]); + int32x4_t flt_1 = vld1q_s32(&flt[j + 4]); + int16x8_t d_s16 = + vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS)); + + int32x4_t sub_0 = vsubw_s16(flt_0, vget_low_s16(d_s16)); + int32x4_t sub_1 = vsubw_s16(flt_1, vget_high_s16(d_s16)); + + int32x4_t offset = + vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)); + int32x4_t v0 = vmlaq_lane_s32(offset, sub_0, xq_v, 0); + int32x4_t v1 = vmlaq_lane_s32(offset, sub_1, xq_v, 0); + + int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS); + + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s)); + int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff); + int16x4_t e_lo = vget_low_s16(e); + int16x4_t e_hi = vget_high_s16(e); + + sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo); + sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi); + + j += 8; + } while (j <= width - 8); + + for (int k = j; k < width; ++k) { + int32_t u = dat[k] << SGRPROJ_RST_BITS; + int32_t v = xq_active * (flt[k] - u); + int32_t e = ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) + + dat[k] - src[k]; + sse += e * e; + } + + sse_s64 = vpadalq_s32(sse_s64, sse_s32); + + dat += dat_stride; + src += src_stride; + flt += flt_stride; + } while (--height != 0); + } else { + uint32x4_t sse_s32 = vdupq_n_u32(0); + + do { + int j = 0; + + do { + const uint8x16_t d = vld1q_u8(&dat[j]); + const uint8x16_t s = vld1q_u8(&src[j]); + + uint8x16_t diff = vabdq_u8(d, s); + uint8x8_t diff_lo = vget_low_u8(diff); + uint8x8_t diff_hi = vget_high_u8(diff); + + sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_lo, diff_lo)); + sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_hi, diff_hi)); + + j += 16; + } while (j <= width - 16); + + for (int k = j; k < width; ++k) { + int32_t e = dat[k] - src[k]; + sse += e * e; + } + + dat += dat_stride; + src += src_stride; + } while (--height != 0); + + sse_s64 = vreinterpretq_s64_u64(vpaddlq_u32(sse_s32)); + } + + sse += horizontal_add_s64x2(sse_s64); + return sse; +} + +// We can accumulate up to 65536 8-bit multiplication results in 32-bit. We are +// processing 2 pixels at a time, so the accumulator max can be as high as 32768 +// for the compute stats. +#define STAT_ACCUMULATOR_MAX 32768 + +static INLINE uint8x8_t tbl2(uint8x16_t a, uint8x16_t b, uint8x8_t idx) { +#if AOM_ARCH_AARCH64 + uint8x16x2_t table = { { a, b } }; + return vqtbl2_u8(table, idx); +#else + uint8x8x4_t table = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b), + vget_high_u8(b) } }; + return vtbl4_u8(table, idx); +#endif +} + +static INLINE uint8x16_t tbl2q(uint8x16_t a, uint8x16_t b, uint8x16_t idx) { +#if AOM_ARCH_AARCH64 + uint8x16x2_t table = { { a, b } }; + return vqtbl2q_u8(table, idx); +#else + uint8x8x4_t table = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b), + vget_high_u8(b) } }; + return vcombine_u8(vtbl4_u8(table, vget_low_u8(idx)), + vtbl4_u8(table, vget_high_u8(idx))); +#endif +} + +// The M matrix is accumulated in STAT_ACCUMULATOR_MAX steps to speed-up the +// computation. This function computes the final M from the accumulated +// (src_s64) and the residual parts (src_s32). It also transposes the result as +// the output needs to be column-major. 
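As a reference for the statistics code that follows, M and H are the cross-correlation and auto-covariance terms of the Wiener normal equations. A minimal scalar sketch of one pixel's contribution (loosely mirroring the generic C path; the helper name and layout are illustrative, and the row-/column-major bookkeeping handled by acc_transpose_M and update_H is ignored):

// Scalar sketch: `d` holds the flattened wiener_win * wiener_win window of
// (dgd - avg) around the current pixel, and src_minus_avg is the co-located
// source pixel minus avg.
static void accumulate_stats_scalar(int64_t *M, int64_t *H, const int16_t *d,
                                    int src_minus_avg, int wiener_win2) {
  for (int k = 0; k < wiener_win2; ++k) {
    M[k] += (int64_t)src_minus_avg * d[k];             // cross-correlation
    for (int l = 0; l < wiener_win2; ++l) {
      H[k * wiener_win2 + l] += (int64_t)d[k] * d[l];  // auto-covariance
    }
  }
}

The NEON paths below build the same sums, two pixels at a time and only for the upper triangle of H, using 32-bit partial accumulators that are drained into 64 bits every STAT_ACCUMULATOR_MAX iterations.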
+static INLINE void acc_transpose_M(int64_t *dst, const int64_t *src_s64, + const int32_t *src_s32, const int wiener_win, + int scale) { + for (int i = 0; i < wiener_win; ++i) { + for (int j = 0; j < wiener_win; ++j) { + int tr_idx = j * wiener_win + i; + *dst++ += (int64_t)(src_s64[tr_idx] + src_s32[tr_idx]) * scale; + } + } +} + +// The resulting H is a column-major matrix accumulated from the transposed +// (column-major) samples of the filter kernel (5x5 or 7x7) viewed as a single +// vector. For the 7x7 filter case: H(49x49) = [49 x 1] x [1 x 49]. This +// function transforms back to the originally expected format (double +// transpose). The H matrix is accumulated in STAT_ACCUMULATOR_MAX steps to +// speed-up the computation. This function computes the final H from the +// accumulated (src_s64) and the residual parts (src_s32). The computed H is +// only an upper triangle matrix, this function also fills the lower triangle of +// the resulting matrix. +static void update_H(int64_t *dst, const int64_t *src_s64, + const int32_t *src_s32, const int wiener_win, int stride, + int scale) { + // For a simplified theoretical 3x3 case where `wiener_win` is 3 and + // `wiener_win2` is 9, the M matrix is 3x3: + // 0, 3, 6 + // 1, 4, 7 + // 2, 5, 8 + // + // This is viewed as a vector to compute H (9x9) by vector outer product: + // 0, 3, 6, 1, 4, 7, 2, 5, 8 + // + // Double transpose and upper triangle remapping for 3x3 -> 9x9 case: + // 0, 3, 6, 1, 4, 7, 2, 5, 8, + // 3, 30, 33, 12, 31, 34, 21, 32, 35, + // 6, 33, 60, 15, 42, 61, 24, 51, 62, + // 1, 12, 15, 10, 13, 16, 11, 14, 17, + // 4, 31, 42, 13, 40, 43, 22, 41, 44, + // 7, 34, 61, 16, 43, 70, 25, 52, 71, + // 2, 21, 24, 11, 22, 25, 20, 23, 26, + // 5, 32, 51, 14, 41, 52, 23, 50, 53, + // 8, 35, 62, 17, 44, 71, 26, 53, 80, + const int wiener_win2 = wiener_win * wiener_win; + + // Loop through the indices according to the remapping above, along the + // columns: + // 0, wiener_win, 2 * wiener_win, ..., 1, 1 + 2 * wiener_win, ..., + // wiener_win - 1, wiener_win - 1 + wiener_win, ... + // For the 3x3 case `j` will be: 0, 3, 6, 1, 4, 7, 2, 5, 8. + for (int i = 0; i < wiener_win; ++i) { + for (int j = i; j < wiener_win2; j += wiener_win) { + // These two inner loops are the same as the two outer loops, but running + // along rows instead of columns. For the 3x3 case `l` will be: + // 0, 3, 6, 1, 4, 7, 2, 5, 8. + for (int k = 0; k < wiener_win; ++k) { + for (int l = k; l < wiener_win2; l += wiener_win) { + // The nominal double transpose indexing would be: + // int idx = stride * j + l; + // However we need the upper-triangle indices, it is easy with some + // min/max operations. + int tr_idx = stride * AOMMIN(j, l) + AOMMAX(j, l); + + // Resulting matrix is filled by combining the 64-bit and the residual + // 32-bit matrices together with scaling. + *dst++ += (int64_t)(src_s64[tr_idx] + src_s32[tr_idx]) * scale; + } + } + } + } +} + +// Load 7x7 matrix into 3 and a half 128-bit vectors from consecutive rows, the +// last load address is offset to prevent out-of-bounds access. 
+static INLINE void load_and_pack_u8_8x7(uint8x16_t dst[4], const uint8_t *src, + ptrdiff_t stride) { + dst[0] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride)); + src += 2 * stride; + dst[1] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride)); + src += 2 * stride; + dst[2] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride)); + src += 2 * stride; + dst[3] = vcombine_u8(vld1_u8(src - 1), vdup_n_u8(0)); +} + +static INLINE void compute_stats_win7_neon(const uint8_t *dgd, + const uint8_t *src, int width, + int height, int dgd_stride, + int src_stride, int avg, int64_t *M, + int64_t *H, int downsample_factor) { + // Matrix names are capitalized to help readability. + DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_ALIGN3]); + DECLARE_ALIGNED(64, int32_t, H_s32[WIENER_WIN2 * WIENER_WIN2_ALIGN2]); + DECLARE_ALIGNED(64, int64_t, H_s64[WIENER_WIN2 * WIENER_WIN2_ALIGN2]); + + memset(M_s32, 0, sizeof(M_s32)); + memset(M_s64, 0, sizeof(M_s64)); + memset(H_s32, 0, sizeof(H_s32)); + memset(H_s64, 0, sizeof(H_s64)); + + // Look-up tables to create 8x6 matrix with consecutive elements from two 7x7 + // matrices. + // clang-format off + DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats7[96]) = { + 0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 16, 17, + 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, + 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 17, 18, 19, 20, 21, 22, + 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, + 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, + 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, + }; + // clang-format on + + const uint8x16_t lut0 = vld1q_u8(shuffle_stats7 + 0); + const uint8x16_t lut1 = vld1q_u8(shuffle_stats7 + 16); + const uint8x16_t lut2 = vld1q_u8(shuffle_stats7 + 32); + const uint8x16_t lut3 = vld1q_u8(shuffle_stats7 + 48); + const uint8x16_t lut4 = vld1q_u8(shuffle_stats7 + 64); + const uint8x16_t lut5 = vld1q_u8(shuffle_stats7 + 80); + + int acc_cnt = STAT_ACCUMULATOR_MAX; + const int src_next = downsample_factor * src_stride - width; + const int dgd_next = downsample_factor * dgd_stride - width; + const uint8x8_t avg_u8 = vdup_n_u8(avg); + + do { + int j = width; + while (j >= 2) { + // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the + // middle 6x7 elements being shared. + uint8x16_t dgd_rows[4]; + load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride); + + const uint8_t *dgd_ptr = dgd + dgd_stride * 6; + dgd += 2; + + // Re-arrange (and widen) the combined 8x7 matrix to have the 2 whole 7x7 + // matrices (1 for each of the 2 pixels) separated into distinct + // int16x8_t[6] arrays. These arrays contain 48 elements of the 49 (7x7). + // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 49 + // consecutive elements. 
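      // (Each dgd_rows[] vector packs two 8-byte row loads, of which only 7
      // bytes per row belong to a window. lut0-lut2 gather the 48 in-vector
      // taps of pixel 0's 7x7 window, lut3-lut5 gather pixel 1's window one
      // column further right, and the 49th tap of each window is added as a
      // scalar below.)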
+ int16x8_t dgd_avg0[6]; + int16x8_t dgd_avg1[6]; + uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0); + uint8x16_t dgd_shuf3 = tbl2q(dgd_rows[0], dgd_rows[1], lut3); + + dgd_avg0[0] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8)); + dgd_avg0[1] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8)); + dgd_avg1[0] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf3), avg_u8)); + dgd_avg1[1] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf3), avg_u8)); + + vst1q_s16(DGD_AVG0, dgd_avg0[0]); + vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]); + vst1q_s16(DGD_AVG1, dgd_avg1[0]); + vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]); + + uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1); + uint8x16_t dgd_shuf4 = tbl2q(dgd_rows[1], dgd_rows[2], lut4); + + dgd_avg0[2] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8)); + dgd_avg0[3] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8)); + dgd_avg1[2] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf4), avg_u8)); + dgd_avg1[3] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf4), avg_u8)); + + vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]); + vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]); + vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]); + vst1q_s16(DGD_AVG1 + 24, dgd_avg1[3]); + + uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2); + uint8x16_t dgd_shuf5 = tbl2q(dgd_rows[2], dgd_rows[3], lut5); + + dgd_avg0[4] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8)); + dgd_avg0[5] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8)); + dgd_avg1[4] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf5), avg_u8)); + dgd_avg1[5] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf5), avg_u8)); + + vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]); + vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]); + vst1q_s16(DGD_AVG1 + 32, dgd_avg1[4]); + vst1q_s16(DGD_AVG1 + 40, dgd_avg1[5]); + + // The remaining last (49th) elements of `dgd - avg`. + DGD_AVG0[48] = dgd_ptr[6] - avg; + DGD_AVG1[48] = dgd_ptr[7] - avg; + + // Accumulate into row-major variant of matrix M (cross-correlation) for 2 + // output pixels at a time. M is of size 7 * 7. It needs to be filled such + // that multiplying one element from src with each element of a row of the + // wiener window will fill one column of M. However this is not very + // convenient in terms of memory access, as it means we do contiguous + // loads of dgd but strided stores to M. As a result, we use an + // intermediate matrix M_s32 which is instead filled such that one row of + // the wiener window gives one row of M_s32. Once fully computed, M_s32 is + // then transposed to return M. + int src_avg0 = *src++ - avg; + int src_avg1 = *src++ - avg; + int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0); + int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1); + update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0], + dgd_avg1[0]); + update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1], + dgd_avg1[1]); + update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2], + dgd_avg1[2]); + update_M_2pixels(M_s32 + 24, src_avg0_s16, src_avg1_s16, dgd_avg0[3], + dgd_avg1[3]); + update_M_2pixels(M_s32 + 32, src_avg0_s16, src_avg1_s16, dgd_avg0[4], + dgd_avg1[4]); + update_M_2pixels(M_s32 + 40, src_avg0_s16, src_avg1_s16, dgd_avg0[5], + dgd_avg1[5]); + + // Last (49th) element of M_s32 can be computed as scalar more efficiently + // for 2 output pixels. 
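      // (The six update_M_2pixels calls above cover taps 0-47 in 8-wide
      // chunks, so only tap 48 remains for the scalar update here.)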
+ M_s32[48] += DGD_AVG0[48] * src_avg0 + DGD_AVG1[48] * src_avg1; + + // Start accumulating into row-major version of matrix H + // (auto-covariance), it expects the DGD_AVG[01] matrices to also be + // row-major. H is of size 49 * 49. It is filled by multiplying every pair + // of elements of the wiener window together (vector outer product). Since + // it is a symmetric matrix, we only compute the upper-right triangle, and + // then copy it down to the lower-left later. The upper triangle is + // covered by 4x4 tiles. The original algorithm assumes the M matrix is + // column-major and the resulting H matrix is also expected to be + // column-major. It is not efficient to work with column-major matrices, + // so we accumulate into a row-major matrix H_s32. At the end of the + // algorithm a double transpose transformation will convert H_s32 back to + // the expected output layout. + update_H_7x7_2pixels(H_s32, DGD_AVG0, DGD_AVG1); + + // The last element of the triangle of H_s32 matrix can be computed as a + // scalar more efficiently. + H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += + DGD_AVG0[48] * DGD_AVG0[48] + DGD_AVG1[48] * DGD_AVG1[48]; + + // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent + // overflow. + if (--acc_cnt == 0) { + acc_cnt = STAT_ACCUMULATOR_MAX; + + accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_ALIGN2); + + // The widening accumulation is only needed for the upper triangle part + // of the matrix. + int64_t *lh = H_s64; + int32_t *lh32 = H_s32; + for (int k = 0; k < WIENER_WIN2; ++k) { + // The widening accumulation is only run for the relevant parts + // (upper-right triangle) in a row 4-element aligned. + int k4 = k / 4 * 4; + accumulate_and_clear(lh + k4, lh32 + k4, 48 - k4); + + // Last element of the row is computed separately. + lh[48] += lh32[48]; + lh32[48] = 0; + + lh += WIENER_WIN2_ALIGN2; + lh32 += WIENER_WIN2_ALIGN2; + } + } + + j -= 2; + } + + // Computations for odd pixel in the row. + if (width & 1) { + // Load two adjacent, overlapping 7x7 matrices: a 8x7 matrix with the + // middle 6x7 elements being shared. + uint8x16_t dgd_rows[4]; + load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride); + + const uint8_t *dgd_ptr = dgd + dgd_stride * 6; + ++dgd; + + // Re-arrange (and widen) the combined 8x7 matrix to have a whole 7x7 + // matrix tightly packed into a int16x8_t[6] array. This array contains + // 48 elements of the 49 (7x7). Compute `dgd - avg` for the whole buffer. + // The DGD_AVG buffer contains 49 consecutive elements. + int16x8_t dgd_avg0[6]; + uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0); + dgd_avg0[0] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8)); + dgd_avg0[1] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8)); + vst1q_s16(DGD_AVG0, dgd_avg0[0]); + vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]); + + uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1); + dgd_avg0[2] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8)); + dgd_avg0[3] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8)); + vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]); + vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]); + + uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2); + dgd_avg0[4] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8)); + dgd_avg0[5] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8)); + vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]); + vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]); + + // The remaining last (49th) element of `dgd - avg`. 
+ DGD_AVG0[48] = dgd_ptr[6] - avg; + + // Accumulate into row-major order variant of matrix M (cross-correlation) + // for 1 output pixel at a time. M is of size 7 * 7. It needs to be filled + // such that multiplying one element from src with each element of a row + // of the wiener window will fill one column of M. However this is not + // very convenient in terms of memory access, as it means we do + // contiguous loads of dgd but strided stores to M. As a result, we use an + // intermediate matrix M_s32 which is instead filled such that one row of + // the wiener window gives one row of M_s32. Once fully computed, M_s32 is + // then transposed to return M. + int src_avg0 = *src++ - avg; + int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0); + update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]); + update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]); + update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]); + update_M_1pixel(M_s32 + 24, src_avg0_s16, dgd_avg0[3]); + update_M_1pixel(M_s32 + 32, src_avg0_s16, dgd_avg0[4]); + update_M_1pixel(M_s32 + 40, src_avg0_s16, dgd_avg0[5]); + + // Last (49th) element of M_s32 can be computed as scalar more efficiently + // for 1 output pixel. + M_s32[48] += DGD_AVG0[48] * src_avg0; + + // Start accumulating into row-major order version of matrix H + // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major. + // H is of size 49 * 49. It is filled by multiplying every pair of + // elements of the wiener window together (vector outer product). Since it + // is a symmetric matrix, we only compute the upper-right triangle, and + // then copy it down to the lower-left later. The upper triangle is + // covered by 4x4 tiles. The original algorithm assumes the M matrix is + // column-major and the resulting H matrix is also expected to be + // column-major. It is not efficient to work column-major matrices, so we + // accumulate into a row-major matrix H_s32. At the end of the algorithm a + // double transpose transformation will convert H_s32 back to the expected + // output layout. + update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_ALIGN2, 48); + + // The last element of the triangle of H_s32 matrix can be computed as + // scalar more efficiently. + H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += DGD_AVG0[48] * DGD_AVG0[48]; + } + + src += src_next; + dgd += dgd_next; + } while (--height != 0); + + acc_transpose_M(M, M_s64, M_s32, WIENER_WIN, downsample_factor); + + update_H(H, H_s64, H_s32, WIENER_WIN, WIENER_WIN2_ALIGN2, downsample_factor); +} + +// Load 5x5 matrix into 2 and a half 128-bit vectors from consecutive rows, the +// last load address is offset to prevent out-of-bounds access. +static INLINE void load_and_pack_u8_6x5(uint8x16_t dst[3], const uint8_t *src, + ptrdiff_t stride) { + dst[0] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride)); + src += 2 * stride; + dst[1] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride)); + src += 2 * stride; + dst[2] = vcombine_u8(vld1_u8(src - 3), vdup_n_u8(0)); +} + +static INLINE void compute_stats_win5_neon(const uint8_t *dgd, + const uint8_t *src, int width, + int height, int dgd_stride, + int src_stride, int avg, int64_t *M, + int64_t *H, int downsample_factor) { + // Matrix names are capitalized to help readability. 
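  // (This is the 5x5 analogue of compute_stats_win7_neon above: 24 of the 25
  // window taps go through the 8-wide vector path, the 25th is handled as a
  // scalar, and the same 32-bit/64-bit split accumulation is used.)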
+ DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_REDUCED_ALIGN3]); + DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_REDUCED_ALIGN3]); + DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_REDUCED_ALIGN3]); + DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_REDUCED_ALIGN3]); + DECLARE_ALIGNED(64, int32_t, + H_s32[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]); + DECLARE_ALIGNED(64, int64_t, + H_s64[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]); + + memset(M_s32, 0, sizeof(M_s32)); + memset(M_s64, 0, sizeof(M_s64)); + memset(H_s32, 0, sizeof(H_s32)); + memset(H_s64, 0, sizeof(H_s64)); + + // Look-up tables to create 8x3 matrix with consecutive elements from two 5x5 + // matrices. + // clang-format off + DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats5[48]) = { + 0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 16, 17, 18, 19, 20, 24, + 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21, 25, + 9, 10, 11, 12, 19, 20, 21, 22, 10, 11, 12, 13, 20, 21, 22, 23, + }; + // clang-format on + + const uint8x16_t lut0 = vld1q_u8(shuffle_stats5 + 0); + const uint8x16_t lut1 = vld1q_u8(shuffle_stats5 + 16); + const uint8x16_t lut2 = vld1q_u8(shuffle_stats5 + 32); + + int acc_cnt = STAT_ACCUMULATOR_MAX; + const int src_next = downsample_factor * src_stride - width; + const int dgd_next = downsample_factor * dgd_stride - width; + const uint8x8_t avg_u8 = vdup_n_u8(avg); + + do { + int j = width; + while (j >= 2) { + // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the + // middle 4x5 elements being shared. + uint8x16_t dgd_rows[3]; + load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride); + + const uint8_t *dgd_ptr = dgd + dgd_stride * 4; + dgd += 2; + + // Re-arrange (and widen) the combined 6x5 matrix to have the 2 whole 5x5 + // matrices (1 for each of the 2 pixels) separated into distinct + // int16x8_t[3] arrays. These arrays contain 24 elements of the 25 (5x5). + // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 25 + // consecutive elements. + int16x8_t dgd_avg0[3]; + int16x8_t dgd_avg1[3]; + uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0); + uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[0], dgd_rows[1], lut1); + uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[1], dgd_rows[2], lut2); + + dgd_avg0[0] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8)); + dgd_avg0[1] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8)); + dgd_avg0[2] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8)); + dgd_avg1[0] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8)); + dgd_avg1[1] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8)); + dgd_avg1[2] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8)); + + vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]); + vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]); + vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]); + vst1q_s16(DGD_AVG1 + 0, dgd_avg1[0]); + vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]); + vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]); + + // The remaining last (25th) elements of `dgd - avg`. + DGD_AVG0[24] = dgd_ptr[4] - avg; + DGD_AVG1[24] = dgd_ptr[5] - avg; + + // Accumulate into row-major variant of matrix M (cross-correlation) for 2 + // output pixels at a time. M is of size 5 * 5. It needs to be filled such + // that multiplying one element from src with each element of a row of the + // wiener window will fill one column of M. However this is not very + // convenient in terms of memory access, as it means we do contiguous + // loads of dgd but strided stores to M. 
As a result, we use an + // intermediate matrix M_s32 which is instead filled such that one row of + // the wiener window gives one row of M_s32. Once fully computed, M_s32 is + // then transposed to return M. + int src_avg0 = *src++ - avg; + int src_avg1 = *src++ - avg; + int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0); + int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1); + update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0], + dgd_avg1[0]); + update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1], + dgd_avg1[1]); + update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2], + dgd_avg1[2]); + + // Last (25th) element of M_s32 can be computed as scalar more efficiently + // for 2 output pixels. + M_s32[24] += DGD_AVG0[24] * src_avg0 + DGD_AVG1[24] * src_avg1; + + // Start accumulating into row-major version of matrix H + // (auto-covariance), it expects the DGD_AVG[01] matrices to also be + // row-major. H is of size 25 * 25. It is filled by multiplying every pair + // of elements of the wiener window together (vector outer product). Since + // it is a symmetric matrix, we only compute the upper-right triangle, and + // then copy it down to the lower-left later. The upper triangle is + // covered by 4x4 tiles. The original algorithm assumes the M matrix is + // column-major and the resulting H matrix is also expected to be + // column-major. It is not efficient to work with column-major matrices, + // so we accumulate into a row-major matrix H_s32. At the end of the + // algorithm a double transpose transformation will convert H_s32 back to + // the expected output layout. + update_H_5x5_2pixels(H_s32, DGD_AVG0, DGD_AVG1); + + // The last element of the triangle of H_s32 matrix can be computed as a + // scalar more efficiently. + H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] += + DGD_AVG0[24] * DGD_AVG0[24] + DGD_AVG1[24] * DGD_AVG1[24]; + + // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent + // overflow. + if (--acc_cnt == 0) { + acc_cnt = STAT_ACCUMULATOR_MAX; + + accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_REDUCED_ALIGN2); + + // The widening accumulation is only needed for the upper triangle part + // of the matrix. + int64_t *lh = H_s64; + int32_t *lh32 = H_s32; + for (int k = 0; k < WIENER_WIN2_REDUCED; ++k) { + // The widening accumulation is only run for the relevant parts + // (upper-right triangle) in a row 4-element aligned. + int k4 = k / 4 * 4; + accumulate_and_clear(lh + k4, lh32 + k4, 24 - k4); + + // Last element of the row is computed separately. + lh[24] += lh32[24]; + lh32[24] = 0; + + lh += WIENER_WIN2_REDUCED_ALIGN2; + lh32 += WIENER_WIN2_REDUCED_ALIGN2; + } + } + + j -= 2; + } + + // Computations for odd pixel in the row. + if (width & 1) { + // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the + // middle 4x5 elements being shared. + uint8x16_t dgd_rows[3]; + load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride); + + const uint8_t *dgd_ptr = dgd + dgd_stride * 4; + ++dgd; + + // Re-arrange (and widen) the combined 6x5 matrix to have a whole 5x5 + // matrix tightly packed into a int16x8_t[3] array. This array contains + // 24 elements of the 25 (5x5). Compute `dgd - avg` for the whole buffer. + // The DGD_AVG buffer contains 25 consecutive elements. 
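      // (For this single-pixel tail, lut0 gathers taps 0-15 of the 5x5 window
      // and the low half of lut2 gathers taps 16-23; tap 24 is added as a
      // scalar below.)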
+ int16x8_t dgd_avg0[3]; + uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0); + uint8x8_t dgd_shuf1 = tbl2(dgd_rows[1], dgd_rows[2], vget_low_u8(lut2)); + + dgd_avg0[0] = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8)); + dgd_avg0[1] = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8)); + dgd_avg0[2] = vreinterpretq_s16_u16(vsubl_u8(dgd_shuf1, avg_u8)); + + vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]); + vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]); + vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]); + + // The remaining last (25th) element of `dgd - avg`. + DGD_AVG0[24] = dgd_ptr[4] - avg; + + // Accumulate into row-major order variant of matrix M (cross-correlation) + // for 1 output pixel at a time. M is of size 5 * 5. It needs to be filled + // such that multiplying one element from src with each element of a row + // of the wiener window will fill one column of M. However this is not + // very convenient in terms of memory access, as it means we do + // contiguous loads of dgd but strided stores to M. As a result, we use an + // intermediate matrix M_s32 which is instead filled such that one row of + // the wiener window gives one row of M_s32. Once fully computed, M_s32 is + // then transposed to return M. + int src_avg0 = *src++ - avg; + int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0); + update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]); + update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]); + update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]); + + // Last (25th) element of M_s32 can be computed as scalar more efficiently + // for 1 output pixel. + M_s32[24] += DGD_AVG0[24] * src_avg0; + + // Start accumulating into row-major order version of matrix H + // (auto-covariance), it expects the DGD_AVG0 matrix to also be row-major. + // H is of size 25 * 25. It is filled by multiplying every pair of + // elements of the wiener window together (vector outer product). Since it + // is a symmetric matrix, we only compute the upper-right triangle, and + // then copy it down to the lower-left later. The upper triangle is + // covered by 4x4 tiles. The original algorithm assumes the M matrix is + // column-major and the resulting H matrix is also expected to be + // column-major. It is not efficient to work column-major matrices, so we + // accumulate into a row-major matrix H_s32. At the end of the algorithm a + // double transpose transformation will convert H_s32 back to the expected + // output layout. + update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_REDUCED_ALIGN2, 24); + + // The last element of the triangle of H_s32 matrix can be computed as a + // scalar more efficiently. + H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] += + DGD_AVG0[24] * DGD_AVG0[24]; + } + + src += src_next; + dgd += dgd_next; + } while (--height != 0); + + acc_transpose_M(M, M_s64, M_s32, WIENER_WIN_REDUCED, downsample_factor); + + update_H(H, H_s64, H_s32, WIENER_WIN_REDUCED, WIENER_WIN2_REDUCED_ALIGN2, + downsample_factor); +} + +static INLINE uint8_t find_average_neon(const uint8_t *src, int src_stride, + int width, int height) { + uint64_t sum = 0; + + if (width >= 16) { + int h = 0; + // We can accumulate up to 257 8-bit values in a 16-bit value, given + // that each 16-bit vector has 8 elements, that means we can process up to + // int(257*8/width) rows before we need to widen to 32-bit vector + // elements. + int h_overflow = 257 * 8 / width; + int h_limit = height > h_overflow ? 
h_overflow : height; + uint32x4_t avg_u32 = vdupq_n_u32(0); + do { + uint16x8_t avg_u16 = vdupq_n_u16(0); + do { + int j = width; + const uint8_t *src_ptr = src; + do { + uint8x16_t s = vld1q_u8(src_ptr); + avg_u16 = vpadalq_u8(avg_u16, s); + j -= 16; + src_ptr += 16; + } while (j >= 16); + if (j >= 8) { + uint8x8_t s = vld1_u8(src_ptr); + avg_u16 = vaddw_u8(avg_u16, s); + j -= 8; + src_ptr += 8; + } + // Scalar tail case. + while (j > 0) { + sum += src[width - j]; + j--; + } + src += src_stride; + } while (++h < h_limit); + avg_u32 = vpadalq_u16(avg_u32, avg_u16); + + h_limit += h_overflow; + h_limit = height > h_overflow ? h_overflow : height; + } while (h < height); + return (uint8_t)((horizontal_long_add_u32x4(avg_u32) + sum) / + (width * height)); + } + if (width >= 8) { + int h = 0; + // We can accumulate up to 257 8-bit values in a 16-bit value, given + // that each 16-bit vector has 4 elements, that means we can process up to + // int(257*4/width) rows before we need to widen to 32-bit vector + // elements. + int h_overflow = 257 * 4 / width; + int h_limit = height > h_overflow ? h_overflow : height; + uint32x2_t avg_u32 = vdup_n_u32(0); + do { + uint16x4_t avg_u16 = vdup_n_u16(0); + do { + int j = width; + const uint8_t *src_ptr = src; + uint8x8_t s = vld1_u8(src_ptr); + avg_u16 = vpadal_u8(avg_u16, s); + j -= 8; + src_ptr += 8; + // Scalar tail case. + while (j > 0) { + sum += src[width - j]; + j--; + } + src += src_stride; + } while (++h < h_limit); + avg_u32 = vpadal_u16(avg_u32, avg_u16); + + h_limit += h_overflow; + h_limit = height > h_overflow ? h_overflow : height; + } while (h < height); + return (uint8_t)((horizontal_long_add_u32x2(avg_u32) + sum) / + (width * height)); + } + int i = height; + do { + int j = 0; + do { + sum += src[j]; + } while (++j < width); + src += src_stride; + } while (--i != 0); + return (uint8_t)(sum / (width * height)); +} + +void av1_compute_stats_neon(int wiener_win, const uint8_t *dgd, + const uint8_t *src, int16_t *dgd_avg, + int16_t *src_avg, int h_start, int h_end, + int v_start, int v_end, int dgd_stride, + int src_stride, int64_t *M, int64_t *H, + int use_downsampled_wiener_stats) { + assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA); + assert(WIENER_STATS_DOWNSAMPLE_FACTOR == 4); + (void)dgd_avg; + (void)src_avg; + + const int wiener_win2 = wiener_win * wiener_win; + const int wiener_halfwin = wiener_win >> 1; + const int width = h_end - h_start; + const int height = v_end - v_start; + + const uint8_t *dgd_start = dgd + h_start + v_start * dgd_stride; + const uint8_t *src_start = src + h_start + v_start * src_stride; + + // The wiener window will slide along the dgd frame, centered on each pixel. + // For the top left pixel and all the pixels on the side of the frame this + // means half of the window will be outside of the frame. As such the actual + // buffer that we need to subtract the avg from will be 2 * wiener_halfwin + // wider and 2 * wiener_halfwin higher than the original dgd buffer. + const int vert_offset = v_start - wiener_halfwin; + const int horiz_offset = h_start - wiener_halfwin; + const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride; + + uint8_t avg = find_average_neon(dgd_start, dgd_stride, width, height); + + // Since the height is not necessarily a multiple of the downsample factor, + // the last line of src will be scaled according to how many rows remain. + int downsample_factor = + use_downsampled_wiener_stats ? 
WIENER_STATS_DOWNSAMPLE_FACTOR : 1; + + int downsampled_height = height / downsample_factor; + int downsample_remainder = height % downsample_factor; + + memset(M, 0, wiener_win2 * sizeof(*M)); + memset(H, 0, wiener_win2 * wiener_win2 * sizeof(*H)); + + // Calculate the M and H matrices for the normal and downsampled cases. + if (downsampled_height > 0) { + if (wiener_win == WIENER_WIN) { + compute_stats_win7_neon(dgd_win, src_start, width, downsampled_height, + dgd_stride, src_stride, avg, M, H, + downsample_factor); + } else { + compute_stats_win5_neon(dgd_win, src_start, width, downsampled_height, + dgd_stride, src_stride, avg, M, H, + downsample_factor); + } + } + + // Accumulate the remaining last rows in the downsampled case. + if (downsample_remainder > 0) { + int remainder_offset = height - downsample_remainder; + if (wiener_win == WIENER_WIN) { + compute_stats_win7_neon(dgd_win + remainder_offset * dgd_stride, + src_start + remainder_offset * src_stride, width, + 1, dgd_stride, src_stride, avg, M, H, + downsample_remainder); + } else { + compute_stats_win5_neon(dgd_win + remainder_offset * dgd_stride, + src_start + remainder_offset * src_stride, width, + 1, dgd_stride, src_stride, avg, M, H, + downsample_remainder); + } + } +} + +static INLINE void calc_proj_params_r0_r1_neon( + const uint8_t *src8, int width, int height, int src_stride, + const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) { + assert(width % 8 == 0); + const int size = width * height; + + int64x2_t h00_lo = vdupq_n_s64(0); + int64x2_t h00_hi = vdupq_n_s64(0); + int64x2_t h11_lo = vdupq_n_s64(0); + int64x2_t h11_hi = vdupq_n_s64(0); + int64x2_t h01_lo = vdupq_n_s64(0); + int64x2_t h01_hi = vdupq_n_s64(0); + int64x2_t c0_lo = vdupq_n_s64(0); + int64x2_t c0_hi = vdupq_n_s64(0); + int64x2_t c1_lo = vdupq_n_s64(0); + int64x2_t c1_hi = vdupq_n_s64(0); + + do { + const uint8_t *src_ptr = src8; + const uint8_t *dat_ptr = dat8; + int32_t *flt0_ptr = flt0; + int32_t *flt1_ptr = flt1; + int w = width; + + do { + uint8x8_t s = vld1_u8(src_ptr); + uint8x8_t d = vld1_u8(dat_ptr); + int32x4_t f0_lo = vld1q_s32(flt0_ptr); + int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4); + int32x4_t f1_lo = vld1q_s32(flt1_ptr); + int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4); + + int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS)); + int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS)); + + int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u)); + int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u)); + f0_lo = vsubw_s16(f0_lo, vget_low_s16(u)); + f0_hi = vsubw_s16(f0_hi, vget_high_s16(u)); + f1_lo = vsubw_s16(f1_lo, vget_low_s16(u)); + f1_hi = vsubw_s16(f1_hi, vget_high_s16(u)); + + h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo)); + h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo)); + h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi)); + h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi)); + + h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo)); + h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo)); + h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi)); + h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi)); + + h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo)); + h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo)); + h01_hi = 
vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi)); + h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi)); + + c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo)); + c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo)); + c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi)); + c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi)); + + c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo)); + c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo)); + c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi)); + c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi)); + + src_ptr += 8; + dat_ptr += 8; + flt0_ptr += 8; + flt1_ptr += 8; + w -= 8; + } while (w != 0); + + src8 += src_stride; + dat8 += dat_stride; + flt0 += flt0_stride; + flt1 += flt1_stride; + } while (--height != 0); + + H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size; + H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size; + H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size; + H[1][0] = H[0][1]; + C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size; + C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size; +} + +static INLINE void calc_proj_params_r0_neon(const uint8_t *src8, int width, + int height, int src_stride, + const uint8_t *dat8, int dat_stride, + int32_t *flt0, int flt0_stride, + int64_t H[2][2], int64_t C[2]) { + assert(width % 8 == 0); + const int size = width * height; + + int64x2_t h00_lo = vdupq_n_s64(0); + int64x2_t h00_hi = vdupq_n_s64(0); + int64x2_t c0_lo = vdupq_n_s64(0); + int64x2_t c0_hi = vdupq_n_s64(0); + + do { + const uint8_t *src_ptr = src8; + const uint8_t *dat_ptr = dat8; + int32_t *flt0_ptr = flt0; + int w = width; + + do { + uint8x8_t s = vld1_u8(src_ptr); + uint8x8_t d = vld1_u8(dat_ptr); + int32x4_t f0_lo = vld1q_s32(flt0_ptr); + int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4); + + int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS)); + int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS)); + + int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u)); + int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u)); + f0_lo = vsubw_s16(f0_lo, vget_low_s16(u)); + f0_hi = vsubw_s16(f0_hi, vget_high_s16(u)); + + h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo)); + h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo)); + h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi)); + h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi)); + + c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo)); + c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo)); + c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi)); + c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi)); + + src_ptr += 8; + dat_ptr += 8; + flt0_ptr += 8; + w -= 8; + } while (w != 0); + + src8 += src_stride; + dat8 += dat_stride; + flt0 += flt0_stride; + } while (--height != 0); + + H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size; + C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size; +} + +static INLINE void calc_proj_params_r1_neon(const uint8_t *src8, int width, + int height, int src_stride, + const uint8_t *dat8, int dat_stride, + int32_t *flt1, int flt1_stride, + int64_t H[2][2], int64_t C[2]) { + assert(width % 8 == 0); + const int size = width * height; + + int64x2_t h11_lo = 
vdupq_n_s64(0); + int64x2_t h11_hi = vdupq_n_s64(0); + int64x2_t c1_lo = vdupq_n_s64(0); + int64x2_t c1_hi = vdupq_n_s64(0); + + do { + const uint8_t *src_ptr = src8; + const uint8_t *dat_ptr = dat8; + int32_t *flt1_ptr = flt1; + int w = width; + + do { + uint8x8_t s = vld1_u8(src_ptr); + uint8x8_t d = vld1_u8(dat_ptr); + int32x4_t f1_lo = vld1q_s32(flt1_ptr); + int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4); + + int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS)); + int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS)); + + int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u)); + int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u)); + f1_lo = vsubw_s16(f1_lo, vget_low_s16(u)); + f1_hi = vsubw_s16(f1_hi, vget_high_s16(u)); + + h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo)); + h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo)); + h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi)); + h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi)); + + c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo)); + c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo)); + c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi)); + c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi)); + + src_ptr += 8; + dat_ptr += 8; + flt1_ptr += 8; + w -= 8; + } while (w != 0); + + src8 += src_stride; + dat8 += dat_stride; + flt1 += flt1_stride; + } while (--height != 0); + + H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size; + C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size; +} + +// The function calls 3 subfunctions for the following cases : +// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements +// of C and H need to be computed. +// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are +// non-zero and need to be computed. +// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are +// non-zero and need to be computed. +void av1_calc_proj_params_neon(const uint8_t *src8, int width, int height, + int src_stride, const uint8_t *dat8, + int dat_stride, int32_t *flt0, int flt0_stride, + int32_t *flt1, int flt1_stride, int64_t H[2][2], + int64_t C[2], const sgr_params_type *params) { + if ((params->r[0] > 0) && (params->r[1] > 0)) { + calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8, + dat_stride, flt0, flt0_stride, flt1, + flt1_stride, H, C); + } else if (params->r[0] > 0) { + calc_proj_params_r0_neon(src8, width, height, src_stride, dat8, dat_stride, + flt0, flt0_stride, H, C); + } else if (params->r[1] > 0) { + calc_proj_params_r1_neon(src8, width, height, src_stride, dat8, dat_stride, + flt1, flt1_stride, H, C); + } +} diff --git a/third_party/aom/av1/encoder/arm/neon/pickrst_neon.h b/third_party/aom/av1/encoder/arm/neon/pickrst_neon.h new file mode 100644 index 0000000000..7b72dca34d --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/pickrst_neon.h @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. 
If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#ifndef AOM_AV1_ENCODER_ARM_NEON_PICKRST_NEON_H_ +#define AOM_AV1_ENCODER_ARM_NEON_PICKRST_NEON_H_ + +#include <arm_neon.h> + +#include "av1/common/restoration.h" + +// Aligned sizes for Wiener filters. +#define WIENER_WIN2_ALIGN2 ALIGN_POWER_OF_TWO(WIENER_WIN2, 2) +#define WIENER_WIN2_ALIGN3 ALIGN_POWER_OF_TWO(WIENER_WIN2, 3) +#define WIENER_WIN2_REDUCED ((WIENER_WIN_REDUCED) * (WIENER_WIN_REDUCED)) +#define WIENER_WIN2_REDUCED_ALIGN2 ALIGN_POWER_OF_TWO(WIENER_WIN2_REDUCED, 2) +#define WIENER_WIN2_REDUCED_ALIGN3 ALIGN_POWER_OF_TWO(WIENER_WIN2_REDUCED, 3) + +// Compute 8 values of M (cross correlation) for a single source pixel and +// accumulate. +static INLINE void update_M_1pixel(int32_t *M_s32, int16x4_t src_avg, + int16x8_t dgd_avg) { + int32x4_t lo = vld1q_s32(M_s32 + 0); + int32x4_t hi = vld1q_s32(M_s32 + 4); + + lo = vmlal_s16(lo, vget_low_s16(dgd_avg), src_avg); + hi = vmlal_s16(hi, vget_high_s16(dgd_avg), src_avg); + + vst1q_s32(M_s32 + 0, lo); + vst1q_s32(M_s32 + 4, hi); +} + +// Compute 8 values of M (cross correlation) for two source pixels and +// accumulate. +static INLINE void update_M_2pixels(int32_t *M_s32, int16x4_t src_avg0, + int16x4_t src_avg1, int16x8_t dgd_avg0, + int16x8_t dgd_avg1) { + int32x4_t lo = vld1q_s32(M_s32 + 0); + int32x4_t hi = vld1q_s32(M_s32 + 4); + + lo = vmlal_s16(lo, vget_low_s16(dgd_avg0), src_avg0); + hi = vmlal_s16(hi, vget_high_s16(dgd_avg0), src_avg0); + lo = vmlal_s16(lo, vget_low_s16(dgd_avg1), src_avg1); + hi = vmlal_s16(hi, vget_high_s16(dgd_avg1), src_avg1); + + vst1q_s32(M_s32 + 0, lo); + vst1q_s32(M_s32 + 4, hi); +} + +static INLINE void update_H_1pixel(int32_t *H_s32, const int16_t *dgd_avg, + int width, int height) { + for (int i = 0; i < height; i += 4) { + int16x4_t di = vld1_s16(dgd_avg + i); + + for (int j = i; j < width; j += 4) { + int16x4_t dj = vld1_s16(dgd_avg + j); + int32x4_t h0 = vld1q_s32(H_s32 + 0 * width + j); + int32x4_t h1 = vld1q_s32(H_s32 + 1 * width + j); + int32x4_t h2 = vld1q_s32(H_s32 + 2 * width + j); + int32x4_t h3 = vld1q_s32(H_s32 + 3 * width + j); + + h0 = vmlal_lane_s16(h0, dj, di, 0); + h1 = vmlal_lane_s16(h1, dj, di, 1); + h2 = vmlal_lane_s16(h2, dj, di, 2); + h3 = vmlal_lane_s16(h3, dj, di, 3); + + vst1q_s32(H_s32 + 0 * width + j, h0); + vst1q_s32(H_s32 + 1 * width + j, h1); + vst1q_s32(H_s32 + 2 * width + j, h2); + vst1q_s32(H_s32 + 3 * width + j, h3); + } + H_s32 += 4 * width; + } +} + +static INLINE void update_H_5x5_2pixels(int32_t *H_s32, const int16_t *dgd_avg0, + const int16_t *dgd_avg1) { + for (int i = 0; i < 24; i += 4) { + int16x4_t di0 = vld1_s16(dgd_avg0 + i); + int16x4_t di1 = vld1_s16(dgd_avg1 + i); + + for (int j = i + 0; j < WIENER_WIN2_REDUCED_ALIGN2; j += 4) { + int16x4_t dj0 = vld1_s16(dgd_avg0 + j); + int16x4_t dj1 = vld1_s16(dgd_avg1 + j); + int32x4_t h0 = vld1q_s32(H_s32 + 0 * WIENER_WIN2_REDUCED_ALIGN2 + j); + int32x4_t h1 = vld1q_s32(H_s32 + 1 * WIENER_WIN2_REDUCED_ALIGN2 + j); + int32x4_t h2 = vld1q_s32(H_s32 + 2 * WIENER_WIN2_REDUCED_ALIGN2 + j); + int32x4_t h3 = vld1q_s32(H_s32 + 3 * WIENER_WIN2_REDUCED_ALIGN2 + j); + + h0 = vmlal_lane_s16(h0, dj0, di0, 0); + h0 = vmlal_lane_s16(h0, dj1, di1, 0); + h1 = vmlal_lane_s16(h1, dj0, di0, 1); + h1 = vmlal_lane_s16(h1, dj1, di1, 1); + h2 = vmlal_lane_s16(h2, dj0, di0, 2); + h2 = vmlal_lane_s16(h2, dj1, di1, 2); + h3 = vmlal_lane_s16(h3, dj0, di0, 3); + h3 
= vmlal_lane_s16(h3, dj1, di1, 3); + + vst1q_s32(H_s32 + 0 * WIENER_WIN2_REDUCED_ALIGN2 + j, h0); + vst1q_s32(H_s32 + 1 * WIENER_WIN2_REDUCED_ALIGN2 + j, h1); + vst1q_s32(H_s32 + 2 * WIENER_WIN2_REDUCED_ALIGN2 + j, h2); + vst1q_s32(H_s32 + 3 * WIENER_WIN2_REDUCED_ALIGN2 + j, h3); + } + H_s32 += 4 * WIENER_WIN2_REDUCED_ALIGN2; + } +} + +static INLINE void update_H_7x7_2pixels(int32_t *H_s32, const int16_t *dgd_avg0, + const int16_t *dgd_avg1) { + for (int i = 0; i < 48; i += 4) { + int16x4_t di0 = vld1_s16(dgd_avg0 + i); + int16x4_t di1 = vld1_s16(dgd_avg1 + i); + + int32x4_t h0 = vld1q_s32(H_s32 + 0 * WIENER_WIN2_ALIGN2 + i); + int32x4_t h1 = vld1q_s32(H_s32 + 1 * WIENER_WIN2_ALIGN2 + i); + int32x4_t h2 = vld1q_s32(H_s32 + 2 * WIENER_WIN2_ALIGN2 + i); + int32x4_t h3 = vld1q_s32(H_s32 + 3 * WIENER_WIN2_ALIGN2 + i); + + h0 = vmlal_lane_s16(h0, di0, di0, 0); + h0 = vmlal_lane_s16(h0, di1, di1, 0); + h1 = vmlal_lane_s16(h1, di0, di0, 1); + h1 = vmlal_lane_s16(h1, di1, di1, 1); + h2 = vmlal_lane_s16(h2, di0, di0, 2); + h2 = vmlal_lane_s16(h2, di1, di1, 2); + h3 = vmlal_lane_s16(h3, di0, di0, 3); + h3 = vmlal_lane_s16(h3, di1, di1, 3); + + vst1q_s32(H_s32 + 0 * WIENER_WIN2_ALIGN2 + i, h0); + vst1q_s32(H_s32 + 1 * WIENER_WIN2_ALIGN2 + i, h1); + vst1q_s32(H_s32 + 2 * WIENER_WIN2_ALIGN2 + i, h2); + vst1q_s32(H_s32 + 3 * WIENER_WIN2_ALIGN2 + i, h3); + + for (int j = i + 4; j < WIENER_WIN2_ALIGN2; j += 4) { + int16x4_t dj0 = vld1_s16(dgd_avg0 + j); + int16x4_t dj1 = vld1_s16(dgd_avg1 + j); + h0 = vld1q_s32(H_s32 + 0 * WIENER_WIN2_ALIGN2 + j); + h1 = vld1q_s32(H_s32 + 1 * WIENER_WIN2_ALIGN2 + j); + h2 = vld1q_s32(H_s32 + 2 * WIENER_WIN2_ALIGN2 + j); + h3 = vld1q_s32(H_s32 + 3 * WIENER_WIN2_ALIGN2 + j); + + h0 = vmlal_lane_s16(h0, dj0, di0, 0); + h0 = vmlal_lane_s16(h0, dj1, di1, 0); + h1 = vmlal_lane_s16(h1, dj0, di0, 1); + h1 = vmlal_lane_s16(h1, dj1, di1, 1); + h2 = vmlal_lane_s16(h2, dj0, di0, 2); + h2 = vmlal_lane_s16(h2, dj1, di1, 2); + h3 = vmlal_lane_s16(h3, dj0, di0, 3); + h3 = vmlal_lane_s16(h3, dj1, di1, 3); + + vst1q_s32(H_s32 + 0 * WIENER_WIN2_ALIGN2 + j, h0); + vst1q_s32(H_s32 + 1 * WIENER_WIN2_ALIGN2 + j, h1); + vst1q_s32(H_s32 + 2 * WIENER_WIN2_ALIGN2 + j, h2); + vst1q_s32(H_s32 + 3 * WIENER_WIN2_ALIGN2 + j, h3); + } + H_s32 += 4 * WIENER_WIN2_ALIGN2; + } +} + +// Widen 32-bit src data and accumulate into 64-bit dst. Clear src data. +static INLINE void accumulate_and_clear(int64_t *dst, int32_t *src, + int length) { + do { + int32x4_t s32 = vld1q_s32(src); + vst1q_s32(src, vdupq_n_s32(0)); + src += 4; + + int64x2_t d_lo = vld1q_s64(dst + 0); + int64x2_t d_hi = vld1q_s64(dst + 2); + + d_lo = vaddw_s32(d_lo, vget_low_s32(s32)); + d_hi = vaddw_s32(d_hi, vget_high_s32(s32)); + + vst1q_s64(dst + 0, d_lo); + vst1q_s64(dst + 2, d_hi); + + dst += 4; + length -= 4; + } while (length > 0); +} + +#endif // AOM_AV1_ENCODER_ARM_NEON_PICKRST_NEON_H_ diff --git a/third_party/aom/av1/encoder/arm/neon/quantize_neon.c b/third_party/aom/av1/encoder/arm/neon/quantize_neon.c new file mode 100644 index 0000000000..c3b57ce206 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/quantize_neon.c @@ -0,0 +1,928 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. 
If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include <assert.h> +#include <math.h> + +#include "config/aom_config.h" + +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" +#include "aom_mem/aom_mem.h" + +#include "av1/common/quant_common.h" +#include "av1/common/seg_common.h" + +#include "av1/encoder/av1_quantize.h" +#include "av1/encoder/encoder.h" +#include "av1/encoder/rd.h" + +static INLINE uint16_t get_max_eob(int16x8_t v_eobmax) { +#if AOM_ARCH_AARCH64 + return (uint16_t)vmaxvq_s16(v_eobmax); +#else + const int16x4_t v_eobmax_3210 = + vmax_s16(vget_low_s16(v_eobmax), vget_high_s16(v_eobmax)); + const int64x1_t v_eobmax_xx32 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmax_3210), 32); + const int16x4_t v_eobmax_tmp = + vmax_s16(v_eobmax_3210, vreinterpret_s16_s64(v_eobmax_xx32)); + const int64x1_t v_eobmax_xxx3 = + vshr_n_s64(vreinterpret_s64_s16(v_eobmax_tmp), 16); + const int16x4_t v_eobmax_final = + vmax_s16(v_eobmax_tmp, vreinterpret_s16_s64(v_eobmax_xxx3)); + return (uint16_t)vget_lane_s16(v_eobmax_final, 0); +#endif +} + +static INLINE int16x8_t get_max_lane_eob(const int16_t *iscan, + int16x8_t v_eobmax, + uint16x8_t v_mask) { + const int16x8_t v_iscan = vld1q_s16(&iscan[0]); + const int16x8_t v_iscan_plus1 = vaddq_s16(v_iscan, vdupq_n_s16(1)); + const int16x8_t v_nz_iscan = vbslq_s16(v_mask, v_iscan_plus1, vdupq_n_s16(0)); + return vmaxq_s16(v_eobmax, v_nz_iscan); +} + +static INLINE uint16x8_t quantize_fp_8(const tran_low_t *coeff_ptr, + tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, + int16x8_t v_quant, int16x8_t v_dequant, + int16x8_t v_round, int16x8_t v_zero) { + const int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]); + const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + const int16x8_t v_abs = vabsq_s16(v_coeff); + const int16x8_t v_tmp = vqaddq_s16(v_abs, v_round); + const int16x8_t v_tmp2 = vshrq_n_s16(vqdmulhq_s16(v_tmp, v_quant), 1); + const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero); + const int16x8_t v_qcoeff_a = veorq_s16(v_tmp2, v_coeff_sign); + const int16x8_t v_qcoeff = vsubq_s16(v_qcoeff_a, v_coeff_sign); + const int16x8_t v_dqcoeff = vmulq_s16(v_qcoeff, v_dequant); + store_s16q_to_tran_low(&qcoeff_ptr[0], v_qcoeff); + store_s16q_to_tran_low(&dqcoeff_ptr[0], v_dqcoeff); + return v_nz_mask; +} + +void av1_quantize_fp_neon(const tran_low_t *coeff_ptr, intptr_t count, + const int16_t *zbin_ptr, const int16_t *round_ptr, + const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + // TODO(jingning) Decide the need of these arguments after the + // quantization process is completed. + (void)zbin_ptr; + (void)quant_shift_ptr; + (void)scan; + + // Quantization pass: All coefficients with index >= zero_flag are + // skippable. Note: zero_flag can be zero. 
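For reference, quantize_fp_8() above is a vector version of the usual AV1 fast-path quantizer arithmetic. A scalar sketch of the per-coefficient math (hypothetical helper name, log_scale 0 case only, loosely following the C fast-path quantizer):

// vqdmulhq_s16(a, b) returns (2 * a * b) >> 16 with saturation, so following
// it with a right shift by one reproduces the plain (tmp * quant) >> 16 used
// here.
static int16_t quantize_fp_coeff_scalar(int16_t coeff, int16_t round,
                                        int16_t quant, int16_t dequant,
                                        int16_t *dqcoeff) {
  const int sign = coeff < 0 ? -1 : 1;
  const int abs_coeff = coeff < 0 ? -coeff : coeff;
  const int tmp = abs_coeff + round;             // saturating add in NEON
  const int abs_q = (tmp * quant) >> 16;         // vqdmulhq_s16(...) >> 1
  *dqcoeff = (int16_t)(sign * abs_q * dequant);  // vmulq_s16(qcoeff, dequant)
  return (int16_t)(sign * abs_q);
}

The eob bookkeeping in get_max_lane_eob() is the vector form of the scalar rule: if the quantized coefficient is non-zero, eob = AOMMAX(eob, iscan[rc] + 1).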
+ const int16x8_t v_zero = vdupq_n_s16(0); + int16x8_t v_quant = vld1q_s16(quant_ptr); + int16x8_t v_dequant = vld1q_s16(dequant_ptr); + int16x8_t v_round = vld1q_s16(round_ptr); + int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1); + uint16x8_t v_nz_mask; + // process dc and the first seven ac coeffs + v_nz_mask = quantize_fp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant, + v_dequant, v_round, v_zero); + v_eobmax_76543210 = get_max_lane_eob(&iscan[0], v_eobmax_76543210, v_nz_mask); + // overwrite the dc constants with ac constants + v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1); + v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1); + v_round = vdupq_lane_s16(vget_low_s16(v_round), 1); + + count -= 8; + // now process the rest of the ac coeffs + do { + coeff_ptr += 8; + qcoeff_ptr += 8; + dqcoeff_ptr += 8; + iscan += 8; + v_nz_mask = quantize_fp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant, + v_dequant, v_round, v_zero); + v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask); + count -= 8; + } while (count > 0); + *eob_ptr = get_max_eob(v_eobmax_76543210); +} + +static INLINE uint16x8_t quantize_lp_8(const int16_t *coeff_ptr, + int16_t *qcoeff_ptr, + int16_t *dqcoeff_ptr, int16x8_t v_quant, + int16x8_t v_dequant, int16x8_t v_round, + int16x8_t v_zero) { + const int16x8_t v_coeff = vld1q_s16(&coeff_ptr[0]); + const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + const int16x8_t v_abs = vabsq_s16(v_coeff); + const int16x8_t v_tmp = vqaddq_s16(v_abs, v_round); + const int16x8_t v_tmp2 = vshrq_n_s16(vqdmulhq_s16(v_tmp, v_quant), 1); + const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero); + const int16x8_t v_qcoeff_a = veorq_s16(v_tmp2, v_coeff_sign); + const int16x8_t v_qcoeff = vsubq_s16(v_qcoeff_a, v_coeff_sign); + const int16x8_t v_dqcoeff = vmulq_s16(v_qcoeff, v_dequant); + vst1q_s16(qcoeff_ptr, v_qcoeff); + vst1q_s16(dqcoeff_ptr, v_dqcoeff); + return v_nz_mask; +} + +void av1_quantize_lp_neon(const int16_t *coeff_ptr, intptr_t n_coeffs, + const int16_t *round_ptr, const int16_t *quant_ptr, + int16_t *qcoeff_ptr, int16_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + (void)scan; + // Quantization pass: All coefficients with index >= zero_flag are + // skippable. Note: zero_flag can be zero. 
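+  // Low-precision path: coefficients are plain int16_t, so quantize_lp_8()
+  // loads and stores them directly with vld1q_s16()/vst1q_s16(); the
+  // arithmetic is otherwise identical to quantize_fp_8() above.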
+ const int16x8_t v_zero = vdupq_n_s16(0); + int16x8_t v_quant = vld1q_s16(quant_ptr); + int16x8_t v_dequant = vld1q_s16(dequant_ptr); + int16x8_t v_round = vld1q_s16(round_ptr); + int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1); + uint16x8_t v_nz_mask; + intptr_t count = n_coeffs; + + // process dc and the first seven ac coeffs + v_nz_mask = quantize_lp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant, + v_dequant, v_round, v_zero); + v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask); + // overwrite the dc constants with ac constants + v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1); + v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1); + v_round = vdupq_lane_s16(vget_low_s16(v_round), 1); + + count -= 8; + // now process the rest of the ac coeffs + do { + coeff_ptr += 8; + qcoeff_ptr += 8; + dqcoeff_ptr += 8; + iscan += 8; + v_nz_mask = quantize_lp_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant, + v_dequant, v_round, v_zero); + v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask); + count -= 8; + } while (count != 0); + *eob_ptr = get_max_eob(v_eobmax_76543210); +} + +static AOM_FORCE_INLINE uint16x8_t quantize_fp_logscale_8( + const tran_low_t *coeff_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, int16x8_t v_quant, int16x8_t v_dequant, + int16x8_t v_round, int16x8_t v_zero, int log_scale) { + const int16x8_t v_log_scale_minus_1 = vdupq_n_s16(log_scale - 1); + const int16x8_t v_neg_log_scale_plus_1 = vdupq_n_s16(-(1 + log_scale)); + const int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr); + const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + const int16x8_t v_abs_coeff = vabsq_s16(v_coeff); + const uint16x8_t v_mask = + vcgeq_s16(v_abs_coeff, vshlq_s16(v_dequant, v_neg_log_scale_plus_1)); + // const int64_t tmp = vmask ? (int64_t)abs_coeff + log_scaled_round : 0 + const int16x8_t v_tmp = vandq_s16(vqaddq_s16(v_abs_coeff, v_round), + vreinterpretq_s16_u16(v_mask)); + const int16x8_t v_tmp2 = + vqdmulhq_s16(vshlq_s16(v_tmp, v_log_scale_minus_1), v_quant); + const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero); + const int16x8_t v_qcoeff = + vsubq_s16(veorq_s16(v_tmp2, v_coeff_sign), v_coeff_sign); + // Multiplying by dequant here will use all 16 bits. Cast to unsigned before + // shifting right. (vshlq_s16 will shift right if shift value is negative) + const uint16x8_t v_abs_dqcoeff = + vshlq_u16(vreinterpretq_u16_s16(vmulq_s16(v_tmp2, v_dequant)), + vdupq_n_s16(-log_scale)); + const int16x8_t v_dqcoeff = + vsubq_s16(veorq_s16(vreinterpretq_s16_u16(v_abs_dqcoeff), v_coeff_sign), + v_coeff_sign); + store_s16q_to_tran_low(qcoeff_ptr, v_qcoeff); + store_s16q_to_tran_low(dqcoeff_ptr, v_dqcoeff); + return v_nz_mask; +} + +static AOM_FORCE_INLINE uint16x8_t quantize_fp_logscale2_8( + const tran_low_t *coeff_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, int16x8_t v_quant, int16x8_t v_dequant, + int16x8_t v_round, int16x8_t v_zero) { + const int16x8_t v_coeff = load_tran_low_to_s16q(coeff_ptr); + const int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + const int16x8_t v_abs_coeff = vabsq_s16(v_coeff); + const uint16x8_t v_mask = + vcgeq_u16(vshlq_n_u16(vreinterpretq_u16_s16(v_abs_coeff), 1), + vshrq_n_u16(vreinterpretq_u16_s16(v_dequant), 2)); + // abs_coeff = vmask ? 
(int64_t)abs_coeff + log_scaled_round : 0 + const int16x8_t v_tmp = vandq_s16(vqaddq_s16(v_abs_coeff, v_round), + vreinterpretq_s16_u16(v_mask)); + // tmp32 = (int)((abs_coeff * quant_ptr[rc != 0]) >> (16 - log_scale)); + const int16x8_t v_tmp2 = + vorrq_s16(vshlq_n_s16(vqdmulhq_s16(v_tmp, v_quant), 1), + vreinterpretq_s16_u16(vshrq_n_u16( + vreinterpretq_u16_s16(vmulq_s16(v_tmp, v_quant)), 14))); + const uint16x8_t v_nz_mask = vcgtq_s16(v_tmp2, v_zero); + const int16x8_t v_qcoeff = + vsubq_s16(veorq_s16(v_tmp2, v_coeff_sign), v_coeff_sign); + // const tran_low_t abs_dqcoeff = (tmp32 * dequant_ptr[rc != 0]) >> log_scale; + const int16x8_t v_abs_dqcoeff = + vorrq_s16(vshlq_n_s16(vqdmulhq_s16(v_tmp2, v_dequant), 13), + vreinterpretq_s16_u16(vshrq_n_u16( + vreinterpretq_u16_s16(vmulq_s16(v_tmp2, v_dequant)), 2))); + const int16x8_t v_dqcoeff = + vsubq_s16(veorq_s16(v_abs_dqcoeff, v_coeff_sign), v_coeff_sign); + store_s16q_to_tran_low(qcoeff_ptr, v_qcoeff); + store_s16q_to_tran_low(dqcoeff_ptr, v_dqcoeff); + return v_nz_mask; +} + +static AOM_FORCE_INLINE void quantize_fp_no_qmatrix_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *round_ptr, + const int16_t *quant_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *iscan, + int log_scale) { + const int16x8_t v_zero = vdupq_n_s16(0); + int16x8_t v_quant = vld1q_s16(quant_ptr); + int16x8_t v_dequant = vld1q_s16(dequant_ptr); + const int16x8_t v_round_no_scale = vld1q_s16(round_ptr); + int16x8_t v_round = + vqrdmulhq_n_s16(v_round_no_scale, (int16_t)(1 << (15 - log_scale))); + int16x8_t v_eobmax_76543210 = vdupq_n_s16(-1); + intptr_t non_zero_count = n_coeffs; + + assert(n_coeffs > 16); + // Pre-scan pass + const int16x8_t v_dequant_scaled = + vshlq_s16(v_dequant, vdupq_n_s16(-(1 + log_scale))); + const int16x8_t v_zbin_s16 = + vdupq_lane_s16(vget_low_s16(v_dequant_scaled), 1); + intptr_t i = n_coeffs; + do { + const int16x8_t v_coeff_a = load_tran_low_to_s16q(coeff_ptr + i - 8); + const int16x8_t v_coeff_b = load_tran_low_to_s16q(coeff_ptr + i - 16); + const int16x8_t v_abs_coeff_a = vabsq_s16(v_coeff_a); + const int16x8_t v_abs_coeff_b = vabsq_s16(v_coeff_b); + const uint16x8_t v_mask_a = vcgeq_s16(v_abs_coeff_a, v_zbin_s16); + const uint16x8_t v_mask_b = vcgeq_s16(v_abs_coeff_b, v_zbin_s16); + // If the coefficient is in the base ZBIN range, then discard. 
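+    // horizontal_long_add_u16x8() sums both comparison masks; a zero result
+    // means none of these 16 coefficients reached the zbin threshold, so the
+    // whole group can be skipped (the trailing run is cleared by the memset()
+    // calls below).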
+ if (horizontal_long_add_u16x8(v_mask_a, v_mask_b) == 0) { + non_zero_count -= 16; + } else { + break; + } + i -= 16; + } while (i > 0); + + const intptr_t remaining_zcoeffs = n_coeffs - non_zero_count; + memset(qcoeff_ptr + non_zero_count, 0, + remaining_zcoeffs * sizeof(*qcoeff_ptr)); + memset(dqcoeff_ptr + non_zero_count, 0, + remaining_zcoeffs * sizeof(*dqcoeff_ptr)); + + // process dc and the first seven ac coeffs + uint16x8_t v_nz_mask; + if (log_scale == 2) { + v_nz_mask = quantize_fp_logscale2_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, + v_quant, v_dequant, v_round, v_zero); + } else { + v_nz_mask = + quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant, + v_dequant, v_round, v_zero, log_scale); + } + v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask); + // overwrite the dc constants with ac constants + v_quant = vdupq_lane_s16(vget_low_s16(v_quant), 1); + v_dequant = vdupq_lane_s16(vget_low_s16(v_dequant), 1); + v_round = vdupq_lane_s16(vget_low_s16(v_round), 1); + + for (intptr_t count = non_zero_count - 8; count > 0; count -= 8) { + coeff_ptr += 8; + qcoeff_ptr += 8; + dqcoeff_ptr += 8; + iscan += 8; + if (log_scale == 2) { + v_nz_mask = quantize_fp_logscale2_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, + v_quant, v_dequant, v_round, v_zero); + } else { + v_nz_mask = + quantize_fp_logscale_8(coeff_ptr, qcoeff_ptr, dqcoeff_ptr, v_quant, + v_dequant, v_round, v_zero, log_scale); + } + v_eobmax_76543210 = get_max_lane_eob(iscan, v_eobmax_76543210, v_nz_mask); + } + *eob_ptr = get_max_eob(v_eobmax_76543210); +} + +void av1_quantize_fp_32x32_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, + const int16_t *zbin_ptr, + const int16_t *round_ptr, + const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + (void)zbin_ptr; + (void)quant_shift_ptr; + (void)scan; + quantize_fp_no_qmatrix_neon(coeff_ptr, n_coeffs, round_ptr, quant_ptr, + qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr, + iscan, 1); +} + +void av1_quantize_fp_64x64_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, + const int16_t *zbin_ptr, + const int16_t *round_ptr, + const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + (void)zbin_ptr; + (void)quant_shift_ptr; + (void)scan; + quantize_fp_no_qmatrix_neon(coeff_ptr, n_coeffs, round_ptr, quant_ptr, + qcoeff_ptr, dqcoeff_ptr, dequant_ptr, eob_ptr, + iscan, 2); +} + +void aom_quantize_b_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, + const int16_t *zbin_ptr, const int16_t *round_ptr, + const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, + uint16_t *eob_ptr, const int16_t *scan, + const int16_t *iscan) { + (void)quant_shift_ptr; + (void)scan; + + const int zbins[2] = { zbin_ptr[0], zbin_ptr[1] }; + + memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr)); + memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr)); + + const int16x8_t zero = vdupq_n_s16(0); + int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero)); + + int16x8_t vzbins = vdupq_n_s16(zbins[1]), vround = vdupq_n_s16(round_ptr[1]); + int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]); + int16x8_t vquant = vdupq_n_s16(quant_ptr[1]); + int16x8_t vquant_shift = 
vdupq_n_s16(quant_shift_ptr[1]); + + int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]); + int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + int16x8_t v_abs = vabsq_s16(v_coeff); + + vzbins = vsetq_lane_s16(zbins[0], vzbins, 0); + + uint16x8_t vcond = vcgeq_s16(v_abs, vzbins); + uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + vround = vsetq_lane_s16(round_ptr[0], vround, 0); + vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0); + + int16x8_t vtmp = vqaddq_s16(v_abs, vround); + int16x8_t vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1); + + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0])); + store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask); + int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant); + + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0])); + store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask); + + vround = vsetq_lane_s16(round_ptr[1], vround, 0); + vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[0]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + vzbins = vsetq_lane_s16(zbins[1], vzbins, 0); + + for (int i = 8; i < n_coeffs; i += 8) { + v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]); + v_coeff_sign = vshrq_n_s16(v_coeff, 15); + v_abs = vabsq_s16(v_coeff); + vcond = vcgeq_s16(v_abs, vzbins); + + nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + int16x8_t vtmp = vqaddq_s16(v_abs, vround); + int16x8_t vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + + vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1); + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i])); + store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask); + int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant); + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i])); + store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[i]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + } + *eob_ptr = get_max_eob(v_eobmax_76543210) + 1; +} + +#define QM_MULL_SHIFT(x0, x1) \ + vreinterpretq_s16_u16(vorrq_u16( \ + vreinterpretq_u16_s16(vshlq_n_s16( \ + vqdmulhq_s16(x0, vreinterpretq_s16_u16(x1)), 15 - AOM_QM_BITS)), \ + vshrq_n_u16(vmulq_u16(vreinterpretq_u16_s16(x0), x1), AOM_QM_BITS))) + +static void aom_quantize_b_helper_16x16_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const 
int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr, + const qm_val_t *iqm_ptr) { + (void)scan; + + uint16x8_t vwt, viwt; + const int zbins[2] = { zbin_ptr[0], zbin_ptr[1] }; + + memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr)); + memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr)); + + const int16x8_t zero = vdupq_n_s16(0); + int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero)); + + int16x8_t vzbins = vdupq_n_s16(zbins[1]), vround = vdupq_n_s16(round_ptr[1]); + int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]); + int16x8_t vquant = vdupq_n_s16(quant_ptr[1]); + int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]); + + int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]); + int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + int16x8_t v_abs = vabsq_s16(v_coeff); + vzbins = vsetq_lane_s16(zbins[0], vzbins, 0); + uint16x8_t vcond; + if (qm_ptr == NULL) { + vcond = vcgeq_s16(v_abs, vzbins); + } else { + vwt = vmovl_u8(vld1_u8(&qm_ptr[0])); + vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins); + } + uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + vround = vsetq_lane_s16(round_ptr[0], vround, 0); + vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0); + + int16x8_t vtmp = vqaddq_s16(v_abs, vround); + + int16x8_t vtmp2; + if (qm_ptr == NULL) { + vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + } else { + vtmp2 = QM_MULL_SHIFT(vtmp, vwt); + vtmp2 = vaddq_s16(vtmp2, vtmp); + } + + vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1); + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0])); + store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask); + + if (iqm_ptr != NULL) { + viwt = vmovl_u8(vld1_u8(&iqm_ptr[0])); + vdequant = QM_MULL_SHIFT(vdequant, viwt); + } + int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant); + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0])); + store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask); + + vround = vsetq_lane_s16(round_ptr[1], vround, 0); + vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[0]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + vzbins = vsetq_lane_s16(zbins[1], vzbins, 0); + + for (int i = 8; i < n_coeffs; i += 8) { + v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]); + v_coeff_sign = vshrq_n_s16(v_coeff, 15); + v_abs = vabsq_s16(v_coeff); + + if (qm_ptr == NULL) { + vcond = vcgeq_s16(v_abs, vzbins); + } else { + vwt = vmovl_u8(vld1_u8(&qm_ptr[i])); + vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins); + } + nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + int16x8_t vtmp = vqaddq_s16(v_abs, 
vround); + + int16x8_t vtmp2; + if (qm_ptr == NULL) { + vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + } else { + vtmp2 = QM_MULL_SHIFT(vtmp, vwt); + vtmp2 = vaddq_s16(vtmp2, vtmp); + } + + vtmp2 = vshrq_n_s16(vqdmulhq_s16(vtmp2, vquant_shift), 1); + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i])); + store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask); + + if (iqm_ptr != NULL) { + viwt = vmovl_u8(vld1_u8(&iqm_ptr[i])); + vdequant = QM_MULL_SHIFT(vdequant, viwt); + } + int16x8_t v_deq_abs = vmulq_s16(vtmp2, vdequant); + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i])); + store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[i]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + } + *eob_ptr = get_max_eob(v_eobmax_76543210) + 1; +} + +static void aom_quantize_b_helper_32x32_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr, + const qm_val_t *iqm_ptr) { + (void)scan; + + uint16x8_t vwt, viwt; + const int log_scale = 1; + const int zbins[2] = { ROUND_POWER_OF_TWO(zbin_ptr[0], log_scale), + ROUND_POWER_OF_TWO(zbin_ptr[1], log_scale) }; + + memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr)); + memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr)); + + const int16x8_t zero = vdupq_n_s16(0); + int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero)); + const int16x8_t v_log_scale = v_eobmax_76543210; + + int16x8_t vzbins = vdupq_n_s16(zbins[1]), + vround = vdupq_n_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale)); + int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]); + int16x8_t vquant = vdupq_n_s16(quant_ptr[1]); + int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]); + + int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]); + int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + int16x8_t v_abs = vabsq_s16(v_coeff); + vzbins = vsetq_lane_s16(zbins[0], vzbins, 0); + uint16x8_t vcond; + if (qm_ptr == NULL) { + vcond = vcgeq_s16(v_abs, vzbins); + } else { + vwt = vmovl_u8(vld1_u8(&qm_ptr[0])); + vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins); + } + uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + vround = + vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[0], log_scale), vround, 0); + vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0); + + int16x8_t vtmp = vqaddq_s16(v_abs, vround); + + int16x8_t vtmp2; + if (qm_ptr == NULL) { + vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + } else { + vtmp2 = QM_MULL_SHIFT(vtmp, vwt); + vtmp2 = vaddq_s16(vtmp2, vtmp); + } + + vtmp2 = vqdmulhq_s16(vtmp2, vquant_shift); + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, 
load_tran_low_to_s16q(&qcoeff_ptr[0])); + store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask); + + if (iqm_ptr != NULL) { + viwt = vmovl_u8(vld1_u8(&iqm_ptr[0])); + vdequant = QM_MULL_SHIFT(vdequant, viwt); + } + int16x8_t v_deq_abs = vreinterpretq_s16_u16(vshlq_u16( + vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdequant)), v_log_scale)); + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0])); + store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask); + + vzbins = vsetq_lane_s16(zbins[1], vzbins, 0); + vround = + vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale), vround, 0); + vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[0]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + vzbins = vsetq_lane_s16(zbins[1], vzbins, 0); + + for (int i = 8; i < n_coeffs; i += 8) { + v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]); + v_coeff_sign = vshrq_n_s16(v_coeff, 15); + v_abs = vabsq_s16(v_coeff); + + if (qm_ptr == NULL) { + vcond = vcgeq_s16(v_abs, vzbins); + } else { + vwt = vmovl_u8(vld1_u8(&qm_ptr[i])); + vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins); + } + nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + int16x8_t vtmp = vqaddq_s16(v_abs, vround); + + int16x8_t vtmp2; + if (qm_ptr == NULL) { + vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + } else { + vtmp2 = QM_MULL_SHIFT(vtmp, vwt); + vtmp2 = vaddq_s16(vtmp2, vtmp); + } + vtmp2 = vqdmulhq_s16(vtmp2, vquant_shift); + + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i])); + store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask); + + if (iqm_ptr != NULL) { + viwt = vmovl_u8(vld1_u8(&iqm_ptr[i])); + vdequant = QM_MULL_SHIFT(vdequant, viwt); + } + int16x8_t v_deq_abs = vreinterpretq_s16_u16(vshlq_u16( + vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdequant)), v_log_scale)); + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i])); + store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[i]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + } + *eob_ptr = get_max_eob(v_eobmax_76543210) + 1; +} + +static void aom_quantize_b_helper_64x64_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr, + const qm_val_t *iqm_ptr) { + (void)scan; + + uint16x8_t vwt, viwt; + const int log_scale = 2; + const int16x8_t v_log_scale = + vreinterpretq_s16_s64(vdupq_n_s64(0xFFFEFFFEFFFEFFFE)); + + const int zbins[2] = { 
ROUND_POWER_OF_TWO(zbin_ptr[0], log_scale), + ROUND_POWER_OF_TWO(zbin_ptr[1], log_scale) }; + + memset(qcoeff_ptr, 0, n_coeffs * sizeof(*qcoeff_ptr)); + memset(dqcoeff_ptr, 0, n_coeffs * sizeof(*dqcoeff_ptr)); + + const int16x8_t zero = vdupq_n_s16(0); + int16x8_t v_eobmax_76543210 = vreinterpretq_s16_u16(vceqq_s16(zero, zero)); + int16x8_t v_ones = vnegq_s16(v_eobmax_76543210); + + int16x8_t vzbins = vdupq_n_s16(zbins[1]), + vround = vdupq_n_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale)); + int16x8_t vdequant = vdupq_n_s16(dequant_ptr[1]); + int16x8_t vquant = vdupq_n_s16(quant_ptr[1]); + int16x8_t vquant_shift = vdupq_n_s16(quant_shift_ptr[1]); + + int16x8_t v_coeff = load_tran_low_to_s16q(&coeff_ptr[0]); + int16x8_t v_coeff_sign = vshrq_n_s16(v_coeff, 15); + int16x8_t v_abs = vabsq_s16(v_coeff); + vzbins = vsetq_lane_s16(zbins[0], vzbins, 0); + uint16x8_t vcond; + if (qm_ptr == NULL) { + vcond = vcgeq_s16(v_abs, vzbins); + } else { + vwt = vmovl_u8(vld1_u8(&qm_ptr[0])); + vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins); + } + uint64_t nz_check = vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + vround = + vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[0], log_scale), vround, 0); + vquant = vsetq_lane_s16(quant_ptr[0], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[0], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[0], vquant_shift, 0); + int16x8_t vtmp = vqaddq_s16(v_abs, vround); + + int16x8_t vtmp2; + if (qm_ptr == NULL) { + vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + } else { + vtmp2 = QM_MULL_SHIFT(vtmp, vwt); + vtmp2 = vaddq_s16(vtmp2, vtmp); + } + + int16x8_t ones = + vandq_s16(vshrq_n_s16(vmulq_s16(vtmp2, vquant_shift), 14), v_ones); + vtmp2 = + vaddq_s16(vshlq_s16(vqdmulhq_s16(vtmp2, vquant_shift), v_ones), ones); + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[0])); + store_s16q_to_tran_low(&qcoeff_ptr[0], coeff_nz_mask); + + if (iqm_ptr != NULL) { + viwt = vmovl_u8(vld1_u8(&iqm_ptr[0])); + vdequant = QM_MULL_SHIFT(vdequant, viwt); + } + int16x8_t v_deq_abs = vreinterpretq_s16_u16(vshlq_u16( + vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdequant)), v_log_scale)); + v_deq_abs = + vorrq_s16(vshlq_n_s16(vqdmulhq_s16(vtmp2, vdequant), 13), v_deq_abs); + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[0])); + store_s16q_to_tran_low(&dqcoeff_ptr[0], coeff_nz_mask); + + vround = + vsetq_lane_s16(ROUND_POWER_OF_TWO(round_ptr[1], log_scale), vround, 0); + vquant = vsetq_lane_s16(quant_ptr[1], vquant, 0); + vdequant = vsetq_lane_s16(dequant_ptr[1], vdequant, 0); + vquant_shift = vsetq_lane_s16(quant_shift_ptr[1], vquant_shift, 0); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[0]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + vzbins = vsetq_lane_s16(zbins[1], vzbins, 0); + + for (int i = 8; i < n_coeffs; i += 8) { + v_coeff = load_tran_low_to_s16q(&coeff_ptr[i]); + v_coeff_sign = vshrq_n_s16(v_coeff, 15); + v_abs = vabsq_s16(v_coeff); + + if (qm_ptr == NULL) { + vcond = vcgeq_s16(v_abs, vzbins); + } else { + vwt = vmovl_u8(vld1_u8(&qm_ptr[i])); + vcond = vcgeq_s16(QM_MULL_SHIFT(v_abs, vwt), vzbins); + } + nz_check 
= vget_lane_u64(vreinterpret_u64_u8(vmovn_u16(vcond)), 0); + if (nz_check) { + int16x8_t vtmp = vqaddq_s16(v_abs, vround); + + int16x8_t vtmp2; + if (qm_ptr == NULL) { + vtmp2 = vsraq_n_s16(vtmp, vqdmulhq_s16(vtmp, vquant), 1); + } else { + vtmp2 = QM_MULL_SHIFT(vtmp, vwt); + vtmp2 = vaddq_s16(vtmp2, vtmp); + } + + int16x8_t ones = + vandq_s16(vshrq_n_s16(vmulq_s16(vtmp2, vquant_shift), 14), v_ones); + vtmp2 = + vaddq_s16(vshlq_s16(vqdmulhq_s16(vtmp2, vquant_shift), v_ones), ones); + int16x8_t vdest = vsubq_s16(veorq_s16(vtmp2, v_coeff_sign), v_coeff_sign); + int16x8_t coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&qcoeff_ptr[i])); + store_s16q_to_tran_low(&qcoeff_ptr[i], coeff_nz_mask); + + if (iqm_ptr != NULL) { + viwt = vmovl_u8(vld1_u8(&iqm_ptr[i])); + vdequant = QM_MULL_SHIFT(vdequant, viwt); + } + int16x8_t v_deq_abs = vreinterpretq_s16_u16(vshlq_u16( + vreinterpretq_u16_s16(vmulq_s16(vtmp2, vdequant)), v_log_scale)); + v_deq_abs = + vorrq_s16(vshlq_n_s16(vqdmulhq_s16(vtmp2, vdequant), 13), v_deq_abs); + vdest = vsubq_s16(veorq_s16(v_deq_abs, v_coeff_sign), v_coeff_sign); + coeff_nz_mask = + vbslq_s16(vcond, vdest, load_tran_low_to_s16q(&dqcoeff_ptr[i])); + store_s16q_to_tran_low(&dqcoeff_ptr[i], coeff_nz_mask); + + uint16x8_t vtmp_mask = vcgtq_s16(vtmp2, zero); + const uint16x8_t v_nz_mask = vandq_u16(vtmp_mask, vcond); + int16x8_t v_iscan = vld1q_s16(&iscan[i]); + vcond = vandq_u16(v_nz_mask, vcgtq_s16(v_iscan, v_eobmax_76543210)); + v_eobmax_76543210 = vbslq_s16(vcond, v_iscan, v_eobmax_76543210); + } + } + *eob_ptr = get_max_eob(v_eobmax_76543210) + 1; +} + +void aom_quantize_b_helper_neon( + const tran_low_t *coeff_ptr, intptr_t n_coeffs, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, const qm_val_t *qm_ptr, + const qm_val_t *iqm_ptr, const int log_scale) { + switch (log_scale) { // log_scale for AV1 encoder can be only 0, 1, 2 + case 0: + aom_quantize_b_helper_16x16_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, + quant_ptr, quant_shift_ptr, qcoeff_ptr, + dqcoeff_ptr, dequant_ptr, eob_ptr, scan, + iscan, qm_ptr, iqm_ptr); + break; + case 1: + aom_quantize_b_helper_32x32_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, + quant_ptr, quant_shift_ptr, qcoeff_ptr, + dqcoeff_ptr, dequant_ptr, eob_ptr, scan, + iscan, qm_ptr, iqm_ptr); + break; + case 2: + aom_quantize_b_helper_64x64_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, + quant_ptr, quant_shift_ptr, qcoeff_ptr, + dqcoeff_ptr, dequant_ptr, eob_ptr, scan, + iscan, qm_ptr, iqm_ptr); + break; + } +} + +void aom_quantize_b_32x32_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, + const int16_t *zbin_ptr, + const int16_t *round_ptr, + const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan) { + aom_quantize_b_helper_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, + quant_ptr, quant_shift_ptr, qcoeff_ptr, + dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan, + NULL, NULL, 1); +} + +void aom_quantize_b_64x64_neon(const tran_low_t *coeff_ptr, intptr_t n_coeffs, + const int16_t *zbin_ptr, + const int16_t *round_ptr, + const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, + tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, + const int16_t *dequant_ptr, uint16_t 
*eob_ptr, + const int16_t *scan, const int16_t *iscan) { + aom_quantize_b_helper_neon(coeff_ptr, n_coeffs, zbin_ptr, round_ptr, + quant_ptr, quant_shift_ptr, qcoeff_ptr, + dqcoeff_ptr, dequant_ptr, eob_ptr, scan, iscan, + NULL, NULL, 2); +} diff --git a/third_party/aom/av1/encoder/arm/neon/rdopt_neon.c b/third_party/aom/av1/encoder/arm/neon/rdopt_neon.c new file mode 100644 index 0000000000..7d3bd4c606 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/rdopt_neon.c @@ -0,0 +1,459 @@ +/* + * Copyright (c) 2020, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <assert.h> + +#include <arm_neon.h> + +#include "av1/encoder/rdopt.h" +#include "config/aom_config.h" +#include "config/av1_rtcd.h" + +// Process horizontal and vertical correlations in a 4x4 block of pixels. +// We actually use the 4x4 pixels to calculate correlations corresponding to +// the top-left 3x3 pixels, so this function must be called with 1x1 overlap, +// moving the window along/down by 3 pixels at a time. +INLINE static void horver_correlation_4x4(const int16_t *diff, int stride, + int32x4_t *xy_sum_32, + int32x4_t *xz_sum_32, + int32x4_t *x_sum_32, + int32x4_t *x2_sum_32) { + // Pixels in this 4x4 [ a b c d ] + // are referred to as: [ e f g h ] + // [ i j k l ] + // [ m n o p ] + + const int16x4_t pixelsa_2_lo = vld1_s16(diff + (0 * stride)); + const int16x4_t pixelsa_2_sli = + vreinterpret_s16_s64(vshl_n_s64(vreinterpret_s64_s16(pixelsa_2_lo), 16)); + const int16x4_t pixelsb_2_lo = vld1_s16(diff + (1 * stride)); + const int16x4_t pixelsb_2_sli = + vreinterpret_s16_s64(vshl_n_s64(vreinterpret_s64_s16(pixelsb_2_lo), 16)); + const int16x4_t pixelsa_1_lo = vld1_s16(diff + (2 * stride)); + const int16x4_t pixelsa_1_sli = + vreinterpret_s16_s64(vshl_n_s64(vreinterpret_s64_s16(pixelsa_1_lo), 16)); + const int16x4_t pixelsb_1_lo = vld1_s16(diff + (3 * stride)); + const int16x4_t pixelsb_1_sli = + vreinterpret_s16_s64(vshl_n_s64(vreinterpret_s64_s16(pixelsb_1_lo), 16)); + + const int16x8_t slli_a = vcombine_s16(pixelsa_1_sli, pixelsa_2_sli); + + *xy_sum_32 = vmlal_s16(*xy_sum_32, pixelsa_1_lo, pixelsa_1_sli); + *xy_sum_32 = vmlal_s16(*xy_sum_32, pixelsa_2_lo, pixelsa_2_sli); + *xy_sum_32 = vmlal_s16(*xy_sum_32, pixelsb_2_lo, pixelsb_2_sli); + + *xz_sum_32 = vmlal_s16(*xz_sum_32, pixelsa_1_sli, pixelsb_1_sli); + *xz_sum_32 = vmlal_s16(*xz_sum_32, pixelsa_2_sli, pixelsb_2_sli); + *xz_sum_32 = vmlal_s16(*xz_sum_32, pixelsa_1_sli, pixelsb_2_sli); + + // Now calculate the straight sums, x_sum += a+b+c+e+f+g+i+j+k + // (sum up every element in slli_a and swap_b) + *x_sum_32 = vpadalq_s16(*x_sum_32, slli_a); + *x_sum_32 = vaddw_s16(*x_sum_32, pixelsb_2_sli); + + // Also sum their squares + *x2_sum_32 = vmlal_s16(*x2_sum_32, pixelsa_1_sli, pixelsa_1_sli); + *x2_sum_32 = vmlal_s16(*x2_sum_32, pixelsa_2_sli, pixelsa_2_sli); + *x2_sum_32 = vmlal_s16(*x2_sum_32, pixelsb_2_sli, pixelsb_2_sli); +} + +void av1_get_horver_correlation_full_neon(const int16_t *diff, int stride, + int width, int height, float *hcorr, + float *vcorr) { + // The following notation is used: + // 
x - current pixel + // y - right neighbour pixel + // z - below neighbour pixel + // w - down-right neighbour pixel + int64_t xy_sum = 0, xz_sum = 0; + int64_t x_sum = 0, x2_sum = 0; + int32x4_t zero = vdupq_n_s32(0); + int64x2_t v_x_sum = vreinterpretq_s64_s32(zero); + int64x2_t v_xy_sum = vreinterpretq_s64_s32(zero); + int64x2_t v_xz_sum = vreinterpretq_s64_s32(zero); + int64x2_t v_x2_sum = vreinterpretq_s64_s32(zero); + // Process horizontal and vertical correlations through the body in 4x4 + // blocks. This excludes the final row and column and possibly one extra + // column depending how 3 divides into width and height + + for (int i = 0; i <= height - 4; i += 3) { + int32x4_t xy_sum_32 = zero; + int32x4_t xz_sum_32 = zero; + int32x4_t x_sum_32 = zero; + int32x4_t x2_sum_32 = zero; + for (int j = 0; j <= width - 4; j += 3) { + horver_correlation_4x4(&diff[i * stride + j], stride, &xy_sum_32, + &xz_sum_32, &x_sum_32, &x2_sum_32); + } + v_xy_sum = vpadalq_s32(v_xy_sum, xy_sum_32); + v_xz_sum = vpadalq_s32(v_xz_sum, xz_sum_32); + v_x_sum = vpadalq_s32(v_x_sum, x_sum_32); + v_x2_sum = vpadalq_s32(v_x2_sum, x2_sum_32); + } +#if AOM_ARCH_AARCH64 + xy_sum = vaddvq_s64(v_xy_sum); + xz_sum = vaddvq_s64(v_xz_sum); + x2_sum = vaddvq_s64(v_x2_sum); + x_sum = vaddvq_s64(v_x_sum); +#else + xy_sum = vget_lane_s64( + vadd_s64(vget_low_s64(v_xy_sum), vget_high_s64(v_xy_sum)), 0); + xz_sum = vget_lane_s64( + vadd_s64(vget_low_s64(v_xz_sum), vget_high_s64(v_xz_sum)), 0); + x2_sum = vget_lane_s64( + vadd_s64(vget_low_s64(v_x2_sum), vget_high_s64(v_x2_sum)), 0); + x_sum = + vget_lane_s64(vadd_s64(vget_low_s64(v_x_sum), vget_high_s64(v_x_sum)), 0); +#endif + // x_sum now covers every pixel except the final 1-2 rows and 1-2 cols + int64_t x_finalrow = 0, x_finalcol = 0, x2_finalrow = 0, x2_finalcol = 0; + + // Do we have 2 rows remaining or just the one? Note that width and height + // are powers of 2, so each modulo 3 must be 1 or 2. 
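+  // The leftover row(s) are handled here: for width >= 8 the row is summed
+  // with 8-lane NEON accumulators and the final partial vector is zero-padded
+  // via vextq_s16(); otherwise a scalar loop is used. x_finalrow/x2_finalrow
+  // track the last row separately so it can be excluded from the vertical
+  // correlation totals later on.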
+ if (height % 3 == 1) { // Just horiz corrs on the final row + const int16_t x0 = diff[(height - 1) * stride]; + x_sum += x0; + x_finalrow += x0; + x2_sum += x0 * x0; + x2_finalrow += x0 * x0; + if (width >= 8) { + int32x4_t v_y_sum = zero; + int32x4_t v_y2_sum = zero; + int32x4_t v_xy_sum_a = zero; + int k = width - 1; + int j = 0; + while ((k - 8) > 0) { + const int16x8_t v_x = vld1q_s16(&diff[(height - 1) * stride + j]); + const int16x8_t v_y = vld1q_s16(&diff[(height - 1) * stride + j + 1]); + const int16x4_t v_x_lo = vget_low_s16(v_x); + const int16x4_t v_x_hi = vget_high_s16(v_x); + const int16x4_t v_y_lo = vget_low_s16(v_y); + const int16x4_t v_y_hi = vget_high_s16(v_y); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_lo, v_y_lo); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_hi, v_y_hi); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_lo, v_y_lo); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_hi, v_y_hi); + v_y_sum = vpadalq_s16(v_y_sum, v_y); + k -= 8; + j += 8; + } + + const int16x8_t v_l = vld1q_s16(&diff[(height - 1) * stride] + j); + const int16x8_t v_x = + vextq_s16(vextq_s16(vreinterpretq_s16_s32(zero), v_l, 7), + vreinterpretq_s16_s32(zero), 1); + const int16x8_t v_y = vextq_s16(v_l, vreinterpretq_s16_s32(zero), 1); + const int16x4_t v_x_lo = vget_low_s16(v_x); + const int16x4_t v_x_hi = vget_high_s16(v_x); + const int16x4_t v_y_lo = vget_low_s16(v_y); + const int16x4_t v_y_hi = vget_high_s16(v_y); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_lo, v_y_lo); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_hi, v_y_hi); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_lo, v_y_lo); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_hi, v_y_hi); + const int32x4_t v_y_sum_a = vpadalq_s16(v_y_sum, v_y); + const int64x2_t v_xy_sum2 = vpaddlq_s32(v_xy_sum_a); +#if AOM_ARCH_AARCH64 + const int64x2_t v_y2_sum_a = vpaddlq_s32(v_y2_sum); + xy_sum += vaddvq_s64(v_xy_sum2); + const int32_t y = vaddvq_s32(v_y_sum_a); + const int64_t y2 = vaddvq_s64(v_y2_sum_a); +#else + xy_sum += vget_lane_s64( + vadd_s64(vget_low_s64(v_xy_sum2), vget_high_s64(v_xy_sum2)), 0); + const int64x2_t v_y_a = vpaddlq_s32(v_y_sum_a); + const int64_t y = + vget_lane_s64(vadd_s64(vget_low_s64(v_y_a), vget_high_s64(v_y_a)), 0); + const int64x2_t v_y2_sum_b = vpaddlq_s32(v_y2_sum); + int64_t y2 = vget_lane_s64( + vadd_s64(vget_low_s64(v_y2_sum_b), vget_high_s64(v_y2_sum_b)), 0); +#endif + x_sum += y; + x2_sum += y2; + x_finalrow += y; + x2_finalrow += y2; + } else { + for (int j = 0; j < width - 1; ++j) { + const int16_t x = diff[(height - 1) * stride + j]; + const int16_t y = diff[(height - 1) * stride + j + 1]; + xy_sum += x * y; + x_sum += y; + x2_sum += y * y; + x_finalrow += y; + x2_finalrow += y * y; + } + } + } else { // Two rows remaining to do + const int16_t x0 = diff[(height - 2) * stride]; + const int16_t z0 = diff[(height - 1) * stride]; + x_sum += x0 + z0; + x2_sum += x0 * x0 + z0 * z0; + x_finalrow += z0; + x2_finalrow += z0 * z0; + if (width >= 8) { + int32x4_t v_y2_sum = zero; + int32x4_t v_w2_sum = zero; + int32x4_t v_xy_sum_a = zero; + int32x4_t v_xz_sum_a = zero; + int32x4_t v_x_sum_a = zero; + int32x4_t v_w_sum = zero; + int k = width - 1; + int j = 0; + while ((k - 8) > 0) { + const int16x8_t v_x = vld1q_s16(&diff[(height - 2) * stride + j]); + const int16x8_t v_y = vld1q_s16(&diff[(height - 2) * stride + j + 1]); + const int16x8_t v_z = vld1q_s16(&diff[(height - 1) * stride + j]); + const int16x8_t v_w = vld1q_s16(&diff[(height - 1) * stride + j + 1]); + + const int16x4_t v_x_lo = vget_low_s16(v_x); + const int16x4_t v_y_lo = vget_low_s16(v_y); + const int16x4_t 
v_z_lo = vget_low_s16(v_z); + const int16x4_t v_w_lo = vget_low_s16(v_w); + const int16x4_t v_x_hi = vget_high_s16(v_x); + const int16x4_t v_y_hi = vget_high_s16(v_y); + const int16x4_t v_z_hi = vget_high_s16(v_z); + const int16x4_t v_w_hi = vget_high_s16(v_w); + + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_lo, v_y_lo); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_hi, v_y_hi); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_z_lo, v_w_lo); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_z_hi, v_w_hi); + + v_xz_sum_a = vmlal_s16(v_xz_sum_a, v_x_lo, v_z_lo); + v_xz_sum_a = vmlal_s16(v_xz_sum_a, v_x_hi, v_z_hi); + + v_w2_sum = vmlal_s16(v_w2_sum, v_w_lo, v_w_lo); + v_w2_sum = vmlal_s16(v_w2_sum, v_w_hi, v_w_hi); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_lo, v_y_lo); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_hi, v_y_hi); + + v_w_sum = vpadalq_s16(v_w_sum, v_w); + v_x_sum_a = vpadalq_s16(v_x_sum_a, v_y); + v_x_sum_a = vpadalq_s16(v_x_sum_a, v_w); + + k -= 8; + j += 8; + } + const int16x8_t v_l = vld1q_s16(&diff[(height - 2) * stride] + j); + const int16x8_t v_x = + vextq_s16(vextq_s16(vreinterpretq_s16_s32(zero), v_l, 7), + vreinterpretq_s16_s32(zero), 1); + const int16x8_t v_y = vextq_s16(v_l, vreinterpretq_s16_s32(zero), 1); + const int16x8_t v_l_2 = vld1q_s16(&diff[(height - 1) * stride] + j); + const int16x8_t v_z = + vextq_s16(vextq_s16(vreinterpretq_s16_s32(zero), v_l_2, 7), + vreinterpretq_s16_s32(zero), 1); + const int16x8_t v_w = vextq_s16(v_l_2, vreinterpretq_s16_s32(zero), 1); + + const int16x4_t v_x_lo = vget_low_s16(v_x); + const int16x4_t v_y_lo = vget_low_s16(v_y); + const int16x4_t v_z_lo = vget_low_s16(v_z); + const int16x4_t v_w_lo = vget_low_s16(v_w); + const int16x4_t v_x_hi = vget_high_s16(v_x); + const int16x4_t v_y_hi = vget_high_s16(v_y); + const int16x4_t v_z_hi = vget_high_s16(v_z); + const int16x4_t v_w_hi = vget_high_s16(v_w); + + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_lo, v_y_lo); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_x_hi, v_y_hi); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_z_lo, v_w_lo); + v_xy_sum_a = vmlal_s16(v_xy_sum_a, v_z_hi, v_w_hi); + + v_xz_sum_a = vmlal_s16(v_xz_sum_a, v_x_lo, v_z_lo); + v_xz_sum_a = vmlal_s16(v_xz_sum_a, v_x_hi, v_z_hi); + + v_w2_sum = vmlal_s16(v_w2_sum, v_w_lo, v_w_lo); + v_w2_sum = vmlal_s16(v_w2_sum, v_w_hi, v_w_hi); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_lo, v_y_lo); + v_y2_sum = vmlal_s16(v_y2_sum, v_y_hi, v_y_hi); + + v_w_sum = vpadalq_s16(v_w_sum, v_w); + v_x_sum_a = vpadalq_s16(v_x_sum_a, v_y); + v_x_sum_a = vpadalq_s16(v_x_sum_a, v_w); + +#if AOM_ARCH_AARCH64 + xy_sum += vaddvq_s64(vpaddlq_s32(v_xy_sum_a)); + xz_sum += vaddvq_s64(vpaddlq_s32(v_xz_sum_a)); + x_sum += vaddvq_s32(v_x_sum_a); + x_finalrow += vaddvq_s32(v_w_sum); + int64_t y2 = vaddvq_s64(vpaddlq_s32(v_y2_sum)); + int64_t w2 = vaddvq_s64(vpaddlq_s32(v_w2_sum)); +#else + const int64x2_t v_xy_sum2 = vpaddlq_s32(v_xy_sum_a); + xy_sum += vget_lane_s64( + vadd_s64(vget_low_s64(v_xy_sum2), vget_high_s64(v_xy_sum2)), 0); + const int64x2_t v_xz_sum2 = vpaddlq_s32(v_xz_sum_a); + xz_sum += vget_lane_s64( + vadd_s64(vget_low_s64(v_xz_sum2), vget_high_s64(v_xz_sum2)), 0); + const int64x2_t v_x_sum2 = vpaddlq_s32(v_x_sum_a); + x_sum += vget_lane_s64( + vadd_s64(vget_low_s64(v_x_sum2), vget_high_s64(v_x_sum2)), 0); + const int64x2_t v_w_sum_a = vpaddlq_s32(v_w_sum); + x_finalrow += vget_lane_s64( + vadd_s64(vget_low_s64(v_w_sum_a), vget_high_s64(v_w_sum_a)), 0); + const int64x2_t v_y2_sum_a = vpaddlq_s32(v_y2_sum); + int64_t y2 = vget_lane_s64( + vadd_s64(vget_low_s64(v_y2_sum_a), vget_high_s64(v_y2_sum_a)), 0); + const 
int64x2_t v_w2_sum_a = vpaddlq_s32(v_w2_sum); + int64_t w2 = vget_lane_s64( + vadd_s64(vget_low_s64(v_w2_sum_a), vget_high_s64(v_w2_sum_a)), 0); +#endif + x2_sum += y2 + w2; + x2_finalrow += w2; + } else { + for (int j = 0; j < width - 1; ++j) { + const int16_t x = diff[(height - 2) * stride + j]; + const int16_t y = diff[(height - 2) * stride + j + 1]; + const int16_t z = diff[(height - 1) * stride + j]; + const int16_t w = diff[(height - 1) * stride + j + 1]; + + // Horizontal and vertical correlations for the penultimate row: + xy_sum += x * y; + xz_sum += x * z; + + // Now just horizontal correlations for the final row: + xy_sum += z * w; + + x_sum += y + w; + x2_sum += y * y + w * w; + x_finalrow += w; + x2_finalrow += w * w; + } + } + } + + // Do we have 2 columns remaining or just the one? + if (width % 3 == 1) { // Just vert corrs on the final col + const int16_t x0 = diff[width - 1]; + x_sum += x0; + x_finalcol += x0; + x2_sum += x0 * x0; + x2_finalcol += x0 * x0; + for (int i = 0; i < height - 1; ++i) { + const int16_t x = diff[i * stride + width - 1]; + const int16_t z = diff[(i + 1) * stride + width - 1]; + xz_sum += x * z; + x_finalcol += z; + x2_finalcol += z * z; + // So the bottom-right elements don't get counted twice: + if (i < height - (height % 3 == 1 ? 2 : 3)) { + x_sum += z; + x2_sum += z * z; + } + } + } else { // Two cols remaining + const int16_t x0 = diff[width - 2]; + const int16_t y0 = diff[width - 1]; + x_sum += x0 + y0; + x2_sum += x0 * x0 + y0 * y0; + x_finalcol += y0; + x2_finalcol += y0 * y0; + for (int i = 0; i < height - 1; ++i) { + const int16_t x = diff[i * stride + width - 2]; + const int16_t y = diff[i * stride + width - 1]; + const int16_t z = diff[(i + 1) * stride + width - 2]; + const int16_t w = diff[(i + 1) * stride + width - 1]; + + // Horizontal and vertical correlations for the penultimate col: + // Skip these on the last iteration of this loop if we also had two + // rows remaining, otherwise the final horizontal and vertical correlation + // get erroneously processed twice + if (i < height - 2 || height % 3 == 1) { + xy_sum += x * y; + xz_sum += x * z; + } + + x_finalcol += w; + x2_finalcol += w * w; + // So the bottom-right elements don't get counted twice: + if (i < height - (height % 3 == 1 ? 
2 : 3)) { + x_sum += z + w; + x2_sum += z * z + w * w; + } + + // Now just vertical correlations for the final column: + xz_sum += y * w; + } + } + + // Calculate the simple sums and squared-sums + int64_t x_firstrow = 0, x_firstcol = 0; + int64_t x2_firstrow = 0, x2_firstcol = 0; + + if (width >= 8) { + int32x4_t v_x_firstrow = zero; + int32x4_t v_x2_firstrow = zero; + for (int j = 0; j < width; j += 8) { + const int16x8_t v_diff = vld1q_s16(diff + j); + const int16x4_t v_diff_lo = vget_low_s16(v_diff); + const int16x4_t v_diff_hi = vget_high_s16(v_diff); + v_x_firstrow = vpadalq_s16(v_x_firstrow, v_diff); + v_x2_firstrow = vmlal_s16(v_x2_firstrow, v_diff_lo, v_diff_lo); + v_x2_firstrow = vmlal_s16(v_x2_firstrow, v_diff_hi, v_diff_hi); + } +#if AOM_ARCH_AARCH64 + x_firstrow += vaddvq_s32(v_x_firstrow); + x2_firstrow += vaddvq_s32(v_x2_firstrow); +#else + const int64x2_t v_x_firstrow_64 = vpaddlq_s32(v_x_firstrow); + x_firstrow += vget_lane_s64( + vadd_s64(vget_low_s64(v_x_firstrow_64), vget_high_s64(v_x_firstrow_64)), + 0); + const int64x2_t v_x2_firstrow_64 = vpaddlq_s32(v_x2_firstrow); + x2_firstrow += vget_lane_s64(vadd_s64(vget_low_s64(v_x2_firstrow_64), + vget_high_s64(v_x2_firstrow_64)), + 0); +#endif + } else { + for (int j = 0; j < width; ++j) { + x_firstrow += diff[j]; + x2_firstrow += diff[j] * diff[j]; + } + } + for (int i = 0; i < height; ++i) { + x_firstcol += diff[i * stride]; + x2_firstcol += diff[i * stride] * diff[i * stride]; + } + + int64_t xhor_sum = x_sum - x_finalcol; + int64_t xver_sum = x_sum - x_finalrow; + int64_t y_sum = x_sum - x_firstcol; + int64_t z_sum = x_sum - x_firstrow; + int64_t x2hor_sum = x2_sum - x2_finalcol; + int64_t x2ver_sum = x2_sum - x2_finalrow; + int64_t y2_sum = x2_sum - x2_firstcol; + int64_t z2_sum = x2_sum - x2_firstrow; + + const float num_hor = (float)(height * (width - 1)); + const float num_ver = (float)((height - 1) * width); + + const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor; + const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver; + + const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor; + const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver; + + const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor; + const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver; + + if (xhor_var_n > 0 && y_var_n > 0) { + *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n); + *hcorr = *hcorr < 0 ? 0 : *hcorr; + } else { + *hcorr = 1.0; + } + if (xver_var_n > 0 && z_var_n > 0) { + *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n); + *vcorr = *vcorr < 0 ? 0 : *vcorr; + } else { + *vcorr = 1.0; + } +} diff --git a/third_party/aom/av1/encoder/arm/neon/reconinter_enc_neon.c b/third_party/aom/av1/encoder/arm/neon/reconinter_enc_neon.c new file mode 100644 index 0000000000..3d17723224 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/reconinter_enc_neon.c @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 
+ */ + +#include <arm_neon.h> +#include <assert.h> + +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" + +#include "aom_dsp/arm/mem_neon.h" + +#include "av1/encoder/reconinter_enc.h" + +void aom_upsampled_pred_neon(MACROBLOCKD *xd, const AV1_COMMON *const cm, + int mi_row, int mi_col, const MV *const mv, + uint8_t *comp_pred, int width, int height, + int subpel_x_q3, int subpel_y_q3, + const uint8_t *ref, int ref_stride, + int subpel_search) { + // expect xd == NULL only in tests + if (xd != NULL) { + const MB_MODE_INFO *mi = xd->mi[0]; + const int ref_num = 0; + const int is_intrabc = is_intrabc_block(mi); + const struct scale_factors *const sf = + is_intrabc ? &cm->sf_identity : xd->block_ref_scale_factors[ref_num]; + const int is_scaled = av1_is_scaled(sf); + + if (is_scaled) { + int plane = 0; + const int mi_x = mi_col * MI_SIZE; + const int mi_y = mi_row * MI_SIZE; + const struct macroblockd_plane *const pd = &xd->plane[plane]; + const struct buf_2d *const dst_buf = &pd->dst; + const struct buf_2d *const pre_buf = + is_intrabc ? dst_buf : &pd->pre[ref_num]; + + InterPredParams inter_pred_params; + inter_pred_params.conv_params = get_conv_params(0, plane, xd->bd); + const int_interpfilters filters = + av1_broadcast_interp_filter(EIGHTTAP_REGULAR); + av1_init_inter_params( + &inter_pred_params, width, height, mi_y >> pd->subsampling_y, + mi_x >> pd->subsampling_x, pd->subsampling_x, pd->subsampling_y, + xd->bd, is_cur_buf_hbd(xd), is_intrabc, sf, pre_buf, filters); + av1_enc_build_one_inter_predictor(comp_pred, width, mv, + &inter_pred_params); + return; + } + } + + const InterpFilterParams *filter_params = av1_get_filter(subpel_search); + + if (!subpel_x_q3 && !subpel_y_q3) { + if (width > 8) { + assert(width % 16 == 0); + int i = height; + do { + int j = 0; + do { + uint8x16_t r = vld1q_u8(ref + j); + vst1q_u8(comp_pred + j, r); + j += 16; + } while (j < width); + ref += ref_stride; + comp_pred += width; + } while (--i != 0); + } else if (width == 8) { + int i = height; + do { + uint8x8_t r = vld1_u8(ref); + vst1_u8(comp_pred, r); + ref += ref_stride; + comp_pred += width; + } while (--i != 0); + } else { + assert(width == 4); + int i = height / 2; + do { + uint8x8_t r = load_unaligned_u8(ref, ref_stride); + vst1_u8(comp_pred, r); + ref += 2 * ref_stride; + comp_pred += 2 * width; + } while (--i != 0); + } + } else if (!subpel_y_q3) { + const int16_t *const filter_x = + av1_get_interp_filter_subpel_kernel(filter_params, subpel_x_q3 << 1); + aom_convolve8_horiz(ref, ref_stride, comp_pred, width, filter_x, 16, NULL, + -1, width, height); + } else if (!subpel_x_q3) { + const int16_t *const filter_y = + av1_get_interp_filter_subpel_kernel(filter_params, subpel_y_q3 << 1); + aom_convolve8_vert(ref, ref_stride, comp_pred, width, NULL, -1, filter_y, + 16, width, height); + } else { + DECLARE_ALIGNED(16, uint8_t, + im_block[((MAX_SB_SIZE * 2 + 16) + 16) * MAX_SB_SIZE]); + + const int16_t *const filter_x = + av1_get_interp_filter_subpel_kernel(filter_params, subpel_x_q3 << 1); + const int16_t *const filter_y = + av1_get_interp_filter_subpel_kernel(filter_params, subpel_y_q3 << 1); + + const int im_stride = MAX_SB_SIZE; + const int im_height = (((height - 1) * 8 + subpel_y_q3) >> 3) + SUBPEL_TAPS; + + const int ref_vert_offset = ref_stride * ((SUBPEL_TAPS >> 1) - 1); + const int im_vert_offset = im_stride * ((filter_params->taps >> 1) - 1); + + assert(im_height <= (MAX_SB_SIZE * 2 + 16) + 16); + aom_convolve8_horiz(ref - ref_vert_offset, ref_stride, im_block, + MAX_SB_SIZE, 
filter_x, 16, NULL, -1, width, im_height); + aom_convolve8_vert(im_block + im_vert_offset, MAX_SB_SIZE, comp_pred, width, + NULL, -1, filter_y, 16, width, height); + } +} + +void aom_comp_avg_upsampled_pred_neon(MACROBLOCKD *xd, + const AV1_COMMON *const cm, int mi_row, + int mi_col, const MV *const mv, + uint8_t *comp_pred, const uint8_t *pred, + int width, int height, int subpel_x_q3, + int subpel_y_q3, const uint8_t *ref, + int ref_stride, int subpel_search) { + aom_upsampled_pred_neon(xd, cm, mi_row, mi_col, mv, comp_pred, width, height, + subpel_x_q3, subpel_y_q3, ref, ref_stride, + subpel_search); + + aom_comp_avg_pred_neon(comp_pred, pred, width, height, comp_pred, width); +} + +void aom_dist_wtd_comp_avg_upsampled_pred_neon( + MACROBLOCKD *xd, const AV1_COMMON *const cm, int mi_row, int mi_col, + const MV *const mv, uint8_t *comp_pred, const uint8_t *pred, int width, + int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref, + int ref_stride, const DIST_WTD_COMP_PARAMS *jcp_param, int subpel_search) { + aom_upsampled_pred_neon(xd, cm, mi_row, mi_col, mv, comp_pred, width, height, + subpel_x_q3, subpel_y_q3, ref, ref_stride, + subpel_search); + + aom_dist_wtd_comp_avg_pred_neon(comp_pred, pred, width, height, comp_pred, + width, jcp_param); +} + +#if CONFIG_AV1_HIGHBITDEPTH +void aom_highbd_upsampled_pred_neon(MACROBLOCKD *xd, + const struct AV1Common *const cm, + int mi_row, int mi_col, const MV *const mv, + uint8_t *comp_pred8, int width, int height, + int subpel_x_q3, int subpel_y_q3, + const uint8_t *ref8, int ref_stride, int bd, + int subpel_search) { + // expect xd == NULL only in tests + if (xd != NULL) { + const MB_MODE_INFO *mi = xd->mi[0]; + const int ref_num = 0; + const int is_intrabc = is_intrabc_block(mi); + const struct scale_factors *const sf = + is_intrabc ? &cm->sf_identity : xd->block_ref_scale_factors[ref_num]; + const int is_scaled = av1_is_scaled(sf); + + if (is_scaled) { + int plane = 0; + const int mi_x = mi_col * MI_SIZE; + const int mi_y = mi_row * MI_SIZE; + const struct macroblockd_plane *const pd = &xd->plane[plane]; + const struct buf_2d *const dst_buf = &pd->dst; + const struct buf_2d *const pre_buf = + is_intrabc ? 
dst_buf : &pd->pre[ref_num]; + + InterPredParams inter_pred_params; + inter_pred_params.conv_params = get_conv_params(0, plane, xd->bd); + const int_interpfilters filters = + av1_broadcast_interp_filter(EIGHTTAP_REGULAR); + av1_init_inter_params( + &inter_pred_params, width, height, mi_y >> pd->subsampling_y, + mi_x >> pd->subsampling_x, pd->subsampling_x, pd->subsampling_y, + xd->bd, is_cur_buf_hbd(xd), is_intrabc, sf, pre_buf, filters); + av1_enc_build_one_inter_predictor(comp_pred8, width, mv, + &inter_pred_params); + return; + } + } + + const InterpFilterParams *filter = av1_get_filter(subpel_search); + + if (!subpel_x_q3 && !subpel_y_q3) { + const uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); + uint16_t *comp_pred = CONVERT_TO_SHORTPTR(comp_pred8); + if (width > 4) { + assert(width % 8 == 0); + int i = height; + do { + int j = 0; + do { + uint16x8_t r = vld1q_u16(ref + j); + vst1q_u16(comp_pred + j, r); + j += 8; + } while (j < width); + ref += ref_stride; + comp_pred += width; + } while (--i != 0); + } else if (width == 4) { + int i = height; + do { + uint16x4_t r = vld1_u16(ref); + vst1_u16(comp_pred, r); + ref += ref_stride; + comp_pred += width; + } while (--i != 0); + } else { + assert(width == 2); + int i = height / 2; + do { + uint16x4_t r = load_u16_2x2(ref, ref_stride); + store_u16x2_strided_x2(comp_pred, width, r); + ref += 2 * ref_stride; + comp_pred += 2 * width; + } while (--i != 0); + } + } else if (!subpel_y_q3) { + const int16_t *const kernel = + av1_get_interp_filter_subpel_kernel(filter, subpel_x_q3 << 1); + aom_highbd_convolve8_horiz_neon(ref8, ref_stride, comp_pred8, width, kernel, + 16, NULL, -1, width, height, bd); + } else if (!subpel_x_q3) { + const int16_t *const kernel = + av1_get_interp_filter_subpel_kernel(filter, subpel_y_q3 << 1); + aom_highbd_convolve8_vert_neon(ref8, ref_stride, comp_pred8, width, NULL, + -1, kernel, 16, width, height, bd); + } else { + DECLARE_ALIGNED(16, uint16_t, + temp[((MAX_SB_SIZE + 16) + 16) * MAX_SB_SIZE]); + const int16_t *const kernel_x = + av1_get_interp_filter_subpel_kernel(filter, subpel_x_q3 << 1); + const int16_t *const kernel_y = + av1_get_interp_filter_subpel_kernel(filter, subpel_y_q3 << 1); + const int intermediate_height = + (((height - 1) * 8 + subpel_y_q3) >> 3) + filter->taps; + assert(intermediate_height <= (MAX_SB_SIZE * 2 + 16) + 16); + aom_highbd_convolve8_horiz_neon( + ref8 - ref_stride * ((filter->taps >> 1) - 1), ref_stride, + CONVERT_TO_BYTEPTR(temp), MAX_SB_SIZE, kernel_x, 16, NULL, -1, width, + intermediate_height, bd); + aom_highbd_convolve8_vert_neon( + CONVERT_TO_BYTEPTR(temp + MAX_SB_SIZE * ((filter->taps >> 1) - 1)), + MAX_SB_SIZE, comp_pred8, width, NULL, -1, kernel_y, 16, width, height, + bd); + } +} + +void aom_highbd_comp_avg_upsampled_pred_neon( + MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col, + const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width, + int height, int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8, + int ref_stride, int bd, int subpel_search) { + aom_highbd_upsampled_pred_neon(xd, cm, mi_row, mi_col, mv, comp_pred8, width, + height, subpel_x_q3, subpel_y_q3, ref8, + ref_stride, bd, subpel_search); + + aom_highbd_comp_avg_pred_neon(comp_pred8, pred8, width, height, comp_pred8, + width); +} + +void aom_highbd_dist_wtd_comp_avg_upsampled_pred_neon( + MACROBLOCKD *xd, const struct AV1Common *const cm, int mi_row, int mi_col, + const MV *const mv, uint8_t *comp_pred8, const uint8_t *pred8, int width, + int height, int subpel_x_q3, int 
subpel_y_q3, const uint8_t *ref8, + int ref_stride, int bd, const DIST_WTD_COMP_PARAMS *jcp_param, + int subpel_search) { + aom_highbd_upsampled_pred_neon(xd, cm, mi_row, mi_col, mv, comp_pred8, width, + height, subpel_x_q3, subpel_y_q3, ref8, + ref_stride, bd, subpel_search); + + aom_highbd_dist_wtd_comp_avg_pred_neon(comp_pred8, pred8, width, height, + comp_pred8, width, jcp_param); +} + +#endif // CONFIG_AV1_HIGHBITDEPTH diff --git a/third_party/aom/av1/encoder/arm/neon/shift_neon.h b/third_party/aom/av1/encoder/arm/neon/shift_neon.h new file mode 100644 index 0000000000..d73aef2f25 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/shift_neon.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#ifndef AOM_AV1_ENCODER_ARM_NEON_SHIFT_NEON_H_ +#define AOM_AV1_ENCODER_ARM_NEON_SHIFT_NEON_H_ + +#include <arm_neon.h> + +#include "aom/aom_integer.h" // For AOM_INLINE. + +#define SHIFT_LOOP_HELPER(name, type, intrinsic, arg) \ + static AOM_INLINE void name(const type *in, type *out, int size) { \ + int i = 0; \ + do { \ + out[i] = intrinsic(in[i], arg); \ + } while (++i < size); \ + } + +SHIFT_LOOP_HELPER(shift_left_2_s16_x4, int16x4_t, vshl_n_s16, 2) +SHIFT_LOOP_HELPER(shift_left_2_s16_x8, int16x8_t, vshlq_n_s16, 2) +SHIFT_LOOP_HELPER(shift_left_2_s32_x4, int32x4_t, vshlq_n_s32, 2) +SHIFT_LOOP_HELPER(shift_right_2_round_s16_x8, int16x8_t, vrshrq_n_s16, 2) +SHIFT_LOOP_HELPER(shift_right_2_round_s32_x4, int32x4_t, vrshrq_n_s32, 2) +SHIFT_LOOP_HELPER(shift_right_4_round_s16_x8, int16x8_t, vrshrq_n_s16, 4) +SHIFT_LOOP_HELPER(shift_right_4_round_s32_x4, int32x4_t, vrshrq_n_s32, 4) + +// Addition instructions have slightly better performance compared to shift +// instructions on some micro-architectures, so use these for shifts by one. + +SHIFT_LOOP_HELPER(shift_left_1_s16_x4, int16x4_t, vadd_s16, in[i]) +SHIFT_LOOP_HELPER(shift_left_1_s16_x8, int16x8_t, vaddq_s16, in[i]) +SHIFT_LOOP_HELPER(shift_right_1_round_s16_x4, int16x4_t, vrhadd_s16, + vdup_n_s16(0)) +SHIFT_LOOP_HELPER(shift_right_1_round_s16_x8, int16x8_t, vrhaddq_s16, + vdupq_n_s16(0)) +SHIFT_LOOP_HELPER(shift_right_1_round_s32_x4, int32x4_t, vrhaddq_s32, + vdupq_n_s32(0)) + +#undef SHIFT_LOOP_HELPER + +#endif // AOM_AV1_ENCODER_ARM_NEON_SHIFT_NEON_H_ diff --git a/third_party/aom/av1/encoder/arm/neon/temporal_filter_neon.c b/third_party/aom/av1/encoder/arm/neon/temporal_filter_neon.c new file mode 100644 index 0000000000..986f143864 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/temporal_filter_neon.c @@ -0,0 +1,548 @@ +/* + * Copyright (c) 2022, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. 
If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/av1_rtcd.h" +#include "av1/encoder/encoder.h" +#include "av1/encoder/temporal_filter.h" +#include "aom_dsp/mathutils.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +// For the squared error buffer, add padding for 4 samples. +#define SSE_STRIDE (BW + 4) + +// When using vld1q_u16_x4 compilers may insert an alignment hint of 256 bits. +DECLARE_ALIGNED(32, static const uint16_t, kSlidingWindowMask[]) = { + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0x0000, + 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0x0000, + 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, + 0x0000, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF +}; + +static INLINE void get_squared_error( + const uint8_t *frame1, const uint32_t stride1, const uint8_t *frame2, + const uint32_t stride2, const uint32_t block_width, + const uint32_t block_height, uint16_t *frame_sse, + const unsigned int dst_stride) { + uint16_t *dst = frame_sse; + + uint32_t i = 0; + do { + uint32_t j = 0; + do { + uint8x16_t s = vld1q_u8(frame1 + i * stride1 + j); + uint8x16_t r = vld1q_u8(frame2 + i * stride2 + j); + + uint8x16_t abs_diff = vabdq_u8(s, r); + uint16x8_t sse_lo = + vmull_u8(vget_low_u8(abs_diff), vget_low_u8(abs_diff)); + uint16x8_t sse_hi = + vmull_u8(vget_high_u8(abs_diff), vget_high_u8(abs_diff)); + + vst1q_u16(dst + j + 2, sse_lo); + vst1q_u16(dst + j + 10, sse_hi); + + j += 16; + } while (j < block_width); + + dst += dst_stride; + } while (++i < block_height); +} + +static INLINE uint16x8_t load_and_pad(const uint16_t *src, const uint32_t col, + const uint32_t block_width) { + uint16x8_t s = vld1q_u16(src); + + if (col == 0) { + const uint16_t lane2 = vgetq_lane_u16(s, 2); + s = vsetq_lane_u16(lane2, s, 0); + s = vsetq_lane_u16(lane2, s, 1); + } else if (col >= block_width - 4) { + const uint16_t lane5 = vgetq_lane_u16(s, 5); + s = vsetq_lane_u16(lane5, s, 6); + s = vsetq_lane_u16(lane5, s, 7); + } + return s; +} + +static void apply_temporal_filter( + const uint8_t *frame, const unsigned int stride, const uint32_t block_width, + const uint32_t block_height, const int *subblock_mses, + unsigned int *accumulator, uint16_t *count, const uint16_t *frame_sse, + const uint32_t *luma_sse_sum, const double inv_num_ref_pixels, + const double decay_factor, const double inv_factor, + const double weight_factor, const double *d_factor, int tf_wgt_calc_lvl) { + assert(((block_width == 16) || (block_width == 32)) && + ((block_height == 16) || (block_height == 32))); + + uint32_t acc_5x5_neon[BH][BW]; + const uint16x8x4_t vmask = vld1q_u16_x4(kSlidingWindowMask); + + // Traverse 4 columns at a time - first and last two columns need padding. + for (uint32_t col = 0; col < block_width; col += 4) { + uint16x8_t vsrc[5]; + const uint16_t *src = frame_sse + col; + + // Load and pad (for first and last two columns) 3 rows from the top. + for (int i = 2; i < 5; i++) { + vsrc[i] = load_and_pad(src, col, block_width); + src += SSE_STRIDE; + } + + // Pad the top 2 rows. 
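+ // Duplicating the first loaded row upwards gives the first two output rows a full 5-row window (edge replication at the top of the block).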
+ vsrc[0] = vsrc[2]; + vsrc[1] = vsrc[2]; + + for (unsigned int row = 0; row < block_height; row++) { + for (int i = 0; i < 4; i++) { + uint32x4_t vsum = vdupq_n_u32(0); + for (int j = 0; j < 5; j++) { + vsum = vpadalq_u16(vsum, vandq_u16(vsrc[j], vmask.val[i])); + } + acc_5x5_neon[row][col + i] = horizontal_add_u32x4(vsum); + } + + // Push all rows in the sliding window up one. + for (int i = 0; i < 4; i++) { + vsrc[i] = vsrc[i + 1]; + } + + if (row <= block_height - 4) { + // Load next row into the bottom of the sliding window. + vsrc[4] = load_and_pad(src, col, block_width); + src += SSE_STRIDE; + } else { + // Pad the bottom 2 rows. + vsrc[4] = vsrc[3]; + } + } + } + + // Perform filtering. + if (tf_wgt_calc_lvl == 0) { + for (unsigned int i = 0, k = 0; i < block_height; i++) { + for (unsigned int j = 0; j < block_width; j++, k++) { + const int pixel_value = frame[i * stride + j]; + const uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j]; + + const double window_error = diff_sse * inv_num_ref_pixels; + const int subblock_idx = + (i >= block_height / 2) * 2 + (j >= block_width / 2); + const double block_error = (double)subblock_mses[subblock_idx]; + const double combined_error = + weight_factor * window_error + block_error * inv_factor; + // Compute filter weight. + double scaled_error = + combined_error * d_factor[subblock_idx] * decay_factor; + scaled_error = AOMMIN(scaled_error, 7); + const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE); + accumulator[k] += weight * pixel_value; + count[k] += weight; + } + } + } else { + for (unsigned int i = 0, k = 0; i < block_height; i++) { + for (unsigned int j = 0; j < block_width; j++, k++) { + const int pixel_value = frame[i * stride + j]; + const uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j]; + + const double window_error = diff_sse * inv_num_ref_pixels; + const int subblock_idx = + (i >= block_height / 2) * 2 + (j >= block_width / 2); + const double block_error = (double)subblock_mses[subblock_idx]; + const double combined_error = + weight_factor * window_error + block_error * inv_factor; + // Compute filter weight. + double scaled_error = + combined_error * d_factor[subblock_idx] * decay_factor; + scaled_error = AOMMIN(scaled_error, 7); + const float fweight = + approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE; + const int weight = iroundpf(fweight); + accumulator[k] += weight * pixel_value; + count[k] += weight; + } + } + } +} + +void av1_apply_temporal_filter_neon( + const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd, + const BLOCK_SIZE block_size, const int mb_row, const int mb_col, + const int num_planes, const double *noise_levels, const MV *subblock_mvs, + const int *subblock_mses, const int q_factor, const int filter_strength, + int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum, + uint16_t *count) { + const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH; + assert(block_size == BLOCK_32X32 && "Only support 32x32 block with Neon!"); + assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with Neon!"); + assert(!is_high_bitdepth && "Only support low bit-depth with Neon!"); + assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE); + (void)is_high_bitdepth; + + // Block information. + const int mb_height = block_size_high[block_size]; + const int mb_width = block_size_wide[block_size]; + // Frame information. 
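+ // Frame dimensions are only used to derive min_frame_size, which scales the motion-vector distance threshold for d_factor below.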
+ const int frame_height = frame_to_filter->y_crop_height; + const int frame_width = frame_to_filter->y_crop_width; + const int min_frame_size = AOMMIN(frame_height, frame_width); + // Variables to simplify combined error calculation. + const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) * + TF_SEARCH_ERROR_NORM_WEIGHT); + const double weight_factor = + (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor; + // Adjust filtering based on q. + // Larger q -> stronger filtering -> larger weight. + // Smaller q -> weaker filtering -> smaller weight. + double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2); + q_decay = CLIP(q_decay, 1e-5, 1); + if (q_factor >= TF_QINDEX_CUTOFF) { + // Max q_factor is 255, therefore the upper bound of q_decay is 8. + // We do not need a clip here. + q_decay = 0.5 * pow((double)q_factor / 64, 2); + } + // Smaller strength -> smaller filtering weight. + double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2); + s_decay = CLIP(s_decay, 1e-5, 1); + double d_factor[4] = { 0 }; + uint16_t frame_sse[SSE_STRIDE * BH] = { 0 }; + uint32_t luma_sse_sum[BW * BH] = { 0 }; + + for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) { + // Larger motion vector -> smaller filtering weight. + const MV mv = subblock_mvs[subblock_idx]; + const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2)); + double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD; + distance_threshold = AOMMAX(distance_threshold, 1); + d_factor[subblock_idx] = distance / distance_threshold; + d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1); + } + + // Handle planes in sequence. + int plane_offset = 0; + for (int plane = 0; plane < num_planes; ++plane) { + const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y; + const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x; + const uint32_t frame_stride = + frame_to_filter->strides[plane == AOM_PLANE_Y ? 0 : 1]; + const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w; + + const uint8_t *ref = frame_to_filter->buffers[plane] + frame_offset; + const int ss_x_shift = + mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x; + const int ss_y_shift = + mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y; + const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH + + ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0); + const double inv_num_ref_pixels = 1.0 / num_ref_pixels; + // Larger noise -> larger filtering weight. + const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0); + // Decay factors for non-local mean approach. + const double decay_factor = 1 / (n_decay * q_decay * s_decay); + + // Filter U-plane and V-plane using Y-plane. This is because motion + // search is only done on Y-plane, so the information from Y-plane + // will be more accurate. The luma sse sum is reused in both chroma + // planes. + if (plane == AOM_PLANE_U) { + for (unsigned int i = 0; i < plane_h; i++) { + for (unsigned int j = 0; j < plane_w; j++) { + for (int ii = 0; ii < (1 << ss_y_shift); ++ii) { + for (int jj = 0; jj < (1 << ss_x_shift); ++jj) { + const int yy = (i << ss_y_shift) + ii; // Y-coord on Y-plane. + const int xx = (j << ss_x_shift) + jj; // X-coord on Y-plane. 
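+ // The +2 column offset skips the left padding of the squared-error buffer (SSE_STRIDE == BW + 4), matching where get_squared_error writes its results.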
+ luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2]; + } + } + } + } + } + + get_squared_error(ref, frame_stride, pred + plane_offset, plane_w, plane_w, + plane_h, frame_sse, SSE_STRIDE); + + apply_temporal_filter(pred + plane_offset, plane_w, plane_w, plane_h, + subblock_mses, accum + plane_offset, + count + plane_offset, frame_sse, luma_sse_sum, + inv_num_ref_pixels, decay_factor, inv_factor, + weight_factor, d_factor, tf_wgt_calc_lvl); + + plane_offset += plane_h * plane_w; + } +} + +double av1_estimate_noise_from_single_plane_neon(const uint8_t *src, int height, + int width, int stride, + int edge_thresh) { + uint16x8_t thresh = vdupq_n_u16(edge_thresh); + uint32x4_t acc = vdupq_n_u32(0); + // Count is in theory positive as it counts the number of times we're under + // the threshold, but it will be counted negatively in order to make best use + // of the vclt instruction, which sets every bit of a lane to 1 when the + // condition is true. + int32x4_t count = vdupq_n_s32(0); + int final_count = 0; + int64_t final_acc = 0; + const uint8_t *src_start = src + stride + 1; + int h = 1; + + do { + int w = 1; + const uint8_t *src_ptr = src_start; + + while (w <= (width - 1) - 16) { + uint8x16_t mat[3][3]; + mat[0][0] = vld1q_u8(src_ptr - stride - 1); + mat[0][1] = vld1q_u8(src_ptr - stride); + mat[0][2] = vld1q_u8(src_ptr - stride + 1); + mat[1][0] = vld1q_u8(src_ptr - 1); + mat[1][1] = vld1q_u8(src_ptr); + mat[1][2] = vld1q_u8(src_ptr + 1); + mat[2][0] = vld1q_u8(src_ptr + stride - 1); + mat[2][1] = vld1q_u8(src_ptr + stride); + mat[2][2] = vld1q_u8(src_ptr + stride + 1); + + // Compute Sobel gradients. + uint16x8_t gxa_lo = + vaddl_u8(vget_low_u8(mat[0][0]), vget_low_u8(mat[2][0])); + uint16x8_t gxa_hi = + vaddl_u8(vget_high_u8(mat[0][0]), vget_high_u8(mat[2][0])); + uint16x8_t gxb_lo = + vaddl_u8(vget_low_u8(mat[0][2]), vget_low_u8(mat[2][2])); + uint16x8_t gxb_hi = + vaddl_u8(vget_high_u8(mat[0][2]), vget_high_u8(mat[2][2])); + gxa_lo = vaddq_u16( + gxa_lo, vaddl_u8(vget_low_u8(mat[1][0]), vget_low_u8(mat[1][0]))); + gxa_hi = vaddq_u16( + gxa_hi, vaddl_u8(vget_high_u8(mat[1][0]), vget_high_u8(mat[1][0]))); + gxb_lo = vaddq_u16( + gxb_lo, vaddl_u8(vget_low_u8(mat[1][2]), vget_low_u8(mat[1][2]))); + gxb_hi = vaddq_u16( + gxb_hi, vaddl_u8(vget_high_u8(mat[1][2]), vget_high_u8(mat[1][2]))); + + uint16x8_t gya_lo = + vaddl_u8(vget_low_u8(mat[0][0]), vget_low_u8(mat[0][2])); + uint16x8_t gya_hi = + vaddl_u8(vget_high_u8(mat[0][0]), vget_high_u8(mat[0][2])); + uint16x8_t gyb_lo = + vaddl_u8(vget_low_u8(mat[2][0]), vget_low_u8(mat[2][2])); + uint16x8_t gyb_hi = + vaddl_u8(vget_high_u8(mat[2][0]), vget_high_u8(mat[2][2])); + gya_lo = vaddq_u16( + gya_lo, vaddl_u8(vget_low_u8(mat[0][1]), vget_low_u8(mat[0][1]))); + gya_hi = vaddq_u16( + gya_hi, vaddl_u8(vget_high_u8(mat[0][1]), vget_high_u8(mat[0][1]))); + gyb_lo = vaddq_u16( + gyb_lo, vaddl_u8(vget_low_u8(mat[2][1]), vget_low_u8(mat[2][1]))); + gyb_hi = vaddq_u16( + gyb_hi, vaddl_u8(vget_high_u8(mat[2][1]), vget_high_u8(mat[2][1]))); + + uint16x8_t ga_lo = vabaq_u16(vabdq_u16(gxa_lo, gxb_lo), gya_lo, gyb_lo); + uint16x8_t ga_hi = vabaq_u16(vabdq_u16(gxa_hi, gxb_hi), gya_hi, gyb_hi); + + // Check which vector elements are under the threshold. The Laplacian is + // then unconditionally computed and we accumulate zeros if we're not + // under the threshold. This is much faster than using an if statement. 
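+ // Each all-ones comparison lane doubles as an AND mask, so the Laplacian value is kept for pixels under the threshold and zeroed otherwise.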
+ uint16x8_t thresh_u16_lo = vcltq_u16(ga_lo, thresh); + uint16x8_t thresh_u16_hi = vcltq_u16(ga_hi, thresh); + + uint16x8_t center_lo = vshll_n_u8(vget_low_u8(mat[1][1]), 2); + uint16x8_t center_hi = vshll_n_u8(vget_high_u8(mat[1][1]), 2); + + uint16x8_t adj0_lo = + vaddl_u8(vget_low_u8(mat[0][1]), vget_low_u8(mat[2][1])); + uint16x8_t adj0_hi = + vaddl_u8(vget_high_u8(mat[0][1]), vget_high_u8(mat[2][1])); + uint16x8_t adj1_lo = + vaddl_u8(vget_low_u8(mat[1][0]), vget_low_u8(mat[1][2])); + uint16x8_t adj1_hi = + vaddl_u8(vget_high_u8(mat[1][0]), vget_high_u8(mat[1][2])); + uint16x8_t adj_lo = vaddq_u16(adj0_lo, adj1_lo); + adj_lo = vaddq_u16(adj_lo, adj_lo); + uint16x8_t adj_hi = vaddq_u16(adj0_hi, adj1_hi); + adj_hi = vaddq_u16(adj_hi, adj_hi); + + uint16x8_t diag0_lo = + vaddl_u8(vget_low_u8(mat[0][0]), vget_low_u8(mat[0][2])); + uint16x8_t diag0_hi = + vaddl_u8(vget_high_u8(mat[0][0]), vget_high_u8(mat[0][2])); + uint16x8_t diag1_lo = + vaddl_u8(vget_low_u8(mat[2][0]), vget_low_u8(mat[2][2])); + uint16x8_t diag1_hi = + vaddl_u8(vget_high_u8(mat[2][0]), vget_high_u8(mat[2][2])); + uint16x8_t diag_lo = vaddq_u16(diag0_lo, diag1_lo); + uint16x8_t diag_hi = vaddq_u16(diag0_hi, diag1_hi); + + uint16x8_t v_lo = vaddq_u16(center_lo, diag_lo); + v_lo = vabdq_u16(v_lo, adj_lo); + uint16x8_t v_hi = vaddq_u16(center_hi, diag_hi); + v_hi = vabdq_u16(v_hi, adj_hi); + + acc = vpadalq_u16(acc, vandq_u16(v_lo, thresh_u16_lo)); + acc = vpadalq_u16(acc, vandq_u16(v_hi, thresh_u16_hi)); + + // Add -1 for each lane where the gradient is under the threshold. + count = vpadalq_s16(count, vreinterpretq_s16_u16(thresh_u16_lo)); + count = vpadalq_s16(count, vreinterpretq_s16_u16(thresh_u16_hi)); + + w += 16; + src_ptr += 16; + } + + if (w <= (width - 1) - 8) { + uint8x8_t mat[3][3]; + mat[0][0] = vld1_u8(src_ptr - stride - 1); + mat[0][1] = vld1_u8(src_ptr - stride); + mat[0][2] = vld1_u8(src_ptr - stride + 1); + mat[1][0] = vld1_u8(src_ptr - 1); + mat[1][1] = vld1_u8(src_ptr); + mat[1][2] = vld1_u8(src_ptr + 1); + mat[2][0] = vld1_u8(src_ptr + stride - 1); + mat[2][1] = vld1_u8(src_ptr + stride); + mat[2][2] = vld1_u8(src_ptr + stride + 1); + + // Compute Sobel gradients. + uint16x8_t gxa = vaddl_u8(mat[0][0], mat[2][0]); + uint16x8_t gxb = vaddl_u8(mat[0][2], mat[2][2]); + gxa = vaddq_u16(gxa, vaddl_u8(mat[1][0], mat[1][0])); + gxb = vaddq_u16(gxb, vaddl_u8(mat[1][2], mat[1][2])); + + uint16x8_t gya = vaddl_u8(mat[0][0], mat[0][2]); + uint16x8_t gyb = vaddl_u8(mat[2][0], mat[2][2]); + gya = vaddq_u16(gya, vaddl_u8(mat[0][1], mat[0][1])); + gyb = vaddq_u16(gyb, vaddl_u8(mat[2][1], mat[2][1])); + + uint16x8_t ga = vabaq_u16(vabdq_u16(gxa, gxb), gya, gyb); + + // Check which vector elements are under the threshold. The Laplacian is + // then unconditionally computed and we accumulate zeros if we're not + // under the threshold. This is much faster than using an if statement. + uint16x8_t thresh_u16 = vcltq_u16(ga, thresh); + + uint16x8_t center = vshll_n_u8(mat[1][1], 2); + + uint16x8_t adj0 = vaddl_u8(mat[0][1], mat[2][1]); + uint16x8_t adj1 = vaddl_u8(mat[1][0], mat[1][2]); + uint16x8_t adj = vaddq_u16(adj0, adj1); + adj = vaddq_u16(adj, adj); + + uint16x8_t diag0 = vaddl_u8(mat[0][0], mat[0][2]); + uint16x8_t diag1 = vaddl_u8(mat[2][0], mat[2][2]); + uint16x8_t diag = vaddq_u16(diag0, diag1); + + uint16x8_t v = vaddq_u16(center, diag); + v = vabdq_u16(v, adj); + + acc = vpadalq_u16(acc, vandq_u16(v, thresh_u16)); + // Add -1 for each lane where the gradient is under the threshold. 
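+ // Reinterpreted as signed, an all-ones lane is -1, so pairwise accumulation counts qualifying pixels negatively; the sign is corrected after the loops.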
+ count = vpadalq_s16(count, vreinterpretq_s16_u16(thresh_u16)); + + w += 8; + src_ptr += 8; + } + + if (w <= (width - 1) - 4) { + uint16x8_t mask = vcombine_u16(vdup_n_u16(65535), vdup_n_u16(0)); + uint8x8_t mat[3][3]; + mat[0][0] = load_u8_4x1(src_ptr - stride - 1); + mat[0][1] = load_u8_4x1(src_ptr - stride); + mat[0][2] = load_u8_4x1(src_ptr - stride + 1); + mat[1][0] = load_u8_4x1(src_ptr - 1); + mat[1][1] = load_u8_4x1(src_ptr); + mat[1][2] = load_u8_4x1(src_ptr + 1); + mat[2][0] = load_u8_4x1(src_ptr + stride - 1); + mat[2][1] = load_u8_4x1(src_ptr + stride); + mat[2][2] = load_u8_4x1(src_ptr + stride + 1); + + // Compute Sobel gradients. + uint16x8_t gxa = vaddl_u8(mat[0][0], mat[2][0]); + uint16x8_t gxb = vaddl_u8(mat[0][2], mat[2][2]); + gxa = vaddq_u16(gxa, vaddl_u8(mat[1][0], mat[1][0])); + gxb = vaddq_u16(gxb, vaddl_u8(mat[1][2], mat[1][2])); + + uint16x8_t gya = vaddl_u8(mat[0][0], mat[0][2]); + uint16x8_t gyb = vaddl_u8(mat[2][0], mat[2][2]); + gya = vaddq_u16(gya, vaddl_u8(mat[0][1], mat[0][1])); + gyb = vaddq_u16(gyb, vaddl_u8(mat[2][1], mat[2][1])); + + uint16x8_t ga = vabaq_u16(vabdq_u16(gxa, gxb), gya, gyb); + + // Check which vector elements are under the threshold. The Laplacian is + // then unconditionally computed and we accumulate zeros if we're not + // under the threshold. This is much faster than using an if statement. + uint16x8_t thresh_u16 = vandq_u16(vcltq_u16(ga, thresh), mask); + + uint16x8_t center = vshll_n_u8(mat[1][1], 2); + + uint16x8_t adj0 = vaddl_u8(mat[0][1], mat[2][1]); + uint16x8_t adj1 = vaddl_u8(mat[1][0], mat[1][2]); + uint16x8_t adj = vaddq_u16(adj0, adj1); + adj = vaddq_u16(adj, adj); + + uint16x8_t diag0 = vaddl_u8(mat[0][0], mat[0][2]); + uint16x8_t diag1 = vaddl_u8(mat[2][0], mat[2][2]); + uint16x8_t diag = vaddq_u16(diag0, diag1); + + uint16x8_t v = vaddq_u16(center, diag); + v = vabdq_u16(v, adj); + + acc = vpadalq_u16(acc, vandq_u16(v, thresh_u16)); + // Add -1 for each lane where the gradient is under the threshold. + count = vpadalq_s16(count, vreinterpretq_s16_u16(thresh_u16)); + + w += 4; + src_ptr += 4; + } + + while (w < width - 1) { + int mat[3][3]; + mat[0][0] = *(src_ptr - stride - 1); + mat[0][1] = *(src_ptr - stride); + mat[0][2] = *(src_ptr - stride + 1); + mat[1][0] = *(src_ptr - 1); + mat[1][1] = *(src_ptr); + mat[1][2] = *(src_ptr + 1); + mat[2][0] = *(src_ptr + stride - 1); + mat[2][1] = *(src_ptr + stride); + mat[2][2] = *(src_ptr + stride + 1); + + // Compute Sobel gradients. + const int gx = (mat[0][0] - mat[0][2]) + (mat[2][0] - mat[2][2]) + + 2 * (mat[1][0] - mat[1][2]); + const int gy = (mat[0][0] - mat[2][0]) + (mat[0][2] - mat[2][2]) + + 2 * (mat[0][1] - mat[2][1]); + const int ga = abs(gx) + abs(gy); + + // Accumulate Laplacian. + const int is_under = ga < edge_thresh; + const int v = 4 * mat[1][1] - + 2 * (mat[0][1] + mat[2][1] + mat[1][0] + mat[1][2]) + + (mat[0][0] + mat[0][2] + mat[2][0] + mat[2][2]); + final_acc += abs(v) * is_under; + final_count += is_under; + + src_ptr++; + w++; + } + src_start += stride; + } while (++h < height - 1); + + // We counted negatively, so subtract to get the final value. + final_count -= horizontal_add_s32x4(count); + final_acc += horizontal_long_add_u32x4(acc); + return (final_count < 16) + ? 
-1.0 + : (double)final_acc / (6 * final_count) * SQRT_PI_BY_2; +} diff --git a/third_party/aom/av1/encoder/arm/neon/temporal_filter_neon_dotprod.c b/third_party/aom/av1/encoder/arm/neon/temporal_filter_neon_dotprod.c new file mode 100644 index 0000000000..5a52e701a2 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/temporal_filter_neon_dotprod.c @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> + +#include "config/aom_config.h" +#include "config/av1_rtcd.h" +#include "av1/encoder/encoder.h" +#include "av1/encoder/temporal_filter.h" +#include "aom_dsp/mathutils.h" +#include "aom_dsp/arm/mem_neon.h" +#include "aom_dsp/arm/sum_neon.h" + +// For the squared error buffer, add padding for 4 samples. +#define SSE_STRIDE (BW + 4) + +// clang-format off + +DECLARE_ALIGNED(16, static const uint8_t, kSlidingWindowMask[]) = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, + 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, + 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, + 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF +}; + +// clang-format on + +static INLINE void get_abs_diff(const uint8_t *frame1, const uint32_t stride1, + const uint8_t *frame2, const uint32_t stride2, + const uint32_t block_width, + const uint32_t block_height, + uint8_t *frame_abs_diff, + const unsigned int dst_stride) { + uint8_t *dst = frame_abs_diff; + + uint32_t i = 0; + do { + uint32_t j = 0; + do { + uint8x16_t s = vld1q_u8(frame1 + i * stride1 + j); + uint8x16_t r = vld1q_u8(frame2 + i * stride2 + j); + uint8x16_t abs_diff = vabdq_u8(s, r); + vst1q_u8(dst + j + 2, abs_diff); + j += 16; + } while (j < block_width); + + dst += dst_stride; + } while (++i < block_height); +} + +static INLINE uint8x16_t load_and_pad(const uint8_t *src, const uint32_t col, + const uint32_t block_width) { + uint8x8_t s = vld1_u8(src); + + if (col == 0) { + const uint8_t lane2 = vget_lane_u8(s, 2); + s = vset_lane_u8(lane2, s, 0); + s = vset_lane_u8(lane2, s, 1); + } else if (col >= block_width - 4) { + const uint8_t lane5 = vget_lane_u8(s, 5); + s = vset_lane_u8(lane5, s, 6); + s = vset_lane_u8(lane5, s, 7); + } + return vcombine_u8(s, s); +} + +static void apply_temporal_filter( + const uint8_t *frame, const unsigned int stride, const uint32_t block_width, + const uint32_t block_height, const int *subblock_mses, + unsigned int *accumulator, uint16_t *count, const uint8_t *frame_abs_diff, + const uint32_t *luma_sse_sum, const double inv_num_ref_pixels, + const double decay_factor, const double inv_factor, + const double weight_factor, const double *d_factor, int tf_wgt_calc_lvl) { + assert(((block_width == 16) || (block_width == 32)) && + ((block_height == 16) || (block_height == 32))); + + uint32_t acc_5x5_neon[BH][BW]; + const uint8x16x2_t vmask = vld1q_u8_x2(kSlidingWindowMask); + + // Traverse 4 columns at a time - first and last two columns need padding. 
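+ // The 8 absolute differences per window are duplicated and masked with four shifted 5-wide masks (two output columns per 16-byte vector) so vdotq_u32 can square-accumulate them; vpaddq_u32 then folds the partial sums into one 5x5 total per column.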
+ for (uint32_t col = 0; col < block_width; col += 4) { + uint8x16_t vsrc[5][2]; + const uint8_t *src = frame_abs_diff + col; + + // Load, pad (for first and last two columns) and mask 3 rows from the top. + for (int i = 2; i < 5; i++) { + const uint8x16_t s = load_and_pad(src, col, block_width); + vsrc[i][0] = vandq_u8(s, vmask.val[0]); + vsrc[i][1] = vandq_u8(s, vmask.val[1]); + src += SSE_STRIDE; + } + + // Pad the top 2 rows. + vsrc[0][0] = vsrc[2][0]; + vsrc[0][1] = vsrc[2][1]; + vsrc[1][0] = vsrc[2][0]; + vsrc[1][1] = vsrc[2][1]; + + for (unsigned int row = 0; row < block_height; row++) { + uint32x4_t sum_01 = vdupq_n_u32(0); + uint32x4_t sum_23 = vdupq_n_u32(0); + + sum_01 = vdotq_u32(sum_01, vsrc[0][0], vsrc[0][0]); + sum_01 = vdotq_u32(sum_01, vsrc[1][0], vsrc[1][0]); + sum_01 = vdotq_u32(sum_01, vsrc[2][0], vsrc[2][0]); + sum_01 = vdotq_u32(sum_01, vsrc[3][0], vsrc[3][0]); + sum_01 = vdotq_u32(sum_01, vsrc[4][0], vsrc[4][0]); + + sum_23 = vdotq_u32(sum_23, vsrc[0][1], vsrc[0][1]); + sum_23 = vdotq_u32(sum_23, vsrc[1][1], vsrc[1][1]); + sum_23 = vdotq_u32(sum_23, vsrc[2][1], vsrc[2][1]); + sum_23 = vdotq_u32(sum_23, vsrc[3][1], vsrc[3][1]); + sum_23 = vdotq_u32(sum_23, vsrc[4][1], vsrc[4][1]); + + vst1q_u32(&acc_5x5_neon[row][col], vpaddq_u32(sum_01, sum_23)); + + // Push all rows in the sliding window up one. + for (int i = 0; i < 4; i++) { + vsrc[i][0] = vsrc[i + 1][0]; + vsrc[i][1] = vsrc[i + 1][1]; + } + + if (row <= block_height - 4) { + // Load next row into the bottom of the sliding window. + uint8x16_t s = load_and_pad(src, col, block_width); + vsrc[4][0] = vandq_u8(s, vmask.val[0]); + vsrc[4][1] = vandq_u8(s, vmask.val[1]); + src += SSE_STRIDE; + } else { + // Pad the bottom 2 rows. + vsrc[4][0] = vsrc[3][0]; + vsrc[4][1] = vsrc[3][1]; + } + } + } + + // Perform filtering. + if (tf_wgt_calc_lvl == 0) { + for (unsigned int i = 0, k = 0; i < block_height; i++) { + for (unsigned int j = 0; j < block_width; j++, k++) { + const int pixel_value = frame[i * stride + j]; + const uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j]; + + const double window_error = diff_sse * inv_num_ref_pixels; + const int subblock_idx = + (i >= block_height / 2) * 2 + (j >= block_width / 2); + const double block_error = (double)subblock_mses[subblock_idx]; + const double combined_error = + weight_factor * window_error + block_error * inv_factor; + // Compute filter weight. + double scaled_error = + combined_error * d_factor[subblock_idx] * decay_factor; + scaled_error = AOMMIN(scaled_error, 7); + const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE); + accumulator[k] += weight * pixel_value; + count[k] += weight; + } + } + } else { + for (unsigned int i = 0, k = 0; i < block_height; i++) { + for (unsigned int j = 0; j < block_width; j++, k++) { + const int pixel_value = frame[i * stride + j]; + const uint32_t diff_sse = acc_5x5_neon[i][j] + luma_sse_sum[i * BW + j]; + + const double window_error = diff_sse * inv_num_ref_pixels; + const int subblock_idx = + (i >= block_height / 2) * 2 + (j >= block_width / 2); + const double block_error = (double)subblock_mses[subblock_idx]; + const double combined_error = + weight_factor * window_error + block_error * inv_factor; + // Compute filter weight. 
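+ // This path uses the float approx_exp() approximation instead of exp(); scaled_error is still capped at 7 before exponentiation, as in the branch above.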
+ double scaled_error = + combined_error * d_factor[subblock_idx] * decay_factor; + scaled_error = AOMMIN(scaled_error, 7); + const float fweight = + approx_exp((float)-scaled_error) * TF_WEIGHT_SCALE; + const int weight = iroundpf(fweight); + accumulator[k] += weight * pixel_value; + count[k] += weight; + } + } + } +} + +void av1_apply_temporal_filter_neon_dotprod( + const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd, + const BLOCK_SIZE block_size, const int mb_row, const int mb_col, + const int num_planes, const double *noise_levels, const MV *subblock_mvs, + const int *subblock_mses, const int q_factor, const int filter_strength, + int tf_wgt_calc_lvl, const uint8_t *pred, uint32_t *accum, + uint16_t *count) { + const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH; + assert(block_size == BLOCK_32X32 && "Only support 32x32 block with Neon!"); + assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with Neon!"); + assert(!is_high_bitdepth && "Only support low bit-depth with Neon!"); + assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE); + (void)is_high_bitdepth; + + // Block information. + const int mb_height = block_size_high[block_size]; + const int mb_width = block_size_wide[block_size]; + // Frame information. + const int frame_height = frame_to_filter->y_crop_height; + const int frame_width = frame_to_filter->y_crop_width; + const int min_frame_size = AOMMIN(frame_height, frame_width); + // Variables to simplify combined error calculation. + const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) * + TF_SEARCH_ERROR_NORM_WEIGHT); + const double weight_factor = + (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor; + // Adjust filtering based on q. + // Larger q -> stronger filtering -> larger weight. + // Smaller q -> weaker filtering -> smaller weight. + double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2); + q_decay = CLIP(q_decay, 1e-5, 1); + if (q_factor >= TF_QINDEX_CUTOFF) { + // Max q_factor is 255, therefore the upper bound of q_decay is 8. + // We do not need a clip here. + q_decay = 0.5 * pow((double)q_factor / 64, 2); + } + // Smaller strength -> smaller filtering weight. + double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2); + s_decay = CLIP(s_decay, 1e-5, 1); + double d_factor[4] = { 0 }; + uint8_t frame_abs_diff[SSE_STRIDE * BH] = { 0 }; + uint32_t luma_sse_sum[BW * BH] = { 0 }; + + for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) { + // Larger motion vector -> smaller filtering weight. + const MV mv = subblock_mvs[subblock_idx]; + const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2)); + double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD; + distance_threshold = AOMMAX(distance_threshold, 1); + d_factor[subblock_idx] = distance / distance_threshold; + d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1); + } + + // Handle planes in sequence. + int plane_offset = 0; + for (int plane = 0; plane < num_planes; ++plane) { + const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y; + const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x; + const uint32_t frame_stride = + frame_to_filter->strides[plane == AOM_PLANE_Y ? 
0 : 1]; + const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w; + + const uint8_t *ref = frame_to_filter->buffers[plane] + frame_offset; + const int ss_x_shift = + mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x; + const int ss_y_shift = + mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y; + const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH + + ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0); + const double inv_num_ref_pixels = 1.0 / num_ref_pixels; + // Larger noise -> larger filtering weight. + const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0); + // Decay factors for non-local mean approach. + const double decay_factor = 1 / (n_decay * q_decay * s_decay); + + // Filter U-plane and V-plane using Y-plane. This is because motion + // search is only done on Y-plane, so the information from Y-plane + // will be more accurate. The luma sse sum is reused in both chroma + // planes. + if (plane == AOM_PLANE_U) { + for (unsigned int i = 0; i < plane_h; i++) { + for (unsigned int j = 0; j < plane_w; j++) { + for (int ii = 0; ii < (1 << ss_y_shift); ++ii) { + for (int jj = 0; jj < (1 << ss_x_shift); ++jj) { + const int yy = (i << ss_y_shift) + ii; // Y-coord on Y-plane. + const int xx = (j << ss_x_shift) + jj; // X-coord on Y-plane. + luma_sse_sum[i * BW + j] += + (frame_abs_diff[yy * SSE_STRIDE + xx + 2] * + frame_abs_diff[yy * SSE_STRIDE + xx + 2]); + } + } + } + } + } + + get_abs_diff(ref, frame_stride, pred + plane_offset, plane_w, plane_w, + plane_h, frame_abs_diff, SSE_STRIDE); + + apply_temporal_filter(pred + plane_offset, plane_w, plane_w, plane_h, + subblock_mses, accum + plane_offset, + count + plane_offset, frame_abs_diff, luma_sse_sum, + inv_num_ref_pixels, decay_factor, inv_factor, + weight_factor, d_factor, tf_wgt_calc_lvl); + + plane_offset += plane_h * plane_w; + } +} diff --git a/third_party/aom/av1/encoder/arm/neon/txfm_neon.h b/third_party/aom/av1/encoder/arm/neon/txfm_neon.h new file mode 100644 index 0000000000..635364f46a --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/txfm_neon.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#ifndef AOM_AV1_ENCODER_ARM_NEON_TXFM_NEON_H_ +#define AOM_AV1_ENCODER_ARM_NEON_TXFM_NEON_H_ + +#include "aom/aom_integer.h" // For AOM_INLINE. + +static AOM_INLINE void ud_adjust_input_and_stride(int ud_flip, + const int16_t **input, + int *stride, int out_size) { + if (ud_flip) { + *input = *input + (out_size - 1) * *stride; + *stride = -*stride; + } +} + +#endif // AOM_AV1_ENCODER_ARM_NEON_TXFM_NEON_H_ diff --git a/third_party/aom/av1/encoder/arm/neon/wedge_utils_neon.c b/third_party/aom/av1/encoder/arm/neon/wedge_utils_neon.c new file mode 100644 index 0000000000..1b35269b33 --- /dev/null +++ b/third_party/aom/av1/encoder/arm/neon/wedge_utils_neon.c @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2022, Alliance for Open Media. 
All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <arm_neon.h> +#include <assert.h> + +#include "aom_dsp/arm/sum_neon.h" +#include "av1/common/reconinter.h" + +#define MAX_MASK_VALUE (1 << WEDGE_WEIGHT_BITS) + +/** + * See av1_wedge_sse_from_residuals_c for details of the parameters and + * computation. + */ +uint64_t av1_wedge_sse_from_residuals_neon(const int16_t *r1, const int16_t *d, + const uint8_t *m, int N) { + assert(N % 64 == 0); + + uint64x2_t v_csse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) }; + + int i = 0; + do { + int32x4_t sum[4]; + int32x4_t sse[2]; + int16x4_t sum_s16[4]; + + const int16x8_t r1_l = vld1q_s16(r1 + i); + const int16x8_t r1_h = vld1q_s16(r1 + i + 8); + const int16x8_t d_l = vld1q_s16(d + i); + const int16x8_t d_h = vld1q_s16(d + i + 8); + // The following three lines are a bit inelegant compared to using a pair + // of vmovl_u8()... but it forces the compiler to emit a ZIP1, ZIP2 pair - + // which can be executed in parallel with the subsequent SSHL instructions. + // (SSHL can only be executed on half of the Neon pipes in modern Arm + // cores, whereas ZIP1/2 can be executed on all of them.) + const uint8x16x2_t m_u16 = vzipq_u8(vld1q_u8(m + i), vdupq_n_u8(0)); + const int16x8_t m_l = vreinterpretq_s16_u8(m_u16.val[0]); + const int16x8_t m_h = vreinterpretq_s16_u8(m_u16.val[1]); + + sum[0] = vshll_n_s16(vget_low_s16(r1_l), WEDGE_WEIGHT_BITS); + sum[1] = vshll_n_s16(vget_high_s16(r1_l), WEDGE_WEIGHT_BITS); + sum[2] = vshll_n_s16(vget_low_s16(r1_h), WEDGE_WEIGHT_BITS); + sum[3] = vshll_n_s16(vget_high_s16(r1_h), WEDGE_WEIGHT_BITS); + + sum[0] = vmlal_s16(sum[0], vget_low_s16(m_l), vget_low_s16(d_l)); + sum[1] = vmlal_s16(sum[1], vget_high_s16(m_l), vget_high_s16(d_l)); + sum[2] = vmlal_s16(sum[2], vget_low_s16(m_h), vget_low_s16(d_h)); + sum[3] = vmlal_s16(sum[3], vget_high_s16(m_h), vget_high_s16(d_h)); + + sum_s16[0] = vqmovn_s32(sum[0]); + sum_s16[1] = vqmovn_s32(sum[1]); + sum_s16[2] = vqmovn_s32(sum[2]); + sum_s16[3] = vqmovn_s32(sum[3]); + + sse[0] = vmull_s16(sum_s16[0], sum_s16[0]); + sse[1] = vmull_s16(sum_s16[2], sum_s16[2]); + sse[0] = vmlal_s16(sse[0], sum_s16[1], sum_s16[1]); + sse[1] = vmlal_s16(sse[1], sum_s16[3], sum_s16[3]); + + v_csse[0] = vpadalq_u32(v_csse[0], vreinterpretq_u32_s32(sse[0])); + v_csse[1] = vpadalq_u32(v_csse[1], vreinterpretq_u32_s32(sse[1])); + + i += 16; + } while (i < N); + + uint64_t csse = horizontal_add_u64x2(vaddq_u64(v_csse[0], v_csse[1])); + return ROUND_POWER_OF_TWO(csse, 2 * WEDGE_WEIGHT_BITS); +} + +int8_t av1_wedge_sign_from_residuals_neon(const int16_t *ds, const uint8_t *m, + int N, int64_t limit) { + int32x4_t acc[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), + vdupq_n_s32(0) }; + + do { + int16x8_t ds_l = vld1q_s16(ds); + int16x8_t ds_h = vld1q_s16(ds + 8); + + int8x16_t m_s8 = vreinterpretq_s8_u8(vld1q_u8(m)); + int16x8_t m_l = vmovl_s8(vget_low_s8(m_s8)); + int16x8_t m_h = vmovl_s8(vget_high_s8(m_s8)); + + acc[0] = vmlal_s16(acc[0], vget_low_s16(ds_l), vget_low_s16(m_l)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(ds_l), vget_high_s16(m_l)); + acc[2] = 
vmlal_s16(acc[2], vget_low_s16(ds_h), vget_low_s16(m_h)); + acc[3] = vmlal_s16(acc[3], vget_high_s16(ds_h), vget_high_s16(m_h)); + + ds += 16; + m += 16; + N -= 16; + } while (N != 0); + + int64x2_t sum = vpaddlq_s32(acc[0]); + sum = vpadalq_s32(sum, acc[1]); + sum = vpadalq_s32(sum, acc[2]); + sum = vpadalq_s32(sum, acc[3]); + + return (horizontal_add_s64x2(sum) > limit); +} + +void av1_wedge_compute_delta_squares_neon(int16_t *d_ptr, const int16_t *a_ptr, + const int16_t *b_ptr, int N) { + do { + int16x8_t a = vld1q_s16(a_ptr); + int16x8_t b = vld1q_s16(b_ptr); + + int32x4_t sq_lo = vmull_s16(vget_low_s16(a), vget_low_s16(a)); + int32x4_t sq_hi = vmull_s16(vget_high_s16(a), vget_high_s16(a)); + + sq_lo = vmlsl_s16(sq_lo, vget_low_s16(b), vget_low_s16(b)); + sq_hi = vmlsl_s16(sq_hi, vget_high_s16(b), vget_high_s16(b)); + + int16x8_t res = vcombine_s16(vqmovn_s32(sq_lo), vqmovn_s32(sq_hi)); + + vst1q_s16(d_ptr, res); + + d_ptr += 8; + a_ptr += 8; + b_ptr += 8; + N -= 8; + } while (N != 0); +}