/*
 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS. All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
#include "modules/audio_processing/aec3/matched_filter.h"

// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"

#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/experiments/field_trial_parser.h"
#include "rtc_base/logging.h"
#include "system_wrappers/include/field_trial.h"

namespace {

// Subsample rate used for computing the accumulated error.
// The implementation of some core functions depends on this constant being
// equal to 4.
constexpr int kAccumulatedErrorSubSampleRate = 4;

// Folds the errors produced by one matched filter run into the running
// `accumulated_error`.
// Each instantaneous value is first normalized by `one_over_error_sum_anchor`.
// A normalized value below the current accumulated value replaces it
// immediately. A larger value is approached with the smoothing factor
// `smooth_constant_increases`.
void UpdateAccumulatedError(
    const rtc::ArrayView<const float> instantaneous_accumulated_error,
    const rtc::ArrayView<float> accumulated_error,
    float one_over_error_sum_anchor,
    float smooth_constant_increases) {
  for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
    float error_norm =
        instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
    if (error_norm < accumulated_error[k]) {
      accumulated_error[k] = error_norm;
    } else {
      accumulated_error[k] +=
          smooth_constant_increases * (error_norm - accumulated_error[k]);
    }
  }
}

// Returns the pre-echo lag for a matched filter whose region starts at
// `alignment_shift_winner` and whose detected lag is `lag`.
// Only the accumulated error entries preceding the lag are examined. Each
// entry summarizes kAccumulatedErrorSubSampleRate filter coefficients. When no
// entry satisfies the criterion selected by `pre_echo_configuration.mode`,
// `lag` itself is returned.
size_t ComputePreEchoLag(
    const webrtc::MatchedFilter::PreEchoConfiguration& pre_echo_configuration,
    const rtc::ArrayView<float> accumulated_error,
    size_t lag,
    size_t alignment_shift_winner) {
  RTC_DCHECK_GE(lag, alignment_shift_winner);
  size_t pre_echo_lag_estimate = lag - alignment_shift_winner;
  size_t maximum_pre_echo_lag =
      std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate,
               accumulated_error.size());
  switch (pre_echo_configuration.mode) {
    case 0:
      // Mode 0: Pre echo lag is defined as the first coefficient with an error
      // lower than a threshold with a certain decrease slope.
      for (size_t k = 1; k < maximum_pre_echo_lag; ++k) {
        if (accumulated_error[k] <
                pre_echo_configuration.threshold * accumulated_error[k - 1] &&
            accumulated_error[k] < pre_echo_configuration.threshold) {
          pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
          break;
        }
      }
      break;
    case 1:
      // Mode 1: Pre echo lag is defined as the first coefficient with an error
      // lower than a certain threshold.
      for (size_t k = 0; k < maximum_pre_echo_lag; ++k) {
        if (accumulated_error[k] < pre_echo_configuration.threshold) {
          pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
          break;
        }
      }
      break;
    case 2:
    case 3:
      // Mode 2,3: Walking from the entry closest to the lag towards the start
      // of the filter, the pre echo lag is set to the earliest entry of the
      // contiguous run whose error does not exceed the threshold.
      for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) {
        if (accumulated_error[k] > pre_echo_configuration.threshold) {
          break;
        }
        pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
      }
      break;
    default:
      RTC_DCHECK_NOTREACHED();
      break;
  }
  return pre_echo_lag_estimate + alignment_shift_winner;
}

// Reads the pre-echo settings from the "WebRTC-Aec3PreEchoConfiguration" field
// trial. Each setting keeps its default (threshold 0.5, mode 3) when the
// parsed value lies outside its valid range; such values are logged as errors.
webrtc::MatchedFilter::PreEchoConfiguration FetchPreEchoConfiguration() {
  constexpr float kDefaultThreshold = 0.5f;
  constexpr int kDefaultMode = 3;
  float threshold = kDefaultThreshold;
  int mode = kDefaultMode;
  const std::string pre_echo_configuration_field_trial =
      webrtc::field_trial::FindFullName("WebRTC-Aec3PreEchoConfiguration");
  webrtc::FieldTrialParameter<double> threshold_field_trial_parameter(
      /*key=*/"threshold", /*default_value=*/kDefaultThreshold);
  webrtc::FieldTrialParameter<int> mode_field_trial_parameter(
      /*key=*/"mode", /*default_value=*/kDefaultMode);
  webrtc::ParseFieldTrial(
      {&threshold_field_trial_parameter, &mode_field_trial_parameter},
      pre_echo_configuration_field_trial);
  float threshold_read =
      static_cast<float>(threshold_field_trial_parameter.Get());
  int mode_read = mode_field_trial_parameter.Get();
  // Only thresholds strictly inside (0, 1) are accepted.
  if (threshold_read < 1.0f && threshold_read > 0.0f) {
    threshold = threshold_read;
  } else {
    RTC_LOG(LS_ERROR)
        << "AEC3: Pre echo configuration: wrong input, threshold = "
        << threshold_read << ".";
  }
  // Only the modes handled by ComputePreEchoLag (0..3) are accepted.
  if (mode_read >= 0 && mode_read <= 3) {
    mode = mode_read;
  } else {
    RTC_LOG(LS_ERROR) << "AEC3: Pre echo configuration: wrong input, mode = "
                      << mode_read << ".";
  }
  RTC_LOG(LS_INFO) << "AEC3: Pre echo configuration: threshold = " << threshold
                   << ", mode = " << mode << ".";
  return {.threshold = threshold, .mode = mode};
}

}  // namespace

