diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr. (refs: upstream/115.7.0esr, upstream)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc')
-rw-r--r-- | third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc | 900 |
1 files changed, 900 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc new file mode 100644 index 0000000000..af30ff1b9f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc @@ -0,0 +1,900 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/matched_filter.h" + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_HAS_NEON) +#include <arm_neon.h> +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif +#include <algorithm> +#include <cstddef> +#include <initializer_list> +#include <iterator> +#include <numeric> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/experiments/field_trial_parser.h" +#include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" + +namespace { + +// Subsample rate used for computing the accumulated error. +// The implementation of some core functions depends on this constant being +// equal to 4. 
+constexpr int kAccumulatedErrorSubSampleRate = 4; + +void UpdateAccumulatedError( + const rtc::ArrayView<const float> instantaneous_accumulated_error, + const rtc::ArrayView<float> accumulated_error, + float one_over_error_sum_anchor, + float smooth_constant_increases) { + for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) { + float error_norm = + instantaneous_accumulated_error[k] * one_over_error_sum_anchor; + if (error_norm < accumulated_error[k]) { + accumulated_error[k] = error_norm; + } else { + accumulated_error[k] += + smooth_constant_increases * (error_norm - accumulated_error[k]); + } + } +} + +size_t ComputePreEchoLag( + const webrtc::MatchedFilter::PreEchoConfiguration& pre_echo_configuration, + const rtc::ArrayView<const float> accumulated_error, + size_t lag, + size_t alignment_shift_winner) { + RTC_DCHECK_GE(lag, alignment_shift_winner); + size_t pre_echo_lag_estimate = lag - alignment_shift_winner; + size_t maximum_pre_echo_lag = + std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate, + accumulated_error.size()); + switch (pre_echo_configuration.mode) { + case 0: + // Mode 0: Pre echo lag is defined as the first coefficient with an error + // lower than a threshold with a certain decrease slope. + for (size_t k = 1; k < maximum_pre_echo_lag; ++k) { + if (accumulated_error[k] < + pre_echo_configuration.threshold * accumulated_error[k - 1] && + accumulated_error[k] < pre_echo_configuration.threshold) { + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + break; + } + } + break; + case 1: + // Mode 1: Pre echo lag is defined as the first coefficient with an error + // lower than a certain threshold. 
+ for (size_t k = 0; k < maximum_pre_echo_lag; ++k) { + if (accumulated_error[k] < pre_echo_configuration.threshold) { + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + break; + } + } + break; + case 2: + case 3: + // Mode 2,3: Pre echo lag is defined as the closest coefficient to the lag + // with an error lower than a certain threshold. + for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) { + if (accumulated_error[k] > pre_echo_configuration.threshold) { + break; + } + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + } + break; + default: + RTC_DCHECK_NOTREACHED(); + break; + } + return pre_echo_lag_estimate + alignment_shift_winner; +} + +webrtc::MatchedFilter::PreEchoConfiguration FetchPreEchoConfiguration() { + float threshold = 0.5f; + int mode = 0; + const std::string pre_echo_configuration_field_trial = + webrtc::field_trial::FindFullName("WebRTC-Aec3PreEchoConfiguration"); + webrtc::FieldTrialParameter<double> threshold_field_trial_parameter( + /*key=*/"threshold", /*default_value=*/threshold); + webrtc::FieldTrialParameter<int> mode_field_trial_parameter( + /*key=*/"mode", /*default_value=*/mode); + webrtc::ParseFieldTrial( + {&threshold_field_trial_parameter, &mode_field_trial_parameter}, + pre_echo_configuration_field_trial); + float threshold_read = + static_cast<float>(threshold_field_trial_parameter.Get()); + int mode_read = mode_field_trial_parameter.Get(); + if (threshold_read < 1.0f && threshold_read > 0.0f) { + threshold = threshold_read; + } else { + RTC_LOG(LS_ERROR) + << "AEC3: Pre echo configuration: wrong input, threshold = " + << threshold_read << "."; + } + if (mode_read >= 0 && mode_read <= 3) { + mode = mode_read; + } else { + RTC_LOG(LS_ERROR) << "AEC3: Pre echo configuration: wrong input, mode = " + << mode_read << "."; + } + RTC_LOG(LS_INFO) << "AEC3: Pre echo configuration: threshold = " << threshold + << ", mode = " << mode << "."; + return {.threshold = threshold, 
.mode = mode}; +} + +} // namespace + +namespace webrtc { +namespace aec3 { + +#if defined(WEBRTC_HAS_NEON) + +inline float SumAllElements(float32x4_t elements) { + float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements)); + sum = vpadd_f32(sum, sum); + return vget_lane_f32(sum, 0); +} + +void MatchedFilterCoreWithAccumulatedError_NEON( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* accumulated_error_p = &accumulated_error[0]; + // Initialize values for the accumulation. + float32x4_t x2_sum_128 = vdupq_n_f32(0); + float x2_sum = 0.f; + float s = 0; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; + --k, h_p += 4, x_p += 4, accumulated_error_p++) { + // Load the data into 128 bit vectors. 
+ const float32x4_t x_k = vld1q_f32(x_p); + const float32x4_t h_k = vld1q_f32(h_p); + // Compute and accumulate x * x. + x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k); + // Compute x * h + float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k); + s += SumAllElements(hk_xk_128); + const float e = s - y[i]; + accumulated_error_p[0] += e * e; + } + // Combine the accumulated vector and scalar values. + x2_sum += SumAllElements(x2_sum_128); + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const float32x4_t alpha_128 = vmovq_n_f32(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + float32x4_t h_k = vld1q_f32(h_p); + const float32x4_t x_k = vld1q_f32(x_p); + // Compute h = h + alpha * x. + h_k = vmlaq_f32(h_k, alpha_128, x_k); + // Store the result. + vst1q_f32(h_p, h_k); + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? 
x_start_index - 1 : x_size - 1; + } +} + +void MatchedFilterCore_NEON(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + + if (compute_accumulated_error) { + return MatchedFilterCoreWithAccumulatedError_NEON( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + + RTC_DCHECK_GT(x_size, x_start_index); + const float* x_p = &x[x_start_index]; + const float* h_p = &h[0]; + + // Initialize values for the accumulation. + float32x4_t s_128 = vdupq_n_f32(0); + float32x4_t x2_sum_128 = vdupq_n_f32(0); + float x2_sum = 0.f; + float s = 0; + + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + + // Perform the loop in two chunks. + const int chunk2 = h_size - chunk1; + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + const float32x4_t x_k = vld1q_f32(x_p); + const float32x4_t h_k = vld1q_f32(h_p); + // Compute and accumulate x * x and h * x. + x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k); + s_128 = vmlaq_f32(s_128, h_k, x_k); + } + + // Perform non-vector operations for any remaining items. 
+ for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + const float x_k = *x_p; + x2_sum += x_k * x_k; + s += *h_p * x_k; + } + + x_p = &x[0]; + } + + // Combine the accumulated vector and scalar values. + s += SumAllElements(s_128); + x2_sum += SumAllElements(x2_sum_128); + + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const float32x4_t alpha_128 = vmovq_n_f32(alpha); + + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = &x[x_start_index]; + + // Perform the loop in two chunks. + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + float32x4_t h_k = vld1q_f32(h_p); + const float32x4_t x_k = vld1q_f32(x_p); + // Compute h = h + alpha * x. + h_k = vmlaq_f32(h_k, alpha_128, x_k); + + // Store the result. + vst1q_f32(h_p, h_k); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + *h_p += alpha * *x_p; + } + + x_p = &x[0]; + } + + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? 
x_start_index - 1 : x_size - 1; + } +} + +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +void MatchedFilterCore_AccumulatedError_SSE2( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 8); + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* a_p = &accumulated_error[0]; + __m128 s_inst_128; + __m128 s_inst_128_4; + __m128 x2_sum_128 = _mm_set1_ps(0); + __m128 x2_sum_128_4 = _mm_set1_ps(0); + __m128 e_128; + float* const s_p = reinterpret_cast<float*>(&s_inst_128); + float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4); + float* const e_p = reinterpret_cast<float*>(&e_128); + float x2_sum = 0.0f; + float s_acum = 0; + // Perform 128 bit vector operations. + const int limit_by_8 = h_size >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) { + // Load the data into 128 bit vectors. 
+ const __m128 x_k = _mm_loadu_ps(x_p); + const __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k_4 = _mm_loadu_ps(x_p + 4); + const __m128 h_k_4 = _mm_loadu_ps(h_p + 4); + const __m128 xx = _mm_mul_ps(x_k, x_k); + const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4); + // Compute and accumulate x * x and h * x. + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4); + s_inst_128 = _mm_mul_ps(h_k, x_k); + s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4); + s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3]; + e_p[0] = s_acum - y[i]; + s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3]; + e_p[1] = s_acum - y[i]; + a_p[0] += e_p[0] * e_p[0]; + a_p[1] += e_p[1] * e_p[1]; + } + // Combine the accumulated vector and scalar values. + x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); + float* v = reinterpret_cast<float*>(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + // Compute the matched filter error. + float e = y[i] - s_acum; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m128 alpha_128 = _mm_set1_ps(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k = _mm_loadu_ps(x_p); + // Compute h = h + alpha * x. + const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k); + h_k = _mm_add_ps(h_k, alpha_x); + // Store the result. + _mm_storeu_ps(h_p, h_k); + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? 
x_start_index - 1 : x_size - 1; + } +} + +void MatchedFilterCore_SSE2(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + if (compute_accumulated_error) { + return MatchedFilterCore_AccumulatedError_SSE2( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + const float* x_p = &x[x_start_index]; + const float* h_p = &h[0]; + // Initialize values for the accumulation. + __m128 s_128 = _mm_set1_ps(0); + __m128 s_128_4 = _mm_set1_ps(0); + __m128 x2_sum_128 = _mm_set1_ps(0); + __m128 x2_sum_128_4 = _mm_set1_ps(0); + float x2_sum = 0.f; + float s = 0; + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + // Perform the loop in two chunks. + const int chunk2 = h_size - chunk1; + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_8 = limit >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + // Load the data into 128 bit vectors. 
+ const __m128 x_k = _mm_loadu_ps(x_p); + const __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k_4 = _mm_loadu_ps(x_p + 4); + const __m128 h_k_4 = _mm_loadu_ps(h_p + 4); + const __m128 xx = _mm_mul_ps(x_k, x_k); + const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4); + // Compute and accumulate x * x and h * x. + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4); + const __m128 hx = _mm_mul_ps(h_k, x_k); + const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4); + s_128 = _mm_add_ps(s_128, hx); + s_128_4 = _mm_add_ps(s_128_4, hx_4); + } + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { + const float x_k = *x_p; + x2_sum += x_k * x_k; + s += *h_p * x_k; + } + x_p = &x[0]; + } + // Combine the accumulated vector and scalar values. + x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); + float* v = reinterpret_cast<float*>(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + s_128 = _mm_add_ps(s_128, s_128_4); + v = reinterpret_cast<float*>(&s_128); + s += v[0] + v[1] + v[2] + v[3]; + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m128 alpha_128 = _mm_set1_ps(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = &x[x_start_index]; + // Perform the loop in two chunks. + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k = _mm_loadu_ps(x_p); + + // Compute h = h + alpha * x. 
+ const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k); + h_k = _mm_add_ps(h_k, alpha_x); + // Store the result. + _mm_storeu_ps(h_p, h_k); + } + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + *h_p += alpha * *x_p; + } + x_p = &x[0]; + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} +#endif + +void MatchedFilterCore(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error) { + if (compute_accumulated_error) { + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + } + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + float x2_sum = 0.f; + float s = 0; + size_t x_index = x_start_index; + if (compute_accumulated_error) { + for (size_t k = 0; k < h.size(); ++k) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + if ((k + 1 & 0b11) == 0) { + int idx = k >> 2; + accumulated_error[idx] += (y[i] - s) * (y[i] - s); + } + } + } else { + for (size_t k = 0; k < h.size(); ++k) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } + } + + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + + // filter = filter + smoothing * (y - filter * x) * x / x * x. 
+ size_t x_index = x_start_index; + for (size_t k = 0; k < h.size(); ++k) { + h[k] += alpha * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; + } +} + +size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) { + if (h.size() < 2) { + return 0; + } + float max_element1 = h[0] * h[0]; + float max_element2 = h[1] * h[1]; + size_t lag_estimate1 = 0; + size_t lag_estimate2 = 1; + const size_t last_index = h.size() - 1; + // Keeping track of even & odd max elements separately typically allows the + // compiler to produce more efficient code. + for (size_t k = 2; k < last_index; k += 2) { + float element1 = h[k] * h[k]; + float element2 = h[k + 1] * h[k + 1]; + if (element1 > max_element1) { + max_element1 = element1; + lag_estimate1 = k; + } + if (element2 > max_element2) { + max_element2 = element2; + lag_estimate2 = k + 1; + } + } + if (max_element2 > max_element1) { + max_element1 = max_element2; + lag_estimate1 = lag_estimate2; + } + // In case of odd h size, we have not yet checked the last element. 
+ float last_element = h[last_index] * h[last_index]; + if (last_element > max_element1) { + return last_index; + } + return lag_estimate1; +} + +} // namespace aec3 + +MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, + Aec3Optimization optimization, + size_t sub_block_size, + size_t window_size_sub_blocks, + int num_matched_filters, + size_t alignment_shift_sub_blocks, + float excitation_limit, + float smoothing_fast, + float smoothing_slow, + float matching_filter_threshold, + bool detect_pre_echo) + : data_dumper_(data_dumper), + optimization_(optimization), + sub_block_size_(sub_block_size), + filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_), + filters_( + num_matched_filters, + std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)), + filters_offsets_(num_matched_filters, 0), + excitation_limit_(excitation_limit), + smoothing_fast_(smoothing_fast), + smoothing_slow_(smoothing_slow), + matching_filter_threshold_(matching_filter_threshold), + detect_pre_echo_(detect_pre_echo), + pre_echo_config_(FetchPreEchoConfiguration()) { + RTC_DCHECK(data_dumper); + RTC_DCHECK_LT(0, window_size_sub_blocks); + RTC_DCHECK((kBlockSize % sub_block_size) == 0); + RTC_DCHECK((sub_block_size % 4) == 0); + static_assert(kAccumulatedErrorSubSampleRate == 4); + if (detect_pre_echo_) { + accumulated_error_ = std::vector<std::vector<float>>( + num_matched_filters, + std::vector<float>(window_size_sub_blocks * sub_block_size_ / + kAccumulatedErrorSubSampleRate, + 1.0f)); + + instantaneous_accumulated_error_ = + std::vector<float>(window_size_sub_blocks * sub_block_size_ / + kAccumulatedErrorSubSampleRate, + 0.0f); + scratch_memory_ = + std::vector<float>(window_size_sub_blocks * sub_block_size_); + } +} + +MatchedFilter::~MatchedFilter() = default; + +void MatchedFilter::Reset(bool full_reset) { + for (auto& f : filters_) { + std::fill(f.begin(), f.end(), 0.f); + } + + winner_lag_ = absl::nullopt; + reported_lag_estimate_ = absl::nullopt; + if 
(pre_echo_config_.mode != 3 || full_reset) { + for (auto& e : accumulated_error_) { + std::fill(e.begin(), e.end(), 1.0f); + } + number_pre_echo_updates_ = 0; + } +} + +void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, + rtc::ArrayView<const float> capture, + bool use_slow_smoothing) { + RTC_DCHECK_EQ(sub_block_size_, capture.size()); + auto& y = capture; + + const float smoothing = + use_slow_smoothing ? smoothing_slow_ : smoothing_fast_; + + const float x2_sum_threshold = + filters_[0].size() * excitation_limit_ * excitation_limit_; + + // Compute anchor for the matched filter error. + float error_sum_anchor = 0.0f; + for (size_t k = 0; k < y.size(); ++k) { + error_sum_anchor += y[k] * y[k]; + } + + // Apply all matched filters. + float winner_error_sum = error_sum_anchor; + winner_lag_ = absl::nullopt; + reported_lag_estimate_ = absl::nullopt; + size_t alignment_shift = 0; + absl::optional<size_t> previous_lag_estimate; + const int num_filters = static_cast<int>(filters_.size()); + int winner_index = -1; + for (int n = 0; n < num_filters; ++n) { + float error_sum = 0.f; + bool filters_updated = false; + const bool compute_pre_echo = + detect_pre_echo_ && n == last_detected_best_lag_filter_; + + size_t x_start_index = + (render_buffer.read + alignment_shift + sub_block_size_ - 1) % + render_buffer.buffer.size(); + + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::MatchedFilterCore_SSE2( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); + break; + case Aec3Optimization::kAvx2: + aec3::MatchedFilterCore_AVX2( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case 
Aec3Optimization::kNeon: + aec3::MatchedFilterCore_NEON( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); + break; +#endif + default: + aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing, + render_buffer.buffer, y, filters_[n], + &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_); + } + + // Estimate the lag in the matched filter as the distance to the portion in + // the filter that contributes the most to the matched filter output. This + // is detected as the peak of the matched filter. + const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]); + const bool reliable = + lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && + error_sum < matching_filter_threshold_ * error_sum_anchor; + + // Find the best estimate + const size_t lag = lag_estimate + alignment_shift; + if (filters_updated && reliable && error_sum < winner_error_sum) { + winner_error_sum = error_sum; + winner_index = n; + // In case that 2 matched filters return the same winner candidate + // (overlap region), the one with the smaller index is chosen in order + // to search for pre-echoes. + if (previous_lag_estimate && previous_lag_estimate == lag) { + winner_lag_ = previous_lag_estimate; + winner_index = n - 1; + } else { + winner_lag_ = lag; + } + } + previous_lag_estimate = lag; + alignment_shift += filter_intra_lag_shift_; + } + + if (winner_index != -1) { + RTC_DCHECK(winner_lag_.has_value()); + reported_lag_estimate_ = + LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value()); + if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) { + const float energy_threshold = + pre_echo_config_.mode == 3 ? 1.0f : 30.0f * 30.0f * y.size(); + + if (error_sum_anchor > energy_threshold) { + const float smooth_constant_increases = + pre_echo_config_.mode != 3 ? 
0.01f : 0.015f; + + UpdateAccumulatedError( + instantaneous_accumulated_error_, accumulated_error_[winner_index], + 1.0f / error_sum_anchor, smooth_constant_increases); + number_pre_echo_updates_++; + } + if (pre_echo_config_.mode != 3 || number_pre_echo_updates_ >= 50) { + reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag( + pre_echo_config_, accumulated_error_[winner_index], + winner_lag_.value(), + winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/); + } else { + reported_lag_estimate_->pre_echo_lag = winner_lag_.value(); + } + } + last_detected_best_lag_filter_ = winner_index; + } + if (ApmDataDumper::IsAvailable()) { + Dump(); + data_dumper_->DumpRaw("error_sum_anchor", error_sum_anchor / y.size()); + data_dumper_->DumpRaw("number_pre_echo_updates", number_pre_echo_updates_); + data_dumper_->DumpRaw("filter_smoothing", smoothing); + } +} + +void MatchedFilter::LogFilterProperties(int sample_rate_hz, + size_t shift, + size_t downsampling_factor) const { + size_t alignment_shift = 0; + constexpr int kFsBy1000 = 16; + for (size_t k = 0; k < filters_.size(); ++k) { + int start = static_cast<int>(alignment_shift * downsampling_factor); + int end = static_cast<int>((alignment_shift + filters_[k].size()) * + downsampling_factor); + RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: " + << (start - static_cast<int>(shift)) / kFsBy1000 + << " ms, end: " + << (end - static_cast<int>(shift)) / kFsBy1000 + << " ms."; + alignment_shift += filter_intra_lag_shift_; + } +} + +void MatchedFilter::Dump() { + for (size_t n = 0; n < filters_.size(); ++n) { + const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]); + std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h"; + data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]); + std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n); + data_dumper_->DumpRaw(dumper_lag.c_str(), + lag_estimate + n * filter_intra_lag_shift_); + if (detect_pre_echo_) { + std::string 
dumper_error = + "aec3_correlator_error_" + std::to_string(n) + "_h"; + data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]); + + size_t pre_echo_lag = + ComputePreEchoLag(pre_echo_config_, accumulated_error_[n], + lag_estimate + n * filter_intra_lag_shift_, + n * filter_intra_lag_shift_); + std::string dumper_pre_lag = + "aec3_correlator_pre_echo_lag_" + std::to_string(n); + data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag); + if (static_cast<int>(n) == last_detected_best_lag_filter_) { + data_dumper_->DumpRaw("aec3_pre_echo_delay_winner_inst", pre_echo_lag); + } + } + } +} + +} // namespace webrtc |