summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc')
-rw-r--r--third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc900
1 files changed, 900 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc
new file mode 100644
index 0000000000..af30ff1b9f
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc
@@ -0,0 +1,900 @@
+/*
+ * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "modules/audio_processing/aec3/matched_filter.h"
+
+// Defines WEBRTC_ARCH_X86_FAMILY, used below.
+#include "rtc_base/system/arch.h"
+
+#if defined(WEBRTC_HAS_NEON)
+#include <arm_neon.h>
+#endif
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+#include <emmintrin.h>
+#endif
+#include <algorithm>
+#include <cstddef>
+#include <initializer_list>
+#include <iterator>
+#include <numeric>
+
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
+#include "modules/audio_processing/logging/apm_data_dumper.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/experiments/field_trial_parser.h"
+#include "rtc_base/logging.h"
+#include "system_wrappers/include/field_trial.h"
+
+namespace {
+
+// Subsample rate used for computing the accumulated error.
+// The implementation of some core functions depends on this constant being
+// equal to 4.
+constexpr int kAccumulatedErrorSubSampleRate = 4;
+
+void UpdateAccumulatedError(
+ const rtc::ArrayView<const float> instantaneous_accumulated_error,
+ const rtc::ArrayView<float> accumulated_error,
+ float one_over_error_sum_anchor,
+ float smooth_constant_increases) {
+ for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
+ float error_norm =
+ instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
+ if (error_norm < accumulated_error[k]) {
+ accumulated_error[k] = error_norm;
+ } else {
+ accumulated_error[k] +=
+ smooth_constant_increases * (error_norm - accumulated_error[k]);
+ }
+ }
+}
+
+size_t ComputePreEchoLag(
+ const webrtc::MatchedFilter::PreEchoConfiguration& pre_echo_configuration,
+ const rtc::ArrayView<const float> accumulated_error,
+ size_t lag,
+ size_t alignment_shift_winner) {
+ RTC_DCHECK_GE(lag, alignment_shift_winner);
+ size_t pre_echo_lag_estimate = lag - alignment_shift_winner;
+ size_t maximum_pre_echo_lag =
+ std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate,
+ accumulated_error.size());
+ switch (pre_echo_configuration.mode) {
+ case 0:
+ // Mode 0: Pre echo lag is defined as the first coefficient with an error
+ // lower than a threshold with a certain decrease slope.
+ for (size_t k = 1; k < maximum_pre_echo_lag; ++k) {
+ if (accumulated_error[k] <
+ pre_echo_configuration.threshold * accumulated_error[k - 1] &&
+ accumulated_error[k] < pre_echo_configuration.threshold) {
+ pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
+ break;
+ }
+ }
+ break;
+ case 1:
+ // Mode 1: Pre echo lag is defined as the first coefficient with an error
+ // lower than a certain threshold.
+ for (size_t k = 0; k < maximum_pre_echo_lag; ++k) {
+ if (accumulated_error[k] < pre_echo_configuration.threshold) {
+ pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
+ break;
+ }
+ }
+ break;
+ case 2:
+ case 3:
+ // Mode 2,3: Pre echo lag is defined as the closest coefficient to the lag
+ // with an error lower than a certain threshold.
+ for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) {
+ if (accumulated_error[k] > pre_echo_configuration.threshold) {
+ break;
+ }
+ pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
+ }
+ break;
+ default:
+ RTC_DCHECK_NOTREACHED();
+ break;
+ }
+ return pre_echo_lag_estimate + alignment_shift_winner;
+}
+
+webrtc::MatchedFilter::PreEchoConfiguration FetchPreEchoConfiguration() {
+ float threshold = 0.5f;
+ int mode = 0;
+ const std::string pre_echo_configuration_field_trial =
+ webrtc::field_trial::FindFullName("WebRTC-Aec3PreEchoConfiguration");
+ webrtc::FieldTrialParameter<double> threshold_field_trial_parameter(
+ /*key=*/"threshold", /*default_value=*/threshold);
+ webrtc::FieldTrialParameter<int> mode_field_trial_parameter(
+ /*key=*/"mode", /*default_value=*/mode);
+ webrtc::ParseFieldTrial(
+ {&threshold_field_trial_parameter, &mode_field_trial_parameter},
+ pre_echo_configuration_field_trial);
+ float threshold_read =
+ static_cast<float>(threshold_field_trial_parameter.Get());
+ int mode_read = mode_field_trial_parameter.Get();
+ if (threshold_read < 1.0f && threshold_read > 0.0f) {
+ threshold = threshold_read;
+ } else {
+ RTC_LOG(LS_ERROR)
+ << "AEC3: Pre echo configuration: wrong input, threshold = "
+ << threshold_read << ".";
+ }
+ if (mode_read >= 0 && mode_read <= 3) {
+ mode = mode_read;
+ } else {
+ RTC_LOG(LS_ERROR) << "AEC3: Pre echo configuration: wrong input, mode = "
+ << mode_read << ".";
+ }
+ RTC_LOG(LS_INFO) << "AEC3: Pre echo configuration: threshold = " << threshold
+ << ", mode = " << mode << ".";
+ return {.threshold = threshold, .mode = mode};
+}
+
+} // namespace
+
+namespace webrtc {
+namespace aec3 {
+
+#if defined(WEBRTC_HAS_NEON)
+
+inline float SumAllElements(float32x4_t elements) {
+ float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
+ sum = vpadd_f32(sum, sum);
+ return vget_lane_f32(sum, 0);
+}
+
+void MatchedFilterCoreWithAccumulatedError_NEON(
+ size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ const int h_size = static_cast<int>(h.size());
+ const int x_size = static_cast<int>(x.size());
+ RTC_DCHECK_EQ(0, h_size % 4);
+ std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+ RTC_DCHECK_GT(x_size, x_start_index);
+ // Compute loop chunk sizes until, and after, the wraparound of the circular
+ // buffer for x.
+ const int chunk1 =
+ std::min(h_size, static_cast<int>(x_size - x_start_index));
+ if (chunk1 != h_size) {
+ const int chunk2 = h_size - chunk1;
+ std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
+ std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
+ }
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ const float* h_p = &h[0];
+ float* accumulated_error_p = &accumulated_error[0];
+ // Initialize values for the accumulation.
+ float32x4_t x2_sum_128 = vdupq_n_f32(0);
+ float x2_sum = 0.f;
+ float s = 0;
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = h_size >> 2;
+ for (int k = limit_by_4; k > 0;
+ --k, h_p += 4, x_p += 4, accumulated_error_p++) {
+ // Load the data into 128 bit vectors.
+ const float32x4_t x_k = vld1q_f32(x_p);
+ const float32x4_t h_k = vld1q_f32(h_p);
+ // Compute and accumulate x * x.
+ x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
+ // Compute x * h
+ float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
+ s += SumAllElements(hk_xk_128);
+ const float e = s - y[i];
+ accumulated_error_p[0] += e * e;
+ }
+ // Combine the accumulated vector and scalar values.
+ x2_sum += SumAllElements(x2_sum_128);
+ // Compute the matched filter error.
+ float e = y[i] - s;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+ const float32x4_t alpha_128 = vmovq_n_f32(alpha);
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ float* h_p = &h[0];
+ x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = h_size >> 2;
+ for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+ // Load the data into 128 bit vectors.
+ float32x4_t h_k = vld1q_f32(h_p);
+ const float32x4_t x_k = vld1q_f32(x_p);
+ // Compute h = h + alpha * x.
+ h_k = vmlaq_f32(h_k, alpha_128, x_k);
+ // Store the result.
+ vst1q_f32(h_p, h_k);
+ }
+ *filters_updated = true;
+ }
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
+ }
+}
+
+void MatchedFilterCore_NEON(size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ const int h_size = static_cast<int>(h.size());
+ const int x_size = static_cast<int>(x.size());
+ RTC_DCHECK_EQ(0, h_size % 4);
+
+ if (compute_accumulated_error) {
+ return MatchedFilterCoreWithAccumulatedError_NEON(
+ x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
+ error_sum, accumulated_error, scratch_memory);
+ }
+
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+
+ RTC_DCHECK_GT(x_size, x_start_index);
+ const float* x_p = &x[x_start_index];
+ const float* h_p = &h[0];
+
+ // Initialize values for the accumulation.
+ float32x4_t s_128 = vdupq_n_f32(0);
+ float32x4_t x2_sum_128 = vdupq_n_f32(0);
+ float x2_sum = 0.f;
+ float s = 0;
+
+ // Compute loop chunk sizes until, and after, the wraparound of the circular
+ // buffer for x.
+ const int chunk1 =
+ std::min(h_size, static_cast<int>(x_size - x_start_index));
+
+ // Perform the loop in two chunks.
+ const int chunk2 = h_size - chunk1;
+ for (int limit : {chunk1, chunk2}) {
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = limit >> 2;
+ for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+ // Load the data into 128 bit vectors.
+ const float32x4_t x_k = vld1q_f32(x_p);
+ const float32x4_t h_k = vld1q_f32(h_p);
+ // Compute and accumulate x * x and h * x.
+ x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
+ s_128 = vmlaq_f32(s_128, h_k, x_k);
+ }
+
+ // Perform non-vector operations for any remaining items.
+ for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
+ const float x_k = *x_p;
+ x2_sum += x_k * x_k;
+ s += *h_p * x_k;
+ }
+
+ x_p = &x[0];
+ }
+
+ // Combine the accumulated vector and scalar values.
+ s += SumAllElements(s_128);
+ x2_sum += SumAllElements(x2_sum_128);
+
+ // Compute the matched filter error.
+ float e = y[i] - s;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+ const float32x4_t alpha_128 = vmovq_n_f32(alpha);
+
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ float* h_p = &h[0];
+ x_p = &x[x_start_index];
+
+ // Perform the loop in two chunks.
+ for (int limit : {chunk1, chunk2}) {
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = limit >> 2;
+ for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+ // Load the data into 128 bit vectors.
+ float32x4_t h_k = vld1q_f32(h_p);
+ const float32x4_t x_k = vld1q_f32(x_p);
+ // Compute h = h + alpha * x.
+ h_k = vmlaq_f32(h_k, alpha_128, x_k);
+
+ // Store the result.
+ vst1q_f32(h_p, h_k);
+ }
+
+ // Perform non-vector operations for any remaining items.
+ for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
+ *h_p += alpha * *x_p;
+ }
+
+ x_p = &x[0];
+ }
+
+ *filters_updated = true;
+ }
+
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
+ }
+}
+
+#endif
+
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+
+void MatchedFilterCore_AccumulatedError_SSE2(
+ size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ const int h_size = static_cast<int>(h.size());
+ const int x_size = static_cast<int>(x.size());
+ RTC_DCHECK_EQ(0, h_size % 8);
+ std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+ RTC_DCHECK_GT(x_size, x_start_index);
+ const int chunk1 =
+ std::min(h_size, static_cast<int>(x_size - x_start_index));
+ if (chunk1 != h_size) {
+ const int chunk2 = h_size - chunk1;
+ std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
+ std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
+ }
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ const float* h_p = &h[0];
+ float* a_p = &accumulated_error[0];
+ __m128 s_inst_128;
+ __m128 s_inst_128_4;
+ __m128 x2_sum_128 = _mm_set1_ps(0);
+ __m128 x2_sum_128_4 = _mm_set1_ps(0);
+ __m128 e_128;
+ float* const s_p = reinterpret_cast<float*>(&s_inst_128);
+ float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4);
+ float* const e_p = reinterpret_cast<float*>(&e_128);
+ float x2_sum = 0.0f;
+ float s_acum = 0;
+ // Perform 128 bit vector operations.
+ const int limit_by_8 = h_size >> 3;
+ for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) {
+ // Load the data into 128 bit vectors.
+ const __m128 x_k = _mm_loadu_ps(x_p);
+ const __m128 h_k = _mm_loadu_ps(h_p);
+ const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
+ const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
+ const __m128 xx = _mm_mul_ps(x_k, x_k);
+ const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
+ // Compute and accumulate x * x and h * x.
+ x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
+ x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
+ s_inst_128 = _mm_mul_ps(h_k, x_k);
+ s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4);
+ s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3];
+ e_p[0] = s_acum - y[i];
+ s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3];
+ e_p[1] = s_acum - y[i];
+ a_p[0] += e_p[0] * e_p[0];
+ a_p[1] += e_p[1] * e_p[1];
+ }
+ // Combine the accumulated vector and scalar values.
+ x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
+ float* v = reinterpret_cast<float*>(&x2_sum_128);
+ x2_sum += v[0] + v[1] + v[2] + v[3];
+ // Compute the matched filter error.
+ float e = y[i] - s_acum;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+ const __m128 alpha_128 = _mm_set1_ps(alpha);
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ float* h_p = &h[0];
+ const float* x_p =
+ chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = h_size >> 2;
+ for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+ // Load the data into 128 bit vectors.
+ __m128 h_k = _mm_loadu_ps(h_p);
+ const __m128 x_k = _mm_loadu_ps(x_p);
+ // Compute h = h + alpha * x.
+ const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
+ h_k = _mm_add_ps(h_k, alpha_x);
+ // Store the result.
+ _mm_storeu_ps(h_p, h_k);
+ }
+ *filters_updated = true;
+ }
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
+ }
+}
+
+void MatchedFilterCore_SSE2(size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error,
+ rtc::ArrayView<float> scratch_memory) {
+ if (compute_accumulated_error) {
+ return MatchedFilterCore_AccumulatedError_SSE2(
+ x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
+ error_sum, accumulated_error, scratch_memory);
+ }
+ const int h_size = static_cast<int>(h.size());
+ const int x_size = static_cast<int>(x.size());
+ RTC_DCHECK_EQ(0, h_size % 4);
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+ RTC_DCHECK_GT(x_size, x_start_index);
+ const float* x_p = &x[x_start_index];
+ const float* h_p = &h[0];
+ // Initialize values for the accumulation.
+ __m128 s_128 = _mm_set1_ps(0);
+ __m128 s_128_4 = _mm_set1_ps(0);
+ __m128 x2_sum_128 = _mm_set1_ps(0);
+ __m128 x2_sum_128_4 = _mm_set1_ps(0);
+ float x2_sum = 0.f;
+ float s = 0;
+ // Compute loop chunk sizes until, and after, the wraparound of the circular
+ // buffer for x.
+ const int chunk1 =
+ std::min(h_size, static_cast<int>(x_size - x_start_index));
+ // Perform the loop in two chunks.
+ const int chunk2 = h_size - chunk1;
+ for (int limit : {chunk1, chunk2}) {
+ // Perform 128 bit vector operations.
+ const int limit_by_8 = limit >> 3;
+ for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
+ // Load the data into 128 bit vectors.
+ const __m128 x_k = _mm_loadu_ps(x_p);
+ const __m128 h_k = _mm_loadu_ps(h_p);
+ const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
+ const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
+ const __m128 xx = _mm_mul_ps(x_k, x_k);
+ const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
+ // Compute and accumulate x * x and h * x.
+ x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
+ x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
+ const __m128 hx = _mm_mul_ps(h_k, x_k);
+ const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4);
+ s_128 = _mm_add_ps(s_128, hx);
+ s_128_4 = _mm_add_ps(s_128_4, hx_4);
+ }
+ // Perform non-vector operations for any remaining items.
+ for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
+ const float x_k = *x_p;
+ x2_sum += x_k * x_k;
+ s += *h_p * x_k;
+ }
+ x_p = &x[0];
+ }
+ // Combine the accumulated vector and scalar values.
+ x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
+ float* v = reinterpret_cast<float*>(&x2_sum_128);
+ x2_sum += v[0] + v[1] + v[2] + v[3];
+ s_128 = _mm_add_ps(s_128, s_128_4);
+ v = reinterpret_cast<float*>(&s_128);
+ s += v[0] + v[1] + v[2] + v[3];
+ // Compute the matched filter error.
+ float e = y[i] - s;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+ const __m128 alpha_128 = _mm_set1_ps(alpha);
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ float* h_p = &h[0];
+ x_p = &x[x_start_index];
+ // Perform the loop in two chunks.
+ for (int limit : {chunk1, chunk2}) {
+ // Perform 128 bit vector operations.
+ const int limit_by_4 = limit >> 2;
+ for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
+ // Load the data into 128 bit vectors.
+ __m128 h_k = _mm_loadu_ps(h_p);
+ const __m128 x_k = _mm_loadu_ps(x_p);
+
+ // Compute h = h + alpha * x.
+ const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
+ h_k = _mm_add_ps(h_k, alpha_x);
+ // Store the result.
+ _mm_storeu_ps(h_p, h_k);
+ }
+ // Perform non-vector operations for any remaining items.
+ for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
+ *h_p += alpha * *x_p;
+ }
+ x_p = &x[0];
+ }
+ *filters_updated = true;
+ }
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
+ }
+}
+#endif
+
+void MatchedFilterCore(size_t x_start_index,
+ float x2_sum_threshold,
+ float smoothing,
+ rtc::ArrayView<const float> x,
+ rtc::ArrayView<const float> y,
+ rtc::ArrayView<float> h,
+ bool* filters_updated,
+ float* error_sum,
+ bool compute_accumulated_error,
+ rtc::ArrayView<float> accumulated_error) {
+ if (compute_accumulated_error) {
+ std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
+ }
+
+ // Process for all samples in the sub-block.
+ for (size_t i = 0; i < y.size(); ++i) {
+ // Apply the matched filter as filter * x, and compute x * x.
+ float x2_sum = 0.f;
+ float s = 0;
+ size_t x_index = x_start_index;
+ if (compute_accumulated_error) {
+ for (size_t k = 0; k < h.size(); ++k) {
+ x2_sum += x[x_index] * x[x_index];
+ s += h[k] * x[x_index];
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
+ if ((k + 1 & 0b11) == 0) {
+ int idx = k >> 2;
+ accumulated_error[idx] += (y[i] - s) * (y[i] - s);
+ }
+ }
+ } else {
+ for (size_t k = 0; k < h.size(); ++k) {
+ x2_sum += x[x_index] * x[x_index];
+ s += h[k] * x[x_index];
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
+ }
+ }
+
+ // Compute the matched filter error.
+ float e = y[i] - s;
+ const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
+ (*error_sum) += e * e;
+
+ // Update the matched filter estimate in an NLMS manner.
+ if (x2_sum > x2_sum_threshold && !saturation) {
+ RTC_DCHECK_LT(0.f, x2_sum);
+ const float alpha = smoothing * e / x2_sum;
+
+ // filter = filter + smoothing * (y - filter * x) * x / x * x.
+ size_t x_index = x_start_index;
+ for (size_t k = 0; k < h.size(); ++k) {
+ h[k] += alpha * x[x_index];
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
+ }
+ *filters_updated = true;
+ }
+
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
+ }
+}
+
+size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) {
+ if (h.size() < 2) {
+ return 0;
+ }
+ float max_element1 = h[0] * h[0];
+ float max_element2 = h[1] * h[1];
+ size_t lag_estimate1 = 0;
+ size_t lag_estimate2 = 1;
+ const size_t last_index = h.size() - 1;
+ // Keeping track of even & odd max elements separately typically allows the
+ // compiler to produce more efficient code.
+ for (size_t k = 2; k < last_index; k += 2) {
+ float element1 = h[k] * h[k];
+ float element2 = h[k + 1] * h[k + 1];
+ if (element1 > max_element1) {
+ max_element1 = element1;
+ lag_estimate1 = k;
+ }
+ if (element2 > max_element2) {
+ max_element2 = element2;
+ lag_estimate2 = k + 1;
+ }
+ }
+ if (max_element2 > max_element1) {
+ max_element1 = max_element2;
+ lag_estimate1 = lag_estimate2;
+ }
+ // In case of odd h size, we have not yet checked the last element.
+ float last_element = h[last_index] * h[last_index];
+ if (last_element > max_element1) {
+ return last_index;
+ }
+ return lag_estimate1;
+}
+
+} // namespace aec3
+
+MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
+ Aec3Optimization optimization,
+ size_t sub_block_size,
+ size_t window_size_sub_blocks,
+ int num_matched_filters,
+ size_t alignment_shift_sub_blocks,
+ float excitation_limit,
+ float smoothing_fast,
+ float smoothing_slow,
+ float matching_filter_threshold,
+ bool detect_pre_echo)
+ : data_dumper_(data_dumper),
+ optimization_(optimization),
+ sub_block_size_(sub_block_size),
+ filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
+ filters_(
+ num_matched_filters,
+ std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
+ filters_offsets_(num_matched_filters, 0),
+ excitation_limit_(excitation_limit),
+ smoothing_fast_(smoothing_fast),
+ smoothing_slow_(smoothing_slow),
+ matching_filter_threshold_(matching_filter_threshold),
+ detect_pre_echo_(detect_pre_echo),
+ pre_echo_config_(FetchPreEchoConfiguration()) {
+ RTC_DCHECK(data_dumper);
+ RTC_DCHECK_LT(0, window_size_sub_blocks);
+ RTC_DCHECK((kBlockSize % sub_block_size) == 0);
+ RTC_DCHECK((sub_block_size % 4) == 0);
+ static_assert(kAccumulatedErrorSubSampleRate == 4);
+ if (detect_pre_echo_) {
+ accumulated_error_ = std::vector<std::vector<float>>(
+ num_matched_filters,
+ std::vector<float>(window_size_sub_blocks * sub_block_size_ /
+ kAccumulatedErrorSubSampleRate,
+ 1.0f));
+
+ instantaneous_accumulated_error_ =
+ std::vector<float>(window_size_sub_blocks * sub_block_size_ /
+ kAccumulatedErrorSubSampleRate,
+ 0.0f);
+ scratch_memory_ =
+ std::vector<float>(window_size_sub_blocks * sub_block_size_);
+ }
+}
+
+MatchedFilter::~MatchedFilter() = default;
+
+void MatchedFilter::Reset(bool full_reset) {
+ for (auto& f : filters_) {
+ std::fill(f.begin(), f.end(), 0.f);
+ }
+
+ winner_lag_ = absl::nullopt;
+ reported_lag_estimate_ = absl::nullopt;
+ if (pre_echo_config_.mode != 3 || full_reset) {
+ for (auto& e : accumulated_error_) {
+ std::fill(e.begin(), e.end(), 1.0f);
+ }
+ number_pre_echo_updates_ = 0;
+ }
+}
+
+void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
+ rtc::ArrayView<const float> capture,
+ bool use_slow_smoothing) {
+ RTC_DCHECK_EQ(sub_block_size_, capture.size());
+ auto& y = capture;
+
+ const float smoothing =
+ use_slow_smoothing ? smoothing_slow_ : smoothing_fast_;
+
+ const float x2_sum_threshold =
+ filters_[0].size() * excitation_limit_ * excitation_limit_;
+
+ // Compute anchor for the matched filter error.
+ float error_sum_anchor = 0.0f;
+ for (size_t k = 0; k < y.size(); ++k) {
+ error_sum_anchor += y[k] * y[k];
+ }
+
+ // Apply all matched filters.
+ float winner_error_sum = error_sum_anchor;
+ winner_lag_ = absl::nullopt;
+ reported_lag_estimate_ = absl::nullopt;
+ size_t alignment_shift = 0;
+ absl::optional<size_t> previous_lag_estimate;
+ const int num_filters = static_cast<int>(filters_.size());
+ int winner_index = -1;
+ for (int n = 0; n < num_filters; ++n) {
+ float error_sum = 0.f;
+ bool filters_updated = false;
+ const bool compute_pre_echo =
+ detect_pre_echo_ && n == last_detected_best_lag_filter_;
+
+ size_t x_start_index =
+ (render_buffer.read + alignment_shift + sub_block_size_ - 1) %
+ render_buffer.buffer.size();
+
+ switch (optimization_) {
+#if defined(WEBRTC_ARCH_X86_FAMILY)
+ case Aec3Optimization::kSse2:
+ aec3::MatchedFilterCore_SSE2(
+ x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
+ filters_[n], &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_, scratch_memory_);
+ break;
+ case Aec3Optimization::kAvx2:
+ aec3::MatchedFilterCore_AVX2(
+ x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
+ filters_[n], &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_, scratch_memory_);
+ break;
+#endif
+#if defined(WEBRTC_HAS_NEON)
+ case Aec3Optimization::kNeon:
+ aec3::MatchedFilterCore_NEON(
+ x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
+ filters_[n], &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_, scratch_memory_);
+ break;
+#endif
+ default:
+ aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
+ render_buffer.buffer, y, filters_[n],
+ &filters_updated, &error_sum, compute_pre_echo,
+ instantaneous_accumulated_error_);
+ }
+
+ // Estimate the lag in the matched filter as the distance to the portion in
+ // the filter that contributes the most to the matched filter output. This
+ // is detected as the peak of the matched filter.
+ const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
+ const bool reliable =
+ lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
+ error_sum < matching_filter_threshold_ * error_sum_anchor;
+
+ // Find the best estimate
+ const size_t lag = lag_estimate + alignment_shift;
+ if (filters_updated && reliable && error_sum < winner_error_sum) {
+ winner_error_sum = error_sum;
+ winner_index = n;
+ // In case that 2 matched filters return the same winner candidate
+ // (overlap region), the one with the smaller index is chosen in order
+ // to search for pre-echoes.
+ if (previous_lag_estimate && previous_lag_estimate == lag) {
+ winner_lag_ = previous_lag_estimate;
+ winner_index = n - 1;
+ } else {
+ winner_lag_ = lag;
+ }
+ }
+ previous_lag_estimate = lag;
+ alignment_shift += filter_intra_lag_shift_;
+ }
+
+ if (winner_index != -1) {
+ RTC_DCHECK(winner_lag_.has_value());
+ reported_lag_estimate_ =
+ LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
+ if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
+ const float energy_threshold =
+ pre_echo_config_.mode == 3 ? 1.0f : 30.0f * 30.0f * y.size();
+
+ if (error_sum_anchor > energy_threshold) {
+ const float smooth_constant_increases =
+ pre_echo_config_.mode != 3 ? 0.01f : 0.015f;
+
+ UpdateAccumulatedError(
+ instantaneous_accumulated_error_, accumulated_error_[winner_index],
+ 1.0f / error_sum_anchor, smooth_constant_increases);
+ number_pre_echo_updates_++;
+ }
+ if (pre_echo_config_.mode != 3 || number_pre_echo_updates_ >= 50) {
+ reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
+ pre_echo_config_, accumulated_error_[winner_index],
+ winner_lag_.value(),
+ winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
+ } else {
+ reported_lag_estimate_->pre_echo_lag = winner_lag_.value();
+ }
+ }
+ last_detected_best_lag_filter_ = winner_index;
+ }
+ if (ApmDataDumper::IsAvailable()) {
+ Dump();
+ data_dumper_->DumpRaw("error_sum_anchor", error_sum_anchor / y.size());
+ data_dumper_->DumpRaw("number_pre_echo_updates", number_pre_echo_updates_);
+ data_dumper_->DumpRaw("filter_smoothing", smoothing);
+ }
+}
+
+void MatchedFilter::LogFilterProperties(int sample_rate_hz,
+ size_t shift,
+ size_t downsampling_factor) const {
+ size_t alignment_shift = 0;
+ constexpr int kFsBy1000 = 16;
+ for (size_t k = 0; k < filters_.size(); ++k) {
+ int start = static_cast<int>(alignment_shift * downsampling_factor);
+ int end = static_cast<int>((alignment_shift + filters_[k].size()) *
+ downsampling_factor);
+ RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: "
+ << (start - static_cast<int>(shift)) / kFsBy1000
+ << " ms, end: "
+ << (end - static_cast<int>(shift)) / kFsBy1000
+ << " ms.";
+ alignment_shift += filter_intra_lag_shift_;
+ }
+}
+
+void MatchedFilter::Dump() {
+ for (size_t n = 0; n < filters_.size(); ++n) {
+ const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
+ std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h";
+ data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]);
+ std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n);
+ data_dumper_->DumpRaw(dumper_lag.c_str(),
+ lag_estimate + n * filter_intra_lag_shift_);
+ if (detect_pre_echo_) {
+ std::string dumper_error =
+ "aec3_correlator_error_" + std::to_string(n) + "_h";
+ data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]);
+
+ size_t pre_echo_lag =
+ ComputePreEchoLag(pre_echo_config_, accumulated_error_[n],
+ lag_estimate + n * filter_intra_lag_shift_,
+ n * filter_intra_lag_shift_);
+ std::string dumper_pre_lag =
+ "aec3_correlator_pre_echo_lag_" + std::to_string(n);
+ data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
+ if (static_cast<int>(n) == last_detected_best_lag_filter_) {
+ data_dumper_->DumpRaw("aec3_pre_echo_delay_winner_inst", pre_echo_lag);
+ }
+ }
+ }
+}
+
+} // namespace webrtc