namespace webrtc {
namespace aec3 {

#if defined(WEBRTC_HAS_NEON)

// Returns the horizontal sum of the four lanes of `elements`.
inline float SumAllElements(float32x4_t elements) {
  float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
  sum = vpadd_f32(sum, sum);
  return vget_lane_f32(sum, 0);
}

// NEON matched filter core that also fills `accumulated_error`.
// After every group of 4 coefficients, the squared difference between the
// partial filter output and y[i] is added to the group's entry of
// `accumulated_error`. When the filter window wraps around the circular
// buffer `x`, the window is first linearized into `scratch_memory`, which must
// hold at least h.size() values.
void MatchedFilterCoreWithAccumulatedError_NEON(
    size_t x_start_index,
    float x2_sum_threshold,
    float smoothing,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float> h,
    bool* filters_updated,
    float* error_sum,
    rtc::ArrayView<float> accumulated_error,
    rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    // Compute loop chunk sizes until, and after, the wraparound of the circular
    // buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    if (chunk1 != h_size) {
      const int chunk2 = h_size - chunk1;
      std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
      std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
    }
    const float* x_p =
        chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
    const float* h_p = &h[0];
    float* accumulated_error_p = &accumulated_error[0];
    // Initialize values for the accumulation.
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;
    // Perform 128 bit vector operations.
    const int limit_by_4 = h_size >> 2;
    for (int k = limit_by_4; k > 0;
         --k, h_p += 4, x_p += 4, accumulated_error_p++) {
      // Load the data into 128 bit vectors.
      const float32x4_t x_k = vld1q_f32(x_p);
      const float32x4_t h_k = vld1q_f32(h_p);
      // Compute and accumulate x * x.
      x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
      // Compute x * h and record the error of the partial filter output.
      float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
      s += SumAllElements(hk_xk_128);
      const float e = s - y[i];
      accumulated_error_p[0] += e * e;
    }
    // Combine the accumulated vector and scalar values.
    x2_sum += SumAllElements(x2_sum_128);
    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;
    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);
      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      float* h_update_p = &h[0];
      x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
      // Perform 128 bit vector operations.
      for (int k = limit_by_4; k > 0; --k, h_update_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        float32x4_t h_k = vld1q_f32(h_update_p);
        const float32x4_t x_k = vld1q_f32(x_p);
        // Compute h = h + alpha * x.
        h_k = vmlaq_f32(h_k, alpha_128, x_k);
        // Store the result.
        vst1q_f32(h_update_p, h_k);
      }
      *filters_updated = true;
    }
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

// NEON matched filter core.
// For every sample y[i], computes the filter output s = h . x over the
// circular buffer `x` starting at `x_start_index`, adds (y[i] - s)^2 to
// `*error_sum` and, unless the excitation is too low or y[i] is saturated,
// performs an NLMS update of `h` and sets `*filters_updated`. The start index
// then moves back one sample, wrapping at 0. When `compute_accumulated_error`
// is set, the work is delegated to the accumulated-error variant.
void MatchedFilterCore_NEON(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum,
                            bool compute_accumulated_error,
                            rtc::ArrayView<float> accumulated_error,
                            rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);

  if (compute_accumulated_error) {
    return MatchedFilterCoreWithAccumulatedError_NEON(
        x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
        error_sum, accumulated_error, scratch_memory);
  }

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.

    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];

    // Initialize values for the accumulation.
    float32x4_t s_128 = vdupq_n_f32(0);
    float32x4_t x2_sum_128 = vdupq_n_f32(0);
    float x2_sum = 0.f;
    float s = 0;

    // Compute loop chunk sizes until, and after, the wraparound of the circular
    // buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));

    // Perform the loop in two chunks.
    const int chunk2 = h_size - chunk1;
    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_4 = limit >> 2;
      for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        const float32x4_t x_k = vld1q_f32(x_p);
        const float32x4_t h_k = vld1q_f32(h_p);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
        s_128 = vmlaq_f32(s_128, h_k, x_k);
      }

      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }

      x_p = &x[0];
    }

    // Combine the accumulated vector and scalar values.
    s += SumAllElements(s_128);
    x2_sum += SumAllElements(x2_sum_128);

    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const float32x4_t alpha_128 = vmovq_n_f32(alpha);

      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      float* h_update_p = &h[0];
      x_p = &x[x_start_index];

      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        // Perform 128 bit vector operations.
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_update_p += 4, x_p += 4) {
          // Load the data into 128 bit vectors.
          float32x4_t h_k = vld1q_f32(h_update_p);
          const float32x4_t x_k = vld1q_f32(x_p);
          // Compute h = h + alpha * x.
          h_k = vmlaq_f32(h_k, alpha_128, x_k);

          // Store the result.
          vst1q_f32(h_update_p, h_k);
        }

        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0;
             --k, ++h_update_p, ++x_p) {
          *h_update_p += alpha * *x_p;
        }

        x_p = &x[0];
      }

      *filters_updated = true;
    }

    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

#endif

#if defined(WEBRTC_ARCH_X86_FAMILY)

// SSE2 matched filter core that also fills `accumulated_error`.
// Coefficients are processed 8 at a time. After each group of 4, the squared
// difference between the partial filter output and y[i] is added to that
// group's entry of `accumulated_error`. When the filter window wraps around
// the circular buffer `x`, the window is first linearized into
// `scratch_memory`, which must hold at least h.size() values.
void MatchedFilterCore_AccumulatedError_SSE2(
    size_t x_start_index,
    float x2_sum_threshold,
    float smoothing,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float> h,
    bool* filters_updated,
    float* error_sum,
    rtc::ArrayView<float> accumulated_error,
    rtc::ArrayView<float> scratch_memory) {
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 8);
  std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    if (chunk1 != h_size) {
      const int chunk2 = h_size - chunk1;
      std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
      std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
    }
    const float* x_p =
        chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
    const float* h_p = &h[0];
    float* a_p = &accumulated_error[0];
    __m128 s_inst_128;
    __m128 s_inst_128_4;
    __m128 x2_sum_128 = _mm_set1_ps(0);
    __m128 x2_sum_128_4 = _mm_set1_ps(0);
    __m128 e_128;
    float* const s_p = reinterpret_cast<float*>(&s_inst_128);
    float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4);
    float* const e_p = reinterpret_cast<float*>(&e_128);
    float x2_sum = 0.0f;
    float s_acum = 0;
    // Perform 128 bit vector operations.
    const int limit_by_8 = h_size >> 3;
    for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) {
      // Load the data into 128 bit vectors.
      const __m128 x_k = _mm_loadu_ps(x_p);
      const __m128 h_k = _mm_loadu_ps(h_p);
      const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
      const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
      const __m128 xx = _mm_mul_ps(x_k, x_k);
      const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
      // Compute and accumulate x * x and h * x.
      x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
      x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
      s_inst_128 = _mm_mul_ps(h_k, x_k);
      s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4);
      // Record the error of the partial filter output after each group of 4.
      s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3];
      e_p[0] = s_acum - y[i];
      s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3];
      e_p[1] = s_acum - y[i];
      a_p[0] += e_p[0] * e_p[0];
      a_p[1] += e_p[1] * e_p[1];
    }
    // Combine the accumulated vector and scalar values.
    x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
    float* v = reinterpret_cast<float*>(&x2_sum_128);
    x2_sum += v[0] + v[1] + v[2] + v[3];
    // Compute the matched filter error.
    float e = y[i] - s_acum;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;
    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const __m128 alpha_128 = _mm_set1_ps(alpha);
      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      float* h_update_p = &h[0];
      // Rewind over the same (possibly linearized) window used above.
      x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
      // Perform 128 bit vector operations.
      const int limit_by_4 = h_size >> 2;
      for (int k = limit_by_4; k > 0; --k, h_update_p += 4, x_p += 4) {
        // Load the data into 128 bit vectors.
        __m128 h_k = _mm_loadu_ps(h_update_p);
        const __m128 x_k = _mm_loadu_ps(x_p);
        // Compute h = h + alpha * x.
        const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
        h_k = _mm_add_ps(h_k, alpha_x);
        // Store the result.
        _mm_storeu_ps(h_update_p, h_k);
      }
      *filters_updated = true;
    }
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

// SSE2 matched filter core.
// For every sample y[i], computes the filter output s = h . x over the
// circular buffer `x` starting at `x_start_index`, adds (y[i] - s)^2 to
// `*error_sum` and, unless the excitation is too low or y[i] is saturated,
// performs an NLMS update of `h` and sets `*filters_updated`. The start index
// then moves back one sample, wrapping at 0. When `compute_accumulated_error`
// is set, the work is delegated to the accumulated-error variant.
void MatchedFilterCore_SSE2(size_t x_start_index,
                            float x2_sum_threshold,
                            float smoothing,
                            rtc::ArrayView<const float> x,
                            rtc::ArrayView<const float> y,
                            rtc::ArrayView<float> h,
                            bool* filters_updated,
                            float* error_sum,
                            bool compute_accumulated_error,
                            rtc::ArrayView<float> accumulated_error,
                            rtc::ArrayView<float> scratch_memory) {
  if (compute_accumulated_error) {
    return MatchedFilterCore_AccumulatedError_SSE2(
        x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
        error_sum, accumulated_error, scratch_memory);
  }
  const int h_size = static_cast<int>(h.size());
  const int x_size = static_cast<int>(x.size());
  RTC_DCHECK_EQ(0, h_size % 4);
  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    RTC_DCHECK_GT(x_size, x_start_index);
    const float* x_p = &x[x_start_index];
    const float* h_p = &h[0];
    // Initialize values for the accumulation.
    __m128 s_128 = _mm_set1_ps(0);
    __m128 s_128_4 = _mm_set1_ps(0);
    __m128 x2_sum_128 = _mm_set1_ps(0);
    __m128 x2_sum_128_4 = _mm_set1_ps(0);
    float x2_sum = 0.f;
    float s = 0;
    // Compute loop chunk sizes until, and after, the wraparound of the circular
    // buffer for x.
    const int chunk1 =
        std::min(h_size, static_cast<int>(x_size - x_start_index));
    // Perform the loop in two chunks.
    const int chunk2 = h_size - chunk1;
    for (int limit : {chunk1, chunk2}) {
      // Perform 128 bit vector operations.
      const int limit_by_8 = limit >> 3;
      for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
        // Load the data into 128 bit vectors.
        const __m128 x_k = _mm_loadu_ps(x_p);
        const __m128 h_k = _mm_loadu_ps(h_p);
        const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
        const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
        const __m128 xx = _mm_mul_ps(x_k, x_k);
        const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
        // Compute and accumulate x * x and h * x.
        x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
        x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
        const __m128 hx = _mm_mul_ps(h_k, x_k);
        const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4);
        s_128 = _mm_add_ps(s_128, hx);
        s_128_4 = _mm_add_ps(s_128_4, hx_4);
      }
      // Perform non-vector operations for any remaining items.
      for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
        const float x_k = *x_p;
        x2_sum += x_k * x_k;
        s += *h_p * x_k;
      }
      x_p = &x[0];
    }
    // Combine the accumulated vector and scalar values.
    x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
    float* v = reinterpret_cast<float*>(&x2_sum_128);
    x2_sum += v[0] + v[1] + v[2] + v[3];
    s_128 = _mm_add_ps(s_128, s_128_4);
    v = reinterpret_cast<float*>(&s_128);
    s += v[0] + v[1] + v[2] + v[3];
    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;
    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;
      const __m128 alpha_128 = _mm_set1_ps(alpha);
      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      float* h_update_p = &h[0];
      x_p = &x[x_start_index];
      // Perform the loop in two chunks.
      for (int limit : {chunk1, chunk2}) {
        // Perform 128 bit vector operations.
        const int limit_by_4 = limit >> 2;
        for (int k = limit_by_4; k > 0; --k, h_update_p += 4, x_p += 4) {
          // Load the data into 128 bit vectors.
          __m128 h_k = _mm_loadu_ps(h_update_p);
          const __m128 x_k = _mm_loadu_ps(x_p);
          // Compute h = h + alpha * x.
          const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
          h_k = _mm_add_ps(h_k, alpha_x);
          // Store the result.
          _mm_storeu_ps(h_update_p, h_k);
        }
        // Perform non-vector operations for any remaining items.
        for (int k = limit - limit_by_4 * 4; k > 0;
             --k, ++h_update_p, ++x_p) {
          *h_update_p += alpha * *x_p;
        }
        x_p = &x[0];
      }
      *filters_updated = true;
    }
    x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
  }
}

#endif

// Portable (scalar) matched filter core.
// For every sample y[i], computes the filter output s = h . x over the
// circular buffer `x` starting at `x_start_index`, adds (y[i] - s)^2 to
// `*error_sum` and, unless the excitation is too low or y[i] is saturated,
// performs an NLMS update of `h` and sets `*filters_updated`. When
// `compute_accumulated_error` is set, `accumulated_error` is zeroed first.
// After every 4th coefficient, the squared error of the partial output is then
// added to entry k / 4.
void MatchedFilterCore(size_t x_start_index,
                       float x2_sum_threshold,
                       float smoothing,
                       rtc::ArrayView<const float> x,
                       rtc::ArrayView<const float> y,
                       rtc::ArrayView<float> h,
                       bool* filters_updated,
                       float* error_sum,
                       bool compute_accumulated_error,
                       rtc::ArrayView<float> accumulated_error) {
  if (compute_accumulated_error) {
    std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
  }

  // Process for all samples in the sub-block.
  for (size_t i = 0; i < y.size(); ++i) {
    // Apply the matched filter as filter * x, and compute x * x.
    float x2_sum = 0.f;
    float s = 0;
    size_t x_index = x_start_index;
    if (compute_accumulated_error) {
      for (size_t k = 0; k < h.size(); ++k) {
        x2_sum += x[x_index] * x[x_index];
        s += h[k] * x[x_index];
        x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
        // Record the partial error once per kAccumulatedErrorSubSampleRate (4)
        // coefficients. The parentheses make the intended (k + 1) & 3 explicit.
        if (((k + 1) & 0b11) == 0) {
          const size_t idx = k >> 2;
          accumulated_error[idx] += (y[i] - s) * (y[i] - s);
        }
      }
    } else {
      for (size_t k = 0; k < h.size(); ++k) {
        x2_sum += x[x_index] * x[x_index];
        s += h[k] * x[x_index];
        x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
      }
    }

    // Compute the matched filter error.
    float e = y[i] - s;
    const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
    (*error_sum) += e * e;

    // Update the matched filter estimate in an NLMS manner.
    if (x2_sum > x2_sum_threshold && !saturation) {
      RTC_DCHECK_LT(0.f, x2_sum);
      const float alpha = smoothing * e / x2_sum;

      // filter = filter + smoothing * (y - filter * x) * x / x * x.
      x_index = x_start_index;
      for (size_t k = 0; k < h.size(); ++k) {
        h[k] += alpha * x[x_index];
        x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
      }
      *filters_updated = true;
    }

    x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
  }
}

// Returns the index of the coefficient with the largest square (i.e. the
// largest magnitude) in `h`, or 0 when `h` has fewer than two elements.
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) {
  if (h.size() < 2) {
    return 0;
  }
  float max_element1 = h[0] * h[0];
  float max_element2 = h[1] * h[1];
  size_t lag_estimate1 = 0;
  size_t lag_estimate2 = 1;
  const size_t last_index = h.size() - 1;
  // Keeping track of even & odd max elements separately typically allows the
  // compiler to produce more efficient code.
  for (size_t k = 2; k < last_index; k += 2) {
    float element1 = h[k] * h[k];
    float element2 = h[k + 1] * h[k + 1];
    if (element1 > max_element1) {
      max_element1 = element1;
      lag_estimate1 = k;
    }
    if (element2 > max_element2) {
      max_element2 = element2;
      lag_estimate2 = k + 1;
    }
  }
  if (max_element2 > max_element1) {
    max_element1 = max_element2;
    lag_estimate1 = lag_estimate2;
  }
  // In case of odd h size, we have not yet checked the last element.
  float last_element = h[last_index] * h[last_index];
  if (last_element > max_element1) {
    return last_index;
  }
  return lag_estimate1;
}

}  // namespace aec3

MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
                             Aec3Optimization optimization,
                             size_t sub_block_size,
                             size_t window_size_sub_blocks,
                             int num_matched_filters,
                             size_t alignment_shift_sub_blocks,
                             float excitation_limit,
                             float smoothing_fast,
                             float smoothing_slow,
                             float matching_filter_threshold,
                             bool detect_pre_echo)
    : data_dumper_(data_dumper),
      optimization_(optimization),
      sub_block_size_(sub_block_size),
      filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_),
      filters_(
          num_matched_filters,
          std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
      filters_offsets_(num_matched_filters, 0),
      excitation_limit_(excitation_limit),
      smoothing_fast_(smoothing_fast),
      smoothing_slow_(smoothing_slow),
      matching_filter_threshold_(matching_filter_threshold),
      detect_pre_echo_(detect_pre_echo),
      pre_echo_config_(FetchPreEchoConfiguration()) {
  RTC_DCHECK(data_dumper);
  RTC_DCHECK_LT(0, window_size_sub_blocks);
  RTC_DCHECK((kBlockSize % sub_block_size) == 0);
  RTC_DCHECK((sub_block_size % 4) == 0);
  static_assert(kAccumulatedErrorSubSampleRate == 4);
  if (detect_pre_echo_) {
    // One accumulated error entry per kAccumulatedErrorSubSampleRate filter
    // coefficients; the running values start at 1.0.
    accumulated_error_ = std::vector<std::vector<float>>(
        num_matched_filters,
        std::vector<float>(window_size_sub_blocks * sub_block_size_ /
                               kAccumulatedErrorSubSampleRate,
                           1.0f));

    instantaneous_accumulated_error_ =
        std::vector<float>(window_size_sub_blocks * sub_block_size_ /
                               kAccumulatedErrorSubSampleRate,
                           0.0f);
    // Large enough to hold one linearized filter window.
    scratch_memory_ =
        std::vector<float>(window_size_sub_blocks * sub_block_size_);
  }
}

MatchedFilter::~MatchedFilter() = default;

void MatchedFilter::Reset(bool full_reset) {
  for (auto& f : filters_) {
    std::fill(f.begin(), f.end(), 0.f);
  }

  winner_lag_ = absl::nullopt;
  reported_lag_estimate_ = absl::nullopt;
  // In mode 3 the accumulated pre-echo statistics survive partial resets.
  if (pre_echo_config_.mode != 3 || full_reset) {
    for (auto& e : accumulated_error_) {
      std::fill(e.begin(), e.end(), 1.0f);
    }
    number_pre_echo_updates_ = 0;
  }
}

void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
                           rtc::ArrayView<const float> capture,
                           bool use_slow_smoothing) {
  RTC_DCHECK_EQ(sub_block_size_, capture.size());
  auto& y = capture;

  const float smoothing =
      use_slow_smoothing ? smoothing_slow_ : smoothing_fast_;

  const float x2_sum_threshold =
      filters_[0].size() * excitation_limit_ * excitation_limit_;

  // Compute anchor for the matched filter error.
  float error_sum_anchor = 0.0f;
  for (size_t k = 0; k < y.size(); ++k) {
    error_sum_anchor += y[k] * y[k];
  }

  // Apply all matched filters.
  float winner_error_sum = error_sum_anchor;
  winner_lag_ = absl::nullopt;
  reported_lag_estimate_ = absl::nullopt;
  size_t alignment_shift = 0;
  absl::optional<size_t> previous_lag_estimate;
  const int num_filters = static_cast<int>(filters_.size());
  int winner_index = -1;
  for (int n = 0; n < num_filters; ++n) {
    float error_sum = 0.f;
    bool filters_updated = false;
    // Only the filter that won the previous update tracks pre-echo errors.
    const bool compute_pre_echo =
        detect_pre_echo_ && n == last_detected_best_lag_filter_;

    size_t x_start_index =
        (render_buffer.read + alignment_shift + sub_block_size_ - 1) %
        render_buffer.buffer.size();

    switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
      case Aec3Optimization::kSse2:
        aec3::MatchedFilterCore_SSE2(
            x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
            filters_[n], &filters_updated, &error_sum, compute_pre_echo,
            instantaneous_accumulated_error_, scratch_memory_);
        break;
      case Aec3Optimization::kAvx2:
        aec3::MatchedFilterCore_AVX2(
            x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
            filters_[n], &filters_updated, &error_sum, compute_pre_echo,
            instantaneous_accumulated_error_, scratch_memory_);
        break;
#endif
#if defined(WEBRTC_HAS_NEON)
      case Aec3Optimization::kNeon:
        aec3::MatchedFilterCore_NEON(
            x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
            filters_[n], &filters_updated, &error_sum, compute_pre_echo,
            instantaneous_accumulated_error_, scratch_memory_);
        break;
#endif
      default:
        aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
                                render_buffer.buffer, y, filters_[n],
                                &filters_updated, &error_sum, compute_pre_echo,
                                instantaneous_accumulated_error_);
    }

    // Estimate the lag in the matched filter as the distance to the portion in
    // the filter that contributes the most to the matched filter output. This
    // is detected as the peak of the matched filter.
    const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
    // The upper bound is written as an addition so that filters shorter than
    // 10 taps do not wrap the unsigned subtraction to a huge value.
    const bool reliable =
        lag_estimate > 2 && lag_estimate + 10 < filters_[n].size() &&
        error_sum < matching_filter_threshold_ * error_sum_anchor;

    // Find the best estimate
    const size_t lag = lag_estimate + alignment_shift;
    if (filters_updated && reliable && error_sum < winner_error_sum) {
      winner_error_sum = error_sum;
      winner_index = n;
      // In case that 2 matched filters return the same winner candidate
      // (overlap region), the one with the smaller index is chosen in order
      // to search for pre-echoes.
      if (previous_lag_estimate && previous_lag_estimate == lag) {
        winner_lag_ = previous_lag_estimate;
        winner_index = n - 1;
      } else {
        winner_lag_ = lag;
      }
    }
    previous_lag_estimate = lag;
    alignment_shift += filter_intra_lag_shift_;
  }

  if (winner_index != -1) {
    RTC_DCHECK(winner_lag_.has_value());
    reported_lag_estimate_ =
        LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
    if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
      const float energy_threshold =
          pre_echo_config_.mode == 3 ? 1.0f : 30.0f * 30.0f * y.size();

      if (error_sum_anchor > energy_threshold) {
        const float smooth_constant_increases =
            pre_echo_config_.mode != 3 ? 0.01f : 0.015f;

        UpdateAccumulatedError(
            instantaneous_accumulated_error_, accumulated_error_[winner_index],
            1.0f / error_sum_anchor, smooth_constant_increases);
        number_pre_echo_updates_++;
      }
      // In mode 3 the pre-echo lag is only reported after 50 updates.
      if (pre_echo_config_.mode != 3 || number_pre_echo_updates_ >= 50) {
        reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
            pre_echo_config_, accumulated_error_[winner_index],
            winner_lag_.value(),
            winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
      } else {
        reported_lag_estimate_->pre_echo_lag = winner_lag_.value();
      }
    }
    last_detected_best_lag_filter_ = winner_index;
  }
  if (ApmDataDumper::IsAvailable()) {
    Dump();
    data_dumper_->DumpRaw("error_sum_anchor", error_sum_anchor / y.size());
    data_dumper_->DumpRaw("number_pre_echo_updates", number_pre_echo_updates_);
    data_dumper_->DumpRaw("filter_smoothing", smoothing);
  }
}

void MatchedFilter::LogFilterProperties(int sample_rate_hz,
                                        size_t shift,
                                        size_t downsampling_factor) const {
  size_t alignment_shift = 0;
  constexpr int kFsBy1000 = 16;
  for (size_t k = 0; k < filters_.size(); ++k) {
    int start = static_cast<int>(alignment_shift * downsampling_factor);
    int end = static_cast<int>((alignment_shift + filters_[k].size()) *
                               downsampling_factor);
    RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: "
                        << (start - static_cast<int>(shift)) / kFsBy1000
                        << " ms, end: "
                        << (end - static_cast<int>(shift)) / kFsBy1000
                        << " ms.";
    alignment_shift += filter_intra_lag_shift_;
  }
}

void MatchedFilter::Dump() {
  for (size_t n = 0; n < filters_.size(); ++n) {
    const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
    std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h";
    data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]);
    std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n);
    data_dumper_->DumpRaw(dumper_lag.c_str(),
                          lag_estimate + n * filter_intra_lag_shift_);
    if (detect_pre_echo_) {
      std::string dumper_error =
          "aec3_correlator_error_" + std::to_string(n) + "_h";
      data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]);

      size_t pre_echo_lag =
          ComputePreEchoLag(pre_echo_config_, accumulated_error_[n],
                            lag_estimate + n * filter_intra_lag_shift_,
                            n * filter_intra_lag_shift_);
      std::string dumper_pre_lag =
          "aec3_correlator_pre_echo_lag_" + std::to_string(n);
      data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
      if (static_cast<int>(n) == last_detected_best_lag_filter_) {
        data_dumper_->DumpRaw("aec3_pre_echo_delay_winner_inst", pre_echo_lag);
      }
    }
  }
}

}  // namespace webrtc