diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/libwebrtc/modules/audio_processing/aec3 | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/aec3')
185 files changed, 30157 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/BUILD.gn b/third_party/libwebrtc/modules/audio_processing/aec3/BUILD.gn new file mode 100644 index 0000000000..c29b893b7d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/BUILD.gn @@ -0,0 +1,384 @@ +# Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("aec3") { + visibility = [ "*" ] + configs += [ "..:apm_debug_dump" ] + sources = [ + "adaptive_fir_filter.cc", + "adaptive_fir_filter_erl.cc", + "aec3_common.cc", + "aec3_fft.cc", + "aec_state.cc", + "aec_state.h", + "alignment_mixer.cc", + "alignment_mixer.h", + "api_call_jitter_metrics.cc", + "api_call_jitter_metrics.h", + "block.h", + "block_buffer.cc", + "block_delay_buffer.cc", + "block_delay_buffer.h", + "block_framer.cc", + "block_framer.h", + "block_processor.cc", + "block_processor.h", + "block_processor_metrics.cc", + "block_processor_metrics.h", + "clockdrift_detector.cc", + "clockdrift_detector.h", + "coarse_filter_update_gain.cc", + "coarse_filter_update_gain.h", + "comfort_noise_generator.cc", + "comfort_noise_generator.h", + "config_selector.cc", + "config_selector.h", + "decimator.cc", + "decimator.h", + "delay_estimate.h", + "dominant_nearend_detector.cc", + "dominant_nearend_detector.h", + "downsampled_render_buffer.cc", + "downsampled_render_buffer.h", + "echo_audibility.cc", + "echo_audibility.h", + "echo_canceller3.cc", + "echo_canceller3.h", + "echo_path_delay_estimator.cc", + "echo_path_delay_estimator.h", + "echo_path_variability.cc", + "echo_path_variability.h", + "echo_remover.cc", + "echo_remover.h", + "echo_remover_metrics.cc", + "echo_remover_metrics.h", + "erl_estimator.cc", + "erl_estimator.h", + "erle_estimator.cc", + "erle_estimator.h", + "fft_buffer.cc", + "filter_analyzer.cc", + "filter_analyzer.h", + "frame_blocker.cc", + "frame_blocker.h", + "fullband_erle_estimator.cc", + "fullband_erle_estimator.h", + "matched_filter.cc", + "matched_filter_lag_aggregator.cc", + "matched_filter_lag_aggregator.h", + "moving_average.cc", + "moving_average.h", + "multi_channel_content_detector.cc", + "multi_channel_content_detector.h", + "nearend_detector.h", + "refined_filter_update_gain.cc", + "refined_filter_update_gain.h", + "render_buffer.cc", + "render_delay_buffer.cc", + "render_delay_buffer.h", + "render_delay_controller.cc", + "render_delay_controller.h", + "render_delay_controller_metrics.cc", + "render_delay_controller_metrics.h", + "render_signal_analyzer.cc", + "render_signal_analyzer.h", + "residual_echo_estimator.cc", + "residual_echo_estimator.h", + "reverb_decay_estimator.cc", + "reverb_decay_estimator.h", + "reverb_frequency_response.cc", + "reverb_frequency_response.h", + "reverb_model.cc", + "reverb_model.h", + "reverb_model_estimator.cc", + "reverb_model_estimator.h", + "signal_dependent_erle_estimator.cc", + "signal_dependent_erle_estimator.h", + "spectrum_buffer.cc", + "stationarity_estimator.cc", + "stationarity_estimator.h", + "subband_erle_estimator.cc", + "subband_erle_estimator.h", + "subband_nearend_detector.cc", + "subband_nearend_detector.h", + "subtractor.cc", + "subtractor.h", + "subtractor_output.cc", + "subtractor_output.h", + "subtractor_output_analyzer.cc", + "subtractor_output_analyzer.h", + "suppression_filter.cc", + "suppression_filter.h", + "suppression_gain.cc", + "suppression_gain.h", + "transparent_mode.cc", + "transparent_mode.h", + ] + + defines = [] + if (rtc_build_with_neon && target_cpu != "arm64") { + suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ] + cflags = [ "-mfpu=neon" ] + } + + deps = [ + ":adaptive_fir_filter", + ":adaptive_fir_filter_erl", + ":aec3_common", + ":aec3_fft", + ":fft_data", + ":matched_filter", + ":render_buffer", + ":vector_math", + "..:apm_logging", + "..:audio_buffer", + "..:high_pass_filter", + "../../../api:array_view", + "../../../api/audio:aec3_config", + "../../../api/audio:echo_control", + "../../../common_audio:common_audio_c", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:macromagic", + "../../../rtc_base:race_checker", + "../../../rtc_base:safe_minmax", + "../../../rtc_base:swap_queue", + "../../../rtc_base/experiments:field_trial_parser", + "../../../rtc_base/system:arch", + "../../../system_wrappers", + "../../../system_wrappers:field_trial", + "../../../system_wrappers:metrics", + "../utility:cascaded_biquad_filter", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + + if (target_cpu == "x86" || target_cpu == "x64") { + deps += [ ":aec3_avx2" ] + } +} + +rtc_source_set("aec3_common") { + sources = [ "aec3_common.h" ] +} + +rtc_source_set("aec3_fft") { + sources = [ "aec3_fft.h" ] + deps = [ + ":aec3_common", + ":fft_data", + "../../../api:array_view", + "../../../common_audio/third_party/ooura:fft_size_128", + "../../../rtc_base:checks", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("render_buffer") { + sources = [ + "block.h", + "block_buffer.h", + "fft_buffer.h", + "render_buffer.h", + "spectrum_buffer.h", + ] + deps = [ + ":aec3_common", + ":fft_data", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("adaptive_fir_filter") { + sources = [ "adaptive_fir_filter.h" ] + deps = [ + ":aec3_common", + ":aec3_fft", + ":fft_data", + ":render_buffer", + "..:apm_logging", + "../../../api:array_view", + "../../../rtc_base/system:arch", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_source_set("adaptive_fir_filter_erl") { + sources = [ "adaptive_fir_filter_erl.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("matched_filter") { + sources = [ "matched_filter.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base:gtest_prod", + "../../../rtc_base/system:arch", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + +rtc_source_set("vector_math") { + sources = [ "vector_math.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base/system:arch", + ] +} + +rtc_source_set("fft_data") { + sources = [ "fft_data.h" ] + deps = [ + ":aec3_common", + "../../../api:array_view", + "../../../rtc_base/system:arch", + ] +} + +if (target_cpu == "x86" || target_cpu == "x64") { + rtc_library("aec3_avx2") { + configs += [ "..:apm_debug_dump" ] + sources = [ + "adaptive_fir_filter_avx2.cc", + "adaptive_fir_filter_erl_avx2.cc", + "fft_data_avx2.cc", + "matched_filter_avx2.cc", + "vector_math_avx2.cc", + ] + + cflags = [ + "-mavx", + "-mavx2", + "-mfma", + ] + + deps = [ + ":adaptive_fir_filter", + ":adaptive_fir_filter_erl", + ":fft_data", + ":matched_filter", + ":vector_math", + "../../../api:array_view", + "../../../rtc_base:checks", + ] + } +} + +if (rtc_include_tests) { + rtc_library("aec3_unittests") { + testonly = true + + configs += [ "..:apm_debug_dump" ] + sources = [ + "mock/mock_block_processor.cc", + "mock/mock_block_processor.h", + "mock/mock_echo_remover.cc", + "mock/mock_echo_remover.h", + "mock/mock_render_delay_buffer.cc", + "mock/mock_render_delay_buffer.h", + "mock/mock_render_delay_controller.cc", + "mock/mock_render_delay_controller.h", + ] + + deps = [ + ":adaptive_fir_filter", + ":adaptive_fir_filter_erl", + ":aec3", + ":aec3_common", + ":aec3_fft", + ":fft_data", + ":matched_filter", + ":render_buffer", + ":vector_math", + "..:apm_logging", + "..:audio_buffer", + "..:audio_processing", + "..:high_pass_filter", + "../../../api:array_view", + "../../../api/audio:aec3_config", + "../../../rtc_base:checks", + "../../../rtc_base:macromagic", + "../../../rtc_base:random", + "../../../rtc_base:safe_minmax", + "../../../rtc_base:stringutils", + "../../../rtc_base/system:arch", + "../../../system_wrappers", + "../../../system_wrappers:metrics", + "../../../test:field_trial", + "../../../test:test_support", + "../utility:cascaded_biquad_filter", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + + defines = [] + + if (rtc_enable_protobuf) { + sources += [ + "adaptive_fir_filter_erl_unittest.cc", + "adaptive_fir_filter_unittest.cc", + "aec3_fft_unittest.cc", + "aec_state_unittest.cc", + "alignment_mixer_unittest.cc", + "api_call_jitter_metrics_unittest.cc", + "block_delay_buffer_unittest.cc", + "block_framer_unittest.cc", + "block_processor_metrics_unittest.cc", + "block_processor_unittest.cc", + "clockdrift_detector_unittest.cc", + "coarse_filter_update_gain_unittest.cc", + "comfort_noise_generator_unittest.cc", + "config_selector_unittest.cc", + "decimator_unittest.cc", + "echo_canceller3_unittest.cc", + "echo_path_delay_estimator_unittest.cc", + "echo_path_variability_unittest.cc", + "echo_remover_metrics_unittest.cc", + "echo_remover_unittest.cc", + "erl_estimator_unittest.cc", + "erle_estimator_unittest.cc", + "fft_data_unittest.cc", + "filter_analyzer_unittest.cc", + "frame_blocker_unittest.cc", + "matched_filter_lag_aggregator_unittest.cc", + "matched_filter_unittest.cc", + "moving_average_unittest.cc", + "multi_channel_content_detector_unittest.cc", + "refined_filter_update_gain_unittest.cc", + "render_buffer_unittest.cc", + "render_delay_buffer_unittest.cc", + "render_delay_controller_metrics_unittest.cc", + "render_delay_controller_unittest.cc", + "render_signal_analyzer_unittest.cc", + "residual_echo_estimator_unittest.cc", + "reverb_model_estimator_unittest.cc", + "signal_dependent_erle_estimator_unittest.cc", + "subtractor_unittest.cc", + "suppression_filter_unittest.cc", + "suppression_gain_unittest.cc", + "vector_math_unittest.cc", + ] + } + + if (!build_with_chromium) { + deps += [ "..:audio_processing_unittests" ] + } + } +} diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc new file mode 100644 index 0000000000..917aa951ee --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -0,0 +1,744 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_HAS_NEON) +#include <arm_neon.h> +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif +#include <math.h> + +#include <algorithm> +#include <functional> + +#include "modules/audio_processing/aec3/fft_data.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace aec3 { + +// Computes and stores the frequency response of the filter. +void ComputeFrequencyResponse( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) { + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) { + float tmp = + H[p][ch].re[j] * H[p][ch].re[j] + H[p][ch].im[j] * H[p][ch].im[j]; + (*H2)[p][j] = std::max((*H2)[p][j], tmp); + } + } + } +} + +#if defined(WEBRTC_HAS_NEON) +// Computes and stores the frequency response of the filter. +void ComputeFrequencyResponse_Neon( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) { + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + auto& H2_p = (*H2)[p]; + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + for (size_t j = 0; j < kFftLengthBy2; j += 4) { + const float32x4_t re = vld1q_f32(&H_p_ch.re[j]); + const float32x4_t im = vld1q_f32(&H_p_ch.im[j]); + float32x4_t H2_new = vmulq_f32(re, re); + H2_new = vmlaq_f32(H2_new, im, im); + float32x4_t H2_p_j = vld1q_f32(&H2_p[j]); + H2_p_j = vmaxq_f32(H2_p_j, H2_new); + vst1q_f32(&H2_p[j], H2_p_j); + } + float H2_new = H_p_ch.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] + + H_p_ch.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + H2_p[kFftLengthBy2] = std::max(H2_p[kFftLengthBy2], H2_new); + } + } +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Computes and stores the frequency response of the filter. +void ComputeFrequencyResponse_Sse2( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) { + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + // constexpr __mmmask8 kMaxMask = static_cast<__mmmask8>(256u); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + auto& H2_p = (*H2)[p]; + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + for (size_t j = 0; j < kFftLengthBy2; j += 4) { + const __m128 re = _mm_loadu_ps(&H_p_ch.re[j]); + const __m128 re2 = _mm_mul_ps(re, re); + const __m128 im = _mm_loadu_ps(&H_p_ch.im[j]); + const __m128 im2 = _mm_mul_ps(im, im); + const __m128 H2_new = _mm_add_ps(re2, im2); + __m128 H2_k_j = _mm_loadu_ps(&H2_p[j]); + H2_k_j = _mm_max_ps(H2_k_j, H2_new); + _mm_storeu_ps(&H2_p[j], H2_k_j); + } + float H2_new = H_p_ch.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] + + H_p_ch.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + H2_p[kFftLengthBy2] = std::max(H2_p[kFftLengthBy2], H2_new); + } + } +} +#endif + +// Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)). +void AdaptPartitions(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H) { + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + size_t index = render_buffer.Position(); + const size_t num_render_channels = render_buffer_data[index].size(); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& X_p_ch = render_buffer_data[index][ch]; + FftData& H_p_ch = (*H)[p][ch]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + H_p_ch.re[k] += X_p_ch.re[k] * G.re[k] + X_p_ch.im[k] * G.im[k]; + H_p_ch.im[k] += X_p_ch.re[k] * G.im[k] - X_p_ch.im[k] * G.re[k]; + } + } + index = index < (render_buffer_data.size() - 1) ? index + 1 : 0; + } +} + +#if defined(WEBRTC_HAS_NEON) +// Adapts the filter partitions. (Neon variant) +void AdaptPartitions_Neon(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H) { + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; + + size_t X_partition = render_buffer.Position(); + size_t limit = lim1; + size_t p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const float32x4_t G_re = vld1q_f32(&G.re[k]); + const float32x4_t G_im = vld1q_f32(&G.im[k]); + const float32x4_t X_re = vld1q_f32(&X.re[k]); + const float32x4_t X_im = vld1q_f32(&X.im[k]); + const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]); + const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]); + const float32x4_t a = vmulq_f32(X_re, G_re); + const float32x4_t e = vmlaq_f32(a, X_im, G_im); + const float32x4_t c = vmulq_f32(X_re, G_im); + const float32x4_t f = vmlsq_f32(c, X_im, G_re); + const float32x4_t g = vaddq_f32(H_re, e); + const float32x4_t h = vaddq_f32(H_im, f); + vst1q_f32(&H_p_ch.re[k], g); + vst1q_f32(&H_p_ch.im[k], h); + } + } + } + + X_partition = 0; + limit = lim2; + } while (p < lim2); + + X_partition = render_buffer.Position(); + limit = lim1; + p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X.im[kFftLengthBy2] * G.im[kFftLengthBy2]; + H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } + } + X_partition = 0; + limit = lim2; + } while (p < lim2); +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Adapts the filter partitions. (SSE2 variant) +void AdaptPartitions_Sse2(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H) { + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; + + size_t X_partition = render_buffer.Position(); + size_t limit = lim1; + size_t p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const __m128 G_re = _mm_loadu_ps(&G.re[k]); + const __m128 G_im = _mm_loadu_ps(&G.im[k]); + const __m128 X_re = _mm_loadu_ps(&X.re[k]); + const __m128 X_im = _mm_loadu_ps(&X.im[k]); + const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]); + const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]); + const __m128 a = _mm_mul_ps(X_re, G_re); + const __m128 b = _mm_mul_ps(X_im, G_im); + const __m128 c = _mm_mul_ps(X_re, G_im); + const __m128 d = _mm_mul_ps(X_im, G_re); + const __m128 e = _mm_add_ps(a, b); + const __m128 f = _mm_sub_ps(c, d); + const __m128 g = _mm_add_ps(H_re, e); + const __m128 h = _mm_add_ps(H_im, f); + _mm_storeu_ps(&H_p_ch.re[k], g); + _mm_storeu_ps(&H_p_ch.im[k], h); + } + } + } + X_partition = 0; + limit = lim2; + } while (p < lim2); + + X_partition = render_buffer.Position(); + limit = lim1; + p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X.im[kFftLengthBy2] * G.im[kFftLengthBy2]; + H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } + } + + X_partition = 0; + limit = lim2; + } while (p < lim2); +} +#endif + +// Produces the filter output. +void ApplyFilter(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S) { + S->re.fill(0.f); + S->im.fill(0.f); + + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + size_t index = render_buffer.Position(); + const size_t num_render_channels = render_buffer_data[index].size(); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(num_render_channels, H[p].size()); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& X_p_ch = render_buffer_data[index][ch]; + const FftData& H_p_ch = H[p][ch]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + S->re[k] += X_p_ch.re[k] * H_p_ch.re[k] - X_p_ch.im[k] * H_p_ch.im[k]; + S->im[k] += X_p_ch.re[k] * H_p_ch.im[k] + X_p_ch.im[k] * H_p_ch.re[k]; + } + } + index = index < (render_buffer_data.size() - 1) ? index + 1 : 0; + } +} + +#if defined(WEBRTC_HAS_NEON) +// Produces the filter output (Neon variant). +void ApplyFilter_Neon(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S) { + // const RenderBuffer& render_buffer, + // rtc::ArrayView<const FftData> H, + // FftData* S) { + RTC_DCHECK_GE(H.size(), H.size() - 1); + S->Clear(); + + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; + + size_t X_partition = render_buffer.Position(); + size_t p = 0; + size_t limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const float32x4_t X_re = vld1q_f32(&X.re[k]); + const float32x4_t X_im = vld1q_f32(&X.im[k]); + const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]); + const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]); + const float32x4_t S_re = vld1q_f32(&S->re[k]); + const float32x4_t S_im = vld1q_f32(&S->im[k]); + const float32x4_t a = vmulq_f32(X_re, H_re); + const float32x4_t e = vmlsq_f32(a, X_im, H_im); + const float32x4_t c = vmulq_f32(X_re, H_im); + const float32x4_t f = vmlaq_f32(c, X_im, H_re); + const float32x4_t g = vaddq_f32(S_re, e); + const float32x4_t h = vaddq_f32(S_im, f); + vst1q_f32(&S->re[k], g); + vst1q_f32(&S->im[k], h); + } + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); + + X_partition = render_buffer.Position(); + p = 0; + limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] - + X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] + + X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2]; + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Produces the filter output (SSE2 variant). +void ApplyFilter_Sse2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S) { + // const RenderBuffer& render_buffer, + // rtc::ArrayView<const FftData> H, + // FftData* S) { + RTC_DCHECK_GE(H.size(), H.size() - 1); + S->re.fill(0.f); + S->im.fill(0.f); + + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; + + size_t X_partition = render_buffer.Position(); + size_t p = 0; + size_t limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const __m128 X_re = _mm_loadu_ps(&X.re[k]); + const __m128 X_im = _mm_loadu_ps(&X.im[k]); + const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]); + const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]); + const __m128 S_re = _mm_loadu_ps(&S->re[k]); + const __m128 S_im = _mm_loadu_ps(&S->im[k]); + const __m128 a = _mm_mul_ps(X_re, H_re); + const __m128 b = _mm_mul_ps(X_im, H_im); + const __m128 c = _mm_mul_ps(X_re, H_im); + const __m128 d = _mm_mul_ps(X_im, H_re); + const __m128 e = _mm_sub_ps(a, b); + const __m128 f = _mm_add_ps(c, d); + const __m128 g = _mm_add_ps(S_re, e); + const __m128 h = _mm_add_ps(S_im, f); + _mm_storeu_ps(&S->re[k], g); + _mm_storeu_ps(&S->im[k], h); + } + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); + + X_partition = render_buffer.Position(); + p = 0; + limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] - + X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] + + X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2]; + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); +} +#endif + +} // namespace aec3 + +namespace { + +// Ensures that the newly added filter partitions after a size increase are set +// to zero. +void ZeroFilter(size_t old_size, + size_t new_size, + std::vector<std::vector<FftData>>* H) { + RTC_DCHECK_GE(H->size(), old_size); + RTC_DCHECK_GE(H->size(), new_size); + + for (size_t p = old_size; p < new_size; ++p) { + RTC_DCHECK_EQ((*H)[p].size(), (*H)[0].size()); + for (size_t ch = 0; ch < (*H)[0].size(); ++ch) { + (*H)[p][ch].Clear(); + } + } +} + +} // namespace + +AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions, + size_t initial_size_partitions, + size_t size_change_duration_blocks, + size_t num_render_channels, + Aec3Optimization optimization, + ApmDataDumper* data_dumper) + : data_dumper_(data_dumper), + fft_(), + optimization_(optimization), + num_render_channels_(num_render_channels), + max_size_partitions_(max_size_partitions), + size_change_duration_blocks_( + static_cast<int>(size_change_duration_blocks)), + current_size_partitions_(initial_size_partitions), + target_size_partitions_(initial_size_partitions), + old_target_size_partitions_(initial_size_partitions), + H_(max_size_partitions_, std::vector<FftData>(num_render_channels_)) { + RTC_DCHECK(data_dumper_); + RTC_DCHECK_GE(max_size_partitions, initial_size_partitions); + + RTC_DCHECK_LT(0, size_change_duration_blocks_); + one_by_size_change_duration_blocks_ = 1.f / size_change_duration_blocks_; + + ZeroFilter(0, max_size_partitions_, &H_); + + SetSizePartitions(current_size_partitions_, true); +} + +AdaptiveFirFilter::~AdaptiveFirFilter() = default; + +void AdaptiveFirFilter::HandleEchoPathChange() { + // TODO(peah): Check the value and purpose of the code below. + ZeroFilter(current_size_partitions_, max_size_partitions_, &H_); +} + +void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) { + RTC_DCHECK_EQ(max_size_partitions_, H_.capacity()); + RTC_DCHECK_LE(size, max_size_partitions_); + + target_size_partitions_ = std::min(max_size_partitions_, size); + if (immediate_effect) { + size_t old_size_partitions_ = current_size_partitions_; + current_size_partitions_ = old_target_size_partitions_ = + target_size_partitions_; + ZeroFilter(old_size_partitions_, current_size_partitions_, &H_); + + partition_to_constrain_ = + std::min(partition_to_constrain_, current_size_partitions_ - 1); + size_change_counter_ = 0; + } else { + size_change_counter_ = size_change_duration_blocks_; + } +} + +void AdaptiveFirFilter::UpdateSize() { + RTC_DCHECK_GE(size_change_duration_blocks_, size_change_counter_); + size_t old_size_partitions_ = current_size_partitions_; + if (size_change_counter_ > 0) { + --size_change_counter_; + + auto average = [](float from, float to, float from_weight) { + return from * from_weight + to * (1.f - from_weight); + }; + + float change_factor = + size_change_counter_ * one_by_size_change_duration_blocks_; + + current_size_partitions_ = average(old_target_size_partitions_, + target_size_partitions_, change_factor); + + partition_to_constrain_ = + std::min(partition_to_constrain_, current_size_partitions_ - 1); + } else { + current_size_partitions_ = old_target_size_partitions_ = + target_size_partitions_; + } + ZeroFilter(old_size_partitions_, current_size_partitions_, &H_); + RTC_DCHECK_LE(0, size_change_counter_); +} + +void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer, + FftData* S) const { + RTC_DCHECK(S); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S); + break; + case Aec3Optimization::kAvx2: + aec3::ApplyFilter_Avx2(render_buffer, current_size_partitions_, H_, S); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::ApplyFilter_Neon(render_buffer, current_size_partitions_, H_, S); + break; +#endif + default: + aec3::ApplyFilter(render_buffer, current_size_partitions_, H_, S); + } +} + +void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, + const FftData& G) { + // Adapt the filter and update the filter size. + AdaptAndUpdateSize(render_buffer, G); + + // Constrain the filter partitions in a cyclic manner. + Constrain(); +} + +void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, + const FftData& G, + std::vector<float>* impulse_response) { + // Adapt the filter and update the filter size. + AdaptAndUpdateSize(render_buffer, G); + + // Constrain the filter partitions in a cyclic manner. + ConstrainAndUpdateImpulseResponse(impulse_response); +} + +void AdaptiveFirFilter::ComputeFrequencyResponse( + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) const { + RTC_DCHECK_GE(max_size_partitions_, H2->capacity()); + + H2->resize(current_size_partitions_); + + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2); + break; + case Aec3Optimization::kAvx2: + aec3::ComputeFrequencyResponse_Avx2(current_size_partitions_, H_, H2); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::ComputeFrequencyResponse_Neon(current_size_partitions_, H_, H2); + break; +#endif + default: + aec3::ComputeFrequencyResponse(current_size_partitions_, H_, H2); + } +} + +void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer, + const FftData& G) { + // Update the filter size if needed. + UpdateSize(); + + // Adapt the filter. + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_, + &H_); + break; + case Aec3Optimization::kAvx2: + aec3::AdaptPartitions_Avx2(render_buffer, G, current_size_partitions_, + &H_); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::AdaptPartitions_Neon(render_buffer, G, current_size_partitions_, + &H_); + break; +#endif + default: + aec3::AdaptPartitions(render_buffer, G, current_size_partitions_, &H_); + } +} + +// Constrains the partition of the frequency domain filter to be limited in +// time via setting the relevant time-domain coefficients to zero and updates +// the corresponding values in an externally stored impulse response estimate. +void AdaptiveFirFilter::ConstrainAndUpdateImpulseResponse( + std::vector<float>* impulse_response) { + RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_), + impulse_response->capacity()); + impulse_response->resize(GetTimeDomainLength(current_size_partitions_)); + std::array<float, kFftLength> h; + impulse_response->resize(GetTimeDomainLength(current_size_partitions_)); + std::fill( + impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2, + impulse_response->begin() + (partition_to_constrain_ + 1) * kFftLengthBy2, + 0.f); + + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + fft_.Ifft(H_[partition_to_constrain_][ch], &h); + + static constexpr float kScale = 1.0f / kFftLengthBy2; + std::for_each(h.begin(), h.begin() + kFftLengthBy2, + [](float& a) { a *= kScale; }); + std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); + + if (ch == 0) { + std::copy( + h.begin(), h.begin() + kFftLengthBy2, + impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2); + } else { + for (size_t k = 0, j = partition_to_constrain_ * kFftLengthBy2; + k < kFftLengthBy2; ++k, ++j) { + if (fabsf((*impulse_response)[j]) < fabsf(h[k])) { + (*impulse_response)[j] = h[k]; + } + } + } + + fft_.Fft(&h, &H_[partition_to_constrain_][ch]); + } + + partition_to_constrain_ = + partition_to_constrain_ < (current_size_partitions_ - 1) + ? partition_to_constrain_ + 1 + : 0; +} + +// Constrains the a partiton of the frequency domain filter to be limited in +// time via setting the relevant time-domain coefficients to zero. +void AdaptiveFirFilter::Constrain() { + std::array<float, kFftLength> h; + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + fft_.Ifft(H_[partition_to_constrain_][ch], &h); + + static constexpr float kScale = 1.0f / kFftLengthBy2; + std::for_each(h.begin(), h.begin() + kFftLengthBy2, + [](float& a) { a *= kScale; }); + std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); + + fft_.Fft(&h, &H_[partition_to_constrain_][ch]); + } + + partition_to_constrain_ = + partition_to_constrain_ < (current_size_partitions_ - 1) + ? partition_to_constrain_ + 1 + : 0; +} + +void AdaptiveFirFilter::ScaleFilter(float factor) { + for (auto& H_p : H_) { + for (auto& H_p_ch : H_p) { + for (auto& re : H_p_ch.re) { + re *= factor; + } + for (auto& im : H_p_ch.im) { + im *= factor; + } + } + } +} + +// Set the filter coefficients. +void AdaptiveFirFilter::SetFilter(size_t num_partitions, + const std::vector<std::vector<FftData>>& H) { + const size_t min_num_partitions = + std::min(current_size_partitions_, num_partitions); + for (size_t p = 0; p < min_num_partitions; ++p) { + RTC_DCHECK_EQ(H_[p].size(), H[p].size()); + RTC_DCHECK_EQ(num_render_channels_, H_[p].size()); + + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + std::copy(H[p][ch].re.begin(), H[p][ch].re.end(), H_[p][ch].re.begin()); + std::copy(H[p][ch].im.begin(), H[p][ch].im.end(), H_[p][ch].im.begin()); + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter.h b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter.h new file mode 100644 index 0000000000..34c06f4367 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_H_ + +#include <stddef.h> + +#include <array> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/system/arch.h" + +namespace webrtc { +namespace aec3 { +// Computes and stores the frequency response of the filter. +void ComputeFrequencyResponse( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2); +#if defined(WEBRTC_HAS_NEON) +void ComputeFrequencyResponse_Neon( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2); +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +void ComputeFrequencyResponse_Sse2( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2); + +void ComputeFrequencyResponse_Avx2( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2); +#endif + +// Adapts the filter partitions. +void AdaptPartitions(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H); +#if defined(WEBRTC_HAS_NEON) +void AdaptPartitions_Neon(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H); +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +void AdaptPartitions_Sse2(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H); + +void AdaptPartitions_Avx2(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H); +#endif + +// Produces the filter output. +void ApplyFilter(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S); +#if defined(WEBRTC_HAS_NEON) +void ApplyFilter_Neon(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S); +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +void ApplyFilter_Sse2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S); + +void ApplyFilter_Avx2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S); +#endif + +} // namespace aec3 + +// Provides a frequency domain adaptive filter functionality. +class AdaptiveFirFilter { + public: + AdaptiveFirFilter(size_t max_size_partitions, + size_t initial_size_partitions, + size_t size_change_duration_blocks, + size_t num_render_channels, + Aec3Optimization optimization, + ApmDataDumper* data_dumper); + + ~AdaptiveFirFilter(); + + AdaptiveFirFilter(const AdaptiveFirFilter&) = delete; + AdaptiveFirFilter& operator=(const AdaptiveFirFilter&) = delete; + + // Produces the output of the filter. + void Filter(const RenderBuffer& render_buffer, FftData* S) const; + + // Adapts the filter and updates an externally stored impulse response + // estimate. + void Adapt(const RenderBuffer& render_buffer, + const FftData& G, + std::vector<float>* impulse_response); + + // Adapts the filter. + void Adapt(const RenderBuffer& render_buffer, const FftData& G); + + // Receives reports that known echo path changes have occured and adjusts + // the filter adaptation accordingly. + void HandleEchoPathChange(); + + // Returns the filter size. + size_t SizePartitions() const { return current_size_partitions_; } + + // Sets the filter size. + void SetSizePartitions(size_t size, bool immediate_effect); + + // Computes the frequency responses for the filter partitions. + void ComputeFrequencyResponse( + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) const; + + // Returns the maximum number of partitions for the filter. + size_t max_filter_size_partitions() const { return max_size_partitions_; } + + void DumpFilter(absl::string_view name_frequency_domain) { + for (size_t p = 0; p < max_size_partitions_; ++p) { + data_dumper_->DumpRaw(name_frequency_domain, H_[p][0].re); + data_dumper_->DumpRaw(name_frequency_domain, H_[p][0].im); + } + } + + // Scale the filter impulse response and spectrum by a factor. + void ScaleFilter(float factor); + + // Set the filter coefficients. + void SetFilter(size_t num_partitions, + const std::vector<std::vector<FftData>>& H); + + // Gets the filter coefficients. + const std::vector<std::vector<FftData>>& GetFilter() const { return H_; } + + private: + // Adapts the filter and updates the filter size. + void AdaptAndUpdateSize(const RenderBuffer& render_buffer, const FftData& G); + + // Constrain the filter partitions in a cyclic manner. + void Constrain(); + // Constrains the filter in a cyclic manner and updates the corresponding + // values in the supplied impulse response. + void ConstrainAndUpdateImpulseResponse(std::vector<float>* impulse_response); + + // Gradually Updates the current filter size towards the target size. + void UpdateSize(); + + ApmDataDumper* const data_dumper_; + const Aec3Fft fft_; + const Aec3Optimization optimization_; + const size_t num_render_channels_; + const size_t max_size_partitions_; + const int size_change_duration_blocks_; + float one_by_size_change_duration_blocks_; + size_t current_size_partitions_; + size_t target_size_partitions_; + size_t old_target_size_partitions_; + int size_change_counter_ = 0; + std::vector<std::vector<FftData>> H_; + size_t partition_to_constrain_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc new file mode 100644 index 0000000000..44d4514275 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" + +#include "common_audio/intrin.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace aec3 { + +// Computes and stores the frequency response of the filter. +void ComputeFrequencyResponse_Avx2( + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) { + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + auto& H2_p = (*H2)[p]; + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + for (size_t j = 0; j < kFftLengthBy2; j += 8) { + __m256 re = _mm256_loadu_ps(&H_p_ch.re[j]); + __m256 re2 = _mm256_mul_ps(re, re); + __m256 im = _mm256_loadu_ps(&H_p_ch.im[j]); + re2 = _mm256_fmadd_ps(im, im, re2); + __m256 H2_k_j = _mm256_loadu_ps(&H2_p[j]); + H2_k_j = _mm256_max_ps(H2_k_j, re2); + _mm256_storeu_ps(&H2_p[j], H2_k_j); + } + float H2_new = H_p_ch.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] + + H_p_ch.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + H2_p[kFftLengthBy2] = std::max(H2_p[kFftLengthBy2], H2_new); + } + } +} + +// Adapts the filter partitions. +void AdaptPartitions_Avx2(const RenderBuffer& render_buffer, + const FftData& G, + size_t num_partitions, + std::vector<std::vector<FftData>>* H) { + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8; + + size_t X_partition = render_buffer.Position(); + size_t limit = lim1; + size_t p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) { + const __m256 G_re = _mm256_loadu_ps(&G.re[k]); + const __m256 G_im = _mm256_loadu_ps(&G.im[k]); + const __m256 X_re = _mm256_loadu_ps(&X.re[k]); + const __m256 X_im = _mm256_loadu_ps(&X.im[k]); + const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]); + const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]); + const __m256 a = _mm256_mul_ps(X_re, G_re); + const __m256 b = _mm256_mul_ps(X_im, G_im); + const __m256 c = _mm256_mul_ps(X_re, G_im); + const __m256 d = _mm256_mul_ps(X_im, G_re); + const __m256 e = _mm256_add_ps(a, b); + const __m256 f = _mm256_sub_ps(c, d); + const __m256 g = _mm256_add_ps(H_re, e); + const __m256 h = _mm256_add_ps(H_im, f); + _mm256_storeu_ps(&H_p_ch.re[k], g); + _mm256_storeu_ps(&H_p_ch.im[k], h); + } + } + } + X_partition = 0; + limit = lim2; + } while (p < lim2); + + X_partition = render_buffer.Position(); + limit = lim1; + p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X.im[kFftLengthBy2] * G.im[kFftLengthBy2]; + H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } + } + + X_partition = 0; + limit = lim2; + } while (p < lim2); +} + +// Produces the filter output (AVX2 variant). +void ApplyFilter_Avx2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector<std::vector<FftData>>& H, + FftData* S) { + RTC_DCHECK_GE(H.size(), H.size() - 1); + S->re.fill(0.f); + S->im.fill(0.f); + + rtc::ArrayView<const std::vector<FftData>> render_buffer_data = + render_buffer.GetFftBuffer(); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumEightBinBands = kFftLengthBy2 / 8; + + size_t X_partition = render_buffer.Position(); + size_t p = 0; + size_t limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for (size_t k = 0, n = 0; n < kNumEightBinBands; ++n, k += 8) { + const __m256 X_re = _mm256_loadu_ps(&X.re[k]); + const __m256 X_im = _mm256_loadu_ps(&X.im[k]); + const __m256 H_re = _mm256_loadu_ps(&H_p_ch.re[k]); + const __m256 H_im = _mm256_loadu_ps(&H_p_ch.im[k]); + const __m256 S_re = _mm256_loadu_ps(&S->re[k]); + const __m256 S_im = _mm256_loadu_ps(&S->im[k]); + const __m256 a = _mm256_mul_ps(X_re, H_re); + const __m256 b = _mm256_mul_ps(X_im, H_im); + const __m256 c = _mm256_mul_ps(X_re, H_im); + const __m256 d = _mm256_mul_ps(X_im, H_re); + const __m256 e = _mm256_sub_ps(a, b); + const __m256 f = _mm256_add_ps(c, d); + const __m256 g = _mm256_add_ps(S_re, e); + const __m256 h = _mm256_add_ps(S_im, f); + _mm256_storeu_ps(&S->re[k], g); + _mm256_storeu_ps(&S->im[k], h); + } + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); + + X_partition = render_buffer.Position(); + p = 0; + limit = lim1; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] - + X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] + + X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2]; + } + } + limit = lim2; + X_partition = 0; + } while (p < lim2); +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc new file mode 100644 index 0000000000..45b8813979 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" + +#include <algorithm> +#include <functional> + +#if defined(WEBRTC_HAS_NEON) +#include <arm_neon.h> +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif + +namespace webrtc { + +namespace aec3 { + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer(const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + std::transform(H2_j.begin(), H2_j.end(), erl.begin(), erl.begin(), + std::plus<float>()); + } +} + +#if defined(WEBRTC_HAS_NEON) +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer_NEON( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + const float32x4_t H2_j_k = vld1q_f32(&H2_j[k]); + float32x4_t erl_k = vld1q_f32(&erl[k]); + erl_k = vaddq_f32(erl_k, H2_j_k); + vst1q_f32(&erl[k], erl_k); + } + erl[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer_SSE2( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]); + __m128 erl_k = _mm_loadu_ps(&erl[k]); + erl_k = _mm_add_ps(erl_k, H2_j_k); + _mm_storeu_ps(&erl[k], erl_k); + } + erl[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} +#endif + +} // namespace aec3 + +void ComputeErl(const Aec3Optimization& optimization, + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, erl.size()); + // Update the frequency response and echo return loss for the filter. + switch (optimization) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::ErlComputer_SSE2(H2, erl); + break; + case Aec3Optimization::kAvx2: + aec3::ErlComputer_AVX2(H2, erl); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::ErlComputer_NEON(H2, erl); + break; +#endif + default: + aec3::ErlComputer(H2, erl); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl.h b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl.h new file mode 100644 index 0000000000..4ac13b1bc3 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_ + +#include <stddef.h> + +#include <array> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/system/arch.h" + +namespace webrtc { +namespace aec3 { + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer(const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl); +#if defined(WEBRTC_HAS_NEON) +void ErlComputer_NEON( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl); +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +void ErlComputer_SSE2( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl); + +void ErlComputer_AVX2( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl); +#endif + +} // namespace aec3 + +// Computes the echo return loss based on a frequency response. +void ComputeErl(const Aec3Optimization& optimization, + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl); + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc new file mode 100644 index 0000000000..5fe7514db1 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" + +#include <immintrin.h> + +namespace webrtc { + +namespace aec3 { + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer_AVX2( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2, + rtc::ArrayView<float> erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 8) { + const __m256 H2_j_k = _mm256_loadu_ps(&H2_j[k]); + __m256 erl_k = _mm256_loadu_ps(&erl[k]); + erl_k = _mm256_add_ps(erl_k, H2_j_k); + _mm256_storeu_ps(&erl[k], erl_k); + } + erl[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_gn/moz.build new file mode 100644 index 0000000000..60ecc93ab9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_gn/moz.build @@ -0,0 +1,205 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("adaptive_fir_filter_erl_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc new file mode 100644 index 0000000000..d2af70a9f2 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" + +#include <array> +#include <vector> + +#include "rtc_base/system/arch.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif + +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/gtest.h" + +namespace webrtc { +namespace aec3 { + +#if defined(WEBRTC_HAS_NEON) +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateErlNeonOptimization) { + const size_t kNumPartitions = 12; + std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions); + std::array<float, kFftLengthBy2Plus1> erl; + std::array<float, kFftLengthBy2Plus1> erl_NEON; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + ErlComputer(H2, erl); + ErlComputer_NEON(H2, erl_NEON); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_NEON[j]); + } +} + +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) { + bool use_sse2 = (GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + const size_t kNumPartitions = 12; + std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions); + std::array<float, kFftLengthBy2Plus1> erl; + std::array<float, kFftLengthBy2Plus1> erl_SSE2; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + ErlComputer(H2, erl); + ErlComputer_SSE2(H2, erl_SSE2); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]); + } + } +} + +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateErlAvx2Optimization) { + bool use_avx2 = (GetCPUInfo(kAVX2) != 0); + if (use_avx2) { + const size_t kNumPartitions = 12; + std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions); + std::array<float, kFftLengthBy2Plus1> erl; + std::array<float, kFftLengthBy2Plus1> erl_AVX2; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + ErlComputer(H2, erl); + ErlComputer_AVX2(H2, erl_AVX2); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_AVX2[j]); + } + } +} + +#endif + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_gn/moz.build new file mode 100644 index 0000000000..fd78a43560 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_gn/moz.build @@ -0,0 +1,216 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "rt" + ] + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + + OS_LIBS += [ + "crypt32", + "iphlpapi", + "secur32", + "winmm" + ] + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("adaptive_fir_filter_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc new file mode 100644 index 0000000000..a13764c109 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -0,0 +1,594 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include <math.h> + +#include <algorithm> +#include <numeric> +#include <string> + +#include "rtc_base/system/arch.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif + +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/coarse_filter_update_gain.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/aec3/render_signal_analyzer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "modules/audio_processing/utility/cascaded_biquad_filter.h" +#include "rtc_base/arraysize.h" +#include "rtc_base/numerics/safe_minmax.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/gtest.h" + +namespace webrtc { +namespace aec3 { +namespace { + +std::string ProduceDebugText(size_t num_render_channels, size_t delay) { + rtc::StringBuilder ss; + ss << "delay: " << delay << ", "; + ss << "num_render_channels:" << num_render_channels; + return ss.Release(); +} + +} // namespace + +class AdaptiveFirFilterOneTwoFourEightRenderChannels + : public ::testing::Test, + public ::testing::WithParamInterface<size_t> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + AdaptiveFirFilterOneTwoFourEightRenderChannels, + ::testing::Values(1, 2, 4, 8)); + +#if defined(WEBRTC_HAS_NEON) +// Verifies that the optimized methods for filter adaptation are similar to +// their reference counterparts. +TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + FilterAdaptationNeonOptimizations) { + const size_t num_render_channels = GetParam(); + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + num_render_channels)); + Random random_generator(42U); + Block x(kNumBands, num_render_channels); + FftData S_C; + FftData S_Neon; + FftData G; + Aec3Fft fft; + std::vector<std::vector<FftData>> H_C( + num_partitions, std::vector<FftData>(num_render_channels)); + std::vector<std::vector<FftData>> H_Neon( + num_partitions, std::vector<FftData>(num_render_channels)); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + H_C[p][ch].Clear(); + H_Neon[p][ch].Clear(); + } + } + + for (int k = 0; k < 30; ++k) { + for (int band = 0; band < x.NumBands(); ++band) { + for (int ch = 0; ch < x.NumChannels(); ++ch) { + RandomizeSampleVector(&random_generator, x.View(band, ch)); + } + } + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + } + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); + + for (size_t j = 0; j < G.re.size(); ++j) { + G.re[j] = j / 10001.f; + } + for (size_t j = 1; j < G.im.size() - 1; ++j) { + G.im[j] = j / 20001.f; + } + G.im[0] = 0.f; + G.im[G.im.size() - 1] = 0.f; + + AdaptPartitions_Neon(*render_buffer, G, num_partitions, &H_Neon); + AdaptPartitions(*render_buffer, G, num_partitions, &H_C); + AdaptPartitions_Neon(*render_buffer, G, num_partitions, &H_Neon); + AdaptPartitions(*render_buffer, G, num_partitions, &H_C); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) { + EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Neon[p][ch].re[j]); + EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Neon[p][ch].im[j]); + } + } + } + + ApplyFilter_Neon(*render_buffer, num_partitions, H_Neon, &S_Neon); + ApplyFilter(*render_buffer, num_partitions, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_NEAR(S_C.re[j], S_Neon.re[j], fabs(S_C.re[j] * 0.00001f)); + EXPECT_NEAR(S_C.im[j], S_Neon.im[j], fabs(S_C.re[j] * 0.00001f)); + } + } +} + +// Verifies that the optimized method for frequency response computation is +// bitexact to the reference counterpart. +TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + ComputeFrequencyResponseNeonOptimization) { + const size_t num_render_channels = GetParam(); + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + std::vector<std::vector<FftData>> H( + num_partitions, std::vector<FftData>(num_render_channels)); + std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions); + std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Neon(num_partitions); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < H[p][ch].re.size(); ++k) { + H[p][ch].re[k] = k + p / 3.f + ch; + H[p][ch].im[k] = p + k / 7.f - ch; + } + } + } + + ComputeFrequencyResponse(num_partitions, H, &H2); + ComputeFrequencyResponse_Neon(num_partitions, H, &H2_Neon); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t k = 0; k < H2[p].size(); ++k) { + EXPECT_FLOAT_EQ(H2[p][k], H2_Neon[p][k]); + } + } + } +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods for filter adaptation are bitexact to +// their reference counterparts. +TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + FilterAdaptationSse2Optimizations) { + const size_t num_render_channels = GetParam(); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + bool use_sse2 = (GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + num_render_channels)); + Random random_generator(42U); + Block x(kNumBands, num_render_channels); + FftData S_C; + FftData S_Sse2; + FftData G; + Aec3Fft fft; + std::vector<std::vector<FftData>> H_C( + num_partitions, std::vector<FftData>(num_render_channels)); + std::vector<std::vector<FftData>> H_Sse2( + num_partitions, std::vector<FftData>(num_render_channels)); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + H_C[p][ch].Clear(); + H_Sse2[p][ch].Clear(); + } + } + + for (size_t k = 0; k < 500; ++k) { + for (int band = 0; band < x.NumBands(); ++band) { + for (int ch = 0; ch < x.NumChannels(); ++ch) { + RandomizeSampleVector(&random_generator, x.View(band, ch)); + } + } + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); + + ApplyFilter_Sse2(*render_buffer, num_partitions, H_Sse2, &S_Sse2); + ApplyFilter(*render_buffer, num_partitions, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_FLOAT_EQ(S_C.re[j], S_Sse2.re[j]); + EXPECT_FLOAT_EQ(S_C.im[j], S_Sse2.im[j]); + } + + std::for_each(G.re.begin(), G.re.end(), + [&](float& a) { a = random_generator.Rand<float>(); }); + std::for_each(G.im.begin(), G.im.end(), + [&](float& a) { a = random_generator.Rand<float>(); }); + + AdaptPartitions_Sse2(*render_buffer, G, num_partitions, &H_Sse2); + AdaptPartitions(*render_buffer, G, num_partitions, &H_C); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) { + EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Sse2[p][ch].re[j]); + EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Sse2[p][ch].im[j]); + } + } + } + } + } + } +} + +// Verifies that the optimized methods for filter adaptation are bitexact to +// their reference counterparts. +TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + FilterAdaptationAvx2Optimizations) { + const size_t num_render_channels = GetParam(); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + bool use_avx2 = (GetCPUInfo(kAVX2) != 0); + if (use_avx2) { + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + num_render_channels)); + Random random_generator(42U); + Block x(kNumBands, num_render_channels); + FftData S_C; + FftData S_Avx2; + FftData G; + Aec3Fft fft; + std::vector<std::vector<FftData>> H_C( + num_partitions, std::vector<FftData>(num_render_channels)); + std::vector<std::vector<FftData>> H_Avx2( + num_partitions, std::vector<FftData>(num_render_channels)); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + H_C[p][ch].Clear(); + H_Avx2[p][ch].Clear(); + } + } + + for (size_t k = 0; k < 500; ++k) { + for (int band = 0; band < x.NumBands(); ++band) { + for (int ch = 0; ch < x.NumChannels(); ++ch) { + RandomizeSampleVector(&random_generator, x.View(band, ch)); + } + } + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); + + ApplyFilter_Avx2(*render_buffer, num_partitions, H_Avx2, &S_Avx2); + ApplyFilter(*render_buffer, num_partitions, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_FLOAT_EQ(S_C.re[j], S_Avx2.re[j]); + EXPECT_FLOAT_EQ(S_C.im[j], S_Avx2.im[j]); + } + + std::for_each(G.re.begin(), G.re.end(), + [&](float& a) { a = random_generator.Rand<float>(); }); + std::for_each(G.im.begin(), G.im.end(), + [&](float& a) { a = random_generator.Rand<float>(); }); + + AdaptPartitions_Avx2(*render_buffer, G, num_partitions, &H_Avx2); + AdaptPartitions(*render_buffer, G, num_partitions, &H_C); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) { + EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Avx2[p][ch].re[j]); + EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Avx2[p][ch].im[j]); + } + } + } + } + } + } +} + +// Verifies that the optimized method for frequency response computation is +// bitexact to the reference counterpart. +TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + ComputeFrequencyResponseSse2Optimization) { + const size_t num_render_channels = GetParam(); + bool use_sse2 = (GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + std::vector<std::vector<FftData>> H( + num_partitions, std::vector<FftData>(num_render_channels)); + std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions); + std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Sse2( + num_partitions); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < H[p][ch].re.size(); ++k) { + H[p][ch].re[k] = k + p / 3.f + ch; + H[p][ch].im[k] = p + k / 7.f - ch; + } + } + } + + ComputeFrequencyResponse(num_partitions, H, &H2); + ComputeFrequencyResponse_Sse2(num_partitions, H, &H2_Sse2); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t k = 0; k < H2[p].size(); ++k) { + EXPECT_FLOAT_EQ(H2[p][k], H2_Sse2[p][k]); + } + } + } + } +} + +// Verifies that the optimized method for frequency response computation is +// bitexact to the reference counterpart. +TEST_P(AdaptiveFirFilterOneTwoFourEightRenderChannels, + ComputeFrequencyResponseAvx2Optimization) { + const size_t num_render_channels = GetParam(); + bool use_avx2 = (GetCPUInfo(kAVX2) != 0); + if (use_avx2) { + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + std::vector<std::vector<FftData>> H( + num_partitions, std::vector<FftData>(num_render_channels)); + std::vector<std::array<float, kFftLengthBy2Plus1>> H2(num_partitions); + std::vector<std::array<float, kFftLengthBy2Plus1>> H2_Avx2( + num_partitions); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < H[p][ch].re.size(); ++k) { + H[p][ch].re[k] = k + p / 3.f + ch; + H[p][ch].im[k] = p + k / 7.f - ch; + } + } + } + + ComputeFrequencyResponse(num_partitions, H, &H2); + ComputeFrequencyResponse_Avx2(num_partitions, H, &H2_Avx2); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t k = 0; k < H2[p].size(); ++k) { + EXPECT_FLOAT_EQ(H2[p][k], H2_Avx2[p][k]); + } + } + } + } +} + +#endif + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies that the check for non-null data dumper works. +TEST(AdaptiveFirFilterDeathTest, NullDataDumper) { + EXPECT_DEATH(AdaptiveFirFilter(9, 9, 250, 1, DetectOptimization(), nullptr), + ""); +} + +// Verifies that the check for non-null filter output works. +TEST(AdaptiveFirFilterDeathTest, NullFilterOutput) { + ApmDataDumper data_dumper(42); + AdaptiveFirFilter filter(9, 9, 250, 1, DetectOptimization(), &data_dumper); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), 48000, 1)); + EXPECT_DEATH(filter.Filter(*render_delay_buffer->GetRenderBuffer(), nullptr), + ""); +} + +#endif + +// Verifies that the filter statistics can be accessed when filter statistics +// are turned on. +TEST(AdaptiveFirFilterTest, FilterStatisticsAccess) { + ApmDataDumper data_dumper(42); + Aec3Optimization optimization = DetectOptimization(); + AdaptiveFirFilter filter(9, 9, 250, 1, optimization, &data_dumper); + std::vector<std::array<float, kFftLengthBy2Plus1>> H2( + filter.max_filter_size_partitions(), + std::array<float, kFftLengthBy2Plus1>()); + for (auto& H2_k : H2) { + H2_k.fill(0.f); + } + + std::array<float, kFftLengthBy2Plus1> erl; + ComputeErl(optimization, H2, erl); + filter.ComputeFrequencyResponse(&H2); +} + +// Verifies that the filter size if correctly repported. +TEST(AdaptiveFirFilterTest, FilterSize) { + ApmDataDumper data_dumper(42); + for (size_t filter_size = 1; filter_size < 5; ++filter_size) { + AdaptiveFirFilter filter(filter_size, filter_size, 250, 1, + DetectOptimization(), &data_dumper); + EXPECT_EQ(filter_size, filter.SizePartitions()); + } +} + +class AdaptiveFirFilterMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + AdaptiveFirFilterMultiChannel, + ::testing::Combine(::testing::Values(1, 4), + ::testing::Values(1, 8))); + +// Verifies that the filter is being able to properly filter a signal and to +// adapt its coefficients. +TEST_P(AdaptiveFirFilterMultiChannel, FilterAndAdapt) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + constexpr size_t kNumBlocksToProcessPerRenderChannel = 1000; + + ApmDataDumper data_dumper(42); + EchoCanceller3Config config; + + if (num_render_channels == 33) { + config.filter.refined = {13, 0.00005f, 0.0005f, 0.0001f, 2.f, 20075344.f}; + config.filter.coarse = {13, 0.1f, 20075344.f}; + config.filter.refined_initial = {12, 0.005f, 0.5f, 0.001f, 2.f, 20075344.f}; + config.filter.coarse_initial = {12, 0.7f, 20075344.f}; + } + + AdaptiveFirFilter filter( + config.filter.refined.length_blocks, config.filter.refined.length_blocks, + config.filter.config_change_duration_blocks, num_render_channels, + DetectOptimization(), &data_dumper); + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2( + num_capture_channels, std::vector<std::array<float, kFftLengthBy2Plus1>>( + filter.max_filter_size_partitions(), + std::array<float, kFftLengthBy2Plus1>())); + std::vector<std::vector<float>> h( + num_capture_channels, + std::vector<float>( + GetTimeDomainLength(filter.max_filter_size_partitions()), 0.f)); + Aec3Fft fft; + config.delay.default_delay = 1; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + CoarseFilterUpdateGain gain(config.filter.coarse, + config.filter.config_change_duration_blocks); + Random random_generator(42U); + Block x(kNumBands, num_render_channels); + std::vector<float> n(kBlockSize, 0.f); + std::vector<float> y(kBlockSize, 0.f); + AecState aec_state(EchoCanceller3Config{}, num_capture_channels); + RenderSignalAnalyzer render_signal_analyzer(config); + absl::optional<DelayEstimate> delay_estimate; + std::vector<float> e(kBlockSize, 0.f); + std::array<float, kFftLength> s_scratch; + std::vector<SubtractorOutput> output(num_capture_channels); + FftData S; + FftData G; + FftData E; + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( + num_capture_channels); + std::array<float, kFftLengthBy2Plus1> E2_coarse; + // [B,A] = butter(2,100/8000,'high') + constexpr CascadedBiQuadFilter::BiQuadCoefficients + kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f}, + {-1.94448f, 0.94598f}}; + for (auto& Y2_ch : Y2) { + Y2_ch.fill(0.f); + } + for (auto& E2_refined_ch : E2_refined) { + E2_refined_ch.fill(0.f); + } + E2_coarse.fill(0.f); + for (auto& subtractor_output : output) { + subtractor_output.Reset(); + } + + constexpr float kScale = 1.0f / kFftLengthBy2; + + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + std::vector<DelayBuffer<float>> delay_buffer( + num_render_channels, DelayBuffer<float>(delay_samples)); + std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter( + num_render_channels); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch] = std::make_unique<CascadedBiQuadFilter>( + kHighPassFilterCoefficients, 1); + } + CascadedBiQuadFilter y_hp_filter(kHighPassFilterCoefficients, 1); + + SCOPED_TRACE(ProduceDebugText(num_render_channels, delay_samples)); + const size_t num_blocks_to_process = + kNumBlocksToProcessPerRenderChannel * num_render_channels; + for (size_t j = 0; j < num_blocks_to_process; ++j) { + std::fill(y.begin(), y.end(), 0.f); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + RandomizeSampleVector(&random_generator, x.View(/*band=*/0, ch)); + std::array<float, kBlockSize> y_channel; + delay_buffer[ch].Delay(x.View(/*band=*/0, ch), y_channel); + for (size_t k = 0; k < y.size(); ++k) { + y[k] += y_channel[k] / num_render_channels; + } + } + + RandomizeSampleVector(&random_generator, n); + const float noise_scaling = 1.f / 100.f / num_render_channels; + for (size_t k = 0; k < y.size(); ++k) { + y[k] += n[k] * noise_scaling; + } + + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch]->Process(x.View(/*band=*/0, ch)); + } + y_hp_filter.Process(y); + + render_delay_buffer->Insert(x); + if (j == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); + + render_signal_analyzer.Update(*render_buffer, + aec_state.MinDirectPathFilterDelay()); + + filter.Filter(*render_buffer, &S); + fft.Ifft(S, &s_scratch); + std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2, + e.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e.begin(), e.end(), + [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); }); + fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E); + for (auto& o : output) { + for (size_t k = 0; k < kBlockSize; ++k) { + o.s_refined[k] = kScale * s_scratch[k + kFftLengthBy2]; + } + } + + std::array<float, kFftLengthBy2Plus1> render_power; + render_buffer->SpectralSum(filter.SizePartitions(), &render_power); + gain.Compute(render_power, render_signal_analyzer, E, + filter.SizePartitions(), false, &G); + filter.Adapt(*render_buffer, G, &h[0]); + aec_state.HandleEchoPathChange(EchoPathVariability( + false, EchoPathVariability::DelayAdjustment::kNone, false)); + + filter.ComputeFrequencyResponse(&H2[0]); + aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_refined, Y2, + output); + } + // Verify that the filter is able to perform well. + EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_avx2_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_avx2_gn/moz.build new file mode 100644 index 0000000000..6f67bd6fad --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_avx2_gn/moz.build @@ -0,0 +1,190 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +CXXFLAGS += [ + "-mavx2", + "-mfma" +] + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_APM_DEBUG_DUMP"] = "0" +DEFINES["WEBRTC_ENABLE_AVX2"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +UNIFIED_SOURCES += [ + "/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_avx2.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl_avx2.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_avx2.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_avx2.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_avx2.cc" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_GNU_SOURCE"] = True + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "rt" + ] + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + + OS_LIBS += [ + "crypt32", + "iphlpapi", + "secur32", + "winmm" + ] + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + CXXFLAGS += [ + "-msse2" + ] + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + CXXFLAGS += [ + "-msse2" + ] + +Library("aec3_avx2_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common.cc b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common.cc new file mode 100644 index 0000000000..3ba10d5baf --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/aec3_common.h" + +#include <stdint.h> + +#include "rtc_base/checks.h" +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" + +namespace webrtc { + +Aec3Optimization DetectOptimization() { +#if defined(WEBRTC_ARCH_X86_FAMILY) + if (GetCPUInfo(kAVX2) != 0) { + return Aec3Optimization::kAvx2; + } else if (GetCPUInfo(kSSE2) != 0) { + return Aec3Optimization::kSse2; + } +#endif + +#if defined(WEBRTC_HAS_NEON) + return Aec3Optimization::kNeon; +#else + return Aec3Optimization::kNone; +#endif +} + +float FastApproxLog2f(const float in) { + RTC_DCHECK_GT(in, .0f); + // Read and interpret float as uint32_t and then cast to float. + // This is done to extract the exponent (bits 30 - 23). + // "Right shift" of the exponent is then performed by multiplying + // with the constant (1/2^23). Finally, we subtract a constant to + // remove the bias (https://en.wikipedia.org/wiki/Exponent_bias). + union { + float dummy; + uint32_t a; + } x = {in}; + float out = x.a; + out *= 1.1920929e-7f; // 1/2^23 + out -= 126.942695f; // Remove bias. + return out; +} + +float Log2TodB(const float in_log2) { + return 3.0102999566398121 * in_log2; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common.h b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common.h new file mode 100644 index 0000000000..32b564f14b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_AEC3_COMMON_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_AEC3_COMMON_H_ + +#include <stddef.h> + +namespace webrtc { + +#ifdef _MSC_VER /* visual c++ */ +#define ALIGN16_BEG __declspec(align(16)) +#define ALIGN16_END +#else /* gcc or icc */ +#define ALIGN16_BEG +#define ALIGN16_END __attribute__((aligned(16))) +#endif + +enum class Aec3Optimization { kNone, kSse2, kAvx2, kNeon }; + +constexpr int kNumBlocksPerSecond = 250; + +constexpr int kMetricsReportingIntervalBlocks = 10 * kNumBlocksPerSecond; +constexpr int kMetricsComputationBlocks = 3; +constexpr int kMetricsCollectionBlocks = + kMetricsReportingIntervalBlocks - kMetricsComputationBlocks; + +constexpr size_t kFftLengthBy2 = 64; +constexpr size_t kFftLengthBy2Plus1 = kFftLengthBy2 + 1; +constexpr size_t kFftLengthBy2Minus1 = kFftLengthBy2 - 1; +constexpr size_t kFftLength = 2 * kFftLengthBy2; +constexpr size_t kFftLengthBy2Log2 = 6; + +constexpr int kRenderTransferQueueSizeFrames = 100; + +constexpr size_t kMaxNumBands = 3; +constexpr size_t kFrameSize = 160; +constexpr size_t kSubFrameLength = kFrameSize / 2; + +constexpr size_t kBlockSize = kFftLengthBy2; +constexpr size_t kBlockSizeLog2 = kFftLengthBy2Log2; + +constexpr size_t kExtendedBlockSize = 2 * kFftLengthBy2; +constexpr size_t kMatchedFilterWindowSizeSubBlocks = 32; +constexpr size_t kMatchedFilterAlignmentShiftSizeSubBlocks = + kMatchedFilterWindowSizeSubBlocks * 3 / 4; + +// TODO(peah): Integrate this with how it is done inside audio_processing_impl. +constexpr size_t NumBandsForRate(int sample_rate_hz) { + return static_cast<size_t>(sample_rate_hz / 16000); +} + +constexpr bool ValidFullBandRate(int sample_rate_hz) { + return sample_rate_hz == 16000 || sample_rate_hz == 32000 || + sample_rate_hz == 48000; +} + +constexpr int GetTimeDomainLength(int filter_length_blocks) { + return filter_length_blocks * kFftLengthBy2; +} + +constexpr size_t GetDownSampledBufferSize(size_t down_sampling_factor, + size_t num_matched_filters) { + return kBlockSize / down_sampling_factor * + (kMatchedFilterAlignmentShiftSizeSubBlocks * num_matched_filters + + kMatchedFilterWindowSizeSubBlocks + 1); +} + +constexpr size_t GetRenderDelayBufferSize(size_t down_sampling_factor, + size_t num_matched_filters, + size_t filter_length_blocks) { + return GetDownSampledBufferSize(down_sampling_factor, num_matched_filters) / + (kBlockSize / down_sampling_factor) + + filter_length_blocks + 1; +} + +// Detects what kind of optimizations to use for the code. +Aec3Optimization DetectOptimization(); + +// Computes the log2 of the input in a fast an approximate manner. +float FastApproxLog2f(float in); + +// Returns dB from a power quantity expressed in log2. +float Log2TodB(float in_log2); + +static_assert(1 << kBlockSizeLog2 == kBlockSize, + "Proper number of shifts for blocksize"); + +static_assert(1 << kFftLengthBy2Log2 == kFftLengthBy2, + "Proper number of shifts for the fft length"); + +static_assert(1 == NumBandsForRate(16000), "Number of bands for 16 kHz"); +static_assert(2 == NumBandsForRate(32000), "Number of bands for 32 kHz"); +static_assert(3 == NumBandsForRate(48000), "Number of bands for 48 kHz"); + +static_assert(ValidFullBandRate(16000), + "Test that 16 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(32000), + "Test that 32 kHz is a valid sample rate"); +static_assert(ValidFullBandRate(48000), + "Test that 48 kHz is a valid sample rate"); +static_assert(!ValidFullBandRate(8001), + "Test that 8001 Hz is not a valid sample rate"); + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_AEC3_COMMON_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common_gn/moz.build new file mode 100644 index 0000000000..b0952a7d0c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common_gn/moz.build @@ -0,0 +1,201 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("aec3_common_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft.cc b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft.cc new file mode 100644 index 0000000000..9cc8016f0b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft.cc @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/aec3_fft.h" + +#include <algorithm> +#include <functional> +#include <iterator> + +#include "rtc_base/checks.h" +#include "system_wrappers/include/cpu_features_wrapper.h" + +namespace webrtc { + +namespace { + +const float kHanning64[kFftLengthBy2] = { + 0.f, 0.00248461f, 0.00991376f, 0.0222136f, 0.03926189f, + 0.06088921f, 0.08688061f, 0.11697778f, 0.15088159f, 0.1882551f, + 0.22872687f, 0.27189467f, 0.31732949f, 0.36457977f, 0.41317591f, + 0.46263495f, 0.51246535f, 0.56217185f, 0.61126047f, 0.65924333f, + 0.70564355f, 0.75f, 0.79187184f, 0.83084292f, 0.86652594f, + 0.89856625f, 0.92664544f, 0.95048443f, 0.96984631f, 0.98453864f, + 0.99441541f, 0.99937846f, 0.99937846f, 0.99441541f, 0.98453864f, + 0.96984631f, 0.95048443f, 0.92664544f, 0.89856625f, 0.86652594f, + 0.83084292f, 0.79187184f, 0.75f, 0.70564355f, 0.65924333f, + 0.61126047f, 0.56217185f, 0.51246535f, 0.46263495f, 0.41317591f, + 0.36457977f, 0.31732949f, 0.27189467f, 0.22872687f, 0.1882551f, + 0.15088159f, 0.11697778f, 0.08688061f, 0.06088921f, 0.03926189f, + 0.0222136f, 0.00991376f, 0.00248461f, 0.f}; + +// Hanning window from Matlab command win = sqrt(hanning(128)). +const float kSqrtHanning128[kFftLength] = { + 0.00000000000000f, 0.02454122852291f, 0.04906767432742f, 0.07356456359967f, + 0.09801714032956f, 0.12241067519922f, 0.14673047445536f, 0.17096188876030f, + 0.19509032201613f, 0.21910124015687f, 0.24298017990326f, 0.26671275747490f, + 0.29028467725446f, 0.31368174039889f, 0.33688985339222f, 0.35989503653499f, + 0.38268343236509f, 0.40524131400499f, 0.42755509343028f, 0.44961132965461f, + 0.47139673682600f, 0.49289819222978f, 0.51410274419322f, 0.53499761988710f, + 0.55557023301960f, 0.57580819141785f, 0.59569930449243f, 0.61523159058063f, + 0.63439328416365f, 0.65317284295378f, 0.67155895484702f, 0.68954054473707f, + 0.70710678118655f, 0.72424708295147f, 0.74095112535496f, 0.75720884650648f, + 0.77301045336274f, 0.78834642762661f, 0.80320753148064f, 0.81758481315158f, + 0.83146961230255f, 0.84485356524971f, 0.85772861000027f, 0.87008699110871f, + 0.88192126434835f, 0.89322430119552f, 0.90398929312344f, 0.91420975570353f, + 0.92387953251129f, 0.93299279883474f, 0.94154406518302f, 0.94952818059304f, + 0.95694033573221f, 0.96377606579544f, 0.97003125319454f, 0.97570213003853f, + 0.98078528040323f, 0.98527764238894f, 0.98917650996478f, 0.99247953459871f, + 0.99518472667220f, 0.99729045667869f, 0.99879545620517f, 0.99969881869620f, + 1.00000000000000f, 0.99969881869620f, 0.99879545620517f, 0.99729045667869f, + 0.99518472667220f, 0.99247953459871f, 0.98917650996478f, 0.98527764238894f, + 0.98078528040323f, 0.97570213003853f, 0.97003125319454f, 0.96377606579544f, + 0.95694033573221f, 0.94952818059304f, 0.94154406518302f, 0.93299279883474f, + 0.92387953251129f, 0.91420975570353f, 0.90398929312344f, 0.89322430119552f, + 0.88192126434835f, 0.87008699110871f, 0.85772861000027f, 0.84485356524971f, + 0.83146961230255f, 0.81758481315158f, 0.80320753148064f, 0.78834642762661f, + 0.77301045336274f, 0.75720884650648f, 0.74095112535496f, 0.72424708295147f, + 0.70710678118655f, 0.68954054473707f, 0.67155895484702f, 0.65317284295378f, + 0.63439328416365f, 0.61523159058063f, 0.59569930449243f, 0.57580819141785f, + 0.55557023301960f, 0.53499761988710f, 0.51410274419322f, 0.49289819222978f, + 0.47139673682600f, 0.44961132965461f, 0.42755509343028f, 0.40524131400499f, + 0.38268343236509f, 0.35989503653499f, 0.33688985339222f, 0.31368174039889f, + 0.29028467725446f, 0.26671275747490f, 0.24298017990326f, 0.21910124015687f, + 0.19509032201613f, 0.17096188876030f, 0.14673047445536f, 0.12241067519922f, + 0.09801714032956f, 0.07356456359967f, 0.04906767432742f, 0.02454122852291f}; + +bool IsSse2Available() { +#if defined(WEBRTC_ARCH_X86_FAMILY) + return GetCPUInfo(kSSE2) != 0; +#else + return false; +#endif +} + +} // namespace + +Aec3Fft::Aec3Fft() : ooura_fft_(IsSse2Available()) {} + +// TODO(peah): Change x to be std::array once the rest of the code allows this. +void Aec3Fft::ZeroPaddedFft(rtc::ArrayView<const float> x, + Window window, + FftData* X) const { + RTC_DCHECK(X); + RTC_DCHECK_EQ(kFftLengthBy2, x.size()); + std::array<float, kFftLength> fft; + std::fill(fft.begin(), fft.begin() + kFftLengthBy2, 0.f); + switch (window) { + case Window::kRectangular: + std::copy(x.begin(), x.end(), fft.begin() + kFftLengthBy2); + break; + case Window::kHanning: + std::transform(x.begin(), x.end(), std::begin(kHanning64), + fft.begin() + kFftLengthBy2, + [](float a, float b) { return a * b; }); + break; + case Window::kSqrtHanning: + RTC_DCHECK_NOTREACHED(); + break; + default: + RTC_DCHECK_NOTREACHED(); + } + + Fft(&fft, X); +} + +void Aec3Fft::PaddedFft(rtc::ArrayView<const float> x, + rtc::ArrayView<const float> x_old, + Window window, + FftData* X) const { + RTC_DCHECK(X); + RTC_DCHECK_EQ(kFftLengthBy2, x.size()); + RTC_DCHECK_EQ(kFftLengthBy2, x_old.size()); + std::array<float, kFftLength> fft; + + switch (window) { + case Window::kRectangular: + std::copy(x_old.begin(), x_old.end(), fft.begin()); + std::copy(x.begin(), x.end(), fft.begin() + x_old.size()); + break; + case Window::kHanning: + RTC_DCHECK_NOTREACHED(); + break; + case Window::kSqrtHanning: + std::transform(x_old.begin(), x_old.end(), std::begin(kSqrtHanning128), + fft.begin(), std::multiplies<float>()); + std::transform(x.begin(), x.end(), + std::begin(kSqrtHanning128) + x_old.size(), + fft.begin() + x_old.size(), std::multiplies<float>()); + break; + default: + RTC_DCHECK_NOTREACHED(); + } + + Fft(&fft, X); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft.h b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft.h new file mode 100644 index 0000000000..c68de53963 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_AEC3_FFT_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_AEC3_FFT_H_ + +#include <array> + +#include "api/array_view.h" +#include "common_audio/third_party/ooura/fft_size_128/ooura_fft.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +// Wrapper class that provides 128 point real valued FFT functionality with the +// FftData type. +class Aec3Fft { + public: + enum class Window { kRectangular, kHanning, kSqrtHanning }; + + Aec3Fft(); + + Aec3Fft(const Aec3Fft&) = delete; + Aec3Fft& operator=(const Aec3Fft&) = delete; + + // Computes the FFT. Note that both the input and output are modified. + void Fft(std::array<float, kFftLength>* x, FftData* X) const { + RTC_DCHECK(x); + RTC_DCHECK(X); + ooura_fft_.Fft(x->data()); + X->CopyFromPackedArray(*x); + } + // Computes the inverse Fft. + void Ifft(const FftData& X, std::array<float, kFftLength>* x) const { + RTC_DCHECK(x); + X.CopyToPackedArray(x); + ooura_fft_.InverseFft(x->data()); + } + + // Windows the input using a Hanning window, and then adds padding of + // kFftLengthBy2 initial zeros before computing the Fft. + void ZeroPaddedFft(rtc::ArrayView<const float> x, + Window window, + FftData* X) const; + + // Concatenates the kFftLengthBy2 values long x and x_old before computing the + // Fft. After that, x is copied to x_old. + void PaddedFft(rtc::ArrayView<const float> x, + rtc::ArrayView<const float> x_old, + FftData* X) const { + PaddedFft(x, x_old, Window::kRectangular, X); + } + + // Padded Fft using a time-domain window. + void PaddedFft(rtc::ArrayView<const float> x, + rtc::ArrayView<const float> x_old, + Window window, + FftData* X) const; + + private: + const OouraFft ooura_fft_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_AEC3_FFT_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft_gn/moz.build new file mode 100644 index 0000000000..97bbc43539 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft_gn/moz.build @@ -0,0 +1,216 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "rt" + ] + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + + OS_LIBS += [ + "crypt32", + "iphlpapi", + "secur32", + "winmm" + ] + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("aec3_fft_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft_unittest.cc new file mode 100644 index 0000000000..e60ef5b713 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft_unittest.cc @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/aec3_fft.h" + +#include <algorithm> + +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null input in Fft works. +TEST(Aec3FftDeathTest, NullFftInput) { + Aec3Fft fft; + FftData X; + EXPECT_DEATH(fft.Fft(nullptr, &X), ""); +} + +// Verifies that the check for non-null input in Fft works. +TEST(Aec3FftDeathTest, NullFftOutput) { + Aec3Fft fft; + std::array<float, kFftLength> x; + EXPECT_DEATH(fft.Fft(&x, nullptr), ""); +} + +// Verifies that the check for non-null output in Ifft works. +TEST(Aec3FftDeathTest, NullIfftOutput) { + Aec3Fft fft; + FftData X; + EXPECT_DEATH(fft.Ifft(X, nullptr), ""); +} + +// Verifies that the check for non-null output in ZeroPaddedFft works. +TEST(Aec3FftDeathTest, NullZeroPaddedFftOutput) { + Aec3Fft fft; + std::array<float, kFftLengthBy2> x; + EXPECT_DEATH(fft.ZeroPaddedFft(x, Aec3Fft::Window::kRectangular, nullptr), + ""); +} + +// Verifies that the check for input length in ZeroPaddedFft works. +TEST(Aec3FftDeathTest, ZeroPaddedFftWrongInputLength) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLengthBy2 - 1> x; + EXPECT_DEATH(fft.ZeroPaddedFft(x, Aec3Fft::Window::kRectangular, &X), ""); +} + +// Verifies that the check for non-null output in PaddedFft works. +TEST(Aec3FftDeathTest, NullPaddedFftOutput) { + Aec3Fft fft; + std::array<float, kFftLengthBy2> x; + std::array<float, kFftLengthBy2> x_old; + EXPECT_DEATH(fft.PaddedFft(x, x_old, nullptr), ""); +} + +// Verifies that the check for input length in PaddedFft works. +TEST(Aec3FftDeathTest, PaddedFftWrongInputLength) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLengthBy2 - 1> x; + std::array<float, kFftLengthBy2> x_old; + EXPECT_DEATH(fft.PaddedFft(x, x_old, &X), ""); +} + +// Verifies that the check for length in the old value in PaddedFft works. +TEST(Aec3FftDeathTest, PaddedFftWrongOldValuesLength) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLengthBy2> x; + std::array<float, kFftLengthBy2 - 1> x_old; + EXPECT_DEATH(fft.PaddedFft(x, x_old, &X), ""); +} + +#endif + +// Verifies that Fft works as intended. +TEST(Aec3Fft, Fft) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLength> x; + x.fill(0.f); + fft.Fft(&x, &X); + EXPECT_THAT(X.re, ::testing::Each(0.f)); + EXPECT_THAT(X.im, ::testing::Each(0.f)); + + x.fill(0.f); + x[0] = 1.f; + fft.Fft(&x, &X); + EXPECT_THAT(X.re, ::testing::Each(1.f)); + EXPECT_THAT(X.im, ::testing::Each(0.f)); + + x.fill(1.f); + fft.Fft(&x, &X); + EXPECT_EQ(128.f, X.re[0]); + std::for_each(X.re.begin() + 1, X.re.end(), + [](float a) { EXPECT_EQ(0.f, a); }); + EXPECT_THAT(X.im, ::testing::Each(0.f)); +} + +// Verifies that InverseFft works as intended. +TEST(Aec3Fft, Ifft) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLength> x; + + X.re.fill(0.f); + X.im.fill(0.f); + fft.Ifft(X, &x); + EXPECT_THAT(x, ::testing::Each(0.f)); + + X.re.fill(1.f); + X.im.fill(0.f); + fft.Ifft(X, &x); + EXPECT_EQ(64.f, x[0]); + std::for_each(x.begin() + 1, x.end(), [](float a) { EXPECT_EQ(0.f, a); }); + + X.re.fill(0.f); + X.re[0] = 128; + X.im.fill(0.f); + fft.Ifft(X, &x); + EXPECT_THAT(x, ::testing::Each(64.f)); +} + +// Verifies that InverseFft and Fft work as intended. +TEST(Aec3Fft, FftAndIfft) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLength> x; + std::array<float, kFftLength> x_ref; + + int v = 0; + for (int k = 0; k < 20; ++k) { + for (size_t j = 0; j < x.size(); ++j) { + x[j] = v++; + x_ref[j] = x[j] * 64.f; + } + fft.Fft(&x, &X); + fft.Ifft(X, &x); + for (size_t j = 0; j < x.size(); ++j) { + EXPECT_NEAR(x_ref[j], x[j], 0.001f); + } + } +} + +// Verifies that ZeroPaddedFft work as intended. +TEST(Aec3Fft, ZeroPaddedFft) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLengthBy2> x_in; + std::array<float, kFftLength> x_ref; + std::array<float, kFftLength> x_out; + + int v = 0; + x_ref.fill(0.f); + for (int k = 0; k < 20; ++k) { + for (size_t j = 0; j < x_in.size(); ++j) { + x_in[j] = v++; + x_ref[j + kFftLengthBy2] = x_in[j] * 64.f; + } + fft.ZeroPaddedFft(x_in, Aec3Fft::Window::kRectangular, &X); + fft.Ifft(X, &x_out); + for (size_t j = 0; j < x_out.size(); ++j) { + EXPECT_NEAR(x_ref[j], x_out[j], 0.1f); + } + } +} + +// Verifies that ZeroPaddedFft work as intended. +TEST(Aec3Fft, PaddedFft) { + Aec3Fft fft; + FftData X; + std::array<float, kFftLengthBy2> x_in; + std::array<float, kFftLength> x_out; + std::array<float, kFftLengthBy2> x_old; + std::array<float, kFftLengthBy2> x_old_ref; + std::array<float, kFftLength> x_ref; + + int v = 0; + x_old.fill(0.f); + for (int k = 0; k < 20; ++k) { + for (size_t j = 0; j < x_in.size(); ++j) { + x_in[j] = v++; + } + + std::copy(x_old.begin(), x_old.end(), x_ref.begin()); + std::copy(x_in.begin(), x_in.end(), x_ref.begin() + kFftLengthBy2); + std::copy(x_in.begin(), x_in.end(), x_old_ref.begin()); + std::for_each(x_ref.begin(), x_ref.end(), [](float& a) { a *= 64.f; }); + + fft.PaddedFft(x_in, x_old, &X); + std::copy(x_in.begin(), x_in.end(), x_old.begin()); + fft.Ifft(X, &x_out); + + for (size_t j = 0; j < x_out.size(); ++j) { + EXPECT_NEAR(x_ref[j], x_out[j], 0.1f); + } + + EXPECT_EQ(x_old_ref, x_old); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec3_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_gn/moz.build new file mode 100644 index 0000000000..6646d41ff3 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec3_gn/moz.build @@ -0,0 +1,289 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_APM_DEBUG_DUMP"] = "0" +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +UNIFIED_SOURCES += [ + "/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/aec3_common.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/aec3_fft.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/aec_state.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/block_buffer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/block_framer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/block_processor.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/config_selector.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/decimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/dominant_nearend_detector.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/downsampled_render_buffer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/echo_audibility.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/fft_buffer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/fullband_erle_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/moving_average.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/reverb_decay_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/reverb_frequency_response.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/spectrum_buffer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/stationarity_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/subband_erle_estimator.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/subband_nearend_detector.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/subtractor.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output_analyzer.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain.cc", + "/third_party/libwebrtc/modules/audio_processing/aec3/transparent_mode.cc" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "rt" + ] + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + + OS_LIBS += [ + "crypt32", + "iphlpapi", + "secur32", + "winmm" + ] + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + CXXFLAGS += [ + "-mfpu=neon" + ] + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + CXXFLAGS += [ + "-msse2" + ] + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + CXXFLAGS += [ + "-msse2" + ] + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("aec3_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec_state.cc b/third_party/libwebrtc/modules/audio_processing/aec3/aec_state.cc new file mode 100644 index 0000000000..81fd91fab9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec_state.cc @@ -0,0 +1,481 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/aec_state.h" + +#include <math.h> + +#include <algorithm> +#include <numeric> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { +namespace { + +bool DeactivateInitialStateResetAtEchoPathChange() { + return field_trial::IsEnabled( + "WebRTC-Aec3DeactivateInitialStateResetKillSwitch"); +} + +bool FullResetAtEchoPathChange() { + return !field_trial::IsEnabled("WebRTC-Aec3AecStateFullResetKillSwitch"); +} + +bool SubtractorAnalyzerResetAtEchoPathChange() { + return !field_trial::IsEnabled( + "WebRTC-Aec3AecStateSubtractorAnalyzerResetKillSwitch"); +} + +void ComputeAvgRenderReverb( + const SpectrumBuffer& spectrum_buffer, + int delay_blocks, + float reverb_decay, + ReverbModel* reverb_model, + rtc::ArrayView<float, kFftLengthBy2Plus1> reverb_power_spectrum) { + RTC_DCHECK(reverb_model); + const size_t num_render_channels = spectrum_buffer.buffer[0].size(); + int idx_at_delay = + spectrum_buffer.OffsetIndex(spectrum_buffer.read, delay_blocks); + int idx_past = spectrum_buffer.IncIndex(idx_at_delay); + + std::array<float, kFftLengthBy2Plus1> X2_data; + rtc::ArrayView<const float> X2; + if (num_render_channels > 1) { + auto average_channels = + [](size_t num_render_channels, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + spectrum_band_0, + rtc::ArrayView<float, kFftLengthBy2Plus1> render_power) { + std::fill(render_power.begin(), render_power.end(), 0.f); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + render_power[k] += spectrum_band_0[ch][k]; + } + } + const float normalizer = 1.f / num_render_channels; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + render_power[k] *= normalizer; + } + }; + average_channels(num_render_channels, spectrum_buffer.buffer[idx_past], + X2_data); + reverb_model->UpdateReverbNoFreqShaping( + X2_data, /*power_spectrum_scaling=*/1.0f, reverb_decay); + + average_channels(num_render_channels, spectrum_buffer.buffer[idx_at_delay], + X2_data); + X2 = X2_data; + } else { + reverb_model->UpdateReverbNoFreqShaping( + spectrum_buffer.buffer[idx_past][/*channel=*/0], + /*power_spectrum_scaling=*/1.0f, reverb_decay); + + X2 = spectrum_buffer.buffer[idx_at_delay][/*channel=*/0]; + } + + rtc::ArrayView<const float, kFftLengthBy2Plus1> reverb_power = + reverb_model->reverb(); + for (size_t k = 0; k < X2.size(); ++k) { + reverb_power_spectrum[k] = X2[k] + reverb_power[k]; + } +} + +} // namespace + +std::atomic<int> AecState::instance_count_(0); + +void AecState::GetResidualEchoScaling( + rtc::ArrayView<float> residual_scaling) const { + bool filter_has_had_time_to_converge; + if (config_.filter.conservative_initial_phase) { + filter_has_had_time_to_converge = + strong_not_saturated_render_blocks_ >= 1.5f * kNumBlocksPerSecond; + } else { + filter_has_had_time_to_converge = + strong_not_saturated_render_blocks_ >= 0.8f * kNumBlocksPerSecond; + } + echo_audibility_.GetResidualEchoScaling(filter_has_had_time_to_converge, + residual_scaling); +} + +AecState::AecState(const EchoCanceller3Config& config, + size_t num_capture_channels) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + config_(config), + num_capture_channels_(num_capture_channels), + deactivate_initial_state_reset_at_echo_path_change_( + DeactivateInitialStateResetAtEchoPathChange()), + full_reset_at_echo_path_change_(FullResetAtEchoPathChange()), + subtractor_analyzer_reset_at_echo_path_change_( + SubtractorAnalyzerResetAtEchoPathChange()), + initial_state_(config_), + delay_state_(config_, num_capture_channels_), + transparent_state_(TransparentMode::Create(config_)), + filter_quality_state_(config_, num_capture_channels_), + erl_estimator_(2 * kNumBlocksPerSecond), + erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels_), + filter_analyzer_(config_, num_capture_channels_), + echo_audibility_( + config_.echo_audibility.use_stationarity_properties_at_init), + reverb_model_estimator_(config_, num_capture_channels_), + subtractor_output_analyzer_(num_capture_channels_) {} + +AecState::~AecState() = default; + +void AecState::HandleEchoPathChange( + const EchoPathVariability& echo_path_variability) { + const auto full_reset = [&]() { + filter_analyzer_.Reset(); + capture_signal_saturation_ = false; + strong_not_saturated_render_blocks_ = 0; + blocks_with_active_render_ = 0; + if (!deactivate_initial_state_reset_at_echo_path_change_) { + initial_state_.Reset(); + } + if (transparent_state_) { + transparent_state_->Reset(); + } + erle_estimator_.Reset(true); + erl_estimator_.Reset(); + filter_quality_state_.Reset(); + }; + + // TODO(peah): Refine the reset scheme according to the type of gain and + // delay adjustment. + + if (full_reset_at_echo_path_change_ && + echo_path_variability.delay_change != + EchoPathVariability::DelayAdjustment::kNone) { + full_reset(); + } else if (echo_path_variability.gain_change) { + erle_estimator_.Reset(false); + } + if (subtractor_analyzer_reset_at_echo_path_change_) { + subtractor_output_analyzer_.HandleEchoPathChange(); + } +} + +void AecState::Update( + const absl::optional<DelayEstimate>& external_delay, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + adaptive_filter_frequency_responses, + rtc::ArrayView<const std::vector<float>> adaptive_filter_impulse_responses, + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2_refined, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const SubtractorOutput> subtractor_output) { + RTC_DCHECK_EQ(num_capture_channels_, Y2.size()); + RTC_DCHECK_EQ(num_capture_channels_, subtractor_output.size()); + RTC_DCHECK_EQ(num_capture_channels_, + adaptive_filter_frequency_responses.size()); + RTC_DCHECK_EQ(num_capture_channels_, + adaptive_filter_impulse_responses.size()); + + // Analyze the filter outputs and filters. + bool any_filter_converged; + bool any_coarse_filter_converged; + bool all_filters_diverged; + subtractor_output_analyzer_.Update(subtractor_output, &any_filter_converged, + &any_coarse_filter_converged, + &all_filters_diverged); + + bool any_filter_consistent; + float max_echo_path_gain; + filter_analyzer_.Update(adaptive_filter_impulse_responses, render_buffer, + &any_filter_consistent, &max_echo_path_gain); + + // Estimate the direct path delay of the filter. + if (config_.filter.use_linear_filter) { + delay_state_.Update(filter_analyzer_.FilterDelaysBlocks(), external_delay, + strong_not_saturated_render_blocks_); + } + + const Block& aligned_render_block = + render_buffer.GetBlock(-delay_state_.MinDirectPathFilterDelay()); + + // Update render counters. + bool active_render = false; + for (int ch = 0; ch < aligned_render_block.NumChannels(); ++ch) { + const float render_energy = + std::inner_product(aligned_render_block.begin(/*block=*/0, ch), + aligned_render_block.end(/*block=*/0, ch), + aligned_render_block.begin(/*block=*/0, ch), 0.f); + if (render_energy > (config_.render_levels.active_render_limit * + config_.render_levels.active_render_limit) * + kFftLengthBy2) { + active_render = true; + break; + } + } + blocks_with_active_render_ += active_render ? 1 : 0; + strong_not_saturated_render_blocks_ += + active_render && !SaturatedCapture() ? 1 : 0; + + std::array<float, kFftLengthBy2Plus1> avg_render_spectrum_with_reverb; + + ComputeAvgRenderReverb(render_buffer.GetSpectrumBuffer(), + delay_state_.MinDirectPathFilterDelay(), + ReverbDecay(/*mild=*/false), &avg_render_reverb_, + avg_render_spectrum_with_reverb); + + if (config_.echo_audibility.use_stationarity_properties) { + // Update the echo audibility evaluator. + echo_audibility_.Update(render_buffer, avg_render_reverb_.reverb(), + delay_state_.MinDirectPathFilterDelay(), + delay_state_.ExternalDelayReported()); + } + + // Update the ERL and ERLE measures. + if (initial_state_.TransitionTriggered()) { + erle_estimator_.Reset(false); + } + + erle_estimator_.Update(render_buffer, adaptive_filter_frequency_responses, + avg_render_spectrum_with_reverb, Y2, E2_refined, + subtractor_output_analyzer_.ConvergedFilters()); + + erl_estimator_.Update( + subtractor_output_analyzer_.ConvergedFilters(), + render_buffer.Spectrum(delay_state_.MinDirectPathFilterDelay()), Y2); + + // Detect and flag echo saturation. + if (config_.ep_strength.echo_can_saturate) { + saturation_detector_.Update(aligned_render_block, SaturatedCapture(), + UsableLinearEstimate(), subtractor_output, + max_echo_path_gain); + } else { + RTC_DCHECK(!saturation_detector_.SaturatedEcho()); + } + + // Update the decision on whether to use the initial state parameter set. + initial_state_.Update(active_render, SaturatedCapture()); + + // Detect whether the transparent mode should be activated. + if (transparent_state_) { + transparent_state_->Update( + delay_state_.MinDirectPathFilterDelay(), any_filter_consistent, + any_filter_converged, any_coarse_filter_converged, all_filters_diverged, + active_render, SaturatedCapture()); + } + + // Analyze the quality of the filter. + filter_quality_state_.Update(active_render, TransparentModeActive(), + SaturatedCapture(), external_delay, + any_filter_converged); + + // Update the reverb estimate. + const bool stationary_block = + config_.echo_audibility.use_stationarity_properties && + echo_audibility_.IsBlockStationary(); + + reverb_model_estimator_.Update( + filter_analyzer_.GetAdjustedFilters(), + adaptive_filter_frequency_responses, + erle_estimator_.GetInstLinearQualityEstimates(), + delay_state_.DirectPathFilterDelays(), + filter_quality_state_.UsableLinearFilterOutputs(), stationary_block); + + erle_estimator_.Dump(data_dumper_); + reverb_model_estimator_.Dump(data_dumper_.get()); + data_dumper_->DumpRaw("aec3_active_render", active_render); + data_dumper_->DumpRaw("aec3_erl", Erl()); + data_dumper_->DumpRaw("aec3_erl_time_domain", ErlTimeDomain()); + data_dumper_->DumpRaw("aec3_erle", Erle(/*onset_compensated=*/false)[0]); + data_dumper_->DumpRaw("aec3_erle_onset_compensated", + Erle(/*onset_compensated=*/true)[0]); + data_dumper_->DumpRaw("aec3_usable_linear_estimate", UsableLinearEstimate()); + data_dumper_->DumpRaw("aec3_transparent_mode", TransparentModeActive()); + data_dumper_->DumpRaw("aec3_filter_delay", + filter_analyzer_.MinFilterDelayBlocks()); + + data_dumper_->DumpRaw("aec3_any_filter_consistent", any_filter_consistent); + data_dumper_->DumpRaw("aec3_initial_state", + initial_state_.InitialStateActive()); + data_dumper_->DumpRaw("aec3_capture_saturation", SaturatedCapture()); + data_dumper_->DumpRaw("aec3_echo_saturation", SaturatedEcho()); + data_dumper_->DumpRaw("aec3_any_filter_converged", any_filter_converged); + data_dumper_->DumpRaw("aec3_any_coarse_filter_converged", + any_coarse_filter_converged); + data_dumper_->DumpRaw("aec3_all_filters_diverged", all_filters_diverged); + + data_dumper_->DumpRaw("aec3_external_delay_avaliable", + external_delay ? 1 : 0); + data_dumper_->DumpRaw("aec3_filter_tail_freq_resp_est", + GetReverbFrequencyResponse()); + data_dumper_->DumpRaw("aec3_subtractor_y2", subtractor_output[0].y2); + data_dumper_->DumpRaw("aec3_subtractor_e2_coarse", + subtractor_output[0].e2_coarse); + data_dumper_->DumpRaw("aec3_subtractor_e2_refined", + subtractor_output[0].e2_refined); +} + +AecState::InitialState::InitialState(const EchoCanceller3Config& config) + : conservative_initial_phase_(config.filter.conservative_initial_phase), + initial_state_seconds_(config.filter.initial_state_seconds) { + Reset(); +} +void AecState::InitialState::InitialState::Reset() { + initial_state_ = true; + strong_not_saturated_render_blocks_ = 0; +} +void AecState::InitialState::InitialState::Update(bool active_render, + bool saturated_capture) { + strong_not_saturated_render_blocks_ += + active_render && !saturated_capture ? 1 : 0; + + // Flag whether the initial state is still active. + bool prev_initial_state = initial_state_; + if (conservative_initial_phase_) { + initial_state_ = + strong_not_saturated_render_blocks_ < 5 * kNumBlocksPerSecond; + } else { + initial_state_ = strong_not_saturated_render_blocks_ < + initial_state_seconds_ * kNumBlocksPerSecond; + } + + // Flag whether the transition from the initial state has started. + transition_triggered_ = !initial_state_ && prev_initial_state; +} + +AecState::FilterDelay::FilterDelay(const EchoCanceller3Config& config, + size_t num_capture_channels) + : delay_headroom_blocks_(config.delay.delay_headroom_samples / kBlockSize), + filter_delays_blocks_(num_capture_channels, delay_headroom_blocks_), + min_filter_delay_(delay_headroom_blocks_) {} + +void AecState::FilterDelay::Update( + rtc::ArrayView<const int> analyzer_filter_delay_estimates_blocks, + const absl::optional<DelayEstimate>& external_delay, + size_t blocks_with_proper_filter_adaptation) { + // Update the delay based on the external delay. + if (external_delay && + (!external_delay_ || external_delay_->delay != external_delay->delay)) { + external_delay_ = external_delay; + external_delay_reported_ = true; + } + + // Override the estimated delay if it is not certain that the filter has had + // time to converge. + const bool delay_estimator_may_not_have_converged = + blocks_with_proper_filter_adaptation < 2 * kNumBlocksPerSecond; + if (delay_estimator_may_not_have_converged && external_delay_) { + const int delay_guess = delay_headroom_blocks_; + std::fill(filter_delays_blocks_.begin(), filter_delays_blocks_.end(), + delay_guess); + } else { + RTC_DCHECK_EQ(filter_delays_blocks_.size(), + analyzer_filter_delay_estimates_blocks.size()); + std::copy(analyzer_filter_delay_estimates_blocks.begin(), + analyzer_filter_delay_estimates_blocks.end(), + filter_delays_blocks_.begin()); + } + + min_filter_delay_ = *std::min_element(filter_delays_blocks_.begin(), + filter_delays_blocks_.end()); +} + +AecState::FilteringQualityAnalyzer::FilteringQualityAnalyzer( + const EchoCanceller3Config& config, + size_t num_capture_channels) + : use_linear_filter_(config.filter.use_linear_filter), + usable_linear_filter_estimates_(num_capture_channels, false) {} + +void AecState::FilteringQualityAnalyzer::Reset() { + std::fill(usable_linear_filter_estimates_.begin(), + usable_linear_filter_estimates_.end(), false); + overall_usable_linear_estimates_ = false; + filter_update_blocks_since_reset_ = 0; +} + +void AecState::FilteringQualityAnalyzer::Update( + bool active_render, + bool transparent_mode, + bool saturated_capture, + const absl::optional<DelayEstimate>& external_delay, + bool any_filter_converged) { + // Update blocks counter. + const bool filter_update = active_render && !saturated_capture; + filter_update_blocks_since_reset_ += filter_update ? 1 : 0; + filter_update_blocks_since_start_ += filter_update ? 1 : 0; + + // Store convergence flag when observed. + convergence_seen_ = convergence_seen_ || any_filter_converged; + + // Verify requirements for achieving a decent filter. The requirements for + // filter adaptation at call startup are more restrictive than after an + // in-call reset. + const bool sufficient_data_to_converge_at_startup = + filter_update_blocks_since_start_ > kNumBlocksPerSecond * 0.4f; + const bool sufficient_data_to_converge_at_reset = + sufficient_data_to_converge_at_startup && + filter_update_blocks_since_reset_ > kNumBlocksPerSecond * 0.2f; + + // The linear filter can only be used if it has had time to converge. + overall_usable_linear_estimates_ = sufficient_data_to_converge_at_startup && + sufficient_data_to_converge_at_reset; + + // The linear filter can only be used if an external delay or convergence have + // been identified + overall_usable_linear_estimates_ = + overall_usable_linear_estimates_ && (external_delay || convergence_seen_); + + // If transparent mode is on, deactivate usign the linear filter. + overall_usable_linear_estimates_ = + overall_usable_linear_estimates_ && !transparent_mode; + + if (use_linear_filter_) { + std::fill(usable_linear_filter_estimates_.begin(), + usable_linear_filter_estimates_.end(), + overall_usable_linear_estimates_); + } +} + +void AecState::SaturationDetector::Update( + const Block& x, + bool saturated_capture, + bool usable_linear_estimate, + rtc::ArrayView<const SubtractorOutput> subtractor_output, + float echo_path_gain) { + saturated_echo_ = false; + if (!saturated_capture) { + return; + } + + if (usable_linear_estimate) { + constexpr float kSaturationThreshold = 20000.f; + for (size_t ch = 0; ch < subtractor_output.size(); ++ch) { + saturated_echo_ = + saturated_echo_ || + (subtractor_output[ch].s_refined_max_abs > kSaturationThreshold || + subtractor_output[ch].s_coarse_max_abs > kSaturationThreshold); + } + } else { + float max_sample = 0.f; + for (int ch = 0; ch < x.NumChannels(); ++ch) { + rtc::ArrayView<const float, kBlockSize> x_ch = x.View(/*band=*/0, ch); + for (float sample : x_ch) { + max_sample = std::max(max_sample, fabsf(sample)); + } + } + + const float kMargin = 10.f; + float peak_echo_amplitude = max_sample * echo_path_gain * kMargin; + saturated_echo_ = saturated_echo_ || peak_echo_amplitude > 32000; + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec_state.h b/third_party/libwebrtc/modules/audio_processing/aec3/aec_state.h new file mode 100644 index 0000000000..a39325c8b8 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec_state.h @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_AEC_STATE_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_AEC_STATE_H_ + +#include <stddef.h> + +#include <array> +#include <atomic> +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/delay_estimate.h" +#include "modules/audio_processing/aec3/echo_audibility.h" +#include "modules/audio_processing/aec3/echo_path_variability.h" +#include "modules/audio_processing/aec3/erl_estimator.h" +#include "modules/audio_processing/aec3/erle_estimator.h" +#include "modules/audio_processing/aec3/filter_analyzer.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/reverb_model_estimator.h" +#include "modules/audio_processing/aec3/subtractor_output.h" +#include "modules/audio_processing/aec3/subtractor_output_analyzer.h" +#include "modules/audio_processing/aec3/transparent_mode.h" + +namespace webrtc { + +class ApmDataDumper; + +// Handles the state and the conditions for the echo removal functionality. +class AecState { + public: + AecState(const EchoCanceller3Config& config, size_t num_capture_channels); + ~AecState(); + + // Returns whether the echo subtractor can be used to determine the residual + // echo. + bool UsableLinearEstimate() const { + return filter_quality_state_.LinearFilterUsable() && + config_.filter.use_linear_filter; + } + + // Returns whether the echo subtractor output should be used as output. + bool UseLinearFilterOutput() const { + return filter_quality_state_.LinearFilterUsable() && + config_.filter.use_linear_filter; + } + + // Returns whether the render signal is currently active. + bool ActiveRender() const { return blocks_with_active_render_ > 200; } + + // Returns the appropriate scaling of the residual echo to match the + // audibility. + void GetResidualEchoScaling(rtc::ArrayView<float> residual_scaling) const; + + // Returns whether the stationary properties of the signals are used in the + // aec. + bool UseStationarityProperties() const { + return config_.echo_audibility.use_stationarity_properties; + } + + // Returns the ERLE. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle( + bool onset_compensated) const { + return erle_estimator_.Erle(onset_compensated); + } + + // Returns the non-capped ERLE. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> ErleUnbounded() + const { + return erle_estimator_.ErleUnbounded(); + } + + // Returns the fullband ERLE estimate in log2 units. + float FullBandErleLog2() const { return erle_estimator_.FullbandErleLog2(); } + + // Returns the ERL. + const std::array<float, kFftLengthBy2Plus1>& Erl() const { + return erl_estimator_.Erl(); + } + + // Returns the time-domain ERL. + float ErlTimeDomain() const { return erl_estimator_.ErlTimeDomain(); } + + // Returns the delay estimate based on the linear filter. + int MinDirectPathFilterDelay() const { + return delay_state_.MinDirectPathFilterDelay(); + } + + // Returns whether the capture signal is saturated. + bool SaturatedCapture() const { return capture_signal_saturation_; } + + // Returns whether the echo signal is saturated. + bool SaturatedEcho() const { return saturation_detector_.SaturatedEcho(); } + + // Updates the capture signal saturation. + void UpdateCaptureSaturation(bool capture_signal_saturation) { + capture_signal_saturation_ = capture_signal_saturation; + } + + // Returns whether the transparent mode is active + bool TransparentModeActive() const { + return transparent_state_ && transparent_state_->Active(); + } + + // Takes appropriate action at an echo path change. + void HandleEchoPathChange(const EchoPathVariability& echo_path_variability); + + // Returns the decay factor for the echo reverberation. The parameter `mild` + // indicates which exponential decay to return. The default one or a milder + // one that can be used during nearend regions. + float ReverbDecay(bool mild) const { + return reverb_model_estimator_.ReverbDecay(mild); + } + + // Return the frequency response of the reverberant echo. + rtc::ArrayView<const float> GetReverbFrequencyResponse() const { + return reverb_model_estimator_.GetReverbFrequencyResponse(); + } + + // Returns whether the transition for going out of the initial stated has + // been triggered. + bool TransitionTriggered() const { + return initial_state_.TransitionTriggered(); + } + + // Updates the aec state. + // TODO(bugs.webrtc.org/10913): Compute multi-channel ERL. + void Update( + const absl::optional<DelayEstimate>& external_delay, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + adaptive_filter_frequency_responses, + rtc::ArrayView<const std::vector<float>> + adaptive_filter_impulse_responses, + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2_refined, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const SubtractorOutput> subtractor_output); + + // Returns filter length in blocks. + int FilterLengthBlocks() const { + // All filters have the same length, so arbitrarily return channel 0 length. + return filter_analyzer_.FilterLengthBlocks(); + } + + private: + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const EchoCanceller3Config config_; + const size_t num_capture_channels_; + const bool deactivate_initial_state_reset_at_echo_path_change_; + const bool full_reset_at_echo_path_change_; + const bool subtractor_analyzer_reset_at_echo_path_change_; + + // Class for controlling the transition from the intial state, which in turn + // controls when the filter parameters for the initial state should be used. + class InitialState { + public: + explicit InitialState(const EchoCanceller3Config& config); + // Resets the state to again begin in the initial state. + void Reset(); + + // Updates the state based on new data. + void Update(bool active_render, bool saturated_capture); + + // Returns whether the initial state is active or not. + bool InitialStateActive() const { return initial_state_; } + + // Returns that the transition from the initial state has was started. + bool TransitionTriggered() const { return transition_triggered_; } + + private: + const bool conservative_initial_phase_; + const float initial_state_seconds_; + bool transition_triggered_ = false; + bool initial_state_ = true; + size_t strong_not_saturated_render_blocks_ = 0; + } initial_state_; + + // Class for choosing the direct-path delay relative to the beginning of the + // filter, as well as any other data related to the delay used within + // AecState. + class FilterDelay { + public: + FilterDelay(const EchoCanceller3Config& config, + size_t num_capture_channels); + + // Returns whether an external delay has been reported to the AecState (from + // the delay estimator). + bool ExternalDelayReported() const { return external_delay_reported_; } + + // Returns the delay in blocks relative to the beginning of the filter that + // corresponds to the direct path of the echo. + rtc::ArrayView<const int> DirectPathFilterDelays() const { + return filter_delays_blocks_; + } + + // Returns the minimum delay among the direct path delays relative to the + // beginning of the filter + int MinDirectPathFilterDelay() const { return min_filter_delay_; } + + // Updates the delay estimates based on new data. + void Update( + rtc::ArrayView<const int> analyzer_filter_delay_estimates_blocks, + const absl::optional<DelayEstimate>& external_delay, + size_t blocks_with_proper_filter_adaptation); + + private: + const int delay_headroom_blocks_; + bool external_delay_reported_ = false; + std::vector<int> filter_delays_blocks_; + int min_filter_delay_; + absl::optional<DelayEstimate> external_delay_; + } delay_state_; + + // Classifier for toggling transparent mode when there is no echo. + std::unique_ptr<TransparentMode> transparent_state_; + + // Class for analyzing how well the linear filter is, and can be expected to, + // perform on the current signals. The purpose of this is for using to + // select the echo suppression functionality as well as the input to the echo + // suppressor. + class FilteringQualityAnalyzer { + public: + FilteringQualityAnalyzer(const EchoCanceller3Config& config, + size_t num_capture_channels); + + // Returns whether the linear filter can be used for the echo + // canceller output. + bool LinearFilterUsable() const { return overall_usable_linear_estimates_; } + + // Returns whether an individual filter output can be used for the echo + // canceller output. + const std::vector<bool>& UsableLinearFilterOutputs() const { + return usable_linear_filter_estimates_; + } + + // Resets the state of the analyzer. + void Reset(); + + // Updates the analysis based on new data. + void Update(bool active_render, + bool transparent_mode, + bool saturated_capture, + const absl::optional<DelayEstimate>& external_delay, + bool any_filter_converged); + + private: + const bool use_linear_filter_; + bool overall_usable_linear_estimates_ = false; + size_t filter_update_blocks_since_reset_ = 0; + size_t filter_update_blocks_since_start_ = 0; + bool convergence_seen_ = false; + std::vector<bool> usable_linear_filter_estimates_; + } filter_quality_state_; + + // Class for detecting whether the echo is to be considered to be + // saturated. + class SaturationDetector { + public: + // Returns whether the echo is to be considered saturated. + bool SaturatedEcho() const { return saturated_echo_; } + + // Updates the detection decision based on new data. + void Update(const Block& x, + bool saturated_capture, + bool usable_linear_estimate, + rtc::ArrayView<const SubtractorOutput> subtractor_output, + float echo_path_gain); + + private: + bool saturated_echo_ = false; + } saturation_detector_; + + ErlEstimator erl_estimator_; + ErleEstimator erle_estimator_; + size_t strong_not_saturated_render_blocks_ = 0; + size_t blocks_with_active_render_ = 0; + bool capture_signal_saturation_ = false; + FilterAnalyzer filter_analyzer_; + EchoAudibility echo_audibility_; + ReverbModelEstimator reverb_model_estimator_; + ReverbModel avg_render_reverb_; + SubtractorOutputAnalyzer subtractor_output_analyzer_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_AEC_STATE_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/aec_state_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/aec_state_unittest.cc new file mode 100644 index 0000000000..6662c8fb1a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/aec_state_unittest.cc @@ -0,0 +1,297 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/aec_state.h" + +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +void RunNormalUsageTest(size_t num_render_channels, + size_t num_capture_channels) { + // TODO(bugs.webrtc.org/10913): Test with different content in different + // channels. + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + ApmDataDumper data_dumper(42); + EchoCanceller3Config config; + AecState state(config, num_capture_channels); + absl::optional<DelayEstimate> delay_estimate = + DelayEstimate(DelayEstimate::Quality::kRefined, 10); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( + num_capture_channels); + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); + Block x(kNumBands, num_render_channels); + EchoPathVariability echo_path_variability( + false, EchoPathVariability::DelayAdjustment::kNone, false); + std::vector<std::array<float, kBlockSize>> y(num_capture_channels); + std::vector<SubtractorOutput> subtractor_output(num_capture_channels); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].Reset(); + subtractor_output[ch].s_refined.fill(100.f); + subtractor_output[ch].e_refined.fill(100.f); + y[ch].fill(1000.f); + E2_refined[ch].fill(0.f); + Y2[ch].fill(0.f); + } + Aec3Fft fft; + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> + converged_filter_frequency_response( + num_capture_channels, + std::vector<std::array<float, kFftLengthBy2Plus1>>(10)); + for (auto& v_ch : converged_filter_frequency_response) { + for (auto& v : v_ch) { + v.fill(0.01f); + } + } + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> + diverged_filter_frequency_response = converged_filter_frequency_response; + converged_filter_frequency_response[0][2].fill(100.f); + converged_filter_frequency_response[0][2][0] = 1.f; + std::vector<std::vector<float>> impulse_response( + num_capture_channels, + std::vector<float>( + GetTimeDomainLength(config.filter.refined.length_blocks), 0.f)); + + // Verify that linear AEC usability is true when the filter is converged + for (size_t band = 0; band < kNumBands; ++band) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + std::fill(x.begin(band, ch), x.end(band, ch), 101.f); + } + } + for (int k = 0; k < 3000; ++k) { + render_delay_buffer->Insert(x); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].ComputeMetrics(y[ch]); + } + state.Update(delay_estimate, converged_filter_frequency_response, + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_refined, Y2, subtractor_output); + } + EXPECT_TRUE(state.UsableLinearEstimate()); + + // Verify that linear AEC usability becomes false after an echo path + // change is reported + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].ComputeMetrics(y[ch]); + } + state.HandleEchoPathChange(EchoPathVariability( + false, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false)); + state.Update(delay_estimate, converged_filter_frequency_response, + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_refined, Y2, subtractor_output); + EXPECT_FALSE(state.UsableLinearEstimate()); + + // Verify that the active render detection works as intended. + for (size_t ch = 0; ch < num_render_channels; ++ch) { + std::fill(x.begin(0, ch), x.end(0, ch), 101.f); + } + render_delay_buffer->Insert(x); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].ComputeMetrics(y[ch]); + } + state.HandleEchoPathChange(EchoPathVariability( + true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false)); + state.Update(delay_estimate, converged_filter_frequency_response, + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_refined, Y2, subtractor_output); + EXPECT_FALSE(state.ActiveRender()); + + for (int k = 0; k < 1000; ++k) { + render_delay_buffer->Insert(x); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].ComputeMetrics(y[ch]); + } + state.Update(delay_estimate, converged_filter_frequency_response, + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_refined, Y2, subtractor_output); + } + EXPECT_TRUE(state.ActiveRender()); + + // Verify that the ERL is properly estimated + for (int band = 0; band < x.NumBands(); ++band) { + for (int channel = 0; channel < x.NumChannels(); ++channel) { + std::fill(x.begin(band, channel), x.end(band, channel), 0.0f); + } + } + + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x.View(/*band=*/0, ch)[0] = 5000.f; + } + for (size_t k = 0; + k < render_delay_buffer->GetRenderBuffer()->GetFftBuffer().size(); ++k) { + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + } + + for (auto& Y2_ch : Y2) { + Y2_ch.fill(10.f * 10000.f * 10000.f); + } + for (size_t k = 0; k < 1000; ++k) { + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].ComputeMetrics(y[ch]); + } + state.Update(delay_estimate, converged_filter_frequency_response, + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_refined, Y2, subtractor_output); + } + + ASSERT_TRUE(state.UsableLinearEstimate()); + const std::array<float, kFftLengthBy2Plus1>& erl = state.Erl(); + EXPECT_EQ(erl[0], erl[1]); + for (size_t k = 1; k < erl.size() - 1; ++k) { + EXPECT_NEAR(k % 2 == 0 ? 10.f : 1000.f, erl[k], 0.1); + } + EXPECT_EQ(erl[erl.size() - 2], erl[erl.size() - 1]); + + // Verify that the ERLE is properly estimated + for (auto& E2_refined_ch : E2_refined) { + E2_refined_ch.fill(1.f * 10000.f * 10000.f); + } + for (auto& Y2_ch : Y2) { + Y2_ch.fill(10.f * E2_refined[0][0]); + } + for (size_t k = 0; k < 1000; ++k) { + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].ComputeMetrics(y[ch]); + } + state.Update(delay_estimate, converged_filter_frequency_response, + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_refined, Y2, subtractor_output); + } + ASSERT_TRUE(state.UsableLinearEstimate()); + { + // Note that the render spectrum is built so it does not have energy in + // the odd bands but just in the even bands. + const auto& erle = state.Erle(/*onset_compensated=*/true)[0]; + EXPECT_EQ(erle[0], erle[1]); + constexpr size_t kLowFrequencyLimit = 32; + for (size_t k = 2; k < kLowFrequencyLimit; k = k + 2) { + EXPECT_NEAR(4.f, erle[k], 0.1); + } + for (size_t k = kLowFrequencyLimit; k < erle.size() - 1; k = k + 2) { + EXPECT_NEAR(1.5f, erle[k], 0.1); + } + EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]); + } + for (auto& E2_refined_ch : E2_refined) { + E2_refined_ch.fill(1.f * 10000.f * 10000.f); + } + for (auto& Y2_ch : Y2) { + Y2_ch.fill(5.f * E2_refined[0][0]); + } + for (size_t k = 0; k < 1000; ++k) { + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + subtractor_output[ch].ComputeMetrics(y[ch]); + } + state.Update(delay_estimate, converged_filter_frequency_response, + impulse_response, *render_delay_buffer->GetRenderBuffer(), + E2_refined, Y2, subtractor_output); + } + + ASSERT_TRUE(state.UsableLinearEstimate()); + { + const auto& erle = state.Erle(/*onset_compensated=*/true)[0]; + EXPECT_EQ(erle[0], erle[1]); + constexpr size_t kLowFrequencyLimit = 32; + for (size_t k = 1; k < kLowFrequencyLimit; ++k) { + EXPECT_NEAR(k % 2 == 0 ? 4.f : 1.f, erle[k], 0.1); + } + for (size_t k = kLowFrequencyLimit; k < erle.size() - 1; ++k) { + EXPECT_NEAR(k % 2 == 0 ? 1.5f : 1.f, erle[k], 0.1); + } + EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]); + } +} + +} // namespace + +class AecStateMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + AecStateMultiChannel, + ::testing::Combine(::testing::Values(1, 2, 8), + ::testing::Values(1, 2, 8))); + +// Verify the general functionality of AecState +TEST_P(AecStateMultiChannel, NormalUsage) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + RunNormalUsageTest(num_render_channels, num_capture_channels); +} + +// Verifies the delay for a converged filter is correctly identified. +TEST(AecState, ConvergedFilterDelay) { + constexpr int kFilterLengthBlocks = 10; + constexpr size_t kNumCaptureChannels = 1; + EchoCanceller3Config config; + AecState state(config, kNumCaptureChannels); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, 48000, 1)); + absl::optional<DelayEstimate> delay_estimate; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( + kNumCaptureChannels); + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels); + std::array<float, kBlockSize> x; + EchoPathVariability echo_path_variability( + false, EchoPathVariability::DelayAdjustment::kNone, false); + std::vector<SubtractorOutput> subtractor_output(kNumCaptureChannels); + for (auto& output : subtractor_output) { + output.Reset(); + output.s_refined.fill(100.f); + } + std::array<float, kBlockSize> y; + x.fill(0.f); + y.fill(0.f); + + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> + frequency_response(kNumCaptureChannels, + std::vector<std::array<float, kFftLengthBy2Plus1>>( + kFilterLengthBlocks)); + for (auto& v_ch : frequency_response) { + for (auto& v : v_ch) { + v.fill(0.01f); + } + } + + std::vector<std::vector<float>> impulse_response( + kNumCaptureChannels, + std::vector<float>( + GetTimeDomainLength(config.filter.refined.length_blocks), 0.f)); + + // Verify that the filter delay for a converged filter is properly + // identified. + for (int k = 0; k < kFilterLengthBlocks; ++k) { + for (auto& ir : impulse_response) { + std::fill(ir.begin(), ir.end(), 0.f); + ir[k * kBlockSize + 1] = 1.f; + } + + state.HandleEchoPathChange(echo_path_variability); + subtractor_output[0].ComputeMetrics(y); + state.Update(delay_estimate, frequency_response, impulse_response, + *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2, + subtractor_output); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer.cc new file mode 100644 index 0000000000..7f076dea8e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer.cc @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/alignment_mixer.h" + +#include <algorithm> + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +AlignmentMixer::MixingVariant ChooseMixingVariant(bool downmix, + bool adaptive_selection, + int num_channels) { + RTC_DCHECK(!(adaptive_selection && downmix)); + RTC_DCHECK_LT(0, num_channels); + + if (num_channels == 1) { + return AlignmentMixer::MixingVariant::kFixed; + } + if (downmix) { + return AlignmentMixer::MixingVariant::kDownmix; + } + if (adaptive_selection) { + return AlignmentMixer::MixingVariant::kAdaptive; + } + return AlignmentMixer::MixingVariant::kFixed; +} + +} // namespace + +AlignmentMixer::AlignmentMixer( + size_t num_channels, + const EchoCanceller3Config::Delay::AlignmentMixing& config) + : AlignmentMixer(num_channels, + config.downmix, + config.adaptive_selection, + config.activity_power_threshold, + config.prefer_first_two_channels) {} + +AlignmentMixer::AlignmentMixer(size_t num_channels, + bool downmix, + bool adaptive_selection, + float activity_power_threshold, + bool prefer_first_two_channels) + : num_channels_(num_channels), + one_by_num_channels_(1.f / num_channels_), + excitation_energy_threshold_(kBlockSize * activity_power_threshold), + prefer_first_two_channels_(prefer_first_two_channels), + selection_variant_( + ChooseMixingVariant(downmix, adaptive_selection, num_channels_)) { + if (selection_variant_ == MixingVariant::kAdaptive) { + std::fill(strong_block_counters_.begin(), strong_block_counters_.end(), 0); + cumulative_energies_.resize(num_channels_); + std::fill(cumulative_energies_.begin(), cumulative_energies_.end(), 0.f); + } +} + +void AlignmentMixer::ProduceOutput(const Block& x, + rtc::ArrayView<float, kBlockSize> y) { + RTC_DCHECK_EQ(x.NumChannels(), num_channels_); + + if (selection_variant_ == MixingVariant::kDownmix) { + Downmix(x, y); + return; + } + + int ch = selection_variant_ == MixingVariant::kFixed ? 0 : SelectChannel(x); + + RTC_DCHECK_GT(x.NumChannels(), ch); + std::copy(x.begin(/*band=*/0, ch), x.end(/*band=*/0, ch), y.begin()); +} + +void AlignmentMixer::Downmix(const Block& x, + rtc::ArrayView<float, kBlockSize> y) const { + RTC_DCHECK_EQ(x.NumChannels(), num_channels_); + RTC_DCHECK_GE(num_channels_, 2); + std::memcpy(&y[0], x.View(/*band=*/0, /*channel=*/0).data(), + kBlockSize * sizeof(y[0])); + for (size_t ch = 1; ch < num_channels_; ++ch) { + const auto x_ch = x.View(/*band=*/0, ch); + for (size_t i = 0; i < kBlockSize; ++i) { + y[i] += x_ch[i]; + } + } + + for (size_t i = 0; i < kBlockSize; ++i) { + y[i] *= one_by_num_channels_; + } +} + +int AlignmentMixer::SelectChannel(const Block& x) { + RTC_DCHECK_EQ(x.NumChannels(), num_channels_); + RTC_DCHECK_GE(num_channels_, 2); + RTC_DCHECK_EQ(cumulative_energies_.size(), num_channels_); + + constexpr size_t kBlocksToChooseLeftOrRight = + static_cast<size_t>(0.5f * kNumBlocksPerSecond); + const bool good_signal_in_left_or_right = + prefer_first_two_channels_ && + (strong_block_counters_[0] > kBlocksToChooseLeftOrRight || + strong_block_counters_[1] > kBlocksToChooseLeftOrRight); + + const int num_ch_to_analyze = + good_signal_in_left_or_right ? 2 : num_channels_; + + constexpr int kNumBlocksBeforeEnergySmoothing = 60 * kNumBlocksPerSecond; + ++block_counter_; + + for (int ch = 0; ch < num_ch_to_analyze; ++ch) { + float x2_sum = 0.f; + rtc::ArrayView<const float, kBlockSize> x_ch = x.View(/*band=*/0, ch); + for (size_t i = 0; i < kBlockSize; ++i) { + x2_sum += x_ch[i] * x_ch[i]; + } + + if (ch < 2 && x2_sum > excitation_energy_threshold_) { + ++strong_block_counters_[ch]; + } + + if (block_counter_ <= kNumBlocksBeforeEnergySmoothing) { + cumulative_energies_[ch] += x2_sum; + } else { + constexpr float kSmoothing = 1.f / (10 * kNumBlocksPerSecond); + cumulative_energies_[ch] += + kSmoothing * (x2_sum - cumulative_energies_[ch]); + } + } + + // Normalize the energies to allow the energy computations to from now be + // based on smoothing. + if (block_counter_ == kNumBlocksBeforeEnergySmoothing) { + constexpr float kOneByNumBlocksBeforeEnergySmoothing = + 1.f / kNumBlocksBeforeEnergySmoothing; + for (int ch = 0; ch < num_ch_to_analyze; ++ch) { + cumulative_energies_[ch] *= kOneByNumBlocksBeforeEnergySmoothing; + } + } + + int strongest_ch = 0; + for (int ch = 0; ch < num_ch_to_analyze; ++ch) { + if (cumulative_energies_[ch] > cumulative_energies_[strongest_ch]) { + strongest_ch = ch; + } + } + + if ((good_signal_in_left_or_right && selected_channel_ > 1) || + cumulative_energies_[strongest_ch] > + 2.f * cumulative_energies_[selected_channel_]) { + selected_channel_ = strongest_ch; + } + + return selected_channel_; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer.h b/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer.h new file mode 100644 index 0000000000..b3ed04755c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ALIGNMENT_MIXER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ALIGNMENT_MIXER_H_ + +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block.h" + +namespace webrtc { + +// Performs channel conversion to mono for the purpose of providing a decent +// mono input for the delay estimation. This is achieved by analyzing all +// incoming channels and produce one single channel output. +class AlignmentMixer { + public: + AlignmentMixer(size_t num_channels, + const EchoCanceller3Config::Delay::AlignmentMixing& config); + + AlignmentMixer(size_t num_channels, + bool downmix, + bool adaptive_selection, + float excitation_limit, + bool prefer_first_two_channels); + + void ProduceOutput(const Block& x, rtc::ArrayView<float, kBlockSize> y); + + enum class MixingVariant { kDownmix, kAdaptive, kFixed }; + + private: + const size_t num_channels_; + const float one_by_num_channels_; + const float excitation_energy_threshold_; + const bool prefer_first_two_channels_; + const MixingVariant selection_variant_; + std::array<size_t, 2> strong_block_counters_; + std::vector<float> cumulative_energies_; + int selected_channel_ = 0; + size_t block_counter_ = 0; + + void Downmix(const Block& x, rtc::ArrayView<float, kBlockSize> y) const; + int SelectChannel(const Block& x); +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ALIGNMENT_MIXER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer_unittest.cc new file mode 100644 index 0000000000..eaf6dcb235 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/alignment_mixer_unittest.cc @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/alignment_mixer.h" + +#include <string> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gmock.h" +#include "test/gtest.h" + +using ::testing::AllOf; +using ::testing::Each; + +namespace webrtc { +namespace { +std::string ProduceDebugText(bool initial_silence, + bool huge_activity_threshold, + bool prefer_first_two_channels, + int num_channels, + int strongest_ch) { + rtc::StringBuilder ss; + ss << ", Initial silence: " << initial_silence; + ss << ", Huge activity threshold: " << huge_activity_threshold; + ss << ", Prefer first two channels: " << prefer_first_two_channels; + ss << ", Number of channels: " << num_channels; + ss << ", Strongest channel: " << strongest_ch; + return ss.Release(); +} + +} // namespace + +TEST(AlignmentMixer, GeneralAdaptiveMode) { + constexpr int kChannelOffset = 100; + constexpr int kMaxChannelsToTest = 8; + constexpr float kStrongestSignalScaling = + kMaxChannelsToTest * kChannelOffset * 100; + + for (bool initial_silence : {false, true}) { + for (bool huge_activity_threshold : {false, true}) { + for (bool prefer_first_two_channels : {false, true}) { + for (int num_channels = 2; num_channels < 8; ++num_channels) { + for (int strongest_ch = 0; strongest_ch < num_channels; + ++strongest_ch) { + SCOPED_TRACE(ProduceDebugText( + initial_silence, huge_activity_threshold, + prefer_first_two_channels, num_channels, strongest_ch)); + const float excitation_limit = + huge_activity_threshold ? 1000000000.f : 0.001f; + AlignmentMixer am(num_channels, /*downmix*/ false, + /*adaptive_selection*/ true, excitation_limit, + prefer_first_two_channels); + + Block x( + /*num_bands=*/1, num_channels); + if (initial_silence) { + std::array<float, kBlockSize> y; + for (int frame = 0; frame < 10 * kNumBlocksPerSecond; ++frame) { + am.ProduceOutput(x, y); + } + } + + for (int frame = 0; frame < 2 * kNumBlocksPerSecond; ++frame) { + const auto channel_value = [&](int frame_index, + int channel_index) { + return static_cast<float>(frame_index + + channel_index * kChannelOffset); + }; + + for (int ch = 0; ch < num_channels; ++ch) { + float scaling = + ch == strongest_ch ? kStrongestSignalScaling : 1.f; + auto x_ch = x.View(/*band=*/0, ch); + std::fill(x_ch.begin(), x_ch.end(), + channel_value(frame, ch) * scaling); + } + + std::array<float, kBlockSize> y; + y.fill(-1.f); + am.ProduceOutput(x, y); + + if (frame > 1 * kNumBlocksPerSecond) { + if (!prefer_first_two_channels || huge_activity_threshold) { + EXPECT_THAT(y, + AllOf(Each(x.View(/*band=*/0, strongest_ch)[0]))); + } else { + bool left_or_right_chosen; + for (int ch = 0; ch < 2; ++ch) { + left_or_right_chosen = true; + const auto x_ch = x.View(/*band=*/0, ch); + for (size_t k = 0; k < kBlockSize; ++k) { + if (y[k] != x_ch[k]) { + left_or_right_chosen = false; + break; + } + } + if (left_or_right_chosen) { + break; + } + } + EXPECT_TRUE(left_or_right_chosen); + } + } + } + } + } + } + } + } +} + +TEST(AlignmentMixer, DownmixMode) { + for (int num_channels = 1; num_channels < 8; ++num_channels) { + AlignmentMixer am(num_channels, /*downmix*/ true, + /*adaptive_selection*/ false, /*excitation_limit*/ 1.f, + /*prefer_first_two_channels*/ false); + + Block x(/*num_bands=*/1, num_channels); + const auto channel_value = [](int frame_index, int channel_index) { + return static_cast<float>(frame_index + channel_index); + }; + for (int frame = 0; frame < 10; ++frame) { + for (int ch = 0; ch < num_channels; ++ch) { + auto x_ch = x.View(/*band=*/0, ch); + std::fill(x_ch.begin(), x_ch.end(), channel_value(frame, ch)); + } + + std::array<float, kBlockSize> y; + y.fill(-1.f); + am.ProduceOutput(x, y); + + float expected_mixed_value = 0.f; + for (int ch = 0; ch < num_channels; ++ch) { + expected_mixed_value += channel_value(frame, ch); + } + expected_mixed_value *= 1.f / num_channels; + + EXPECT_THAT(y, AllOf(Each(expected_mixed_value))); + } + } +} + +TEST(AlignmentMixer, FixedMode) { + for (int num_channels = 1; num_channels < 8; ++num_channels) { + AlignmentMixer am(num_channels, /*downmix*/ false, + /*adaptive_selection*/ false, /*excitation_limit*/ 1.f, + /*prefer_first_two_channels*/ false); + + Block x(/*num_band=*/1, num_channels); + const auto channel_value = [](int frame_index, int channel_index) { + return static_cast<float>(frame_index + channel_index); + }; + for (int frame = 0; frame < 10; ++frame) { + for (int ch = 0; ch < num_channels; ++ch) { + auto x_ch = x.View(/*band=*/0, ch); + std::fill(x_ch.begin(), x_ch.end(), channel_value(frame, ch)); + } + + std::array<float, kBlockSize> y; + y.fill(-1.f); + am.ProduceOutput(x, y); + EXPECT_THAT(y, AllOf(Each(x.View(/*band=*/0, /*channel=*/0)[0]))); + } + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +TEST(AlignmentMixerDeathTest, ZeroNumChannels) { + EXPECT_DEATH( + AlignmentMixer(/*num_channels*/ 0, /*downmix*/ false, + /*adaptive_selection*/ false, /*excitation_limit*/ 1.f, + /*prefer_first_two_channels*/ false); + , ""); +} + +TEST(AlignmentMixerDeathTest, IncorrectVariant) { + EXPECT_DEATH( + AlignmentMixer(/*num_channels*/ 1, /*downmix*/ true, + /*adaptive_selection*/ true, /*excitation_limit*/ 1.f, + /*prefer_first_two_channels*/ false); + , ""); +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics.cc b/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics.cc new file mode 100644 index 0000000000..45f56a5dce --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/api_call_jitter_metrics.h" + +#include <algorithm> +#include <limits> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "system_wrappers/include/metrics.h" + +namespace webrtc { +namespace { + +bool TimeToReportMetrics(int frames_since_last_report) { + constexpr int kNumFramesPerSecond = 100; + constexpr int kReportingIntervalFrames = 10 * kNumFramesPerSecond; + return frames_since_last_report == kReportingIntervalFrames; +} + +} // namespace + +ApiCallJitterMetrics::Jitter::Jitter() + : max_(0), min_(std::numeric_limits<int>::max()) {} + +void ApiCallJitterMetrics::Jitter::Update(int num_api_calls_in_a_row) { + min_ = std::min(min_, num_api_calls_in_a_row); + max_ = std::max(max_, num_api_calls_in_a_row); +} + +void ApiCallJitterMetrics::Jitter::Reset() { + min_ = std::numeric_limits<int>::max(); + max_ = 0; +} + +void ApiCallJitterMetrics::Reset() { + render_jitter_.Reset(); + capture_jitter_.Reset(); + num_api_calls_in_a_row_ = 0; + frames_since_last_report_ = 0; + last_call_was_render_ = false; + proper_call_observed_ = false; +} + +void ApiCallJitterMetrics::ReportRenderCall() { + if (!last_call_was_render_) { + // If the previous call was a capture and a proper call has been observed + // (containing both render and capture data), storing the last number of + // capture calls into the metrics. + if (proper_call_observed_) { + capture_jitter_.Update(num_api_calls_in_a_row_); + } + + // Reset the call counter to start counting render calls. + num_api_calls_in_a_row_ = 0; + } + ++num_api_calls_in_a_row_; + last_call_was_render_ = true; +} + +void ApiCallJitterMetrics::ReportCaptureCall() { + if (last_call_was_render_) { + // If the previous call was a render and a proper call has been observed + // (containing both render and capture data), storing the last number of + // render calls into the metrics. + if (proper_call_observed_) { + render_jitter_.Update(num_api_calls_in_a_row_); + } + // Reset the call counter to start counting capture calls. + num_api_calls_in_a_row_ = 0; + + // If this statement is reached, at least one render and one capture call + // have been observed. + proper_call_observed_ = true; + } + ++num_api_calls_in_a_row_; + last_call_was_render_ = false; + + // Only report and update jitter metrics for when a proper call, containing + // both render and capture data, has been observed. + if (proper_call_observed_ && + TimeToReportMetrics(++frames_since_last_report_)) { + // Report jitter, where the base basic unit is frames. + constexpr int kMaxJitterToReport = 50; + + // Report max and min jitter for render and capture, in units of 20 ms. + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.MaxRenderJitter", + std::min(kMaxJitterToReport, render_jitter().max()), 1, + kMaxJitterToReport, kMaxJitterToReport); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.MinRenderJitter", + std::min(kMaxJitterToReport, render_jitter().min()), 1, + kMaxJitterToReport, kMaxJitterToReport); + + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.MaxCaptureJitter", + std::min(kMaxJitterToReport, capture_jitter().max()), 1, + kMaxJitterToReport, kMaxJitterToReport); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.MinCaptureJitter", + std::min(kMaxJitterToReport, capture_jitter().min()), 1, + kMaxJitterToReport, kMaxJitterToReport); + + frames_since_last_report_ = 0; + Reset(); + } +} + +bool ApiCallJitterMetrics::WillReportMetricsAtNextCapture() const { + return TimeToReportMetrics(frames_since_last_report_ + 1); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics.h b/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics.h new file mode 100644 index 0000000000..dd1fa82e93 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_API_CALL_JITTER_METRICS_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_API_CALL_JITTER_METRICS_H_ + +namespace webrtc { + +// Stores data for reporting metrics on the API call jitter. +class ApiCallJitterMetrics { + public: + class Jitter { + public: + Jitter(); + void Update(int num_api_calls_in_a_row); + void Reset(); + + int min() const { return min_; } + int max() const { return max_; } + + private: + int max_; + int min_; + }; + + ApiCallJitterMetrics() { Reset(); } + + // Update metrics for render API call. + void ReportRenderCall(); + + // Update and periodically report metrics for capture API call. + void ReportCaptureCall(); + + // Methods used only for testing. + const Jitter& render_jitter() const { return render_jitter_; } + const Jitter& capture_jitter() const { return capture_jitter_; } + bool WillReportMetricsAtNextCapture() const; + + private: + void Reset(); + + Jitter render_jitter_; + Jitter capture_jitter_; + + int num_api_calls_in_a_row_ = 0; + int frames_since_last_report_ = 0; + bool last_call_was_render_ = false; + bool proper_call_observed_ = false; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_API_CALL_JITTER_METRICS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics_unittest.cc new file mode 100644 index 0000000000..b902487152 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/api_call_jitter_metrics_unittest.cc @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/api_call_jitter_metrics.h" + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "test/gtest.h" + +namespace webrtc { + +// Verify constant jitter. +TEST(ApiCallJitterMetrics, ConstantJitter) { + for (int jitter = 1; jitter < 20; ++jitter) { + ApiCallJitterMetrics metrics; + for (size_t k = 0; k < 30 * kNumBlocksPerSecond; ++k) { + for (int j = 0; j < jitter; ++j) { + metrics.ReportRenderCall(); + } + + for (int j = 0; j < jitter; ++j) { + metrics.ReportCaptureCall(); + + if (metrics.WillReportMetricsAtNextCapture()) { + EXPECT_EQ(jitter, metrics.render_jitter().min()); + EXPECT_EQ(jitter, metrics.render_jitter().max()); + EXPECT_EQ(jitter, metrics.capture_jitter().min()); + EXPECT_EQ(jitter, metrics.capture_jitter().max()); + } + } + } + } +} + +// Verify peaky jitter for the render. +TEST(ApiCallJitterMetrics, JitterPeakRender) { + constexpr int kMinJitter = 2; + constexpr int kJitterPeak = 10; + constexpr int kPeakInterval = 100; + + ApiCallJitterMetrics metrics; + int render_surplus = 0; + + for (size_t k = 0; k < 30 * kNumBlocksPerSecond; ++k) { + const int num_render_calls = + k % kPeakInterval == 0 ? kJitterPeak : kMinJitter; + for (int j = 0; j < num_render_calls; ++j) { + metrics.ReportRenderCall(); + ++render_surplus; + } + + ASSERT_LE(kMinJitter, render_surplus); + const int num_capture_calls = + render_surplus == kMinJitter ? kMinJitter : kMinJitter + 1; + for (int j = 0; j < num_capture_calls; ++j) { + metrics.ReportCaptureCall(); + + if (metrics.WillReportMetricsAtNextCapture()) { + EXPECT_EQ(kMinJitter, metrics.render_jitter().min()); + EXPECT_EQ(kJitterPeak, metrics.render_jitter().max()); + EXPECT_EQ(kMinJitter, metrics.capture_jitter().min()); + EXPECT_EQ(kMinJitter + 1, metrics.capture_jitter().max()); + } + --render_surplus; + } + } +} + +// Verify peaky jitter for the capture. +TEST(ApiCallJitterMetrics, JitterPeakCapture) { + constexpr int kMinJitter = 2; + constexpr int kJitterPeak = 10; + constexpr int kPeakInterval = 100; + + ApiCallJitterMetrics metrics; + int capture_surplus = kMinJitter; + + for (size_t k = 0; k < 30 * kNumBlocksPerSecond; ++k) { + ASSERT_LE(kMinJitter, capture_surplus); + const int num_render_calls = + capture_surplus == kMinJitter ? kMinJitter : kMinJitter + 1; + for (int j = 0; j < num_render_calls; ++j) { + metrics.ReportRenderCall(); + --capture_surplus; + } + + const int num_capture_calls = + k % kPeakInterval == 0 ? kJitterPeak : kMinJitter; + for (int j = 0; j < num_capture_calls; ++j) { + metrics.ReportCaptureCall(); + + if (metrics.WillReportMetricsAtNextCapture()) { + EXPECT_EQ(kMinJitter, metrics.render_jitter().min()); + EXPECT_EQ(kMinJitter + 1, metrics.render_jitter().max()); + EXPECT_EQ(kMinJitter, metrics.capture_jitter().min()); + EXPECT_EQ(kJitterPeak, metrics.capture_jitter().max()); + } + ++capture_surplus; + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block.h b/third_party/libwebrtc/modules/audio_processing/aec3/block.h new file mode 100644 index 0000000000..c1fc70722d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_BLOCK_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_BLOCK_H_ + +#include <array> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Contains one or more channels of 4 milliseconds of audio data. +// The audio is split in one or more frequency bands, each with a sampling +// rate of 16 kHz. +class Block { + public: + Block(int num_bands, int num_channels, float default_value = 0.0f) + : num_bands_(num_bands), + num_channels_(num_channels), + data_(num_bands * num_channels * kBlockSize, default_value) {} + + // Returns the number of bands. + int NumBands() const { return num_bands_; } + + // Returns the number of channels. + int NumChannels() const { return num_channels_; } + + // Modifies the number of channels and sets all samples to zero. + void SetNumChannels(int num_channels) { + num_channels_ = num_channels; + data_.resize(num_bands_ * num_channels_ * kBlockSize); + std::fill(data_.begin(), data_.end(), 0.0f); + } + + // Iterators for accessing the data. + auto begin(int band, int channel) { + return data_.begin() + GetIndex(band, channel); + } + + auto begin(int band, int channel) const { + return data_.begin() + GetIndex(band, channel); + } + + auto end(int band, int channel) { return begin(band, channel) + kBlockSize; } + + auto end(int band, int channel) const { + return begin(band, channel) + kBlockSize; + } + + // Access data via ArrayView. + rtc::ArrayView<float, kBlockSize> View(int band, int channel) { + return rtc::ArrayView<float, kBlockSize>(&data_[GetIndex(band, channel)], + kBlockSize); + } + + rtc::ArrayView<const float, kBlockSize> View(int band, int channel) const { + return rtc::ArrayView<const float, kBlockSize>( + &data_[GetIndex(band, channel)], kBlockSize); + } + + // Lets two Blocks swap audio data. + void Swap(Block& b) { + std::swap(num_bands_, b.num_bands_); + std::swap(num_channels_, b.num_channels_); + data_.swap(b.data_); + } + + private: + // Returns the index of the first sample of the requested |band| and + // |channel|. + int GetIndex(int band, int channel) const { + return (band * num_channels_ + channel) * kBlockSize; + } + + int num_bands_; + int num_channels_; + std::vector<float> data_; +}; + +} // namespace webrtc +#endif // MODULES_AUDIO_PROCESSING_AEC3_BLOCK_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_buffer.cc new file mode 100644 index 0000000000..289c3f0d10 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_buffer.cc @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/block_buffer.h" + +#include <algorithm> + +namespace webrtc { + +BlockBuffer::BlockBuffer(size_t size, size_t num_bands, size_t num_channels) + : size(static_cast<int>(size)), + buffer(size, Block(num_bands, num_channels)) {} + +BlockBuffer::~BlockBuffer() = default; + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/block_buffer.h new file mode 100644 index 0000000000..3489d51646 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_buffer.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_BLOCK_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_BLOCK_BUFFER_H_ + +#include <stddef.h> + +#include <vector> + +#include "modules/audio_processing/aec3/block.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +// Struct for bundling a circular buffer of two dimensional vector objects +// together with the read and write indices. +struct BlockBuffer { + BlockBuffer(size_t size, size_t num_bands, size_t num_channels); + ~BlockBuffer(); + + int IncIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index < size - 1 ? index + 1 : 0; + } + + int DecIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index > 0 ? index - 1 : size - 1; + } + + int OffsetIndex(int index, int offset) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + RTC_DCHECK_GE(size, offset); + return (size + index + offset) % size; + } + + void UpdateWriteIndex(int offset) { write = OffsetIndex(write, offset); } + void IncWriteIndex() { write = IncIndex(write); } + void DecWriteIndex() { write = DecIndex(write); } + void UpdateReadIndex(int offset) { read = OffsetIndex(read, offset); } + void IncReadIndex() { read = IncIndex(read); } + void DecReadIndex() { read = DecIndex(read); } + + const int size; + std::vector<Block> buffer; + int write = 0; + int read = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_BLOCK_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer.cc new file mode 100644 index 0000000000..059bbafcdb --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer.cc @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/block_delay_buffer.h" + +#include "api/array_view.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +BlockDelayBuffer::BlockDelayBuffer(size_t num_channels, + size_t num_bands, + size_t frame_length, + size_t delay_samples) + : frame_length_(frame_length), + delay_(delay_samples), + buf_(num_channels, + std::vector<std::vector<float>>(num_bands, + std::vector<float>(delay_, 0.f))) {} + +BlockDelayBuffer::~BlockDelayBuffer() = default; + +void BlockDelayBuffer::DelaySignal(AudioBuffer* frame) { + RTC_DCHECK_EQ(buf_.size(), frame->num_channels()); + if (delay_ == 0) { + return; + } + + const size_t num_bands = buf_[0].size(); + const size_t num_channels = buf_.size(); + + const size_t i_start = last_insert_; + size_t i = 0; + for (size_t ch = 0; ch < num_channels; ++ch) { + RTC_DCHECK_EQ(buf_[ch].size(), frame->num_bands()); + RTC_DCHECK_EQ(buf_[ch].size(), num_bands); + rtc::ArrayView<float* const> frame_ch(frame->split_bands(ch), num_bands); + const size_t delay = delay_; + + for (size_t band = 0; band < num_bands; ++band) { + RTC_DCHECK_EQ(delay_, buf_[ch][band].size()); + i = i_start; + + // Offloading these pointers and class variables to local variables allows + // the compiler to optimize the below loop when compiling with + // '-fno-strict-aliasing'. + float* buf_ch_band = buf_[ch][band].data(); + float* frame_ch_band = frame_ch[band]; + + for (size_t k = 0, frame_length = frame_length_; k < frame_length; ++k) { + const float tmp = buf_ch_band[i]; + buf_ch_band[i] = frame_ch_band[k]; + frame_ch_band[k] = tmp; + + i = i < delay - 1 ? i + 1 : 0; + } + } + } + + last_insert_ = i; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer.h new file mode 100644 index 0000000000..711a790bfe --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_BLOCK_DELAY_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_BLOCK_DELAY_BUFFER_H_ + +#include <stddef.h> + +#include <vector> + +#include "modules/audio_processing/audio_buffer.h" + +namespace webrtc { + +// Class for applying a fixed delay to the samples in a signal partitioned using +// the audiobuffer band-splitting scheme. +class BlockDelayBuffer { + public: + BlockDelayBuffer(size_t num_channels, + size_t num_bands, + size_t frame_length, + size_t delay_samples); + ~BlockDelayBuffer(); + + // Delays the samples by the specified delay. + void DelaySignal(AudioBuffer* frame); + + private: + const size_t frame_length_; + const size_t delay_; + std::vector<std::vector<std::vector<float>>> buf_; + size_t last_insert_ = 0; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_BLOCK_DELAY_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer_unittest.cc new file mode 100644 index 0000000000..011ab49651 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_delay_buffer_unittest.cc @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/block_delay_buffer.h" + +#include <string> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/audio_buffer.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { + +float SampleValue(size_t sample_index) { + return sample_index % 32768; +} + +// Populates the frame with linearly increasing sample values for each band. +void PopulateInputFrame(size_t frame_length, + size_t num_bands, + size_t first_sample_index, + float* const* frame) { + for (size_t k = 0; k < num_bands; ++k) { + for (size_t i = 0; i < frame_length; ++i) { + frame[k][i] = SampleValue(first_sample_index + i); + } + } +} + +std::string ProduceDebugText(int sample_rate_hz, size_t delay) { + char log_stream_buffer[8 * 1024]; + rtc::SimpleStringBuilder ss(log_stream_buffer); + ss << "Sample rate: " << sample_rate_hz; + ss << ", Delay: " << delay; + return ss.str(); +} + +} // namespace + +class BlockDelayBufferTest + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, int, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P( + ParameterCombinations, + BlockDelayBufferTest, + ::testing::Combine(::testing::Values(0, 1, 27, 160, 4321, 7021), + ::testing::Values(16000, 32000, 48000), + ::testing::Values(1, 2, 4))); + +// Verifies that the correct signal delay is achived. +TEST_P(BlockDelayBufferTest, CorrectDelayApplied) { + const size_t delay = std::get<0>(GetParam()); + const int rate = std::get<1>(GetParam()); + const size_t num_channels = std::get<2>(GetParam()); + + SCOPED_TRACE(ProduceDebugText(rate, delay)); + size_t num_bands = NumBandsForRate(rate); + size_t subband_frame_length = 160; + + BlockDelayBuffer delay_buffer(num_channels, num_bands, subband_frame_length, + delay); + + static constexpr size_t kNumFramesToProcess = 20; + for (size_t frame_index = 0; frame_index < kNumFramesToProcess; + ++frame_index) { + AudioBuffer audio_buffer(rate, num_channels, rate, num_channels, rate, + num_channels); + if (rate > 16000) { + audio_buffer.SplitIntoFrequencyBands(); + } + size_t first_sample_index = frame_index * subband_frame_length; + for (size_t ch = 0; ch < num_channels; ++ch) { + PopulateInputFrame(subband_frame_length, num_bands, first_sample_index, + &audio_buffer.split_bands(ch)[0]); + } + delay_buffer.DelaySignal(&audio_buffer); + + for (size_t ch = 0; ch < num_channels; ++ch) { + for (size_t band = 0; band < num_bands; ++band) { + size_t sample_index = first_sample_index; + for (size_t i = 0; i < subband_frame_length; ++i, ++sample_index) { + if (sample_index < delay) { + EXPECT_EQ(0.f, audio_buffer.split_bands(ch)[band][i]); + } else { + EXPECT_EQ(SampleValue(sample_index - delay), + audio_buffer.split_bands(ch)[band][i]); + } + } + } + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_framer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_framer.cc new file mode 100644 index 0000000000..4243ddeba0 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_framer.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/block_framer.h" + +#include <algorithm> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +BlockFramer::BlockFramer(size_t num_bands, size_t num_channels) + : num_bands_(num_bands), + num_channels_(num_channels), + buffer_(num_bands_, + std::vector<std::vector<float>>( + num_channels, + std::vector<float>(kBlockSize, 0.f))) { + RTC_DCHECK_LT(0, num_bands); + RTC_DCHECK_LT(0, num_channels); +} + +BlockFramer::~BlockFramer() = default; + +// All the constants are chosen so that the buffer is either empty or has enough +// samples for InsertBlockAndExtractSubFrame to produce a frame. In order to +// achieve this, the InsertBlockAndExtractSubFrame and InsertBlock methods need +// to be called in the correct order. +void BlockFramer::InsertBlock(const Block& block) { + RTC_DCHECK_EQ(num_bands_, block.NumBands()); + RTC_DCHECK_EQ(num_channels_, block.NumChannels()); + for (size_t band = 0; band < num_bands_; ++band) { + for (size_t channel = 0; channel < num_channels_; ++channel) { + RTC_DCHECK_EQ(0, buffer_[band][channel].size()); + + buffer_[band][channel].insert(buffer_[band][channel].begin(), + block.begin(band, channel), + block.end(band, channel)); + } + } +} + +void BlockFramer::InsertBlockAndExtractSubFrame( + const Block& block, + std::vector<std::vector<rtc::ArrayView<float>>>* sub_frame) { + RTC_DCHECK(sub_frame); + RTC_DCHECK_EQ(num_bands_, block.NumBands()); + RTC_DCHECK_EQ(num_channels_, block.NumChannels()); + RTC_DCHECK_EQ(num_bands_, sub_frame->size()); + for (size_t band = 0; band < num_bands_; ++band) { + RTC_DCHECK_EQ(num_channels_, (*sub_frame)[0].size()); + for (size_t channel = 0; channel < num_channels_; ++channel) { + RTC_DCHECK_LE(kSubFrameLength, + buffer_[band][channel].size() + kBlockSize); + RTC_DCHECK_GE(kBlockSize, buffer_[band][channel].size()); + RTC_DCHECK_EQ(kSubFrameLength, (*sub_frame)[band][channel].size()); + + const int samples_to_frame = + kSubFrameLength - buffer_[band][channel].size(); + std::copy(buffer_[band][channel].begin(), buffer_[band][channel].end(), + (*sub_frame)[band][channel].begin()); + std::copy( + block.begin(band, channel), + block.begin(band, channel) + samples_to_frame, + (*sub_frame)[band][channel].begin() + buffer_[band][channel].size()); + buffer_[band][channel].clear(); + buffer_[band][channel].insert( + buffer_[band][channel].begin(), + block.begin(band, channel) + samples_to_frame, + block.end(band, channel)); + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_framer.h b/third_party/libwebrtc/modules/audio_processing/aec3/block_framer.h new file mode 100644 index 0000000000..e2cdd5a17c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_framer.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_BLOCK_FRAMER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_BLOCK_FRAMER_H_ + +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block.h" + +namespace webrtc { + +// Class for producing frames consisting of 2 subframes of 80 samples each +// from 64 sample blocks. The class is designed to work together with the +// FrameBlocker class which performs the reverse conversion. Used together with +// that, this class produces output frames are the same rate as frames are +// received by the FrameBlocker class. Note that the internal buffers will +// overrun if any other rate of packets insertion is used. +class BlockFramer { + public: + BlockFramer(size_t num_bands, size_t num_channels); + ~BlockFramer(); + BlockFramer(const BlockFramer&) = delete; + BlockFramer& operator=(const BlockFramer&) = delete; + + // Adds a 64 sample block into the data that will form the next output frame. + void InsertBlock(const Block& block); + // Adds a 64 sample block and extracts an 80 sample subframe. + void InsertBlockAndExtractSubFrame( + const Block& block, + std::vector<std::vector<rtc::ArrayView<float>>>* sub_frame); + + private: + const size_t num_bands_; + const size_t num_channels_; + std::vector<std::vector<std::vector<float>>> buffer_; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_BLOCK_FRAMER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_framer_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_framer_unittest.cc new file mode 100644 index 0000000000..9439623f72 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_framer_unittest.cc @@ -0,0 +1,337 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/block_framer.h" + +#include <string> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +void SetupSubFrameView( + std::vector<std::vector<std::vector<float>>>* sub_frame, + std::vector<std::vector<rtc::ArrayView<float>>>* sub_frame_view) { + for (size_t band = 0; band < sub_frame_view->size(); ++band) { + for (size_t channel = 0; channel < (*sub_frame_view)[band].size(); + ++channel) { + (*sub_frame_view)[band][channel] = + rtc::ArrayView<float>((*sub_frame)[band][channel].data(), + (*sub_frame)[band][channel].size()); + } + } +} + +float ComputeSampleValue(size_t chunk_counter, + size_t chunk_size, + size_t band, + size_t channel, + size_t sample_index, + int offset) { + float value = static_cast<int>(100 + chunk_counter * chunk_size + + sample_index + channel) + + offset; + return 5000 * band + value; +} + +bool VerifySubFrame( + size_t sub_frame_counter, + int offset, + const std::vector<std::vector<rtc::ArrayView<float>>>& sub_frame_view) { + for (size_t band = 0; band < sub_frame_view.size(); ++band) { + for (size_t channel = 0; channel < sub_frame_view[band].size(); ++channel) { + for (size_t sample = 0; sample < sub_frame_view[band][channel].size(); + ++sample) { + const float reference_value = ComputeSampleValue( + sub_frame_counter, kSubFrameLength, band, channel, sample, offset); + if (reference_value != sub_frame_view[band][channel][sample]) { + return false; + } + } + } + } + return true; +} + +void FillBlock(size_t block_counter, Block* block) { + for (int band = 0; band < block->NumBands(); ++band) { + for (int channel = 0; channel < block->NumChannels(); ++channel) { + auto b = block->View(band, channel); + for (size_t sample = 0; sample < kBlockSize; ++sample) { + b[sample] = ComputeSampleValue(block_counter, kBlockSize, band, channel, + sample, 0); + } + } + } +} + +// Verifies that the BlockFramer is able to produce the expected frame content. +void RunFramerTest(int sample_rate_hz, size_t num_channels) { + constexpr size_t kNumSubFramesToProcess = 10; + const size_t num_bands = NumBandsForRate(sample_rate_hz); + + Block block(num_bands, num_channels); + std::vector<std::vector<std::vector<float>>> output_sub_frame( + num_bands, std::vector<std::vector<float>>( + num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> output_sub_frame_view( + num_bands, std::vector<rtc::ArrayView<float>>(num_channels)); + SetupSubFrameView(&output_sub_frame, &output_sub_frame_view); + BlockFramer framer(num_bands, num_channels); + + size_t block_index = 0; + for (size_t sub_frame_index = 0; sub_frame_index < kNumSubFramesToProcess; + ++sub_frame_index) { + FillBlock(block_index++, &block); + framer.InsertBlockAndExtractSubFrame(block, &output_sub_frame_view); + if (sub_frame_index > 1) { + EXPECT_TRUE(VerifySubFrame(sub_frame_index, -64, output_sub_frame_view)); + } + + if ((sub_frame_index + 1) % 4 == 0) { + FillBlock(block_index++, &block); + framer.InsertBlock(block); + } + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies that the BlockFramer crashes if the InsertBlockAndExtractSubFrame +// method is called for inputs with the wrong number of bands or band lengths. +void RunWronglySizedInsertAndExtractParametersTest( + int sample_rate_hz, + size_t correct_num_channels, + size_t num_block_bands, + size_t num_block_channels, + size_t num_sub_frame_bands, + size_t num_sub_frame_channels, + size_t sub_frame_length) { + const size_t correct_num_bands = NumBandsForRate(sample_rate_hz); + + Block block(num_block_bands, num_block_channels); + std::vector<std::vector<std::vector<float>>> output_sub_frame( + num_sub_frame_bands, + std::vector<std::vector<float>>( + num_sub_frame_channels, std::vector<float>(sub_frame_length, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> output_sub_frame_view( + output_sub_frame.size(), + std::vector<rtc::ArrayView<float>>(num_sub_frame_channels)); + SetupSubFrameView(&output_sub_frame, &output_sub_frame_view); + BlockFramer framer(correct_num_bands, correct_num_channels); + EXPECT_DEATH( + framer.InsertBlockAndExtractSubFrame(block, &output_sub_frame_view), ""); +} + +// Verifies that the BlockFramer crashes if the InsertBlock method is called for +// inputs with the wrong number of bands or band lengths. +void RunWronglySizedInsertParameterTest(int sample_rate_hz, + size_t correct_num_channels, + size_t num_block_bands, + size_t num_block_channels) { + const size_t correct_num_bands = NumBandsForRate(sample_rate_hz); + + Block correct_block(correct_num_bands, correct_num_channels); + Block wrong_block(num_block_bands, num_block_channels); + std::vector<std::vector<std::vector<float>>> output_sub_frame( + correct_num_bands, + std::vector<std::vector<float>>( + correct_num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> output_sub_frame_view( + output_sub_frame.size(), + std::vector<rtc::ArrayView<float>>(correct_num_channels)); + SetupSubFrameView(&output_sub_frame, &output_sub_frame_view); + BlockFramer framer(correct_num_bands, correct_num_channels); + framer.InsertBlockAndExtractSubFrame(correct_block, &output_sub_frame_view); + framer.InsertBlockAndExtractSubFrame(correct_block, &output_sub_frame_view); + framer.InsertBlockAndExtractSubFrame(correct_block, &output_sub_frame_view); + framer.InsertBlockAndExtractSubFrame(correct_block, &output_sub_frame_view); + + EXPECT_DEATH(framer.InsertBlock(wrong_block), ""); +} + +// Verifies that the BlockFramer crashes if the InsertBlock method is called +// after a wrong number of previous InsertBlockAndExtractSubFrame method calls +// have been made. + +void RunWronglyInsertOrderTest(int sample_rate_hz, + size_t num_channels, + size_t num_preceeding_api_calls) { + const size_t correct_num_bands = NumBandsForRate(sample_rate_hz); + + Block block(correct_num_bands, num_channels); + std::vector<std::vector<std::vector<float>>> output_sub_frame( + correct_num_bands, + std::vector<std::vector<float>>( + num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> output_sub_frame_view( + output_sub_frame.size(), + std::vector<rtc::ArrayView<float>>(num_channels)); + SetupSubFrameView(&output_sub_frame, &output_sub_frame_view); + BlockFramer framer(correct_num_bands, num_channels); + for (size_t k = 0; k < num_preceeding_api_calls; ++k) { + framer.InsertBlockAndExtractSubFrame(block, &output_sub_frame_view); + } + + EXPECT_DEATH(framer.InsertBlock(block), ""); +} +#endif + +std::string ProduceDebugText(int sample_rate_hz, size_t num_channels) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + ss << ", number of channels: " << num_channels; + return ss.Release(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +TEST(BlockFramerDeathTest, + WrongNumberOfBandsInBlockForInsertBlockAndExtractSubFrame) { + for (auto rate : {16000, 32000, 48000}) { + for (auto correct_num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_bands = (correct_num_bands % 3) + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, wrong_num_bands, correct_num_channels, + correct_num_bands, correct_num_channels, kSubFrameLength); + } + } +} + +TEST(BlockFramerDeathTest, + WrongNumberOfChannelsInBlockForInsertBlockAndExtractSubFrame) { + for (auto rate : {16000, 32000, 48000}) { + for (auto correct_num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_channels = correct_num_channels + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, wrong_num_channels, + correct_num_bands, correct_num_channels, kSubFrameLength); + } + } +} + +TEST(BlockFramerDeathTest, + WrongNumberOfBandsInSubFrameForInsertBlockAndExtractSubFrame) { + for (auto rate : {16000, 32000, 48000}) { + for (auto correct_num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_bands = (correct_num_bands % 3) + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, correct_num_channels, + wrong_num_bands, correct_num_channels, kSubFrameLength); + } + } +} + +TEST(BlockFramerDeathTest, + WrongNumberOfChannelsInSubFrameForInsertBlockAndExtractSubFrame) { + for (auto rate : {16000, 32000, 48000}) { + for (auto correct_num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_channels = correct_num_channels + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, correct_num_channels, + correct_num_bands, wrong_num_channels, kSubFrameLength); + } + } +} + +TEST(BlockFramerDeathTest, + WrongNumberOfSamplesInSubFrameForInsertBlockAndExtractSubFrame) { + const size_t correct_num_channels = 1; + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, correct_num_channels, + correct_num_bands, correct_num_channels, kSubFrameLength - 1); + } +} + +TEST(BlockFramerDeathTest, WrongNumberOfBandsInBlockForInsertBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (auto correct_num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_bands = (correct_num_bands % 3) + 1; + RunWronglySizedInsertParameterTest(rate, correct_num_channels, + wrong_num_bands, correct_num_channels); + } + } +} + +TEST(BlockFramerDeathTest, WrongNumberOfChannelsInBlockForInsertBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (auto correct_num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_channels = correct_num_channels + 1; + RunWronglySizedInsertParameterTest(rate, correct_num_channels, + correct_num_bands, wrong_num_channels); + } + } +} + +TEST(BlockFramerDeathTest, WrongNumberOfPreceedingApiCallsForInsertBlock) { + for (size_t num_channels : {1, 2, 8}) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t num_calls = 0; num_calls < 4; ++num_calls) { + rtc::StringBuilder ss; + ss << "Sample rate: " << rate; + ss << ", Num channels: " << num_channels; + ss << ", Num preceeding InsertBlockAndExtractSubFrame calls: " + << num_calls; + + SCOPED_TRACE(ss.str()); + RunWronglyInsertOrderTest(rate, num_channels, num_calls); + } + } + } +} + +// Verifies that the verification for 0 number of channels works. +TEST(BlockFramerDeathTest, ZeroNumberOfChannelsParameter) { + EXPECT_DEATH(BlockFramer(16000, 0), ""); +} + +// Verifies that the verification for 0 number of bands works. +TEST(BlockFramerDeathTest, ZeroNumberOfBandsParameter) { + EXPECT_DEATH(BlockFramer(0, 1), ""); +} + +// Verifies that the verification for null sub_frame pointer works. +TEST(BlockFramerDeathTest, NullSubFrameParameter) { + EXPECT_DEATH( + BlockFramer(1, 1).InsertBlockAndExtractSubFrame(Block(1, 1), nullptr), + ""); +} + +#endif + +TEST(BlockFramer, FrameBitexactness) { + for (auto rate : {16000, 32000, 48000}) { + for (auto num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, num_channels)); + RunFramerTest(rate, num_channels); + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_processor.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor.cc new file mode 100644 index 0000000000..63e3d9cc7c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor.cc @@ -0,0 +1,290 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/block_processor.h" + +#include <stddef.h> + +#include <atomic> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/audio/echo_canceller3_config.h" +#include "api/audio/echo_control.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block_processor_metrics.h" +#include "modules/audio_processing/aec3/delay_estimate.h" +#include "modules/audio_processing/aec3/echo_path_variability.h" +#include "modules/audio_processing/aec3/echo_remover.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/aec3/render_delay_controller.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" + +namespace webrtc { +namespace { + +enum class BlockProcessorApiCall { kCapture, kRender }; + +class BlockProcessorImpl final : public BlockProcessor { + public: + BlockProcessorImpl(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels, + std::unique_ptr<RenderDelayBuffer> render_buffer, + std::unique_ptr<RenderDelayController> delay_controller, + std::unique_ptr<EchoRemover> echo_remover); + + BlockProcessorImpl() = delete; + + ~BlockProcessorImpl() override; + + void ProcessCapture(bool echo_path_gain_change, + bool capture_signal_saturation, + Block* linear_output, + Block* capture_block) override; + + void BufferRender(const Block& block) override; + + void UpdateEchoLeakageStatus(bool leakage_detected) override; + + void GetMetrics(EchoControl::Metrics* metrics) const override; + + void SetAudioBufferDelay(int delay_ms) override; + void SetCaptureOutputUsage(bool capture_output_used) override; + + private: + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const EchoCanceller3Config config_; + bool capture_properly_started_ = false; + bool render_properly_started_ = false; + const size_t sample_rate_hz_; + std::unique_ptr<RenderDelayBuffer> render_buffer_; + std::unique_ptr<RenderDelayController> delay_controller_; + std::unique_ptr<EchoRemover> echo_remover_; + BlockProcessorMetrics metrics_; + RenderDelayBuffer::BufferingEvent render_event_; + size_t capture_call_counter_ = 0; + absl::optional<DelayEstimate> estimated_delay_; +}; + +std::atomic<int> BlockProcessorImpl::instance_count_(0); + +BlockProcessorImpl::BlockProcessorImpl( + const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels, + std::unique_ptr<RenderDelayBuffer> render_buffer, + std::unique_ptr<RenderDelayController> delay_controller, + std::unique_ptr<EchoRemover> echo_remover) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + config_(config), + sample_rate_hz_(sample_rate_hz), + render_buffer_(std::move(render_buffer)), + delay_controller_(std::move(delay_controller)), + echo_remover_(std::move(echo_remover)), + render_event_(RenderDelayBuffer::BufferingEvent::kNone) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); +} + +BlockProcessorImpl::~BlockProcessorImpl() = default; + +void BlockProcessorImpl::ProcessCapture(bool echo_path_gain_change, + bool capture_signal_saturation, + Block* linear_output, + Block* capture_block) { + RTC_DCHECK(capture_block); + RTC_DCHECK_EQ(NumBandsForRate(sample_rate_hz_), capture_block->NumBands()); + + capture_call_counter_++; + + data_dumper_->DumpRaw("aec3_processblock_call_order", + static_cast<int>(BlockProcessorApiCall::kCapture)); + data_dumper_->DumpWav("aec3_processblock_capture_input", + capture_block->View(/*band=*/0, /*channel=*/0), 16000, + 1); + + if (render_properly_started_) { + if (!capture_properly_started_) { + capture_properly_started_ = true; + render_buffer_->Reset(); + if (delay_controller_) + delay_controller_->Reset(true); + } + } else { + // If no render data has yet arrived, do not process the capture signal. + render_buffer_->HandleSkippedCaptureProcessing(); + return; + } + + EchoPathVariability echo_path_variability( + echo_path_gain_change, EchoPathVariability::DelayAdjustment::kNone, + false); + + if (render_event_ == RenderDelayBuffer::BufferingEvent::kRenderOverrun && + render_properly_started_) { + echo_path_variability.delay_change = + EchoPathVariability::DelayAdjustment::kBufferFlush; + if (delay_controller_) + delay_controller_->Reset(true); + RTC_LOG(LS_WARNING) << "Reset due to render buffer overrun at block " + << capture_call_counter_; + } + render_event_ = RenderDelayBuffer::BufferingEvent::kNone; + + // Update the render buffers with any newly arrived render blocks and prepare + // the render buffers for reading the render data corresponding to the current + // capture block. + RenderDelayBuffer::BufferingEvent buffer_event = + render_buffer_->PrepareCaptureProcessing(); + // Reset the delay controller at render buffer underrun. + if (buffer_event == RenderDelayBuffer::BufferingEvent::kRenderUnderrun) { + if (delay_controller_) + delay_controller_->Reset(false); + } + + data_dumper_->DumpWav("aec3_processblock_capture_input2", + capture_block->View(/*band=*/0, /*channel=*/0), 16000, + 1); + + bool has_delay_estimator = !config_.delay.use_external_delay_estimator; + if (has_delay_estimator) { + RTC_DCHECK(delay_controller_); + // Compute and apply the render delay required to achieve proper signal + // alignment. + estimated_delay_ = delay_controller_->GetDelay( + render_buffer_->GetDownsampledRenderBuffer(), render_buffer_->Delay(), + *capture_block); + + if (estimated_delay_) { + bool delay_change = + render_buffer_->AlignFromDelay(estimated_delay_->delay); + if (delay_change) { + rtc::LoggingSeverity log_level = + config_.delay.log_warning_on_delay_changes ? rtc::LS_WARNING + : rtc::LS_INFO; + RTC_LOG_V(log_level) << "Delay changed to " << estimated_delay_->delay + << " at block " << capture_call_counter_; + echo_path_variability.delay_change = + EchoPathVariability::DelayAdjustment::kNewDetectedDelay; + } + } + + echo_path_variability.clock_drift = delay_controller_->HasClockdrift(); + + } else { + render_buffer_->AlignFromExternalDelay(); + } + + // Remove the echo from the capture signal. + if (has_delay_estimator || render_buffer_->HasReceivedBufferDelay()) { + echo_remover_->ProcessCapture( + echo_path_variability, capture_signal_saturation, estimated_delay_, + render_buffer_->GetRenderBuffer(), linear_output, capture_block); + } + + // Update the metrics. + metrics_.UpdateCapture(false); +} + +void BlockProcessorImpl::BufferRender(const Block& block) { + RTC_DCHECK_EQ(NumBandsForRate(sample_rate_hz_), block.NumBands()); + data_dumper_->DumpRaw("aec3_processblock_call_order", + static_cast<int>(BlockProcessorApiCall::kRender)); + data_dumper_->DumpWav("aec3_processblock_render_input", + block.View(/*band=*/0, /*channel=*/0), 16000, 1); + + render_event_ = render_buffer_->Insert(block); + + metrics_.UpdateRender(render_event_ != + RenderDelayBuffer::BufferingEvent::kNone); + + render_properly_started_ = true; + if (delay_controller_) + delay_controller_->LogRenderCall(); +} + +void BlockProcessorImpl::UpdateEchoLeakageStatus(bool leakage_detected) { + echo_remover_->UpdateEchoLeakageStatus(leakage_detected); +} + +void BlockProcessorImpl::GetMetrics(EchoControl::Metrics* metrics) const { + echo_remover_->GetMetrics(metrics); + constexpr int block_size_ms = 4; + absl::optional<size_t> delay = render_buffer_->Delay(); + metrics->delay_ms = delay ? static_cast<int>(*delay) * block_size_ms : 0; +} + +void BlockProcessorImpl::SetAudioBufferDelay(int delay_ms) { + render_buffer_->SetAudioBufferDelay(delay_ms); +} + +void BlockProcessorImpl::SetCaptureOutputUsage(bool capture_output_used) { + echo_remover_->SetCaptureOutputUsage(capture_output_used); +} + +} // namespace + +BlockProcessor* BlockProcessor::Create(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels) { + std::unique_ptr<RenderDelayBuffer> render_buffer( + RenderDelayBuffer::Create(config, sample_rate_hz, num_render_channels)); + std::unique_ptr<RenderDelayController> delay_controller; + if (!config.delay.use_external_delay_estimator) { + delay_controller.reset(RenderDelayController::Create(config, sample_rate_hz, + num_capture_channels)); + } + std::unique_ptr<EchoRemover> echo_remover(EchoRemover::Create( + config, sample_rate_hz, num_render_channels, num_capture_channels)); + return Create(config, sample_rate_hz, num_render_channels, + num_capture_channels, std::move(render_buffer), + std::move(delay_controller), std::move(echo_remover)); +} + +BlockProcessor* BlockProcessor::Create( + const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels, + std::unique_ptr<RenderDelayBuffer> render_buffer) { + std::unique_ptr<RenderDelayController> delay_controller; + if (!config.delay.use_external_delay_estimator) { + delay_controller.reset(RenderDelayController::Create(config, sample_rate_hz, + num_capture_channels)); + } + std::unique_ptr<EchoRemover> echo_remover(EchoRemover::Create( + config, sample_rate_hz, num_render_channels, num_capture_channels)); + return Create(config, sample_rate_hz, num_render_channels, + num_capture_channels, std::move(render_buffer), + std::move(delay_controller), std::move(echo_remover)); +} + +BlockProcessor* BlockProcessor::Create( + const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels, + std::unique_ptr<RenderDelayBuffer> render_buffer, + std::unique_ptr<RenderDelayController> delay_controller, + std::unique_ptr<EchoRemover> echo_remover) { + return new BlockProcessorImpl(config, sample_rate_hz, num_render_channels, + num_capture_channels, std::move(render_buffer), + std::move(delay_controller), + std::move(echo_remover)); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_processor.h b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor.h new file mode 100644 index 0000000000..01a83ae5f7 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_BLOCK_PROCESSOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_BLOCK_PROCESSOR_H_ + +#include <stddef.h> + +#include <memory> +#include <vector> + +#include "api/audio/echo_canceller3_config.h" +#include "api/audio/echo_control.h" +#include "modules/audio_processing/aec3/block.h" +#include "modules/audio_processing/aec3/echo_remover.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/aec3/render_delay_controller.h" + +namespace webrtc { + +// Class for performing echo cancellation on 64 sample blocks of audio data. +class BlockProcessor { + public: + static BlockProcessor* Create(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels); + // Only used for testing purposes. + static BlockProcessor* Create( + const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels, + std::unique_ptr<RenderDelayBuffer> render_buffer); + static BlockProcessor* Create( + const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels, + std::unique_ptr<RenderDelayBuffer> render_buffer, + std::unique_ptr<RenderDelayController> delay_controller, + std::unique_ptr<EchoRemover> echo_remover); + + virtual ~BlockProcessor() = default; + + // Get current metrics. + virtual void GetMetrics(EchoControl::Metrics* metrics) const = 0; + + // Provides an optional external estimate of the audio buffer delay. + virtual void SetAudioBufferDelay(int delay_ms) = 0; + + // Processes a block of capture data. + virtual void ProcessCapture(bool echo_path_gain_change, + bool capture_signal_saturation, + Block* linear_output, + Block* capture_block) = 0; + + // Buffers a block of render data supplied by a FrameBlocker object. + virtual void BufferRender(const Block& render_block) = 0; + + // Reports whether echo leakage has been detected in the echo canceller + // output. + virtual void UpdateEchoLeakageStatus(bool leakage_detected) = 0; + + // Specifies whether the capture output will be used. The purpose of this is + // to allow the block processor to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. + virtual void SetCaptureOutputUsage(bool capture_output_used) = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_BLOCK_PROCESSOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics.cc new file mode 100644 index 0000000000..deac1fcd22 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics.cc @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/block_processor_metrics.h" + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" +#include "system_wrappers/include/metrics.h" + +namespace webrtc { + +namespace { + +enum class RenderUnderrunCategory { + kNone, + kFew, + kSeveral, + kMany, + kConstant, + kNumCategories +}; + +enum class RenderOverrunCategory { + kNone, + kFew, + kSeveral, + kMany, + kConstant, + kNumCategories +}; + +} // namespace + +void BlockProcessorMetrics::UpdateCapture(bool underrun) { + ++capture_block_counter_; + if (underrun) { + ++render_buffer_underruns_; + } + + if (capture_block_counter_ == kMetricsReportingIntervalBlocks) { + metrics_reported_ = true; + + RenderUnderrunCategory underrun_category; + if (render_buffer_underruns_ == 0) { + underrun_category = RenderUnderrunCategory::kNone; + } else if (render_buffer_underruns_ > (capture_block_counter_ >> 1)) { + underrun_category = RenderUnderrunCategory::kConstant; + } else if (render_buffer_underruns_ > 100) { + underrun_category = RenderUnderrunCategory::kMany; + } else if (render_buffer_underruns_ > 10) { + underrun_category = RenderUnderrunCategory::kSeveral; + } else { + underrun_category = RenderUnderrunCategory::kFew; + } + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.Audio.EchoCanceller.RenderUnderruns", + static_cast<int>(underrun_category), + static_cast<int>(RenderUnderrunCategory::kNumCategories)); + + RenderOverrunCategory overrun_category; + if (render_buffer_overruns_ == 0) { + overrun_category = RenderOverrunCategory::kNone; + } else if (render_buffer_overruns_ > (buffer_render_calls_ >> 1)) { + overrun_category = RenderOverrunCategory::kConstant; + } else if (render_buffer_overruns_ > 100) { + overrun_category = RenderOverrunCategory::kMany; + } else if (render_buffer_overruns_ > 10) { + overrun_category = RenderOverrunCategory::kSeveral; + } else { + overrun_category = RenderOverrunCategory::kFew; + } + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.Audio.EchoCanceller.RenderOverruns", + static_cast<int>(overrun_category), + static_cast<int>(RenderOverrunCategory::kNumCategories)); + + ResetMetrics(); + capture_block_counter_ = 0; + } else { + metrics_reported_ = false; + } +} + +void BlockProcessorMetrics::UpdateRender(bool overrun) { + ++buffer_render_calls_; + if (overrun) { + ++render_buffer_overruns_; + } +} + +void BlockProcessorMetrics::ResetMetrics() { + render_buffer_underruns_ = 0; + render_buffer_overruns_ = 0; + buffer_render_calls_ = 0; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics.h b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics.h new file mode 100644 index 0000000000..a70d0dac5b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_BLOCK_PROCESSOR_METRICS_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_BLOCK_PROCESSOR_METRICS_H_ + +namespace webrtc { + +// Handles the reporting of metrics for the block_processor. +class BlockProcessorMetrics { + public: + BlockProcessorMetrics() = default; + + BlockProcessorMetrics(const BlockProcessorMetrics&) = delete; + BlockProcessorMetrics& operator=(const BlockProcessorMetrics&) = delete; + + // Updates the metric with new capture data. + void UpdateCapture(bool underrun); + + // Updates the metric with new render data. + void UpdateRender(bool overrun); + + // Returns true if the metrics have just been reported, otherwise false. + bool MetricsReported() { return metrics_reported_; } + + private: + // Resets the metrics. + void ResetMetrics(); + + int capture_block_counter_ = 0; + bool metrics_reported_ = false; + int render_buffer_underruns_ = 0; + int render_buffer_overruns_ = 0; + int buffer_render_calls_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_BLOCK_PROCESSOR_METRICS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics_unittest.cc new file mode 100644 index 0000000000..3e23c2499d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_metrics_unittest.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/block_processor_metrics.h" + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "test/gtest.h" + +namespace webrtc { + +// Verify the general functionality of BlockProcessorMetrics. +TEST(BlockProcessorMetrics, NormalUsage) { + BlockProcessorMetrics metrics; + + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < kMetricsReportingIntervalBlocks - 1; ++k) { + metrics.UpdateRender(false); + metrics.UpdateRender(false); + metrics.UpdateCapture(false); + EXPECT_FALSE(metrics.MetricsReported()); + } + metrics.UpdateCapture(false); + EXPECT_TRUE(metrics.MetricsReported()); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_unittest.cc new file mode 100644 index 0000000000..aba5c4186d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/block_processor_unittest.cc @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/block_processor.h" + +#include <memory> +#include <string> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/mock/mock_echo_remover.h" +#include "modules/audio_processing/aec3/mock/mock_render_delay_buffer.h" +#include "modules/audio_processing/aec3/mock/mock_render_delay_controller.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/checks.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +using ::testing::_; +using ::testing::AtLeast; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::StrictMock; + +// Verifies that the basic BlockProcessor functionality works and that the API +// methods are callable. +void RunBasicSetupAndApiCallTest(int sample_rate_hz, int num_iterations) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + + std::unique_ptr<BlockProcessor> block_processor( + BlockProcessor::Create(EchoCanceller3Config(), sample_rate_hz, + kNumRenderChannels, kNumCaptureChannels)); + Block block(NumBandsForRate(sample_rate_hz), kNumRenderChannels, 1000.f); + for (int k = 0; k < num_iterations; ++k) { + block_processor->BufferRender(block); + block_processor->ProcessCapture(false, false, nullptr, &block); + block_processor->UpdateEchoLeakageStatus(false); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +void RunRenderBlockSizeVerificationTest(int sample_rate_hz) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + + std::unique_ptr<BlockProcessor> block_processor( + BlockProcessor::Create(EchoCanceller3Config(), sample_rate_hz, + kNumRenderChannels, kNumCaptureChannels)); + Block block(NumBandsForRate(sample_rate_hz), kNumRenderChannels); + + EXPECT_DEATH(block_processor->BufferRender(block), ""); +} + +void RunRenderNumBandsVerificationTest(int sample_rate_hz) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + + const size_t wrong_num_bands = NumBandsForRate(sample_rate_hz) < 3 + ? NumBandsForRate(sample_rate_hz) + 1 + : 1; + std::unique_ptr<BlockProcessor> block_processor( + BlockProcessor::Create(EchoCanceller3Config(), sample_rate_hz, + kNumRenderChannels, kNumCaptureChannels)); + Block block(wrong_num_bands, kNumRenderChannels); + + EXPECT_DEATH(block_processor->BufferRender(block), ""); +} + +void RunCaptureNumBandsVerificationTest(int sample_rate_hz) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + + const size_t wrong_num_bands = NumBandsForRate(sample_rate_hz) < 3 + ? NumBandsForRate(sample_rate_hz) + 1 + : 1; + std::unique_ptr<BlockProcessor> block_processor( + BlockProcessor::Create(EchoCanceller3Config(), sample_rate_hz, + kNumRenderChannels, kNumCaptureChannels)); + Block block(wrong_num_bands, kNumRenderChannels); + + EXPECT_DEATH(block_processor->ProcessCapture(false, false, nullptr, &block), + ""); +} +#endif + +std::string ProduceDebugText(int sample_rate_hz) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.Release(); +} + +void FillSampleVector(int call_counter, + int delay, + rtc::ArrayView<float> samples) { + for (size_t i = 0; i < samples.size(); ++i) { + samples[i] = (call_counter - delay) * 10000.0f + i; + } +} + +} // namespace + +// Verifies that the delay controller functionality is properly integrated with +// the render delay buffer inside block processor. +// TODO(peah): Activate the unittest once the required code has been landed. +TEST(BlockProcessor, DISABLED_DelayControllerIntegration) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + constexpr size_t kNumBlocks = 310; + constexpr size_t kDelayInSamples = 640; + constexpr size_t kDelayHeadroom = 1; + constexpr size_t kDelayInBlocks = + kDelayInSamples / kBlockSize - kDelayHeadroom; + Random random_generator(42U); + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<testing::StrictMock<webrtc::test::MockRenderDelayBuffer>> + render_delay_buffer_mock( + new StrictMock<webrtc::test::MockRenderDelayBuffer>(rate, 1)); + EXPECT_CALL(*render_delay_buffer_mock, Insert(_)) + .Times(kNumBlocks) + .WillRepeatedly(Return(RenderDelayBuffer::BufferingEvent::kNone)); + EXPECT_CALL(*render_delay_buffer_mock, AlignFromDelay(kDelayInBlocks)) + .Times(AtLeast(1)); + EXPECT_CALL(*render_delay_buffer_mock, MaxDelay()).WillOnce(Return(30)); + EXPECT_CALL(*render_delay_buffer_mock, Delay()) + .Times(kNumBlocks + 1) + .WillRepeatedly(Return(0)); + std::unique_ptr<BlockProcessor> block_processor(BlockProcessor::Create( + EchoCanceller3Config(), rate, kNumRenderChannels, kNumCaptureChannels, + std::move(render_delay_buffer_mock))); + + Block render_block(NumBandsForRate(rate), kNumRenderChannels); + Block capture_block(NumBandsForRate(rate), kNumCaptureChannels); + DelayBuffer<float> signal_delay_buffer(kDelayInSamples); + for (size_t k = 0; k < kNumBlocks; ++k) { + RandomizeSampleVector(&random_generator, + render_block.View(/*band=*/0, /*capture=*/0)); + signal_delay_buffer.Delay(render_block.View(/*band=*/0, /*capture=*/0), + capture_block.View(/*band=*/0, /*capture=*/0)); + block_processor->BufferRender(render_block); + block_processor->ProcessCapture(false, false, nullptr, &capture_block); + } + } +} + +// Verifies that BlockProcessor submodules are called in a proper manner. +TEST(BlockProcessor, DISABLED_SubmoduleIntegration) { + constexpr size_t kNumBlocks = 310; + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + + Random random_generator(42U); + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<testing::StrictMock<webrtc::test::MockRenderDelayBuffer>> + render_delay_buffer_mock( + new StrictMock<webrtc::test::MockRenderDelayBuffer>(rate, 1)); + std::unique_ptr< + ::testing::StrictMock<webrtc::test::MockRenderDelayController>> + render_delay_controller_mock( + new StrictMock<webrtc::test::MockRenderDelayController>()); + std::unique_ptr<testing::StrictMock<webrtc::test::MockEchoRemover>> + echo_remover_mock(new StrictMock<webrtc::test::MockEchoRemover>()); + + EXPECT_CALL(*render_delay_buffer_mock, Insert(_)) + .Times(kNumBlocks - 1) + .WillRepeatedly(Return(RenderDelayBuffer::BufferingEvent::kNone)); + EXPECT_CALL(*render_delay_buffer_mock, PrepareCaptureProcessing()) + .Times(kNumBlocks); + EXPECT_CALL(*render_delay_buffer_mock, AlignFromDelay(9)).Times(AtLeast(1)); + EXPECT_CALL(*render_delay_buffer_mock, Delay()) + .Times(kNumBlocks) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*render_delay_controller_mock, GetDelay(_, _, _)) + .Times(kNumBlocks); + EXPECT_CALL(*echo_remover_mock, ProcessCapture(_, _, _, _, _, _)) + .Times(kNumBlocks); + EXPECT_CALL(*echo_remover_mock, UpdateEchoLeakageStatus(_)) + .Times(kNumBlocks); + + std::unique_ptr<BlockProcessor> block_processor(BlockProcessor::Create( + EchoCanceller3Config(), rate, kNumRenderChannels, kNumCaptureChannels, + std::move(render_delay_buffer_mock), + std::move(render_delay_controller_mock), std::move(echo_remover_mock))); + + Block render_block(NumBandsForRate(rate), kNumRenderChannels); + Block capture_block(NumBandsForRate(rate), kNumCaptureChannels); + DelayBuffer<float> signal_delay_buffer(640); + for (size_t k = 0; k < kNumBlocks; ++k) { + RandomizeSampleVector(&random_generator, + render_block.View(/*band=*/0, /*capture=*/0)); + signal_delay_buffer.Delay(render_block.View(/*band=*/0, /*capture=*/0), + capture_block.View(/*band=*/0, /*capture=*/0)); + block_processor->BufferRender(render_block); + block_processor->ProcessCapture(false, false, nullptr, &capture_block); + block_processor->UpdateEchoLeakageStatus(false); + } + } +} + +TEST(BlockProcessor, BasicSetupAndApiCalls) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + RunBasicSetupAndApiCallTest(rate, 1); + } +} + +TEST(BlockProcessor, TestLongerCall) { + RunBasicSetupAndApiCallTest(16000, 20 * kNumBlocksPerSecond); +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// TODO(gustaf): Re-enable the test once the issue with memory leaks during +// DEATH tests on test bots has been fixed. +TEST(BlockProcessorDeathTest, DISABLED_VerifyRenderBlockSizeCheck) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + RunRenderBlockSizeVerificationTest(rate); + } +} + +TEST(BlockProcessorDeathTest, VerifyRenderNumBandsCheck) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + RunRenderNumBandsVerificationTest(rate); + } +} + +// TODO(peah): Verify the check for correct number of bands in the capture +// signal. +TEST(BlockProcessorDeathTest, VerifyCaptureNumBandsCheck) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + RunCaptureNumBandsVerificationTest(rate); + } +} + +// Verifiers that the verification for null ProcessCapture input works. +TEST(BlockProcessorDeathTest, NullProcessCaptureParameter) { + EXPECT_DEATH(std::unique_ptr<BlockProcessor>( + BlockProcessor::Create(EchoCanceller3Config(), 16000, 1, 1)) + ->ProcessCapture(false, false, nullptr, nullptr), + ""); +} + +// Verifies the check for correct sample rate. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(BlockProcessor, DISABLED_WrongSampleRate) { + EXPECT_DEATH(std::unique_ptr<BlockProcessor>( + BlockProcessor::Create(EchoCanceller3Config(), 8001, 1, 1)), + ""); +} + +#endif + +// Verifies that external delay estimator delays are applied correctly when a +// call begins with a sequence of capture blocks. +TEST(BlockProcessor, ExternalDelayAppliedCorrectlyWithInitialCaptureCalls) { + constexpr int kNumRenderChannels = 1; + constexpr int kNumCaptureChannels = 1; + constexpr int kSampleRateHz = 16000; + + EchoCanceller3Config config; + config.delay.use_external_delay_estimator = true; + + std::unique_ptr<RenderDelayBuffer> delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels)); + + std::unique_ptr<testing::NiceMock<webrtc::test::MockEchoRemover>> + echo_remover_mock(new NiceMock<webrtc::test::MockEchoRemover>()); + webrtc::test::MockEchoRemover* echo_remover_mock_pointer = + echo_remover_mock.get(); + + std::unique_ptr<BlockProcessor> block_processor(BlockProcessor::Create( + config, kSampleRateHz, kNumRenderChannels, kNumCaptureChannels, + std::move(delay_buffer), /*delay_controller=*/nullptr, + std::move(echo_remover_mock))); + + Block render_block(NumBandsForRate(kSampleRateHz), kNumRenderChannels); + Block capture_block(NumBandsForRate(kSampleRateHz), kNumCaptureChannels); + + // Process... + // - 10 capture calls, where no render data is available, + // - 10 render calls, populating the buffer, + // - 2 capture calls, verifying that the delay was applied correctly. + constexpr int kDelayInBlocks = 5; + constexpr int kDelayInMs = 20; + block_processor->SetAudioBufferDelay(kDelayInMs); + + int capture_call_counter = 0; + int render_call_counter = 0; + for (size_t k = 0; k < 10; ++k) { + FillSampleVector(++capture_call_counter, kDelayInBlocks, + capture_block.View(/*band=*/0, /*capture=*/0)); + block_processor->ProcessCapture(false, false, nullptr, &capture_block); + } + for (size_t k = 0; k < 10; ++k) { + FillSampleVector(++render_call_counter, 0, + render_block.View(/*band=*/0, /*capture=*/0)); + block_processor->BufferRender(render_block); + } + + EXPECT_CALL(*echo_remover_mock_pointer, ProcessCapture) + .WillRepeatedly( + [](EchoPathVariability /*echo_path_variability*/, + bool /*capture_signal_saturation*/, + const absl::optional<DelayEstimate>& /*external_delay*/, + RenderBuffer* render_buffer, Block* /*linear_output*/, + Block* capture) { + const auto& render = render_buffer->GetBlock(0); + const auto render_view = render.View(/*band=*/0, /*channel=*/0); + const auto capture_view = capture->View(/*band=*/0, /*channel=*/0); + for (size_t i = 0; i < kBlockSize; ++i) { + EXPECT_FLOAT_EQ(render_view[i], capture_view[i]); + } + }); + + FillSampleVector(++capture_call_counter, kDelayInBlocks, + capture_block.View(/*band=*/0, /*capture=*/0)); + block_processor->ProcessCapture(false, false, nullptr, &capture_block); + + FillSampleVector(++capture_call_counter, kDelayInBlocks, + capture_block.View(/*band=*/0, /*capture=*/0)); + block_processor->ProcessCapture(false, false, nullptr, &capture_block); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector.cc b/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector.cc new file mode 100644 index 0000000000..2c49b795c4 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector.cc @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/clockdrift_detector.h" + +namespace webrtc { + +ClockdriftDetector::ClockdriftDetector() + : level_(Level::kNone), stability_counter_(0) { + delay_history_.fill(0); +} + +ClockdriftDetector::~ClockdriftDetector() = default; + +void ClockdriftDetector::Update(int delay_estimate) { + if (delay_estimate == delay_history_[0]) { + // Reset clockdrift level if delay estimate is stable for 7500 blocks (30 + // seconds). + if (++stability_counter_ > 7500) + level_ = Level::kNone; + return; + } + + stability_counter_ = 0; + const int d1 = delay_history_[0] - delay_estimate; + const int d2 = delay_history_[1] - delay_estimate; + const int d3 = delay_history_[2] - delay_estimate; + + // Patterns recognized as positive clockdrift: + // [x-3], x-2, x-1, x. + // [x-3], x-1, x-2, x. + const bool probable_drift_up = + (d1 == -1 && d2 == -2) || (d1 == -2 && d2 == -1); + const bool drift_up = probable_drift_up && d3 == -3; + + // Patterns recognized as negative clockdrift: + // [x+3], x+2, x+1, x. + // [x+3], x+1, x+2, x. + const bool probable_drift_down = (d1 == 1 && d2 == 2) || (d1 == 2 && d2 == 1); + const bool drift_down = probable_drift_down && d3 == 3; + + // Set clockdrift level. + if (drift_up || drift_down) { + level_ = Level::kVerified; + } else if ((probable_drift_up || probable_drift_down) && + level_ == Level::kNone) { + level_ = Level::kProbable; + } + + // Shift delay history one step. + delay_history_[2] = delay_history_[1]; + delay_history_[1] = delay_history_[0]; + delay_history_[0] = delay_estimate; +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector.h b/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector.h new file mode 100644 index 0000000000..2ba90bb889 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_CLOCKDRIFT_DETECTOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_CLOCKDRIFT_DETECTOR_H_ + +#include <stddef.h> + +#include <array> + +namespace webrtc { + +class ApmDataDumper; +struct DownsampledRenderBuffer; +struct EchoCanceller3Config; + +// Detects clockdrift by analyzing the estimated delay. +class ClockdriftDetector { + public: + enum class Level { kNone, kProbable, kVerified, kNumCategories }; + ClockdriftDetector(); + ~ClockdriftDetector(); + void Update(int delay_estimate); + Level ClockdriftLevel() const { return level_; } + + private: + std::array<int, 3> delay_history_; + Level level_; + size_t stability_counter_; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_CLOCKDRIFT_DETECTOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector_unittest.cc new file mode 100644 index 0000000000..0f98b01d3a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/clockdrift_detector_unittest.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/clockdrift_detector.h" + +#include "test/gtest.h" + +namespace webrtc { +TEST(ClockdriftDetector, ClockdriftDetector) { + ClockdriftDetector c; + // No clockdrift at start. + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kNone); + + // Monotonically increasing delay. + for (int i = 0; i < 100; i++) + c.Update(1000); + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kNone); + for (int i = 0; i < 100; i++) + c.Update(1001); + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kNone); + for (int i = 0; i < 100; i++) + c.Update(1002); + // Probable clockdrift. + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kProbable); + for (int i = 0; i < 100; i++) + c.Update(1003); + // Verified clockdrift. + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kVerified); + + // Stable delay. + for (int i = 0; i < 10000; i++) + c.Update(1003); + // No clockdrift. + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kNone); + + // Decreasing delay. + for (int i = 0; i < 100; i++) + c.Update(1001); + for (int i = 0; i < 100; i++) + c.Update(999); + // Probable clockdrift. + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kProbable); + for (int i = 0; i < 100; i++) + c.Update(1000); + for (int i = 0; i < 100; i++) + c.Update(998); + // Verified clockdrift. + EXPECT_TRUE(c.ClockdriftLevel() == ClockdriftDetector::Level::kVerified); +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain.cc b/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain.cc new file mode 100644 index 0000000000..f4fb74d20d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain.cc @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/coarse_filter_update_gain.h" + +#include <algorithm> +#include <functional> + +#include "rtc_base/checks.h" + +namespace webrtc { + +CoarseFilterUpdateGain::CoarseFilterUpdateGain( + const EchoCanceller3Config::Filter::CoarseConfiguration& config, + size_t config_change_duration_blocks) + : config_change_duration_blocks_( + static_cast<int>(config_change_duration_blocks)) { + SetConfig(config, true); + RTC_DCHECK_LT(0, config_change_duration_blocks_); + one_by_config_change_duration_blocks_ = 1.f / config_change_duration_blocks_; +} + +void CoarseFilterUpdateGain::HandleEchoPathChange() { + poor_signal_excitation_counter_ = 0; + call_counter_ = 0; +} + +void CoarseFilterUpdateGain::Compute( + const std::array<float, kFftLengthBy2Plus1>& render_power, + const RenderSignalAnalyzer& render_signal_analyzer, + const FftData& E_coarse, + size_t size_partitions, + bool saturated_capture_signal, + FftData* G) { + RTC_DCHECK(G); + ++call_counter_; + + UpdateCurrentConfig(); + + if (render_signal_analyzer.PoorSignalExcitation()) { + poor_signal_excitation_counter_ = 0; + } + + // Do not update the filter if the render is not sufficiently excited. + if (++poor_signal_excitation_counter_ < size_partitions || + saturated_capture_signal || call_counter_ <= size_partitions) { + G->re.fill(0.f); + G->im.fill(0.f); + return; + } + + // Compute mu. + std::array<float, kFftLengthBy2Plus1> mu; + const auto& X2 = render_power; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + if (X2[k] > current_config_.noise_gate) { + mu[k] = current_config_.rate / X2[k]; + } else { + mu[k] = 0.f; + } + } + + // Avoid updating the filter close to narrow bands in the render signals. + render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu); + + // G = mu * E * X2. + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + G->re[k] = mu[k] * E_coarse.re[k]; + G->im[k] = mu[k] * E_coarse.im[k]; + } +} + +void CoarseFilterUpdateGain::UpdateCurrentConfig() { + RTC_DCHECK_GE(config_change_duration_blocks_, config_change_counter_); + if (config_change_counter_ > 0) { + if (--config_change_counter_ > 0) { + auto average = [](float from, float to, float from_weight) { + return from * from_weight + to * (1.f - from_weight); + }; + + float change_factor = + config_change_counter_ * one_by_config_change_duration_blocks_; + + current_config_.rate = + average(old_target_config_.rate, target_config_.rate, change_factor); + current_config_.noise_gate = + average(old_target_config_.noise_gate, target_config_.noise_gate, + change_factor); + } else { + current_config_ = old_target_config_ = target_config_; + } + } + RTC_DCHECK_LE(0, config_change_counter_); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain.h b/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain.h new file mode 100644 index 0000000000..a1a1399b2c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_COARSE_FILTER_UPDATE_GAIN_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_COARSE_FILTER_UPDATE_GAIN_H_ + +#include <stddef.h> + +#include <array> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/render_signal_analyzer.h" + +namespace webrtc { + +// Provides functionality for computing the fixed gain for the coarse filter. +class CoarseFilterUpdateGain { + public: + explicit CoarseFilterUpdateGain( + const EchoCanceller3Config::Filter::CoarseConfiguration& config, + size_t config_change_duration_blocks); + + // Takes action in the case of a known echo path change. + void HandleEchoPathChange(); + + // Computes the gain. + void Compute(const std::array<float, kFftLengthBy2Plus1>& render_power, + const RenderSignalAnalyzer& render_signal_analyzer, + const FftData& E_coarse, + size_t size_partitions, + bool saturated_capture_signal, + FftData* G); + + // Sets a new config. + void SetConfig( + const EchoCanceller3Config::Filter::CoarseConfiguration& config, + bool immediate_effect) { + if (immediate_effect) { + old_target_config_ = current_config_ = target_config_ = config; + config_change_counter_ = 0; + } else { + old_target_config_ = current_config_; + target_config_ = config; + config_change_counter_ = config_change_duration_blocks_; + } + } + + private: + EchoCanceller3Config::Filter::CoarseConfiguration current_config_; + EchoCanceller3Config::Filter::CoarseConfiguration target_config_; + EchoCanceller3Config::Filter::CoarseConfiguration old_target_config_; + const int config_change_duration_blocks_; + float one_by_config_change_duration_blocks_; + // TODO(peah): Check whether this counter should instead be initialized to a + // large value. + size_t poor_signal_excitation_counter_ = 0; + size_t call_counter_ = 0; + int config_change_counter_ = 0; + + void UpdateCurrentConfig(); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_COARSE_FILTER_UPDATE_GAIN_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain_unittest.cc new file mode 100644 index 0000000000..55b79bb812 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/coarse_filter_update_gain_unittest.cc @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/coarse_filter_update_gain.h" + +#include <algorithm> +#include <memory> +#include <numeric> +#include <string> +#include <vector> + +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/numerics/safe_minmax.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { +// Method for performing the simulations needed to test the refined filter +// update gain functionality. +void RunFilterUpdateTest(int num_blocks_to_process, + size_t delay_samples, + size_t num_render_channels, + int filter_length_blocks, + const std::vector<int>& blocks_with_saturation, + std::array<float, kBlockSize>* e_last_block, + std::array<float, kBlockSize>* y_last_block, + FftData* G_last_block) { + ApmDataDumper data_dumper(42); + EchoCanceller3Config config; + config.filter.refined.length_blocks = filter_length_blocks; + AdaptiveFirFilter refined_filter( + config.filter.refined.length_blocks, config.filter.refined.length_blocks, + config.filter.config_change_duration_blocks, num_render_channels, + DetectOptimization(), &data_dumper); + AdaptiveFirFilter coarse_filter( + config.filter.coarse.length_blocks, config.filter.coarse.length_blocks, + config.filter.config_change_duration_blocks, num_render_channels, + DetectOptimization(), &data_dumper); + Aec3Fft fft; + + constexpr int kSampleRateHz = 48000; + config.delay.default_delay = 1; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + + CoarseFilterUpdateGain coarse_gain( + config.filter.coarse, config.filter.config_change_duration_blocks); + Random random_generator(42U); + Block x(NumBandsForRate(kSampleRateHz), num_render_channels); + std::array<float, kBlockSize> y; + RenderSignalAnalyzer render_signal_analyzer(config); + std::array<float, kFftLength> s; + FftData S; + FftData G; + FftData E_coarse; + std::array<float, kBlockSize> e_coarse; + + constexpr float kScale = 1.0f / kFftLengthBy2; + + DelayBuffer<float> delay_buffer(delay_samples); + for (int k = 0; k < num_blocks_to_process; ++k) { + // Handle saturation. + bool saturation = + std::find(blocks_with_saturation.begin(), blocks_with_saturation.end(), + k) != blocks_with_saturation.end(); + + // Create the render signal. + for (int band = 0; band < x.NumBands(); ++band) { + for (int channel = 0; channel < x.NumChannels(); ++channel) { + RandomizeSampleVector(&random_generator, x.View(band, channel)); + } + } + delay_buffer.Delay(x.View(/*band=*/0, /*channel*/ 0), y); + + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + + render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(), + delay_samples / kBlockSize); + + coarse_filter.Filter(*render_delay_buffer->GetRenderBuffer(), &S); + fft.Ifft(S, &s); + std::transform(y.begin(), y.end(), s.begin() + kFftLengthBy2, + e_coarse.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e_coarse.begin(), e_coarse.end(), + [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); }); + fft.ZeroPaddedFft(e_coarse, Aec3Fft::Window::kRectangular, &E_coarse); + + std::array<float, kFftLengthBy2Plus1> render_power; + render_delay_buffer->GetRenderBuffer()->SpectralSum( + coarse_filter.SizePartitions(), &render_power); + coarse_gain.Compute(render_power, render_signal_analyzer, E_coarse, + coarse_filter.SizePartitions(), saturation, &G); + coarse_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G); + } + + std::copy(e_coarse.begin(), e_coarse.end(), e_last_block->begin()); + std::copy(y.begin(), y.end(), y_last_block->begin()); + std::copy(G.re.begin(), G.re.end(), G_last_block->re.begin()); + std::copy(G.im.begin(), G.im.end(), G_last_block->im.begin()); +} + +std::string ProduceDebugText(int filter_length_blocks) { + rtc::StringBuilder ss; + ss << "Length: " << filter_length_blocks; + return ss.Release(); +} + +std::string ProduceDebugText(size_t delay, int filter_length_blocks) { + rtc::StringBuilder ss; + ss << "Delay: " << delay << ", "; + ss << ProduceDebugText(filter_length_blocks); + return ss.Release(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output gain parameter works. +TEST(CoarseFilterUpdateGainDeathTest, NullDataOutputGain) { + ApmDataDumper data_dumper(42); + FftBuffer fft_buffer(1, 1); + RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); + FftData E; + const EchoCanceller3Config::Filter::CoarseConfiguration& config = { + 12, 0.5f, 220075344.f}; + CoarseFilterUpdateGain gain(config, 250); + std::array<float, kFftLengthBy2Plus1> render_power; + render_power.fill(0.f); + EXPECT_DEATH(gain.Compute(render_power, analyzer, E, 1, false, nullptr), ""); +} + +#endif + +class CoarseFilterUpdateGainOneTwoEightRenderChannels + : public ::testing::Test, + public ::testing::WithParamInterface<size_t> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + CoarseFilterUpdateGainOneTwoEightRenderChannels, + ::testing::Values(1, 2, 8)); + +// Verifies that the gain formed causes the filter using it to converge. +TEST_P(CoarseFilterUpdateGainOneTwoEightRenderChannels, + GainCausesFilterToConverge) { + const size_t num_render_channels = GetParam(); + std::vector<int> blocks_with_echo_path_changes; + std::vector<int> blocks_with_saturation; + + for (size_t filter_length_blocks : {12, 20, 30}) { + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks)); + + std::array<float, kBlockSize> e; + std::array<float, kBlockSize> y; + FftData G; + + RunFilterUpdateTest(5000, delay_samples, num_render_channels, + filter_length_blocks, blocks_with_saturation, &e, &y, + &G); + + // Verify that the refined filter is able to perform well. + // Use different criteria to take overmodelling into account. + if (filter_length_blocks == 12) { + EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + } else { + EXPECT_LT(std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + } + } + } +} + +// Verifies that the gain is zero when there is saturation. +TEST_P(CoarseFilterUpdateGainOneTwoEightRenderChannels, SaturationBehavior) { + const size_t num_render_channels = GetParam(); + std::vector<int> blocks_with_echo_path_changes; + std::vector<int> blocks_with_saturation; + for (int k = 99; k < 200; ++k) { + blocks_with_saturation.push_back(k); + } + for (size_t filter_length_blocks : {12, 20, 30}) { + SCOPED_TRACE(ProduceDebugText(filter_length_blocks)); + + std::array<float, kBlockSize> e; + std::array<float, kBlockSize> y; + FftData G_a; + FftData G_a_ref; + G_a_ref.re.fill(0.f); + G_a_ref.im.fill(0.f); + + RunFilterUpdateTest(100, 65, num_render_channels, filter_length_blocks, + blocks_with_saturation, &e, &y, &G_a); + + EXPECT_EQ(G_a_ref.re, G_a.re); + EXPECT_EQ(G_a_ref.im, G_a.im); + } +} + +class CoarseFilterUpdateGainOneTwoFourRenderChannels + : public ::testing::Test, + public ::testing::WithParamInterface<size_t> {}; + +INSTANTIATE_TEST_SUITE_P( + MultiChannel, + CoarseFilterUpdateGainOneTwoFourRenderChannels, + ::testing::Values(1, 2, 4), + [](const ::testing::TestParamInfo< + CoarseFilterUpdateGainOneTwoFourRenderChannels::ParamType>& info) { + return (rtc::StringBuilder() << "Render" << info.param).str(); + }); + +// Verifies that the magnitude of the gain on average decreases for a +// persistently exciting signal. +TEST_P(CoarseFilterUpdateGainOneTwoFourRenderChannels, DecreasingGain) { + const size_t num_render_channels = GetParam(); + for (size_t filter_length_blocks : {12, 20, 30}) { + SCOPED_TRACE(ProduceDebugText(filter_length_blocks)); + std::vector<int> blocks_with_echo_path_changes; + std::vector<int> blocks_with_saturation; + + std::array<float, kBlockSize> e; + std::array<float, kBlockSize> y; + FftData G_a; + FftData G_b; + FftData G_c; + std::array<float, kFftLengthBy2Plus1> G_a_power; + std::array<float, kFftLengthBy2Plus1> G_b_power; + std::array<float, kFftLengthBy2Plus1> G_c_power; + + RunFilterUpdateTest(100, 65, num_render_channels, filter_length_blocks, + blocks_with_saturation, &e, &y, &G_a); + RunFilterUpdateTest(200, 65, num_render_channels, filter_length_blocks, + blocks_with_saturation, &e, &y, &G_b); + RunFilterUpdateTest(300, 65, num_render_channels, filter_length_blocks, + blocks_with_saturation, &e, &y, &G_c); + + G_a.Spectrum(Aec3Optimization::kNone, G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, G_b_power); + G_c.Spectrum(Aec3Optimization::kNone, G_c_power); + + EXPECT_GT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); + + EXPECT_GT(std::accumulate(G_b_power.begin(), G_b_power.end(), 0.), + std::accumulate(G_c_power.begin(), G_c_power.end(), 0.)); + } +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator.cc new file mode 100644 index 0000000000..de5227c089 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator.cc @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/comfort_noise_generator.h" + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif +#include <algorithm> +#include <array> +#include <cmath> +#include <cstdint> +#include <functional> +#include <numeric> + +#include "common_audio/signal_processing/include/signal_processing_library.h" +#include "modules/audio_processing/aec3/vector_math.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { + +// Computes the noise floor value that matches a WGN input of noise_floor_dbfs. +float GetNoiseFloorFactor(float noise_floor_dbfs) { + // kdBfsNormalization = 20.f*log10(32768.f). + constexpr float kdBfsNormalization = 90.30899869919436f; + return 64.f * powf(10.f, (kdBfsNormalization + noise_floor_dbfs) * 0.1f); +} + +// Table of sqrt(2) * sin(2*pi*i/32). +constexpr float kSqrt2Sin[32] = { + +0.0000000f, +0.2758994f, +0.5411961f, +0.7856950f, +1.0000000f, + +1.1758756f, +1.3065630f, +1.3870398f, +1.4142136f, +1.3870398f, + +1.3065630f, +1.1758756f, +1.0000000f, +0.7856950f, +0.5411961f, + +0.2758994f, +0.0000000f, -0.2758994f, -0.5411961f, -0.7856950f, + -1.0000000f, -1.1758756f, -1.3065630f, -1.3870398f, -1.4142136f, + -1.3870398f, -1.3065630f, -1.1758756f, -1.0000000f, -0.7856950f, + -0.5411961f, -0.2758994f}; + +void GenerateComfortNoise(Aec3Optimization optimization, + const std::array<float, kFftLengthBy2Plus1>& N2, + uint32_t* seed, + FftData* lower_band_noise, + FftData* upper_band_noise) { + FftData* N_low = lower_band_noise; + FftData* N_high = upper_band_noise; + + // Compute square root spectrum. + std::array<float, kFftLengthBy2Plus1> N; + std::copy(N2.begin(), N2.end(), N.begin()); + aec3::VectorMath(optimization).Sqrt(N); + + // Compute the noise level for the upper bands. + constexpr float kOneByNumBands = 1.f / (kFftLengthBy2Plus1 / 2 + 1); + constexpr int kFftLengthBy2Plus1By2 = kFftLengthBy2Plus1 / 2; + const float high_band_noise_level = + std::accumulate(N.begin() + kFftLengthBy2Plus1By2, N.end(), 0.f) * + kOneByNumBands; + + // The analysis and synthesis windowing cause loss of power when + // cross-fading the noise where frames are completely uncorrelated + // (generated with random phase), hence the factor sqrt(2). + // This is not the case for the speech signal where the input is overlapping + // (strong correlation). + N_low->re[0] = N_low->re[kFftLengthBy2] = N_high->re[0] = + N_high->re[kFftLengthBy2] = 0.f; + for (size_t k = 1; k < kFftLengthBy2; k++) { + constexpr int kIndexMask = 32 - 1; + // Generate a random 31-bit integer. + seed[0] = (seed[0] * 69069 + 1) & (0x80000000 - 1); + // Convert to a 5-bit index. + int i = seed[0] >> 26; + + // y = sqrt(2) * sin(a) + const float x = kSqrt2Sin[i]; + // x = sqrt(2) * cos(a) = sqrt(2) * sin(a + pi/2) + const float y = kSqrt2Sin[(i + 8) & kIndexMask]; + + // Form low-frequency noise via spectral shaping. + N_low->re[k] = N[k] * x; + N_low->im[k] = N[k] * y; + + // Form the high-frequency noise via simple levelling. + N_high->re[k] = high_band_noise_level * x; + N_high->im[k] = high_band_noise_level * y; + } +} + +} // namespace + +ComfortNoiseGenerator::ComfortNoiseGenerator(const EchoCanceller3Config& config, + Aec3Optimization optimization, + size_t num_capture_channels) + : optimization_(optimization), + seed_(42), + num_capture_channels_(num_capture_channels), + noise_floor_(GetNoiseFloorFactor(config.comfort_noise.noise_floor_dbfs)), + N2_initial_( + std::make_unique<std::vector<std::array<float, kFftLengthBy2Plus1>>>( + num_capture_channels_)), + Y2_smoothed_(num_capture_channels_), + N2_(num_capture_channels_) { + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + (*N2_initial_)[ch].fill(0.f); + Y2_smoothed_[ch].fill(0.f); + N2_[ch].fill(1.0e6f); + } +} + +ComfortNoiseGenerator::~ComfortNoiseGenerator() = default; + +void ComfortNoiseGenerator::Compute( + bool saturated_capture, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + capture_spectrum, + rtc::ArrayView<FftData> lower_band_noise, + rtc::ArrayView<FftData> upper_band_noise) { + const auto& Y2 = capture_spectrum; + + if (!saturated_capture) { + // Smooth Y2. + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + std::transform(Y2_smoothed_[ch].begin(), Y2_smoothed_[ch].end(), + Y2[ch].begin(), Y2_smoothed_[ch].begin(), + [](float a, float b) { return a + 0.1f * (b - a); }); + } + + if (N2_counter_ > 50) { + // Update N2 from Y2_smoothed. + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + std::transform(N2_[ch].begin(), N2_[ch].end(), Y2_smoothed_[ch].begin(), + N2_[ch].begin(), [](float a, float b) { + return b < a ? (0.9f * b + 0.1f * a) * 1.0002f + : a * 1.0002f; + }); + } + } + + if (N2_initial_) { + if (++N2_counter_ == 1000) { + N2_initial_.reset(); + } else { + // Compute the N2_initial from N2. + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + std::transform(N2_[ch].begin(), N2_[ch].end(), + (*N2_initial_)[ch].begin(), (*N2_initial_)[ch].begin(), + [](float a, float b) { + return a > b ? b + 0.001f * (a - b) : a; + }); + } + } + } + + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + for (auto& n : N2_[ch]) { + n = std::max(n, noise_floor_); + } + if (N2_initial_) { + for (auto& n : (*N2_initial_)[ch]) { + n = std::max(n, noise_floor_); + } + } + } + } + + // Choose N2 estimate to use. + const auto& N2 = N2_initial_ ? (*N2_initial_) : N2_; + + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + GenerateComfortNoise(optimization_, N2[ch], &seed_, &lower_band_noise[ch], + &upper_band_noise[ch]); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator.h b/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator.h new file mode 100644 index 0000000000..2785b765c5 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_COMFORT_NOISE_GENERATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_COMFORT_NOISE_GENERATOR_H_ + +#include <stdint.h> + +#include <array> +#include <memory> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "rtc_base/system/arch.h" + +namespace webrtc { +namespace aec3 { +#if defined(WEBRTC_ARCH_X86_FAMILY) + +void EstimateComfortNoise_SSE2(const std::array<float, kFftLengthBy2Plus1>& N2, + uint32_t* seed, + FftData* lower_band_noise, + FftData* upper_band_noise); +#endif +void EstimateComfortNoise(const std::array<float, kFftLengthBy2Plus1>& N2, + uint32_t* seed, + FftData* lower_band_noise, + FftData* upper_band_noise); + +} // namespace aec3 + +// Generates the comfort noise. +class ComfortNoiseGenerator { + public: + ComfortNoiseGenerator(const EchoCanceller3Config& config, + Aec3Optimization optimization, + size_t num_capture_channels); + ComfortNoiseGenerator() = delete; + ~ComfortNoiseGenerator(); + ComfortNoiseGenerator(const ComfortNoiseGenerator&) = delete; + + // Computes the comfort noise. + void Compute(bool saturated_capture, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + capture_spectrum, + rtc::ArrayView<FftData> lower_band_noise, + rtc::ArrayView<FftData> upper_band_noise); + + // Returns the estimate of the background noise spectrum. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> NoiseSpectrum() + const { + return N2_; + } + + private: + const Aec3Optimization optimization_; + uint32_t seed_; + const size_t num_capture_channels_; + const float noise_floor_; + std::unique_ptr<std::vector<std::array<float, kFftLengthBy2Plus1>>> + N2_initial_; + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_smoothed_; + std::vector<std::array<float, kFftLengthBy2Plus1>> N2_; + int N2_counter_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_COMFORT_NOISE_GENERATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc new file mode 100644 index 0000000000..a9da17559a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/comfort_noise_generator.h" + +#include <algorithm> +#include <numeric> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "rtc_base/random.h" +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/gtest.h" + +namespace webrtc { +namespace aec3 { +namespace { + +float Power(const FftData& N) { + std::array<float, kFftLengthBy2Plus1> N2; + N.Spectrum(Aec3Optimization::kNone, N2); + return std::accumulate(N2.begin(), N2.end(), 0.f) / N2.size(); +} + +} // namespace + +TEST(ComfortNoiseGenerator, CorrectLevel) { + constexpr size_t kNumChannels = 5; + EchoCanceller3Config config; + ComfortNoiseGenerator cng(config, DetectOptimization(), kNumChannels); + AecState aec_state(config, kNumChannels); + + std::vector<std::array<float, kFftLengthBy2Plus1>> N2(kNumChannels); + std::vector<FftData> n_lower(kNumChannels); + std::vector<FftData> n_upper(kNumChannels); + + for (size_t ch = 0; ch < kNumChannels; ++ch) { + N2[ch].fill(1000.f * 1000.f / (ch + 1)); + n_lower[ch].re.fill(0.f); + n_lower[ch].im.fill(0.f); + n_upper[ch].re.fill(0.f); + n_upper[ch].im.fill(0.f); + } + + // Ensure instantaneous updata to nonzero noise. + cng.Compute(false, N2, n_lower, n_upper); + + for (size_t ch = 0; ch < kNumChannels; ++ch) { + EXPECT_LT(0.f, Power(n_lower[ch])); + EXPECT_LT(0.f, Power(n_upper[ch])); + } + + for (int k = 0; k < 10000; ++k) { + cng.Compute(false, N2, n_lower, n_upper); + } + + for (size_t ch = 0; ch < kNumChannels; ++ch) { + EXPECT_NEAR(2.f * N2[ch][0], Power(n_lower[ch]), N2[ch][0] / 10.f); + EXPECT_NEAR(2.f * N2[ch][0], Power(n_upper[ch]), N2[ch][0] / 10.f); + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/config_selector.cc b/third_party/libwebrtc/modules/audio_processing/aec3/config_selector.cc new file mode 100644 index 0000000000..c55344da79 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/config_selector.cc @@ -0,0 +1,71 @@ + +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/config_selector.h" + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +// Validates that the mono and the multichannel configs have compatible fields. +bool CompatibleConfigs(const EchoCanceller3Config& mono_config, + const EchoCanceller3Config& multichannel_config) { + if (mono_config.delay.fixed_capture_delay_samples != + multichannel_config.delay.fixed_capture_delay_samples) { + return false; + } + if (mono_config.filter.export_linear_aec_output != + multichannel_config.filter.export_linear_aec_output) { + return false; + } + if (mono_config.filter.high_pass_filter_echo_reference != + multichannel_config.filter.high_pass_filter_echo_reference) { + return false; + } + if (mono_config.multi_channel.detect_stereo_content != + multichannel_config.multi_channel.detect_stereo_content) { + return false; + } + if (mono_config.multi_channel.stereo_detection_timeout_threshold_seconds != + multichannel_config.multi_channel + .stereo_detection_timeout_threshold_seconds) { + return false; + } + return true; +} + +} // namespace + +ConfigSelector::ConfigSelector( + const EchoCanceller3Config& config, + const absl::optional<EchoCanceller3Config>& multichannel_config, + int num_render_input_channels) + : config_(config), multichannel_config_(multichannel_config) { + if (multichannel_config_.has_value()) { + RTC_DCHECK(CompatibleConfigs(config_, *multichannel_config_)); + } + + Update(!config_.multi_channel.detect_stereo_content && + num_render_input_channels > 1); + + RTC_DCHECK(active_config_); +} + +void ConfigSelector::Update(bool multichannel_content) { + if (multichannel_content && multichannel_config_.has_value()) { + active_config_ = &(*multichannel_config_); + } else { + active_config_ = &config_; + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/config_selector.h b/third_party/libwebrtc/modules/audio_processing/aec3/config_selector.h new file mode 100644 index 0000000000..3b3f94e5ac --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/config_selector.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_CONFIG_SELECTOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_CONFIG_SELECTOR_H_ + +#include "absl/types/optional.h" +#include "api/audio/echo_canceller3_config.h" + +namespace webrtc { + +// Selects the config to use. +class ConfigSelector { + public: + ConfigSelector( + const EchoCanceller3Config& config, + const absl::optional<EchoCanceller3Config>& multichannel_config, + int num_render_input_channels); + + // Updates the config selection based on the detection of multichannel + // content. + void Update(bool multichannel_content); + + const EchoCanceller3Config& active_config() const { return *active_config_; } + + private: + const EchoCanceller3Config config_; + const absl::optional<EchoCanceller3Config> multichannel_config_; + const EchoCanceller3Config* active_config_ = nullptr; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_CONFIG_SELECTOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/config_selector_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/config_selector_unittest.cc new file mode 100644 index 0000000000..1826bfcace --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/config_selector_unittest.cc @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/config_selector.h" + +#include <tuple> + +#include "absl/types/optional.h" +#include "api/audio/echo_canceller3_config.h" +#include "test/gtest.h" + +namespace webrtc { + +class ConfigSelectorChannelsAndContentDetection + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<int, bool>> {}; + +INSTANTIATE_TEST_SUITE_P(ConfigSelectorMultiParameters, + ConfigSelectorChannelsAndContentDetection, + ::testing::Combine(::testing::Values(1, 2, 8), + ::testing::Values(false, true))); + +class ConfigSelectorChannels : public ::testing::Test, + public ::testing::WithParamInterface<int> {}; + +INSTANTIATE_TEST_SUITE_P(ConfigSelectorMultiParameters, + ConfigSelectorChannels, + ::testing::Values(1, 2, 8)); + +TEST_P(ConfigSelectorChannelsAndContentDetection, + MonoConfigIsSelectedWhenNoMultiChannelConfigPresent) { + const auto [num_channels, detect_stereo_content] = GetParam(); + EchoCanceller3Config config; + config.multi_channel.detect_stereo_content = detect_stereo_content; + absl::optional<EchoCanceller3Config> multichannel_config; + + config.delay.default_delay = config.delay.default_delay + 1; + const size_t custom_delay_value_in_config = config.delay.default_delay; + + ConfigSelector cs(config, multichannel_config, + /*num_render_input_channels=*/num_channels); + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_config); + + cs.Update(/*multichannel_content=*/false); + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_config); + + cs.Update(/*multichannel_content=*/true); + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_config); +} + +TEST_P(ConfigSelectorChannelsAndContentDetection, + CorrectInitialConfigIsSelected) { + const auto [num_channels, detect_stereo_content] = GetParam(); + EchoCanceller3Config config; + config.multi_channel.detect_stereo_content = detect_stereo_content; + absl::optional<EchoCanceller3Config> multichannel_config = config; + + config.delay.default_delay += 1; + const size_t custom_delay_value_in_config = config.delay.default_delay; + multichannel_config->delay.default_delay += 2; + const size_t custom_delay_value_in_multichannel_config = + multichannel_config->delay.default_delay; + + ConfigSelector cs(config, multichannel_config, + /*num_render_input_channels=*/num_channels); + + if (num_channels == 1 || detect_stereo_content) { + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_config); + } else { + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_multichannel_config); + } +} + +TEST_P(ConfigSelectorChannels, CorrectConfigUpdateBehavior) { + const int num_channels = GetParam(); + EchoCanceller3Config config; + config.multi_channel.detect_stereo_content = true; + absl::optional<EchoCanceller3Config> multichannel_config = config; + + config.delay.default_delay += 1; + const size_t custom_delay_value_in_config = config.delay.default_delay; + multichannel_config->delay.default_delay += 2; + const size_t custom_delay_value_in_multichannel_config = + multichannel_config->delay.default_delay; + + ConfigSelector cs(config, multichannel_config, + /*num_render_input_channels=*/num_channels); + + cs.Update(/*multichannel_content=*/false); + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_config); + + if (num_channels == 1) { + cs.Update(/*multichannel_content=*/false); + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_config); + } else { + cs.Update(/*multichannel_content=*/true); + EXPECT_EQ(cs.active_config().delay.default_delay, + custom_delay_value_in_multichannel_config); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/decimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/decimator.cc new file mode 100644 index 0000000000..bd03237ca0 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/decimator.cc @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/decimator.h" + +#include <array> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +// signal.butter(2, 3400/8000.0, 'lowpass', analog=False) +const std::vector<CascadedBiQuadFilter::BiQuadParam> GetLowPassFilterDS2() { + return std::vector<CascadedBiQuadFilter::BiQuadParam>{ + {{-1.f, 0.f}, {0.13833231f, 0.40743176f}, 0.22711796393486466f}, + {{-1.f, 0.f}, {0.13833231f, 0.40743176f}, 0.22711796393486466f}, + {{-1.f, 0.f}, {0.13833231f, 0.40743176f}, 0.22711796393486466f}}; +} + +// signal.ellip(6, 1, 40, 1800/8000, btype='lowpass', analog=False) +const std::vector<CascadedBiQuadFilter::BiQuadParam> GetLowPassFilterDS4() { + return std::vector<CascadedBiQuadFilter::BiQuadParam>{ + {{-0.08873842f, 0.99605496f}, {0.75916227f, 0.23841065f}, 0.26250696827f}, + {{0.62273832f, 0.78243018f}, {0.74892112f, 0.5410152f}, 0.26250696827f}, + {{0.71107693f, 0.70311421f}, {0.74895534f, 0.63924616f}, 0.26250696827f}}; +} + +// signal.cheby1(1, 6, [1000/8000, 2000/8000], btype='bandpass', analog=False) +const std::vector<CascadedBiQuadFilter::BiQuadParam> GetBandPassFilterDS8() { + return std::vector<CascadedBiQuadFilter::BiQuadParam>{ + {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true}, + {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true}, + {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true}, + {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true}, + {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true}}; +} + +// signal.butter(2, 1000/8000.0, 'highpass', analog=False) +const std::vector<CascadedBiQuadFilter::BiQuadParam> GetHighPassFilter() { + return std::vector<CascadedBiQuadFilter::BiQuadParam>{ + {{1.f, 0.f}, {0.72712179f, 0.21296904f}, 0.7570763753338849f}}; +} + +const std::vector<CascadedBiQuadFilter::BiQuadParam> GetPassThroughFilter() { + return std::vector<CascadedBiQuadFilter::BiQuadParam>{}; +} +} // namespace + +Decimator::Decimator(size_t down_sampling_factor) + : down_sampling_factor_(down_sampling_factor), + anti_aliasing_filter_(down_sampling_factor_ == 4 + ? GetLowPassFilterDS4() + : (down_sampling_factor_ == 8 + ? GetBandPassFilterDS8() + : GetLowPassFilterDS2())), + noise_reduction_filter_(down_sampling_factor_ == 8 + ? GetPassThroughFilter() + : GetHighPassFilter()) { + RTC_DCHECK(down_sampling_factor_ == 2 || down_sampling_factor_ == 4 || + down_sampling_factor_ == 8); +} + +void Decimator::Decimate(rtc::ArrayView<const float> in, + rtc::ArrayView<float> out) { + RTC_DCHECK_EQ(kBlockSize, in.size()); + RTC_DCHECK_EQ(kBlockSize / down_sampling_factor_, out.size()); + std::array<float, kBlockSize> x; + + // Limit the frequency content of the signal to avoid aliasing. + anti_aliasing_filter_.Process(in, x); + + // Reduce the impact of near-end noise. + noise_reduction_filter_.Process(x); + + // Downsample the signal. + for (size_t j = 0, k = 0; j < out.size(); ++j, k += down_sampling_factor_) { + RTC_DCHECK_GT(kBlockSize, k); + out[j] = x[k]; + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/decimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/decimator.h new file mode 100644 index 0000000000..dbff3d9fff --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/decimator.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_H_ + +#include <array> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/utility/cascaded_biquad_filter.h" + +namespace webrtc { + +// Provides functionality for decimating a signal. +class Decimator { + public: + explicit Decimator(size_t down_sampling_factor); + + Decimator(const Decimator&) = delete; + Decimator& operator=(const Decimator&) = delete; + + // Downsamples the signal. + void Decimate(rtc::ArrayView<const float> in, rtc::ArrayView<float> out); + + private: + const size_t down_sampling_factor_; + CascadedBiQuadFilter anti_aliasing_filter_; + CascadedBiQuadFilter noise_reduction_filter_; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_DECIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/decimator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/decimator_unittest.cc new file mode 100644 index 0000000000..e6f5ea0403 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/decimator_unittest.cc @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/decimator.h" + +#include <math.h> + +#include <algorithm> +#include <array> +#include <cmath> +#include <cstring> +#include <numeric> +#include <string> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { + +std::string ProduceDebugText(int sample_rate_hz) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.Release(); +} + +constexpr size_t kDownSamplingFactors[] = {2, 4, 8}; +constexpr float kPi = 3.141592f; +constexpr size_t kNumStartupBlocks = 50; +constexpr size_t kNumBlocks = 1000; + +void ProduceDecimatedSinusoidalOutputPower(int sample_rate_hz, + size_t down_sampling_factor, + float sinusoidal_frequency_hz, + float* input_power, + float* output_power) { + float input[kBlockSize * kNumBlocks]; + const size_t sub_block_size = kBlockSize / down_sampling_factor; + + // Produce a sinusoid of the specified frequency. + for (size_t k = 0; k < kBlockSize * kNumBlocks; ++k) { + input[k] = 32767.f * std::sin(2.f * kPi * sinusoidal_frequency_hz * k / + sample_rate_hz); + } + + Decimator decimator(down_sampling_factor); + std::vector<float> output(sub_block_size * kNumBlocks); + + for (size_t k = 0; k < kNumBlocks; ++k) { + std::vector<float> sub_block(sub_block_size); + decimator.Decimate( + rtc::ArrayView<const float>(&input[k * kBlockSize], kBlockSize), + sub_block); + + std::copy(sub_block.begin(), sub_block.end(), + output.begin() + k * sub_block_size); + } + + ASSERT_GT(kNumBlocks, kNumStartupBlocks); + rtc::ArrayView<const float> input_to_evaluate( + &input[kNumStartupBlocks * kBlockSize], + (kNumBlocks - kNumStartupBlocks) * kBlockSize); + rtc::ArrayView<const float> output_to_evaluate( + &output[kNumStartupBlocks * sub_block_size], + (kNumBlocks - kNumStartupBlocks) * sub_block_size); + *input_power = + std::inner_product(input_to_evaluate.begin(), input_to_evaluate.end(), + input_to_evaluate.begin(), 0.f) / + input_to_evaluate.size(); + *output_power = + std::inner_product(output_to_evaluate.begin(), output_to_evaluate.end(), + output_to_evaluate.begin(), 0.f) / + output_to_evaluate.size(); +} + +} // namespace + +// Verifies that there is little aliasing from upper frequencies in the +// downsampling. +TEST(Decimator, NoLeakageFromUpperFrequencies) { + float input_power; + float output_power; + for (auto rate : {16000, 32000, 48000}) { + for (auto down_sampling_factor : kDownSamplingFactors) { + ProduceDebugText(rate); + ProduceDecimatedSinusoidalOutputPower(rate, down_sampling_factor, + 3.f / 8.f * rate, &input_power, + &output_power); + EXPECT_GT(0.0001f * input_power, output_power); + } + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies the check for the input size. +TEST(DecimatorDeathTest, WrongInputSize) { + Decimator decimator(4); + std::vector<float> x(kBlockSize - 1, 0.f); + std::array<float, kBlockSize / 4> x_downsampled; + EXPECT_DEATH(decimator.Decimate(x, x_downsampled), ""); +} + +// Verifies the check for non-null output parameter. +TEST(DecimatorDeathTest, NullOutput) { + Decimator decimator(4); + std::vector<float> x(kBlockSize, 0.f); + EXPECT_DEATH(decimator.Decimate(x, nullptr), ""); +} + +// Verifies the check for the output size. +TEST(DecimatorDeathTest, WrongOutputSize) { + Decimator decimator(4); + std::vector<float> x(kBlockSize, 0.f); + std::array<float, kBlockSize / 4 - 1> x_downsampled; + EXPECT_DEATH(decimator.Decimate(x, x_downsampled), ""); +} + +// Verifies the check for the correct downsampling factor. +TEST(DecimatorDeathTest, CorrectDownSamplingFactor) { + EXPECT_DEATH(Decimator(3), ""); +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/delay_estimate.h b/third_party/libwebrtc/modules/audio_processing/aec3/delay_estimate.h new file mode 100644 index 0000000000..7838a0c255 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/delay_estimate.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_DELAY_ESTIMATE_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_DELAY_ESTIMATE_H_ + +#include <stddef.h> + +namespace webrtc { + +// Stores delay_estimates. +struct DelayEstimate { + enum class Quality { kCoarse, kRefined }; + + DelayEstimate(Quality quality, size_t delay) + : quality(quality), delay(delay) {} + + Quality quality; + size_t delay; + size_t blocks_since_last_change = 0; + size_t blocks_since_last_update = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_DELAY_ESTIMATE_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/dominant_nearend_detector.cc b/third_party/libwebrtc/modules/audio_processing/aec3/dominant_nearend_detector.cc new file mode 100644 index 0000000000..40073cf615 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/dominant_nearend_detector.cc @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/dominant_nearend_detector.h" + +#include <numeric> + +namespace webrtc { +DominantNearendDetector::DominantNearendDetector( + const EchoCanceller3Config::Suppressor::DominantNearendDetection& config, + size_t num_capture_channels) + : enr_threshold_(config.enr_threshold), + enr_exit_threshold_(config.enr_exit_threshold), + snr_threshold_(config.snr_threshold), + hold_duration_(config.hold_duration), + trigger_threshold_(config.trigger_threshold), + use_during_initial_phase_(config.use_during_initial_phase), + num_capture_channels_(num_capture_channels), + trigger_counters_(num_capture_channels_), + hold_counters_(num_capture_channels_) {} + +void DominantNearendDetector::Update( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + nearend_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + bool initial_state) { + nearend_state_ = false; + + auto low_frequency_energy = [](rtc::ArrayView<const float> spectrum) { + RTC_DCHECK_LE(16, spectrum.size()); + return std::accumulate(spectrum.begin() + 1, spectrum.begin() + 16, 0.f); + }; + + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + const float ne_sum = low_frequency_energy(nearend_spectrum[ch]); + const float echo_sum = low_frequency_energy(residual_echo_spectrum[ch]); + const float noise_sum = low_frequency_energy(comfort_noise_spectrum[ch]); + + // Detect strong active nearend if the nearend is sufficiently stronger than + // the echo and the nearend noise. + if ((!initial_state || use_during_initial_phase_) && + echo_sum < enr_threshold_ * ne_sum && + ne_sum > snr_threshold_ * noise_sum) { + if (++trigger_counters_[ch] >= trigger_threshold_) { + // After a period of strong active nearend activity, flag nearend mode. + hold_counters_[ch] = hold_duration_; + trigger_counters_[ch] = trigger_threshold_; + } + } else { + // Forget previously detected strong active nearend activity. + trigger_counters_[ch] = std::max(0, trigger_counters_[ch] - 1); + } + + // Exit nearend-state early at strong echo. + if (echo_sum > enr_exit_threshold_ * ne_sum && + echo_sum > snr_threshold_ * noise_sum) { + hold_counters_[ch] = 0; + } + + // Remain in any nearend mode for a certain duration. + hold_counters_[ch] = std::max(0, hold_counters_[ch] - 1); + nearend_state_ = nearend_state_ || hold_counters_[ch] > 0; + } +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/dominant_nearend_detector.h b/third_party/libwebrtc/modules/audio_processing/aec3/dominant_nearend_detector.h new file mode 100644 index 0000000000..046d1488d6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/dominant_nearend_detector.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_DOMINANT_NEAREND_DETECTOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_DOMINANT_NEAREND_DETECTOR_H_ + +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/nearend_detector.h" + +namespace webrtc { +// Class for selecting whether the suppressor is in the nearend or echo state. +class DominantNearendDetector : public NearendDetector { + public: + DominantNearendDetector( + const EchoCanceller3Config::Suppressor::DominantNearendDetection& config, + size_t num_capture_channels); + + // Returns whether the current state is the nearend state. + bool IsNearendState() const override { return nearend_state_; } + + // Updates the state selection based on latest spectral estimates. + void Update(rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + nearend_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + bool initial_state) override; + + private: + const float enr_threshold_; + const float enr_exit_threshold_; + const float snr_threshold_; + const int hold_duration_; + const int trigger_threshold_; + const bool use_during_initial_phase_; + const size_t num_capture_channels_; + + bool nearend_state_ = false; + std::vector<int> trigger_counters_; + std::vector<int> hold_counters_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_DOMINANT_NEAREND_DETECTOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/downsampled_render_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/downsampled_render_buffer.cc new file mode 100644 index 0000000000..c105911aa8 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/downsampled_render_buffer.cc @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" + +#include <algorithm> + +namespace webrtc { + +DownsampledRenderBuffer::DownsampledRenderBuffer(size_t downsampled_buffer_size) + : size(static_cast<int>(downsampled_buffer_size)), + buffer(downsampled_buffer_size, 0.f) { + std::fill(buffer.begin(), buffer.end(), 0.f); +} + +DownsampledRenderBuffer::~DownsampledRenderBuffer() = default; + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/downsampled_render_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/downsampled_render_buffer.h new file mode 100644 index 0000000000..fbdc9b4e93 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/downsampled_render_buffer.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_DOWNSAMPLED_RENDER_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_DOWNSAMPLED_RENDER_BUFFER_H_ + +#include <stddef.h> + +#include <vector> + +#include "rtc_base/checks.h" + +namespace webrtc { + +// Holds the circular buffer of the downsampled render data. +struct DownsampledRenderBuffer { + explicit DownsampledRenderBuffer(size_t downsampled_buffer_size); + ~DownsampledRenderBuffer(); + + int IncIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index < size - 1 ? index + 1 : 0; + } + + int DecIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index > 0 ? index - 1 : size - 1; + } + + int OffsetIndex(int index, int offset) const { + RTC_DCHECK_GE(buffer.size(), offset); + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return (size + index + offset) % size; + } + + void UpdateWriteIndex(int offset) { write = OffsetIndex(write, offset); } + void IncWriteIndex() { write = IncIndex(write); } + void DecWriteIndex() { write = DecIndex(write); } + void UpdateReadIndex(int offset) { read = OffsetIndex(read, offset); } + void IncReadIndex() { read = IncIndex(read); } + void DecReadIndex() { read = DecIndex(read); } + + const int size; + std::vector<float> buffer; + int write = 0; + int read = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_DOWNSAMPLED_RENDER_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_audibility.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_audibility.cc new file mode 100644 index 0000000000..142a33d5e0 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_audibility.cc @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_audibility.h" + +#include <algorithm> +#include <cmath> +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/block_buffer.h" +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "modules/audio_processing/aec3/stationarity_estimator.h" + +namespace webrtc { + +EchoAudibility::EchoAudibility(bool use_render_stationarity_at_init) + : use_render_stationarity_at_init_(use_render_stationarity_at_init) { + Reset(); +} + +EchoAudibility::~EchoAudibility() = default; + +void EchoAudibility::Update(const RenderBuffer& render_buffer, + rtc::ArrayView<const float> average_reverb, + int delay_blocks, + bool external_delay_seen) { + UpdateRenderNoiseEstimator(render_buffer.GetSpectrumBuffer(), + render_buffer.GetBlockBuffer(), + external_delay_seen); + + if (external_delay_seen || use_render_stationarity_at_init_) { + UpdateRenderStationarityFlags(render_buffer, average_reverb, delay_blocks); + } +} + +void EchoAudibility::Reset() { + render_stationarity_.Reset(); + non_zero_render_seen_ = false; + render_spectrum_write_prev_ = absl::nullopt; +} + +void EchoAudibility::UpdateRenderStationarityFlags( + const RenderBuffer& render_buffer, + rtc::ArrayView<const float> average_reverb, + int min_channel_delay_blocks) { + const SpectrumBuffer& spectrum_buffer = render_buffer.GetSpectrumBuffer(); + int idx_at_delay = spectrum_buffer.OffsetIndex(spectrum_buffer.read, + min_channel_delay_blocks); + + int num_lookahead = render_buffer.Headroom() - min_channel_delay_blocks + 1; + num_lookahead = std::max(0, num_lookahead); + + render_stationarity_.UpdateStationarityFlags(spectrum_buffer, average_reverb, + idx_at_delay, num_lookahead); +} + +void EchoAudibility::UpdateRenderNoiseEstimator( + const SpectrumBuffer& spectrum_buffer, + const BlockBuffer& block_buffer, + bool external_delay_seen) { + if (!render_spectrum_write_prev_) { + render_spectrum_write_prev_ = spectrum_buffer.write; + render_block_write_prev_ = block_buffer.write; + return; + } + int render_spectrum_write_current = spectrum_buffer.write; + if (!non_zero_render_seen_ && !external_delay_seen) { + non_zero_render_seen_ = !IsRenderTooLow(block_buffer); + } + if (non_zero_render_seen_) { + for (int idx = render_spectrum_write_prev_.value(); + idx != render_spectrum_write_current; + idx = spectrum_buffer.DecIndex(idx)) { + render_stationarity_.UpdateNoiseEstimator(spectrum_buffer.buffer[idx]); + } + } + render_spectrum_write_prev_ = render_spectrum_write_current; +} + +bool EchoAudibility::IsRenderTooLow(const BlockBuffer& block_buffer) { + const int num_render_channels = + static_cast<int>(block_buffer.buffer[0].NumChannels()); + bool too_low = false; + const int render_block_write_current = block_buffer.write; + if (render_block_write_current == render_block_write_prev_) { + too_low = true; + } else { + for (int idx = render_block_write_prev_; idx != render_block_write_current; + idx = block_buffer.IncIndex(idx)) { + float max_abs_over_channels = 0.f; + for (int ch = 0; ch < num_render_channels; ++ch) { + rtc::ArrayView<const float, kBlockSize> block = + block_buffer.buffer[idx].View(/*band=*/0, /*channel=*/ch); + auto r = std::minmax_element(block.cbegin(), block.cend()); + float max_abs_channel = + std::max(std::fabs(*r.first), std::fabs(*r.second)); + max_abs_over_channels = + std::max(max_abs_over_channels, max_abs_channel); + } + if (max_abs_over_channels < 10.f) { + too_low = true; // Discards all blocks if one of them is too low. + break; + } + } + } + render_block_write_prev_ = render_block_write_current; + return too_low; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_audibility.h b/third_party/libwebrtc/modules/audio_processing/aec3/echo_audibility.h new file mode 100644 index 0000000000..b9d6f87d2a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_audibility.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ECHO_AUDIBILITY_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ECHO_AUDIBILITY_H_ + +#include <stddef.h> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/block_buffer.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "modules/audio_processing/aec3/stationarity_estimator.h" + +namespace webrtc { + +class EchoAudibility { + public: + explicit EchoAudibility(bool use_render_stationarity_at_init); + ~EchoAudibility(); + + EchoAudibility(const EchoAudibility&) = delete; + EchoAudibility& operator=(const EchoAudibility&) = delete; + + // Feed new render data to the echo audibility estimator. + void Update(const RenderBuffer& render_buffer, + rtc::ArrayView<const float> average_reverb, + int min_channel_delay_blocks, + bool external_delay_seen); + // Get the residual echo scaling. + void GetResidualEchoScaling(bool filter_has_had_time_to_converge, + rtc::ArrayView<float> residual_scaling) const { + for (size_t band = 0; band < residual_scaling.size(); ++band) { + if (render_stationarity_.IsBandStationary(band) && + (filter_has_had_time_to_converge || + use_render_stationarity_at_init_)) { + residual_scaling[band] = 0.f; + } else { + residual_scaling[band] = 1.0f; + } + } + } + + // Returns true if the current render block is estimated as stationary. + bool IsBlockStationary() const { + return render_stationarity_.IsBlockStationary(); + } + + private: + // Reset the EchoAudibility class. + void Reset(); + + // Updates the render stationarity flags for the current frame. + void UpdateRenderStationarityFlags(const RenderBuffer& render_buffer, + rtc::ArrayView<const float> average_reverb, + int delay_blocks); + + // Updates the noise estimator with the new render data since the previous + // call to this method. + void UpdateRenderNoiseEstimator(const SpectrumBuffer& spectrum_buffer, + const BlockBuffer& block_buffer, + bool external_delay_seen); + + // Returns a bool being true if the render signal contains just close to zero + // values. + bool IsRenderTooLow(const BlockBuffer& block_buffer); + + absl::optional<int> render_spectrum_write_prev_; + int render_block_write_prev_; + bool non_zero_render_seen_; + const bool use_render_stationarity_at_init_; + StationarityEstimator render_stationarity_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ECHO_AUDIBILITY_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3.cc new file mode 100644 index 0000000000..e8e2175994 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3.cc @@ -0,0 +1,992 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/echo_canceller3.h" + +#include <algorithm> +#include <utility> + +#include "absl/strings/string_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/high_pass_filter.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/experiments/field_trial_parser.h" +#include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { + +namespace { + +enum class EchoCanceller3ApiCall { kCapture, kRender }; + +bool DetectSaturation(rtc::ArrayView<const float> y) { + for (size_t k = 0; k < y.size(); ++k) { + if (y[k] >= 32700.0f || y[k] <= -32700.0f) { + return true; + } + } + return false; +} + +// Retrieves a value from a field trial if it is available. If no value is +// present, the default value is returned. If the retrieved value is beyond the +// specified limits, the default value is returned instead. +void RetrieveFieldTrialValue(absl::string_view trial_name, + float min, + float max, + float* value_to_update) { + const std::string field_trial_str = field_trial::FindFullName(trial_name); + + FieldTrialParameter<double> field_trial_param(/*key=*/"", *value_to_update); + + ParseFieldTrial({&field_trial_param}, field_trial_str); + float field_trial_value = static_cast<float>(field_trial_param.Get()); + + if (field_trial_value >= min && field_trial_value <= max && + field_trial_value != *value_to_update) { + RTC_LOG(LS_INFO) << "Key " << trial_name + << " changing AEC3 parameter value from " + << *value_to_update << " to " << field_trial_value; + *value_to_update = field_trial_value; + } +} + +void RetrieveFieldTrialValue(absl::string_view trial_name, + int min, + int max, + int* value_to_update) { + const std::string field_trial_str = field_trial::FindFullName(trial_name); + + FieldTrialParameter<int> field_trial_param(/*key=*/"", *value_to_update); + + ParseFieldTrial({&field_trial_param}, field_trial_str); + float field_trial_value = field_trial_param.Get(); + + if (field_trial_value >= min && field_trial_value <= max && + field_trial_value != *value_to_update) { + RTC_LOG(LS_INFO) << "Key " << trial_name + << " changing AEC3 parameter value from " + << *value_to_update << " to " << field_trial_value; + *value_to_update = field_trial_value; + } +} + +void FillSubFrameView( + AudioBuffer* frame, + size_t sub_frame_index, + std::vector<std::vector<rtc::ArrayView<float>>>* sub_frame_view) { + RTC_DCHECK_GE(1, sub_frame_index); + RTC_DCHECK_LE(0, sub_frame_index); + RTC_DCHECK_EQ(frame->num_bands(), sub_frame_view->size()); + RTC_DCHECK_EQ(frame->num_channels(), (*sub_frame_view)[0].size()); + for (size_t band = 0; band < sub_frame_view->size(); ++band) { + for (size_t channel = 0; channel < (*sub_frame_view)[0].size(); ++channel) { + (*sub_frame_view)[band][channel] = rtc::ArrayView<float>( + &frame->split_bands(channel)[band][sub_frame_index * kSubFrameLength], + kSubFrameLength); + } + } +} + +void FillSubFrameView( + bool proper_downmix_needed, + std::vector<std::vector<std::vector<float>>>* frame, + size_t sub_frame_index, + std::vector<std::vector<rtc::ArrayView<float>>>* sub_frame_view) { + RTC_DCHECK_GE(1, sub_frame_index); + RTC_DCHECK_EQ(frame->size(), sub_frame_view->size()); + const size_t frame_num_channels = (*frame)[0].size(); + const size_t sub_frame_num_channels = (*sub_frame_view)[0].size(); + if (frame_num_channels > sub_frame_num_channels) { + RTC_DCHECK_EQ(sub_frame_num_channels, 1u); + if (proper_downmix_needed) { + // When a proper downmix is needed (which is the case when proper stereo + // is present in the echo reference signal but the echo canceller does the + // processing in mono) downmix the echo reference by averaging the channel + // content (otherwise downmixing is done by selecting channel 0). + for (size_t band = 0; band < frame->size(); ++band) { + for (size_t ch = 1; ch < frame_num_channels; ++ch) { + for (size_t k = 0; k < kSubFrameLength; ++k) { + (*frame)[band][/*channel=*/0] + [sub_frame_index * kSubFrameLength + k] += + (*frame)[band][ch][sub_frame_index * kSubFrameLength + k]; + } + } + const float one_by_num_channels = 1.0f / frame_num_channels; + for (size_t k = 0; k < kSubFrameLength; ++k) { + (*frame)[band][/*channel=*/0][sub_frame_index * kSubFrameLength + + k] *= one_by_num_channels; + } + } + } + for (size_t band = 0; band < frame->size(); ++band) { + (*sub_frame_view)[band][/*channel=*/0] = rtc::ArrayView<float>( + &(*frame)[band][/*channel=*/0][sub_frame_index * kSubFrameLength], + kSubFrameLength); + } + } else { + RTC_DCHECK_EQ(frame_num_channels, sub_frame_num_channels); + for (size_t band = 0; band < frame->size(); ++band) { + for (size_t channel = 0; channel < (*frame)[band].size(); ++channel) { + (*sub_frame_view)[band][channel] = rtc::ArrayView<float>( + &(*frame)[band][channel][sub_frame_index * kSubFrameLength], + kSubFrameLength); + } + } + } +} + +void ProcessCaptureFrameContent( + AudioBuffer* linear_output, + AudioBuffer* capture, + bool level_change, + bool aec_reference_is_downmixed_stereo, + bool saturated_microphone_signal, + size_t sub_frame_index, + FrameBlocker* capture_blocker, + BlockFramer* linear_output_framer, + BlockFramer* output_framer, + BlockProcessor* block_processor, + Block* linear_output_block, + std::vector<std::vector<rtc::ArrayView<float>>>* + linear_output_sub_frame_view, + Block* capture_block, + std::vector<std::vector<rtc::ArrayView<float>>>* capture_sub_frame_view) { + FillSubFrameView(capture, sub_frame_index, capture_sub_frame_view); + + if (linear_output) { + RTC_DCHECK(linear_output_framer); + RTC_DCHECK(linear_output_block); + RTC_DCHECK(linear_output_sub_frame_view); + FillSubFrameView(linear_output, sub_frame_index, + linear_output_sub_frame_view); + } + + capture_blocker->InsertSubFrameAndExtractBlock(*capture_sub_frame_view, + capture_block); + block_processor->ProcessCapture( + /*echo_path_gain_change=*/level_change || + aec_reference_is_downmixed_stereo, + saturated_microphone_signal, linear_output_block, capture_block); + output_framer->InsertBlockAndExtractSubFrame(*capture_block, + capture_sub_frame_view); + + if (linear_output) { + RTC_DCHECK(linear_output_framer); + linear_output_framer->InsertBlockAndExtractSubFrame( + *linear_output_block, linear_output_sub_frame_view); + } +} + +void ProcessRemainingCaptureFrameContent(bool level_change, + bool aec_reference_is_downmixed_stereo, + bool saturated_microphone_signal, + FrameBlocker* capture_blocker, + BlockFramer* linear_output_framer, + BlockFramer* output_framer, + BlockProcessor* block_processor, + Block* linear_output_block, + Block* block) { + if (!capture_blocker->IsBlockAvailable()) { + return; + } + + capture_blocker->ExtractBlock(block); + block_processor->ProcessCapture( + /*echo_path_gain_change=*/level_change || + aec_reference_is_downmixed_stereo, + saturated_microphone_signal, linear_output_block, block); + output_framer->InsertBlock(*block); + + if (linear_output_framer) { + RTC_DCHECK(linear_output_block); + linear_output_framer->InsertBlock(*linear_output_block); + } +} + +void BufferRenderFrameContent( + bool proper_downmix_needed, + std::vector<std::vector<std::vector<float>>>* render_frame, + size_t sub_frame_index, + FrameBlocker* render_blocker, + BlockProcessor* block_processor, + Block* block, + std::vector<std::vector<rtc::ArrayView<float>>>* sub_frame_view) { + FillSubFrameView(proper_downmix_needed, render_frame, sub_frame_index, + sub_frame_view); + render_blocker->InsertSubFrameAndExtractBlock(*sub_frame_view, block); + block_processor->BufferRender(*block); +} + +void BufferRemainingRenderFrameContent(FrameBlocker* render_blocker, + BlockProcessor* block_processor, + Block* block) { + if (!render_blocker->IsBlockAvailable()) { + return; + } + render_blocker->ExtractBlock(block); + block_processor->BufferRender(*block); +} + +void CopyBufferIntoFrame(const AudioBuffer& buffer, + size_t num_bands, + size_t num_channels, + std::vector<std::vector<std::vector<float>>>* frame) { + RTC_DCHECK_EQ(num_bands, frame->size()); + RTC_DCHECK_EQ(num_channels, (*frame)[0].size()); + RTC_DCHECK_EQ(AudioBuffer::kSplitBandSize, (*frame)[0][0].size()); + for (size_t band = 0; band < num_bands; ++band) { + for (size_t channel = 0; channel < num_channels; ++channel) { + rtc::ArrayView<const float> buffer_view( + &buffer.split_bands_const(channel)[band][0], + AudioBuffer::kSplitBandSize); + std::copy(buffer_view.begin(), buffer_view.end(), + (*frame)[band][channel].begin()); + } + } +} + +} // namespace + +// TODO(webrtc:5298): Move this to a separate file. +EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) { + EchoCanceller3Config adjusted_cfg = config; + + if (field_trial::IsEnabled("WebRTC-Aec3StereoContentDetectionKillSwitch")) { + adjusted_cfg.multi_channel.detect_stereo_content = false; + } + + if (field_trial::IsEnabled("WebRTC-Aec3AntiHowlingMinimizationKillSwitch")) { + adjusted_cfg.suppressor.high_bands_suppression + .anti_howling_activation_threshold = 25.f; + adjusted_cfg.suppressor.high_bands_suppression.anti_howling_gain = 0.01f; + } + + if (field_trial::IsEnabled("WebRTC-Aec3UseShortConfigChangeDuration")) { + adjusted_cfg.filter.config_change_duration_blocks = 10; + } + + if (field_trial::IsEnabled("WebRTC-Aec3UseZeroInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = 0.f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3UseDot1SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = .1f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3UseDot2SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = .2f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3UseDot3SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = .3f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3UseDot6SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = .6f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3UseDot9SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = .9f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3Use1Dot2SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = 1.2f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3Use1Dot6SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = 1.6f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3Use2Dot0SecondsInitialStateDuration")) { + adjusted_cfg.filter.initial_state_seconds = 2.0f; + } + + if (field_trial::IsEnabled("WebRTC-Aec3HighPassFilterEchoReference")) { + adjusted_cfg.filter.high_pass_filter_echo_reference = true; + } + + if (field_trial::IsEnabled("WebRTC-Aec3EchoSaturationDetectionKillSwitch")) { + adjusted_cfg.ep_strength.echo_can_saturate = false; + } + + const std::string use_nearend_reverb_len_tunings = + field_trial::FindFullName("WebRTC-Aec3UseNearendReverbLen"); + FieldTrialParameter<double> nearend_reverb_default_len( + "default_len", adjusted_cfg.ep_strength.default_len); + FieldTrialParameter<double> nearend_reverb_nearend_len( + "nearend_len", adjusted_cfg.ep_strength.nearend_len); + + ParseFieldTrial({&nearend_reverb_default_len, &nearend_reverb_nearend_len}, + use_nearend_reverb_len_tunings); + float default_len = static_cast<float>(nearend_reverb_default_len.Get()); + float nearend_len = static_cast<float>(nearend_reverb_nearend_len.Get()); + if (default_len > -1 && default_len < 1 && nearend_len > -1 && + nearend_len < 1) { + adjusted_cfg.ep_strength.default_len = + static_cast<float>(nearend_reverb_default_len.Get()); + adjusted_cfg.ep_strength.nearend_len = + static_cast<float>(nearend_reverb_nearend_len.Get()); + } + + if (field_trial::IsEnabled("WebRTC-Aec3ConservativeTailFreqResponse")) { + adjusted_cfg.ep_strength.use_conservative_tail_frequency_response = true; + } + + if (field_trial::IsDisabled("WebRTC-Aec3ConservativeTailFreqResponse")) { + adjusted_cfg.ep_strength.use_conservative_tail_frequency_response = false; + } + + if (field_trial::IsEnabled("WebRTC-Aec3ShortHeadroomKillSwitch")) { + // Two blocks headroom. + adjusted_cfg.delay.delay_headroom_samples = kBlockSize * 2; + } + + if (field_trial::IsEnabled("WebRTC-Aec3ClampInstQualityToZeroKillSwitch")) { + adjusted_cfg.erle.clamp_quality_estimate_to_zero = false; + } + + if (field_trial::IsEnabled("WebRTC-Aec3ClampInstQualityToOneKillSwitch")) { + adjusted_cfg.erle.clamp_quality_estimate_to_one = false; + } + + if (field_trial::IsEnabled("WebRTC-Aec3OnsetDetectionKillSwitch")) { + adjusted_cfg.erle.onset_detection = false; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceRenderDelayEstimationDownmixing")) { + adjusted_cfg.delay.render_alignment_mixing.downmix = true; + adjusted_cfg.delay.render_alignment_mixing.adaptive_selection = false; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceCaptureDelayEstimationDownmixing")) { + adjusted_cfg.delay.capture_alignment_mixing.downmix = true; + adjusted_cfg.delay.capture_alignment_mixing.adaptive_selection = false; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceCaptureDelayEstimationLeftRightPrioritization")) { + adjusted_cfg.delay.capture_alignment_mixing.prefer_first_two_channels = + true; + } + + if (field_trial::IsEnabled( + "WebRTC-" + "Aec3RenderDelayEstimationLeftRightPrioritizationKillSwitch")) { + adjusted_cfg.delay.capture_alignment_mixing.prefer_first_two_channels = + false; + } + + if (field_trial::IsEnabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) { + adjusted_cfg.delay.detect_pre_echo = true; + } + + if (field_trial::IsDisabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) { + adjusted_cfg.delay.detect_pre_echo = false; + } + + if (field_trial::IsEnabled("WebRTC-Aec3SensitiveDominantNearendActivation")) { + adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold = 0.5f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3VerySensitiveDominantNearendActivation")) { + adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold = 0.75f; + } + + if (field_trial::IsEnabled("WebRTC-Aec3TransparentAntiHowlingGain")) { + adjusted_cfg.suppressor.high_bands_suppression.anti_howling_gain = 1.f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceMoreTransparentNormalSuppressorTuning")) { + adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_transparent = 0.4f; + adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_suppress = 0.5f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceMoreTransparentNearendSuppressorTuning")) { + adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_transparent = 1.29f; + adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_suppress = 1.3f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceMoreTransparentNormalSuppressorHfTuning")) { + adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_transparent = 0.3f; + adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_suppress = 0.4f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceMoreTransparentNearendSuppressorHfTuning")) { + adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_transparent = 1.09f; + adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_suppress = 1.1f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceRapidlyAdjustingNormalSuppressorTunings")) { + adjusted_cfg.suppressor.normal_tuning.max_inc_factor = 2.5f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceRapidlyAdjustingNearendSuppressorTunings")) { + adjusted_cfg.suppressor.nearend_tuning.max_inc_factor = 2.5f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceSlowlyAdjustingNormalSuppressorTunings")) { + adjusted_cfg.suppressor.normal_tuning.max_dec_factor_lf = .2f; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceSlowlyAdjustingNearendSuppressorTunings")) { + adjusted_cfg.suppressor.nearend_tuning.max_dec_factor_lf = .2f; + } + + if (field_trial::IsEnabled("WebRTC-Aec3EnforceConservativeHfSuppression")) { + adjusted_cfg.suppressor.conservative_hf_suppression = true; + } + + if (field_trial::IsEnabled("WebRTC-Aec3EnforceStationarityProperties")) { + adjusted_cfg.echo_audibility.use_stationarity_properties = true; + } + + if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceStationarityPropertiesAtInit")) { + adjusted_cfg.echo_audibility.use_stationarity_properties_at_init = true; + } + + if (field_trial::IsEnabled("WebRTC-Aec3EnforceLowActiveRenderLimit")) { + adjusted_cfg.render_levels.active_render_limit = 50.f; + } else if (field_trial::IsEnabled( + "WebRTC-Aec3EnforceVeryLowActiveRenderLimit")) { + adjusted_cfg.render_levels.active_render_limit = 30.f; + } + + if (field_trial::IsEnabled("WebRTC-Aec3NonlinearModeReverbKillSwitch")) { + adjusted_cfg.echo_model.model_reverb_in_nonlinear_mode = false; + } + + // Field-trial based override for the whole suppressor tuning. + const std::string suppressor_tuning_override_trial_name = + field_trial::FindFullName("WebRTC-Aec3SuppressorTuningOverride"); + + FieldTrialParameter<double> nearend_tuning_mask_lf_enr_transparent( + "nearend_tuning_mask_lf_enr_transparent", + adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_transparent); + FieldTrialParameter<double> nearend_tuning_mask_lf_enr_suppress( + "nearend_tuning_mask_lf_enr_suppress", + adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_suppress); + FieldTrialParameter<double> nearend_tuning_mask_hf_enr_transparent( + "nearend_tuning_mask_hf_enr_transparent", + adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_transparent); + FieldTrialParameter<double> nearend_tuning_mask_hf_enr_suppress( + "nearend_tuning_mask_hf_enr_suppress", + adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_suppress); + FieldTrialParameter<double> nearend_tuning_max_inc_factor( + "nearend_tuning_max_inc_factor", + adjusted_cfg.suppressor.nearend_tuning.max_inc_factor); + FieldTrialParameter<double> nearend_tuning_max_dec_factor_lf( + "nearend_tuning_max_dec_factor_lf", + adjusted_cfg.suppressor.nearend_tuning.max_dec_factor_lf); + FieldTrialParameter<double> normal_tuning_mask_lf_enr_transparent( + "normal_tuning_mask_lf_enr_transparent", + adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_transparent); + FieldTrialParameter<double> normal_tuning_mask_lf_enr_suppress( + "normal_tuning_mask_lf_enr_suppress", + adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_suppress); + FieldTrialParameter<double> normal_tuning_mask_hf_enr_transparent( + "normal_tuning_mask_hf_enr_transparent", + adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_transparent); + FieldTrialParameter<double> normal_tuning_mask_hf_enr_suppress( + "normal_tuning_mask_hf_enr_suppress", + adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_suppress); + FieldTrialParameter<double> normal_tuning_max_inc_factor( + "normal_tuning_max_inc_factor", + adjusted_cfg.suppressor.normal_tuning.max_inc_factor); + FieldTrialParameter<double> normal_tuning_max_dec_factor_lf( + "normal_tuning_max_dec_factor_lf", + adjusted_cfg.suppressor.normal_tuning.max_dec_factor_lf); + FieldTrialParameter<double> dominant_nearend_detection_enr_threshold( + "dominant_nearend_detection_enr_threshold", + adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold); + FieldTrialParameter<double> dominant_nearend_detection_enr_exit_threshold( + "dominant_nearend_detection_enr_exit_threshold", + adjusted_cfg.suppressor.dominant_nearend_detection.enr_exit_threshold); + FieldTrialParameter<double> dominant_nearend_detection_snr_threshold( + "dominant_nearend_detection_snr_threshold", + adjusted_cfg.suppressor.dominant_nearend_detection.snr_threshold); + FieldTrialParameter<int> dominant_nearend_detection_hold_duration( + "dominant_nearend_detection_hold_duration", + adjusted_cfg.suppressor.dominant_nearend_detection.hold_duration); + FieldTrialParameter<int> dominant_nearend_detection_trigger_threshold( + "dominant_nearend_detection_trigger_threshold", + adjusted_cfg.suppressor.dominant_nearend_detection.trigger_threshold); + + ParseFieldTrial( + {&nearend_tuning_mask_lf_enr_transparent, + &nearend_tuning_mask_lf_enr_suppress, + &nearend_tuning_mask_hf_enr_transparent, + &nearend_tuning_mask_hf_enr_suppress, &nearend_tuning_max_inc_factor, + &nearend_tuning_max_dec_factor_lf, + &normal_tuning_mask_lf_enr_transparent, + &normal_tuning_mask_lf_enr_suppress, + &normal_tuning_mask_hf_enr_transparent, + &normal_tuning_mask_hf_enr_suppress, &normal_tuning_max_inc_factor, + &normal_tuning_max_dec_factor_lf, + &dominant_nearend_detection_enr_threshold, + &dominant_nearend_detection_enr_exit_threshold, + &dominant_nearend_detection_snr_threshold, + &dominant_nearend_detection_hold_duration, + &dominant_nearend_detection_trigger_threshold}, + suppressor_tuning_override_trial_name); + + adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_transparent = + static_cast<float>(nearend_tuning_mask_lf_enr_transparent.Get()); + adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_suppress = + static_cast<float>(nearend_tuning_mask_lf_enr_suppress.Get()); + adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_transparent = + static_cast<float>(nearend_tuning_mask_hf_enr_transparent.Get()); + adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_suppress = + static_cast<float>(nearend_tuning_mask_hf_enr_suppress.Get()); + adjusted_cfg.suppressor.nearend_tuning.max_inc_factor = + static_cast<float>(nearend_tuning_max_inc_factor.Get()); + adjusted_cfg.suppressor.nearend_tuning.max_dec_factor_lf = + static_cast<float>(nearend_tuning_max_dec_factor_lf.Get()); + adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_transparent = + static_cast<float>(normal_tuning_mask_lf_enr_transparent.Get()); + adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_suppress = + static_cast<float>(normal_tuning_mask_lf_enr_suppress.Get()); + adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_transparent = + static_cast<float>(normal_tuning_mask_hf_enr_transparent.Get()); + adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_suppress = + static_cast<float>(normal_tuning_mask_hf_enr_suppress.Get()); + adjusted_cfg.suppressor.normal_tuning.max_inc_factor = + static_cast<float>(normal_tuning_max_inc_factor.Get()); + adjusted_cfg.suppressor.normal_tuning.max_dec_factor_lf = + static_cast<float>(normal_tuning_max_dec_factor_lf.Get()); + adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold = + static_cast<float>(dominant_nearend_detection_enr_threshold.Get()); + adjusted_cfg.suppressor.dominant_nearend_detection.enr_exit_threshold = + static_cast<float>(dominant_nearend_detection_enr_exit_threshold.Get()); + adjusted_cfg.suppressor.dominant_nearend_detection.snr_threshold = + static_cast<float>(dominant_nearend_detection_snr_threshold.Get()); + adjusted_cfg.suppressor.dominant_nearend_detection.hold_duration = + dominant_nearend_detection_hold_duration.Get(); + adjusted_cfg.suppressor.dominant_nearend_detection.trigger_threshold = + dominant_nearend_detection_trigger_threshold.Get(); + + // Field trial-based overrides of individual suppressor parameters. + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNearendLfMaskTransparentOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_transparent); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNearendLfMaskSuppressOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.nearend_tuning.mask_lf.enr_suppress); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNearendHfMaskTransparentOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_transparent); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNearendHfMaskSuppressOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.nearend_tuning.mask_hf.enr_suppress); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNearendMaxIncFactorOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.nearend_tuning.max_inc_factor); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNearendMaxDecFactorLfOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.nearend_tuning.max_dec_factor_lf); + + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNormalLfMaskTransparentOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_transparent); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNormalLfMaskSuppressOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.normal_tuning.mask_lf.enr_suppress); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNormalHfMaskTransparentOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_transparent); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNormalHfMaskSuppressOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.normal_tuning.mask_hf.enr_suppress); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNormalMaxIncFactorOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.normal_tuning.max_inc_factor); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorNormalMaxDecFactorLfOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.normal_tuning.max_dec_factor_lf); + + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorDominantNearendEnrThresholdOverride", 0.f, 100.f, + &adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorDominantNearendEnrExitThresholdOverride", 0.f, + 100.f, + &adjusted_cfg.suppressor.dominant_nearend_detection.enr_exit_threshold); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorDominantNearendSnrThresholdOverride", 0.f, 100.f, + &adjusted_cfg.suppressor.dominant_nearend_detection.snr_threshold); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorDominantNearendHoldDurationOverride", 0, 1000, + &adjusted_cfg.suppressor.dominant_nearend_detection.hold_duration); + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorDominantNearendTriggerThresholdOverride", 0, 1000, + &adjusted_cfg.suppressor.dominant_nearend_detection.trigger_threshold); + + RetrieveFieldTrialValue( + "WebRTC-Aec3SuppressorAntiHowlingGainOverride", 0.f, 10.f, + &adjusted_cfg.suppressor.high_bands_suppression.anti_howling_gain); + + // Field trial-based overrides of individual delay estimator parameters. + RetrieveFieldTrialValue("WebRTC-Aec3DelayEstimateSmoothingOverride", 0.f, 1.f, + &adjusted_cfg.delay.delay_estimate_smoothing); + RetrieveFieldTrialValue( + "WebRTC-Aec3DelayEstimateSmoothingDelayFoundOverride", 0.f, 1.f, + &adjusted_cfg.delay.delay_estimate_smoothing_delay_found); + + return adjusted_cfg; +} + +class EchoCanceller3::RenderWriter { + public: + RenderWriter(ApmDataDumper* data_dumper, + const EchoCanceller3Config& config, + SwapQueue<std::vector<std::vector<std::vector<float>>>, + Aec3RenderQueueItemVerifier>* render_transfer_queue, + size_t num_bands, + size_t num_channels); + + RenderWriter() = delete; + RenderWriter(const RenderWriter&) = delete; + RenderWriter& operator=(const RenderWriter&) = delete; + + ~RenderWriter(); + void Insert(const AudioBuffer& input); + + private: + ApmDataDumper* data_dumper_; + const size_t num_bands_; + const size_t num_channels_; + std::unique_ptr<HighPassFilter> high_pass_filter_; + std::vector<std::vector<std::vector<float>>> render_queue_input_frame_; + SwapQueue<std::vector<std::vector<std::vector<float>>>, + Aec3RenderQueueItemVerifier>* render_transfer_queue_; +}; + +EchoCanceller3::RenderWriter::RenderWriter( + ApmDataDumper* data_dumper, + const EchoCanceller3Config& config, + SwapQueue<std::vector<std::vector<std::vector<float>>>, + Aec3RenderQueueItemVerifier>* render_transfer_queue, + size_t num_bands, + size_t num_channels) + : data_dumper_(data_dumper), + num_bands_(num_bands), + num_channels_(num_channels), + render_queue_input_frame_( + num_bands_, + std::vector<std::vector<float>>( + num_channels_, + std::vector<float>(AudioBuffer::kSplitBandSize, 0.f))), + render_transfer_queue_(render_transfer_queue) { + RTC_DCHECK(data_dumper); + if (config.filter.high_pass_filter_echo_reference) { + high_pass_filter_ = std::make_unique<HighPassFilter>(16000, num_channels); + } +} + +EchoCanceller3::RenderWriter::~RenderWriter() = default; + +void EchoCanceller3::RenderWriter::Insert(const AudioBuffer& input) { + RTC_DCHECK_EQ(AudioBuffer::kSplitBandSize, input.num_frames_per_band()); + RTC_DCHECK_EQ(num_bands_, input.num_bands()); + RTC_DCHECK_EQ(num_channels_, input.num_channels()); + + // TODO(bugs.webrtc.org/8759) Temporary work-around. + if (num_bands_ != input.num_bands()) + return; + + data_dumper_->DumpWav("aec3_render_input", AudioBuffer::kSplitBandSize, + &input.split_bands_const(0)[0][0], 16000, 1); + + CopyBufferIntoFrame(input, num_bands_, num_channels_, + &render_queue_input_frame_); + if (high_pass_filter_) { + high_pass_filter_->Process(&render_queue_input_frame_[0]); + } + + static_cast<void>(render_transfer_queue_->Insert(&render_queue_input_frame_)); +} + +std::atomic<int> EchoCanceller3::instance_count_(0); + +EchoCanceller3::EchoCanceller3( + const EchoCanceller3Config& config, + const absl::optional<EchoCanceller3Config>& multichannel_config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + config_(AdjustConfig(config)), + sample_rate_hz_(sample_rate_hz), + num_bands_(NumBandsForRate(sample_rate_hz_)), + num_render_input_channels_(num_render_channels), + num_capture_channels_(num_capture_channels), + config_selector_(AdjustConfig(config), + multichannel_config, + num_render_input_channels_), + multichannel_content_detector_( + config_selector_.active_config().multi_channel.detect_stereo_content, + num_render_input_channels_, + config_selector_.active_config() + .multi_channel.stereo_detection_threshold, + config_selector_.active_config() + .multi_channel.stereo_detection_timeout_threshold_seconds, + config_selector_.active_config() + .multi_channel.stereo_detection_hysteresis_seconds), + output_framer_(num_bands_, num_capture_channels_), + capture_blocker_(num_bands_, num_capture_channels_), + render_transfer_queue_( + kRenderTransferQueueSizeFrames, + std::vector<std::vector<std::vector<float>>>( + num_bands_, + std::vector<std::vector<float>>( + num_render_input_channels_, + std::vector<float>(AudioBuffer::kSplitBandSize, 0.f))), + Aec3RenderQueueItemVerifier(num_bands_, + num_render_input_channels_, + AudioBuffer::kSplitBandSize)), + render_queue_output_frame_( + num_bands_, + std::vector<std::vector<float>>( + num_render_input_channels_, + std::vector<float>(AudioBuffer::kSplitBandSize, 0.f))), + render_block_(num_bands_, num_render_input_channels_), + capture_block_(num_bands_, num_capture_channels_), + capture_sub_frame_view_( + num_bands_, + std::vector<rtc::ArrayView<float>>(num_capture_channels_)) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); + + if (config_selector_.active_config().delay.fixed_capture_delay_samples > 0) { + block_delay_buffer_.reset(new BlockDelayBuffer( + num_capture_channels_, num_bands_, AudioBuffer::kSplitBandSize, + config_.delay.fixed_capture_delay_samples)); + } + + render_writer_.reset(new RenderWriter( + data_dumper_.get(), config_selector_.active_config(), + &render_transfer_queue_, num_bands_, num_render_input_channels_)); + + RTC_DCHECK_EQ(num_bands_, std::max(sample_rate_hz_, 16000) / 16000); + RTC_DCHECK_GE(kMaxNumBands, num_bands_); + + if (config_selector_.active_config().filter.export_linear_aec_output) { + linear_output_framer_.reset( + new BlockFramer(/*num_bands=*/1, num_capture_channels_)); + linear_output_block_ = + std::make_unique<Block>(/*num_bands=*/1, num_capture_channels_), + linear_output_sub_frame_view_ = + std::vector<std::vector<rtc::ArrayView<float>>>( + 1, std::vector<rtc::ArrayView<float>>(num_capture_channels_)); + } + + Initialize(); + + RTC_LOG(LS_INFO) << "AEC3 created with sample rate: " << sample_rate_hz_ + << " Hz, num render channels: " << num_render_input_channels_ + << ", num capture channels: " << num_capture_channels_; +} + +EchoCanceller3::~EchoCanceller3() = default; + +void EchoCanceller3::Initialize() { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + + num_render_channels_to_aec_ = + multichannel_content_detector_.IsProperMultiChannelContentDetected() + ? num_render_input_channels_ + : 1; + + config_selector_.Update( + multichannel_content_detector_.IsProperMultiChannelContentDetected()); + + render_block_.SetNumChannels(num_render_channels_to_aec_); + + render_blocker_.reset( + new FrameBlocker(num_bands_, num_render_channels_to_aec_)); + + block_processor_.reset(BlockProcessor::Create( + config_selector_.active_config(), sample_rate_hz_, + num_render_channels_to_aec_, num_capture_channels_)); + + render_sub_frame_view_ = std::vector<std::vector<rtc::ArrayView<float>>>( + num_bands_, + std::vector<rtc::ArrayView<float>>(num_render_channels_to_aec_)); +} + +void EchoCanceller3::AnalyzeRender(const AudioBuffer& render) { + RTC_DCHECK_RUNS_SERIALIZED(&render_race_checker_); + + RTC_DCHECK_EQ(render.num_channels(), num_render_input_channels_); + data_dumper_->DumpRaw("aec3_call_order", + static_cast<int>(EchoCanceller3ApiCall::kRender)); + + return render_writer_->Insert(render); +} + +void EchoCanceller3::AnalyzeCapture(const AudioBuffer& capture) { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + data_dumper_->DumpWav("aec3_capture_analyze_input", capture.num_frames(), + capture.channels_const()[0], sample_rate_hz_, 1); + saturated_microphone_signal_ = false; + for (size_t channel = 0; channel < capture.num_channels(); ++channel) { + saturated_microphone_signal_ |= + DetectSaturation(rtc::ArrayView<const float>( + capture.channels_const()[channel], capture.num_frames())); + if (saturated_microphone_signal_) { + break; + } + } +} + +void EchoCanceller3::ProcessCapture(AudioBuffer* capture, bool level_change) { + ProcessCapture(capture, nullptr, level_change); +} + +void EchoCanceller3::ProcessCapture(AudioBuffer* capture, + AudioBuffer* linear_output, + bool level_change) { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + RTC_DCHECK(capture); + RTC_DCHECK_EQ(num_bands_, capture->num_bands()); + RTC_DCHECK_EQ(AudioBuffer::kSplitBandSize, capture->num_frames_per_band()); + RTC_DCHECK_EQ(capture->num_channels(), num_capture_channels_); + data_dumper_->DumpRaw("aec3_call_order", + static_cast<int>(EchoCanceller3ApiCall::kCapture)); + + if (linear_output && !linear_output_framer_) { + RTC_LOG(LS_ERROR) << "Trying to retrieve the linear AEC output without " + "properly configuring AEC3."; + RTC_DCHECK_NOTREACHED(); + } + + // Report capture call in the metrics and periodically update API call + // metrics. + api_call_metrics_.ReportCaptureCall(); + + // Optionally delay the capture signal. + if (config_selector_.active_config().delay.fixed_capture_delay_samples > 0) { + RTC_DCHECK(block_delay_buffer_); + block_delay_buffer_->DelaySignal(capture); + } + + rtc::ArrayView<float> capture_lower_band = rtc::ArrayView<float>( + &capture->split_bands(0)[0][0], AudioBuffer::kSplitBandSize); + + data_dumper_->DumpWav("aec3_capture_input", capture_lower_band, 16000, 1); + + EmptyRenderQueue(); + + ProcessCaptureFrameContent( + linear_output, capture, level_change, + multichannel_content_detector_.IsTemporaryMultiChannelContentDetected(), + saturated_microphone_signal_, 0, &capture_blocker_, + linear_output_framer_.get(), &output_framer_, block_processor_.get(), + linear_output_block_.get(), &linear_output_sub_frame_view_, + &capture_block_, &capture_sub_frame_view_); + + ProcessCaptureFrameContent( + linear_output, capture, level_change, + multichannel_content_detector_.IsTemporaryMultiChannelContentDetected(), + saturated_microphone_signal_, 1, &capture_blocker_, + linear_output_framer_.get(), &output_framer_, block_processor_.get(), + linear_output_block_.get(), &linear_output_sub_frame_view_, + &capture_block_, &capture_sub_frame_view_); + + ProcessRemainingCaptureFrameContent( + level_change, + multichannel_content_detector_.IsTemporaryMultiChannelContentDetected(), + saturated_microphone_signal_, &capture_blocker_, + linear_output_framer_.get(), &output_framer_, block_processor_.get(), + linear_output_block_.get(), &capture_block_); + + data_dumper_->DumpWav("aec3_capture_output", AudioBuffer::kSplitBandSize, + &capture->split_bands(0)[0][0], 16000, 1); +} + +EchoControl::Metrics EchoCanceller3::GetMetrics() const { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + Metrics metrics; + block_processor_->GetMetrics(&metrics); + return metrics; +} + +void EchoCanceller3::SetAudioBufferDelay(int delay_ms) { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + block_processor_->SetAudioBufferDelay(delay_ms); +} + +void EchoCanceller3::SetCaptureOutputUsage(bool capture_output_used) { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + block_processor_->SetCaptureOutputUsage(capture_output_used); +} + +bool EchoCanceller3::ActiveProcessing() const { + return true; +} + +EchoCanceller3Config EchoCanceller3::CreateDefaultMultichannelConfig() { + EchoCanceller3Config cfg; + // Use shorter and more rapidly adapting coarse filter to compensate for + // thge increased number of total filter parameters to adapt. + cfg.filter.coarse.length_blocks = 11; + cfg.filter.coarse.rate = 0.95f; + cfg.filter.coarse_initial.length_blocks = 11; + cfg.filter.coarse_initial.rate = 0.95f; + + // Use more concervative suppressor behavior for non-nearend speech. + cfg.suppressor.normal_tuning.max_dec_factor_lf = 0.35f; + cfg.suppressor.normal_tuning.max_inc_factor = 1.5f; + return cfg; +} + +void EchoCanceller3::SetBlockProcessorForTesting( + std::unique_ptr<BlockProcessor> block_processor) { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + RTC_DCHECK(block_processor); + block_processor_ = std::move(block_processor); +} + +void EchoCanceller3::EmptyRenderQueue() { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + bool frame_to_buffer = + render_transfer_queue_.Remove(&render_queue_output_frame_); + while (frame_to_buffer) { + // Report render call in the metrics. + api_call_metrics_.ReportRenderCall(); + + if (multichannel_content_detector_.UpdateDetection( + render_queue_output_frame_)) { + // Reinitialize the AEC when proper stereo is detected. + Initialize(); + } + + // Buffer frame content. + BufferRenderFrameContent( + /*proper_downmix_needed=*/multichannel_content_detector_ + .IsTemporaryMultiChannelContentDetected(), + &render_queue_output_frame_, 0, render_blocker_.get(), + block_processor_.get(), &render_block_, &render_sub_frame_view_); + + BufferRenderFrameContent( + /*proper_downmix_needed=*/multichannel_content_detector_ + .IsTemporaryMultiChannelContentDetected(), + &render_queue_output_frame_, 1, render_blocker_.get(), + block_processor_.get(), &render_block_, &render_sub_frame_view_); + + BufferRemainingRenderFrameContent(render_blocker_.get(), + block_processor_.get(), &render_block_); + + frame_to_buffer = + render_transfer_queue_.Remove(&render_queue_output_frame_); + } +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3.h b/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3.h new file mode 100644 index 0000000000..7bf8e51a4b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3.h @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ECHO_CANCELLER3_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ECHO_CANCELLER3_H_ + +#include <stddef.h> + +#include <atomic> +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "api/audio/echo_control.h" +#include "modules/audio_processing/aec3/api_call_jitter_metrics.h" +#include "modules/audio_processing/aec3/block_delay_buffer.h" +#include "modules/audio_processing/aec3/block_framer.h" +#include "modules/audio_processing/aec3/block_processor.h" +#include "modules/audio_processing/aec3/config_selector.h" +#include "modules/audio_processing/aec3/frame_blocker.h" +#include "modules/audio_processing/aec3/multi_channel_content_detector.h" +#include "modules/audio_processing/audio_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/race_checker.h" +#include "rtc_base/swap_queue.h" +#include "rtc_base/thread_annotations.h" + +namespace webrtc { + +// Method for adjusting config parameter dependencies. +// Only to be used externally to AEC3 for testing purposes. +// TODO(webrtc:5298): Move this to a separate file. +EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config); + +// Functor for verifying the invariance of the frames being put into the render +// queue. +class Aec3RenderQueueItemVerifier { + public: + Aec3RenderQueueItemVerifier(size_t num_bands, + size_t num_channels, + size_t frame_length) + : num_bands_(num_bands), + num_channels_(num_channels), + frame_length_(frame_length) {} + + bool operator()(const std::vector<std::vector<std::vector<float>>>& v) const { + if (v.size() != num_bands_) { + return false; + } + for (const auto& band : v) { + if (band.size() != num_channels_) { + return false; + } + for (const auto& channel : band) { + if (channel.size() != frame_length_) { + return false; + } + } + } + return true; + } + + private: + const size_t num_bands_; + const size_t num_channels_; + const size_t frame_length_; +}; + +// Main class for the echo canceller3. +// It does 4 things: +// -Receives 10 ms frames of band-split audio. +// -Provides the lower level echo canceller functionality with +// blocks of 64 samples of audio data. +// -Partially handles the jitter in the render and capture API +// call sequence. +// +// The class is supposed to be used in a non-concurrent manner apart from the +// AnalyzeRender call which can be called concurrently with the other methods. +class EchoCanceller3 : public EchoControl { + public: + EchoCanceller3( + const EchoCanceller3Config& config, + const absl::optional<EchoCanceller3Config>& multichannel_config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels); + + ~EchoCanceller3() override; + + EchoCanceller3(const EchoCanceller3&) = delete; + EchoCanceller3& operator=(const EchoCanceller3&) = delete; + + // Analyzes and stores an internal copy of the split-band domain render + // signal. + void AnalyzeRender(AudioBuffer* render) override { AnalyzeRender(*render); } + // Analyzes the full-band domain capture signal to detect signal saturation. + void AnalyzeCapture(AudioBuffer* capture) override { + AnalyzeCapture(*capture); + } + // Processes the split-band domain capture signal in order to remove any echo + // present in the signal. + void ProcessCapture(AudioBuffer* capture, bool level_change) override; + // As above, but also returns the linear filter output. + void ProcessCapture(AudioBuffer* capture, + AudioBuffer* linear_output, + bool level_change) override; + // Collect current metrics from the echo canceller. + Metrics GetMetrics() const override; + // Provides an optional external estimate of the audio buffer delay. + void SetAudioBufferDelay(int delay_ms) override; + + // Specifies whether the capture output will be used. The purpose of this is + // to allow the echo controller to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. + void SetCaptureOutputUsage(bool capture_output_used) override; + + bool ActiveProcessing() const override; + + // Signals whether an external detector has detected echo leakage from the + // echo canceller. + // Note that in the case echo leakage has been flagged, it should be unflagged + // once it is no longer occurring. + void UpdateEchoLeakageStatus(bool leakage_detected) { + RTC_DCHECK_RUNS_SERIALIZED(&capture_race_checker_); + block_processor_->UpdateEchoLeakageStatus(leakage_detected); + } + + // Produces a default configuration for multichannel. + static EchoCanceller3Config CreateDefaultMultichannelConfig(); + + private: + friend class EchoCanceller3Tester; + FRIEND_TEST_ALL_PREFIXES(EchoCanceller3, DetectionOfProperStereo); + FRIEND_TEST_ALL_PREFIXES(EchoCanceller3, + DetectionOfProperStereoUsingThreshold); + FRIEND_TEST_ALL_PREFIXES(EchoCanceller3, + DetectionOfProperStereoUsingHysteresis); + FRIEND_TEST_ALL_PREFIXES(EchoCanceller3, + StereoContentDetectionForMonoSignals); + + class RenderWriter; + + // (Re-)Initializes the selected subset of the EchoCanceller3 fields, at + // creation as well as during reconfiguration. + void Initialize(); + + // Only for testing. Replaces the internal block processor. + void SetBlockProcessorForTesting( + std::unique_ptr<BlockProcessor> block_processor); + + // Only for testing. Returns whether stereo processing is active. + bool StereoRenderProcessingActiveForTesting() const { + return multichannel_content_detector_.IsProperMultiChannelContentDetected(); + } + + // Only for testing. + const EchoCanceller3Config& GetActiveConfigForTesting() const { + return config_selector_.active_config(); + } + + // Empties the render SwapQueue. + void EmptyRenderQueue(); + + // Analyzes and stores an internal copy of the split-band domain render + // signal. + void AnalyzeRender(const AudioBuffer& render); + // Analyzes the full-band domain capture signal to detect signal saturation. + void AnalyzeCapture(const AudioBuffer& capture); + + rtc::RaceChecker capture_race_checker_; + rtc::RaceChecker render_race_checker_; + + // State that is accessed by the AnalyzeRender call. + std::unique_ptr<RenderWriter> render_writer_ + RTC_GUARDED_BY(render_race_checker_); + + // State that may be accessed by the capture thread. + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const EchoCanceller3Config config_; + const int sample_rate_hz_; + const int num_bands_; + const size_t num_render_input_channels_; + size_t num_render_channels_to_aec_; + const size_t num_capture_channels_; + ConfigSelector config_selector_; + MultiChannelContentDetector multichannel_content_detector_; + std::unique_ptr<BlockFramer> linear_output_framer_ + RTC_GUARDED_BY(capture_race_checker_); + BlockFramer output_framer_ RTC_GUARDED_BY(capture_race_checker_); + FrameBlocker capture_blocker_ RTC_GUARDED_BY(capture_race_checker_); + std::unique_ptr<FrameBlocker> render_blocker_ + RTC_GUARDED_BY(capture_race_checker_); + SwapQueue<std::vector<std::vector<std::vector<float>>>, + Aec3RenderQueueItemVerifier> + render_transfer_queue_; + std::unique_ptr<BlockProcessor> block_processor_ + RTC_GUARDED_BY(capture_race_checker_); + std::vector<std::vector<std::vector<float>>> render_queue_output_frame_ + RTC_GUARDED_BY(capture_race_checker_); + bool saturated_microphone_signal_ RTC_GUARDED_BY(capture_race_checker_) = + false; + Block render_block_ RTC_GUARDED_BY(capture_race_checker_); + std::unique_ptr<Block> linear_output_block_ + RTC_GUARDED_BY(capture_race_checker_); + Block capture_block_ RTC_GUARDED_BY(capture_race_checker_); + std::vector<std::vector<rtc::ArrayView<float>>> render_sub_frame_view_ + RTC_GUARDED_BY(capture_race_checker_); + std::vector<std::vector<rtc::ArrayView<float>>> linear_output_sub_frame_view_ + RTC_GUARDED_BY(capture_race_checker_); + std::vector<std::vector<rtc::ArrayView<float>>> capture_sub_frame_view_ + RTC_GUARDED_BY(capture_race_checker_); + std::unique_ptr<BlockDelayBuffer> block_delay_buffer_ + RTC_GUARDED_BY(capture_race_checker_); + ApiCallJitterMetrics api_call_metrics_ RTC_GUARDED_BY(capture_race_checker_); +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ECHO_CANCELLER3_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc new file mode 100644 index 0000000000..ad126af4d3 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_canceller3_unittest.cc @@ -0,0 +1,1160 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_canceller3.h" + +#include <deque> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block_processor.h" +#include "modules/audio_processing/aec3/frame_blocker.h" +#include "modules/audio_processing/aec3/mock/mock_block_processor.h" +#include "modules/audio_processing/audio_buffer.h" +#include "modules/audio_processing/high_pass_filter.h" +#include "modules/audio_processing/utility/cascaded_biquad_filter.h" +#include "rtc_base/strings/string_builder.h" +#include "test/field_trial.h" +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +using ::testing::_; +using ::testing::StrictMock; + +// Populates the frame with linearly increasing sample values for each band, +// with a band-specific offset, in order to allow simple bitexactness +// verification for each band. +void PopulateInputFrame(size_t frame_length, + size_t num_bands, + size_t frame_index, + float* const* frame, + int offset) { + for (size_t k = 0; k < num_bands; ++k) { + for (size_t i = 0; i < frame_length; ++i) { + float value = static_cast<int>(frame_index * frame_length + i) + offset; + frame[k][i] = (value > 0 ? 5000 * k + value : 0); + } + } +} + +// Populates the frame with linearly increasing sample values. +void PopulateInputFrame(size_t frame_length, + size_t frame_index, + float* frame, + int offset) { + for (size_t i = 0; i < frame_length; ++i) { + float value = static_cast<int>(frame_index * frame_length + i) + offset; + frame[i] = std::max(value, 0.f); + } +} + +// Verifies the that samples in the output frame are identical to the samples +// that were produced for the input frame, with an offset in order to compensate +// for buffering delays. +bool VerifyOutputFrameBitexactness(size_t frame_length, + size_t num_bands, + size_t frame_index, + const float* const* frame, + int offset) { + float reference_frame_data[kMaxNumBands][2 * kSubFrameLength]; + float* reference_frame[kMaxNumBands]; + for (size_t k = 0; k < num_bands; ++k) { + reference_frame[k] = &reference_frame_data[k][0]; + } + + PopulateInputFrame(frame_length, num_bands, frame_index, reference_frame, + offset); + for (size_t k = 0; k < num_bands; ++k) { + for (size_t i = 0; i < frame_length; ++i) { + if (reference_frame[k][i] != frame[k][i]) { + return false; + } + } + } + + return true; +} + +bool VerifyOutputFrameBitexactness(rtc::ArrayView<const float> reference, + rtc::ArrayView<const float> frame, + int offset) { + for (size_t k = 0; k < frame.size(); ++k) { + int reference_index = static_cast<int>(k) + offset; + if (reference_index >= 0) { + if (reference[reference_index] != frame[k]) { + return false; + } + } + } + return true; +} + +// Class for testing that the capture data is properly received by the block +// processor and that the processor data is properly passed to the +// EchoCanceller3 output. +class CaptureTransportVerificationProcessor : public BlockProcessor { + public: + explicit CaptureTransportVerificationProcessor(size_t num_bands) {} + + CaptureTransportVerificationProcessor() = delete; + CaptureTransportVerificationProcessor( + const CaptureTransportVerificationProcessor&) = delete; + CaptureTransportVerificationProcessor& operator=( + const CaptureTransportVerificationProcessor&) = delete; + + ~CaptureTransportVerificationProcessor() override = default; + + void ProcessCapture(bool level_change, + bool saturated_microphone_signal, + Block* linear_output, + Block* capture_block) override {} + + void BufferRender(const Block& block) override {} + + void UpdateEchoLeakageStatus(bool leakage_detected) override {} + + void GetMetrics(EchoControl::Metrics* metrics) const override {} + + void SetAudioBufferDelay(int delay_ms) override {} + + void SetCaptureOutputUsage(bool capture_output_used) {} +}; + +// Class for testing that the render data is properly received by the block +// processor. +class RenderTransportVerificationProcessor : public BlockProcessor { + public: + explicit RenderTransportVerificationProcessor(size_t num_bands) {} + + RenderTransportVerificationProcessor() = delete; + RenderTransportVerificationProcessor( + const RenderTransportVerificationProcessor&) = delete; + RenderTransportVerificationProcessor& operator=( + const RenderTransportVerificationProcessor&) = delete; + + ~RenderTransportVerificationProcessor() override = default; + + void ProcessCapture(bool level_change, + bool saturated_microphone_signal, + Block* linear_output, + Block* capture_block) override { + Block render_block = received_render_blocks_.front(); + received_render_blocks_.pop_front(); + capture_block->Swap(render_block); + } + + void BufferRender(const Block& block) override { + received_render_blocks_.push_back(block); + } + + void UpdateEchoLeakageStatus(bool leakage_detected) override {} + + void GetMetrics(EchoControl::Metrics* metrics) const override {} + + void SetAudioBufferDelay(int delay_ms) override {} + + void SetCaptureOutputUsage(bool capture_output_used) {} + + private: + std::deque<Block> received_render_blocks_; +}; + +std::string ProduceDebugText(int sample_rate_hz) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.Release(); +} + +std::string ProduceDebugText(int sample_rate_hz, int variant) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz << ", variant: " << variant; + return ss.Release(); +} + +void RunAecInStereo(AudioBuffer& buffer, + EchoCanceller3& aec3, + float channel_0_value, + float channel_1_value) { + rtc::ArrayView<float> data_channel_0(&buffer.channels()[0][0], + buffer.num_frames()); + std::fill(data_channel_0.begin(), data_channel_0.end(), channel_0_value); + rtc::ArrayView<float> data_channel_1(&buffer.channels()[1][0], + buffer.num_frames()); + std::fill(data_channel_1.begin(), data_channel_1.end(), channel_1_value); + aec3.AnalyzeRender(&buffer); + aec3.AnalyzeCapture(&buffer); + aec3.ProcessCapture(&buffer, /*level_change=*/false); +} + +void RunAecInSMono(AudioBuffer& buffer, + EchoCanceller3& aec3, + float channel_0_value) { + rtc::ArrayView<float> data_channel_0(&buffer.channels()[0][0], + buffer.num_frames()); + std::fill(data_channel_0.begin(), data_channel_0.end(), channel_0_value); + aec3.AnalyzeRender(&buffer); + aec3.AnalyzeCapture(&buffer); + aec3.ProcessCapture(&buffer, /*level_change=*/false); +} + +} // namespace + +class EchoCanceller3Tester { + public: + explicit EchoCanceller3Tester(int sample_rate_hz) + : sample_rate_hz_(sample_rate_hz), + num_bands_(NumBandsForRate(sample_rate_hz_)), + frame_length_(160), + fullband_frame_length_(rtc::CheckedDivExact(sample_rate_hz_, 100)), + capture_buffer_(fullband_frame_length_ * 100, + 1, + fullband_frame_length_ * 100, + 1, + fullband_frame_length_ * 100, + 1), + render_buffer_(fullband_frame_length_ * 100, + 1, + fullband_frame_length_ * 100, + 1, + fullband_frame_length_ * 100, + 1) {} + + EchoCanceller3Tester() = delete; + EchoCanceller3Tester(const EchoCanceller3Tester&) = delete; + EchoCanceller3Tester& operator=(const EchoCanceller3Tester&) = delete; + + // Verifies that the capture data is properly received by the block processor + // and that the processor data is properly passed to the EchoCanceller3 + // output. + void RunCaptureTransportVerificationTest() { + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, sample_rate_hz_, + 1, 1); + aec3.SetBlockProcessorForTesting( + std::make_unique<CaptureTransportVerificationProcessor>(num_bands_)); + + for (size_t frame_index = 0; frame_index < kNumFramesToProcess; + ++frame_index) { + aec3.AnalyzeCapture(&capture_buffer_); + OptionalBandSplit(); + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &capture_buffer_.split_bands(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels()[0][0], 0); + + aec3.AnalyzeRender(&render_buffer_); + aec3.ProcessCapture(&capture_buffer_, false); + EXPECT_TRUE(VerifyOutputFrameBitexactness( + frame_length_, num_bands_, frame_index, + &capture_buffer_.split_bands(0)[0], -64)); + } + } + + // Test method for testing that the render data is properly received by the + // block processor. + void RunRenderTransportVerificationTest() { + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, sample_rate_hz_, + 1, 1); + aec3.SetBlockProcessorForTesting( + std::make_unique<RenderTransportVerificationProcessor>(num_bands_)); + + std::vector<std::vector<float>> render_input(1); + std::vector<float> capture_output; + for (size_t frame_index = 0; frame_index < kNumFramesToProcess; + ++frame_index) { + aec3.AnalyzeCapture(&capture_buffer_); + OptionalBandSplit(); + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &capture_buffer_.split_bands(0)[0], 100); + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &render_buffer_.split_bands(0)[0], 0); + + for (size_t k = 0; k < frame_length_; ++k) { + render_input[0].push_back(render_buffer_.split_bands(0)[0][k]); + } + aec3.AnalyzeRender(&render_buffer_); + aec3.ProcessCapture(&capture_buffer_, false); + for (size_t k = 0; k < frame_length_; ++k) { + capture_output.push_back(capture_buffer_.split_bands(0)[0][k]); + } + } + + EXPECT_TRUE( + VerifyOutputFrameBitexactness(render_input[0], capture_output, -64)); + } + + // Verifies that information about echo path changes are properly propagated + // to the block processor. + // The cases tested are: + // -That no set echo path change flags are received when there is no echo path + // change. + // -That set echo path change flags are received and continues to be received + // as long as echo path changes are flagged. + // -That set echo path change flags are no longer received when echo path + // change events stop being flagged. + enum class EchoPathChangeTestVariant { kNone, kOneSticky, kOneNonSticky }; + + void RunEchoPathChangeVerificationTest( + EchoPathChangeTestVariant echo_path_change_test_variant) { + constexpr size_t kNumFullBlocksPerFrame = 160 / kBlockSize; + constexpr size_t kExpectedNumBlocksToProcess = + (kNumFramesToProcess * 160) / kBlockSize; + std::unique_ptr<testing::StrictMock<webrtc::test::MockBlockProcessor>> + block_processor_mock( + new StrictMock<webrtc::test::MockBlockProcessor>()); + EXPECT_CALL(*block_processor_mock, BufferRender(_)) + .Times(kExpectedNumBlocksToProcess); + EXPECT_CALL(*block_processor_mock, UpdateEchoLeakageStatus(_)).Times(0); + + switch (echo_path_change_test_variant) { + case EchoPathChangeTestVariant::kNone: + EXPECT_CALL(*block_processor_mock, ProcessCapture(false, _, _, _)) + .Times(kExpectedNumBlocksToProcess); + break; + case EchoPathChangeTestVariant::kOneSticky: + EXPECT_CALL(*block_processor_mock, ProcessCapture(true, _, _, _)) + .Times(kExpectedNumBlocksToProcess); + break; + case EchoPathChangeTestVariant::kOneNonSticky: + EXPECT_CALL(*block_processor_mock, ProcessCapture(true, _, _, _)) + .Times(kNumFullBlocksPerFrame); + EXPECT_CALL(*block_processor_mock, ProcessCapture(false, _, _, _)) + .Times(kExpectedNumBlocksToProcess - kNumFullBlocksPerFrame); + break; + } + + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, sample_rate_hz_, + 1, 1); + aec3.SetBlockProcessorForTesting(std::move(block_processor_mock)); + + for (size_t frame_index = 0; frame_index < kNumFramesToProcess; + ++frame_index) { + bool echo_path_change = false; + switch (echo_path_change_test_variant) { + case EchoPathChangeTestVariant::kNone: + break; + case EchoPathChangeTestVariant::kOneSticky: + echo_path_change = true; + break; + case EchoPathChangeTestVariant::kOneNonSticky: + if (frame_index == 0) { + echo_path_change = true; + } + break; + } + + aec3.AnalyzeCapture(&capture_buffer_); + OptionalBandSplit(); + + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &capture_buffer_.split_bands(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels()[0][0], 0); + + aec3.AnalyzeRender(&render_buffer_); + aec3.ProcessCapture(&capture_buffer_, echo_path_change); + } + } + + // Test for verifying that echo leakage information is being properly passed + // to the processor. + // The cases tested are: + // -That no method calls are received when they should not. + // -That false values are received each time they are flagged. + // -That true values are received each time they are flagged. + // -That a false value is received when flagged after a true value has been + // flagged. + enum class EchoLeakageTestVariant { + kNone, + kFalseSticky, + kTrueSticky, + kTrueNonSticky + }; + + void RunEchoLeakageVerificationTest( + EchoLeakageTestVariant leakage_report_variant) { + constexpr size_t kExpectedNumBlocksToProcess = + (kNumFramesToProcess * 160) / kBlockSize; + std::unique_ptr<testing::StrictMock<webrtc::test::MockBlockProcessor>> + block_processor_mock( + new StrictMock<webrtc::test::MockBlockProcessor>()); + EXPECT_CALL(*block_processor_mock, BufferRender(_)) + .Times(kExpectedNumBlocksToProcess); + EXPECT_CALL(*block_processor_mock, ProcessCapture(_, _, _, _)) + .Times(kExpectedNumBlocksToProcess); + + switch (leakage_report_variant) { + case EchoLeakageTestVariant::kNone: + EXPECT_CALL(*block_processor_mock, UpdateEchoLeakageStatus(_)).Times(0); + break; + case EchoLeakageTestVariant::kFalseSticky: + EXPECT_CALL(*block_processor_mock, UpdateEchoLeakageStatus(false)) + .Times(1); + break; + case EchoLeakageTestVariant::kTrueSticky: + EXPECT_CALL(*block_processor_mock, UpdateEchoLeakageStatus(true)) + .Times(1); + break; + case EchoLeakageTestVariant::kTrueNonSticky: { + ::testing::InSequence s; + EXPECT_CALL(*block_processor_mock, UpdateEchoLeakageStatus(true)) + .Times(1); + EXPECT_CALL(*block_processor_mock, UpdateEchoLeakageStatus(false)) + .Times(kNumFramesToProcess - 1); + } break; + } + + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, sample_rate_hz_, + 1, 1); + aec3.SetBlockProcessorForTesting(std::move(block_processor_mock)); + + for (size_t frame_index = 0; frame_index < kNumFramesToProcess; + ++frame_index) { + switch (leakage_report_variant) { + case EchoLeakageTestVariant::kNone: + break; + case EchoLeakageTestVariant::kFalseSticky: + if (frame_index == 0) { + aec3.UpdateEchoLeakageStatus(false); + } + break; + case EchoLeakageTestVariant::kTrueSticky: + if (frame_index == 0) { + aec3.UpdateEchoLeakageStatus(true); + } + break; + case EchoLeakageTestVariant::kTrueNonSticky: + if (frame_index == 0) { + aec3.UpdateEchoLeakageStatus(true); + } else { + aec3.UpdateEchoLeakageStatus(false); + } + break; + } + + aec3.AnalyzeCapture(&capture_buffer_); + OptionalBandSplit(); + + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &capture_buffer_.split_bands(0)[0], 0); + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels()[0][0], 0); + + aec3.AnalyzeRender(&render_buffer_); + aec3.ProcessCapture(&capture_buffer_, false); + } + } + + // This verifies that saturation information is properly passed to the + // BlockProcessor. + // The cases tested are: + // -That no saturation event is passed to the processor if there is no + // saturation. + // -That one frame with one negative saturated sample value is reported to be + // saturated and that following non-saturated frames are properly reported as + // not being saturated. + // -That one frame with one positive saturated sample value is reported to be + // saturated and that following non-saturated frames are properly reported as + // not being saturated. + enum class SaturationTestVariant { kNone, kOneNegative, kOnePositive }; + + void RunCaptureSaturationVerificationTest( + SaturationTestVariant saturation_variant) { + const size_t kNumFullBlocksPerFrame = 160 / kBlockSize; + const size_t kExpectedNumBlocksToProcess = + (kNumFramesToProcess * 160) / kBlockSize; + std::unique_ptr<testing::StrictMock<webrtc::test::MockBlockProcessor>> + block_processor_mock( + new StrictMock<webrtc::test::MockBlockProcessor>()); + EXPECT_CALL(*block_processor_mock, BufferRender(_)) + .Times(kExpectedNumBlocksToProcess); + EXPECT_CALL(*block_processor_mock, UpdateEchoLeakageStatus(_)).Times(0); + + switch (saturation_variant) { + case SaturationTestVariant::kNone: + EXPECT_CALL(*block_processor_mock, ProcessCapture(_, false, _, _)) + .Times(kExpectedNumBlocksToProcess); + break; + case SaturationTestVariant::kOneNegative: { + ::testing::InSequence s; + EXPECT_CALL(*block_processor_mock, ProcessCapture(_, true, _, _)) + .Times(kNumFullBlocksPerFrame); + EXPECT_CALL(*block_processor_mock, ProcessCapture(_, false, _, _)) + .Times(kExpectedNumBlocksToProcess - kNumFullBlocksPerFrame); + } break; + case SaturationTestVariant::kOnePositive: { + ::testing::InSequence s; + EXPECT_CALL(*block_processor_mock, ProcessCapture(_, true, _, _)) + .Times(kNumFullBlocksPerFrame); + EXPECT_CALL(*block_processor_mock, ProcessCapture(_, false, _, _)) + .Times(kExpectedNumBlocksToProcess - kNumFullBlocksPerFrame); + } break; + } + + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, sample_rate_hz_, + 1, 1); + aec3.SetBlockProcessorForTesting(std::move(block_processor_mock)); + for (size_t frame_index = 0; frame_index < kNumFramesToProcess; + ++frame_index) { + for (int k = 0; k < fullband_frame_length_; ++k) { + capture_buffer_.channels()[0][k] = 0.f; + } + switch (saturation_variant) { + case SaturationTestVariant::kNone: + break; + case SaturationTestVariant::kOneNegative: + if (frame_index == 0) { + capture_buffer_.channels()[0][10] = -32768.f; + } + break; + case SaturationTestVariant::kOnePositive: + if (frame_index == 0) { + capture_buffer_.channels()[0][10] = 32767.f; + } + break; + } + + aec3.AnalyzeCapture(&capture_buffer_); + OptionalBandSplit(); + + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &capture_buffer_.split_bands(0)[0], 0); + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &render_buffer_.split_bands(0)[0], 0); + + aec3.AnalyzeRender(&render_buffer_); + aec3.ProcessCapture(&capture_buffer_, false); + } + } + + // This test verifies that the swapqueue is able to handle jitter in the + // capture and render API calls. + void RunRenderSwapQueueVerificationTest() { + const EchoCanceller3Config config; + EchoCanceller3 aec3(config, /*multichannel_config=*/absl::nullopt, + sample_rate_hz_, 1, 1); + aec3.SetBlockProcessorForTesting( + std::make_unique<RenderTransportVerificationProcessor>(num_bands_)); + + std::vector<std::vector<float>> render_input(1); + std::vector<float> capture_output; + + for (size_t frame_index = 0; frame_index < kRenderTransferQueueSizeFrames; + ++frame_index) { + if (sample_rate_hz_ > 16000) { + render_buffer_.SplitIntoFrequencyBands(); + } + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &render_buffer_.split_bands(0)[0], 0); + + if (sample_rate_hz_ > 16000) { + render_buffer_.SplitIntoFrequencyBands(); + } + + for (size_t k = 0; k < frame_length_; ++k) { + render_input[0].push_back(render_buffer_.split_bands(0)[0][k]); + } + aec3.AnalyzeRender(&render_buffer_); + } + + for (size_t frame_index = 0; frame_index < kRenderTransferQueueSizeFrames; + ++frame_index) { + aec3.AnalyzeCapture(&capture_buffer_); + if (sample_rate_hz_ > 16000) { + capture_buffer_.SplitIntoFrequencyBands(); + } + + PopulateInputFrame(frame_length_, num_bands_, frame_index, + &capture_buffer_.split_bands(0)[0], 0); + + aec3.ProcessCapture(&capture_buffer_, false); + for (size_t k = 0; k < frame_length_; ++k) { + capture_output.push_back(capture_buffer_.split_bands(0)[0][k]); + } + } + + EXPECT_TRUE( + VerifyOutputFrameBitexactness(render_input[0], capture_output, -64)); + } + + // This test verifies that a buffer overrun in the render swapqueue is + // properly reported. + void RunRenderPipelineSwapQueueOverrunReturnValueTest() { + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, sample_rate_hz_, + 1, 1); + + constexpr size_t kRenderTransferQueueSize = 30; + for (size_t k = 0; k < 2; ++k) { + for (size_t frame_index = 0; frame_index < kRenderTransferQueueSize; + ++frame_index) { + if (sample_rate_hz_ > 16000) { + render_buffer_.SplitIntoFrequencyBands(); + } + PopulateInputFrame(frame_length_, frame_index, + &render_buffer_.channels()[0][0], 0); + + aec3.AnalyzeRender(&render_buffer_); + } + } + } + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + // Verifies the that the check for the number of bands in the AnalyzeRender + // input is correct by adjusting the sample rates of EchoCanceller3 and the + // input AudioBuffer to have a different number of bands. + void RunAnalyzeRenderNumBandsCheckVerification() { + // Set aec3_sample_rate_hz to be different from sample_rate_hz_ in such a + // way that the number of bands for the rates are different. + const int aec3_sample_rate_hz = sample_rate_hz_ == 48000 ? 32000 : 48000; + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, + aec3_sample_rate_hz, 1, 1); + PopulateInputFrame(frame_length_, 0, &render_buffer_.channels_f()[0][0], 0); + + EXPECT_DEATH(aec3.AnalyzeRender(&render_buffer_), ""); + } + + // Verifies the that the check for the number of bands in the ProcessCapture + // input is correct by adjusting the sample rates of EchoCanceller3 and the + // input AudioBuffer to have a different number of bands. + void RunProcessCaptureNumBandsCheckVerification() { + // Set aec3_sample_rate_hz to be different from sample_rate_hz_ in such a + // way that the number of bands for the rates are different. + const int aec3_sample_rate_hz = sample_rate_hz_ == 48000 ? 32000 : 48000; + EchoCanceller3 aec3(EchoCanceller3Config(), + /*multichannel_config=*/absl::nullopt, + aec3_sample_rate_hz, 1, 1); + PopulateInputFrame(frame_length_, num_bands_, 0, + &capture_buffer_.split_bands_f(0)[0], 100); + EXPECT_DEATH(aec3.ProcessCapture(&capture_buffer_, false), ""); + } + +#endif + + private: + void OptionalBandSplit() { + if (sample_rate_hz_ > 16000) { + capture_buffer_.SplitIntoFrequencyBands(); + render_buffer_.SplitIntoFrequencyBands(); + } + } + + static constexpr size_t kNumFramesToProcess = 20; + const int sample_rate_hz_; + const size_t num_bands_; + const size_t frame_length_; + const int fullband_frame_length_; + AudioBuffer capture_buffer_; + AudioBuffer render_buffer_; +}; + +TEST(EchoCanceller3Buffering, CaptureBitexactness) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + EchoCanceller3Tester(rate).RunCaptureTransportVerificationTest(); + } +} + +TEST(EchoCanceller3Buffering, RenderBitexactness) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + EchoCanceller3Tester(rate).RunRenderTransportVerificationTest(); + } +} + +TEST(EchoCanceller3Buffering, RenderSwapQueue) { + EchoCanceller3Tester(16000).RunRenderSwapQueueVerificationTest(); +} + +TEST(EchoCanceller3Buffering, RenderSwapQueueOverrunReturnValue) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + EchoCanceller3Tester(rate) + .RunRenderPipelineSwapQueueOverrunReturnValueTest(); + } +} + +TEST(EchoCanceller3Messaging, CaptureSaturation) { + auto variants = {EchoCanceller3Tester::SaturationTestVariant::kNone, + EchoCanceller3Tester::SaturationTestVariant::kOneNegative, + EchoCanceller3Tester::SaturationTestVariant::kOnePositive}; + for (auto rate : {16000, 32000, 48000}) { + for (auto variant : variants) { + SCOPED_TRACE(ProduceDebugText(rate, static_cast<int>(variant))); + EchoCanceller3Tester(rate).RunCaptureSaturationVerificationTest(variant); + } + } +} + +TEST(EchoCanceller3Messaging, EchoPathChange) { + auto variants = { + EchoCanceller3Tester::EchoPathChangeTestVariant::kNone, + EchoCanceller3Tester::EchoPathChangeTestVariant::kOneSticky, + EchoCanceller3Tester::EchoPathChangeTestVariant::kOneNonSticky}; + for (auto rate : {16000, 32000, 48000}) { + for (auto variant : variants) { + SCOPED_TRACE(ProduceDebugText(rate, static_cast<int>(variant))); + EchoCanceller3Tester(rate).RunEchoPathChangeVerificationTest(variant); + } + } +} + +TEST(EchoCanceller3Messaging, EchoLeakage) { + auto variants = { + EchoCanceller3Tester::EchoLeakageTestVariant::kNone, + EchoCanceller3Tester::EchoLeakageTestVariant::kFalseSticky, + EchoCanceller3Tester::EchoLeakageTestVariant::kTrueSticky, + EchoCanceller3Tester::EchoLeakageTestVariant::kTrueNonSticky}; + for (auto rate : {16000, 32000, 48000}) { + for (auto variant : variants) { + SCOPED_TRACE(ProduceDebugText(rate, static_cast<int>(variant))); + EchoCanceller3Tester(rate).RunEchoLeakageVerificationTest(variant); + } + } +} + +// Tests the parameter functionality for the field trial override for the +// anti-howling gain. +TEST(EchoCanceller3FieldTrials, Aec3SuppressorAntiHowlingGainOverride) { + EchoCanceller3Config default_config; + EchoCanceller3Config adjusted_config = AdjustConfig(default_config); + ASSERT_EQ( + default_config.suppressor.high_bands_suppression.anti_howling_gain, + adjusted_config.suppressor.high_bands_suppression.anti_howling_gain); + + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-Aec3SuppressorAntiHowlingGainOverride/0.02/"); + adjusted_config = AdjustConfig(default_config); + + ASSERT_NE( + default_config.suppressor.high_bands_suppression.anti_howling_gain, + adjusted_config.suppressor.high_bands_suppression.anti_howling_gain); + EXPECT_FLOAT_EQ( + 0.02f, + adjusted_config.suppressor.high_bands_suppression.anti_howling_gain); +} + +// Tests the field trial override for the enforcement of a low active render +// limit. +TEST(EchoCanceller3FieldTrials, Aec3EnforceLowActiveRenderLimit) { + EchoCanceller3Config default_config; + EchoCanceller3Config adjusted_config = AdjustConfig(default_config); + ASSERT_EQ(default_config.render_levels.active_render_limit, + adjusted_config.render_levels.active_render_limit); + + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-Aec3EnforceLowActiveRenderLimit/Enabled/"); + adjusted_config = AdjustConfig(default_config); + + ASSERT_NE(default_config.render_levels.active_render_limit, + adjusted_config.render_levels.active_render_limit); + EXPECT_FLOAT_EQ(50.f, adjusted_config.render_levels.active_render_limit); +} + +// Testing the field trial-based override of the suppressor parameters for a +// joint passing of all parameters. +TEST(EchoCanceller3FieldTrials, Aec3SuppressorTuningOverrideAllParams) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-Aec3SuppressorTuningOverride/" + "nearend_tuning_mask_lf_enr_transparent:0.1,nearend_tuning_mask_lf_enr_" + "suppress:0.2,nearend_tuning_mask_hf_enr_transparent:0.3,nearend_tuning_" + "mask_hf_enr_suppress:0.4,nearend_tuning_max_inc_factor:0.5,nearend_" + "tuning_max_dec_factor_lf:0.6,normal_tuning_mask_lf_enr_transparent:0.7," + "normal_tuning_mask_lf_enr_suppress:0.8,normal_tuning_mask_hf_enr_" + "transparent:0.9,normal_tuning_mask_hf_enr_suppress:1.0,normal_tuning_" + "max_inc_factor:1.1,normal_tuning_max_dec_factor_lf:1.2,dominant_nearend_" + "detection_enr_threshold:1.3,dominant_nearend_detection_enr_exit_" + "threshold:1.4,dominant_nearend_detection_snr_threshold:1.5,dominant_" + "nearend_detection_hold_duration:10,dominant_nearend_detection_trigger_" + "threshold:11/"); + + EchoCanceller3Config default_config; + EchoCanceller3Config adjusted_config = AdjustConfig(default_config); + + ASSERT_NE(adjusted_config.suppressor.nearend_tuning.mask_lf.enr_transparent, + default_config.suppressor.nearend_tuning.mask_lf.enr_transparent); + ASSERT_NE(adjusted_config.suppressor.nearend_tuning.mask_lf.enr_suppress, + default_config.suppressor.nearend_tuning.mask_lf.enr_suppress); + ASSERT_NE(adjusted_config.suppressor.nearend_tuning.mask_hf.enr_transparent, + default_config.suppressor.nearend_tuning.mask_hf.enr_transparent); + ASSERT_NE(adjusted_config.suppressor.nearend_tuning.mask_hf.enr_suppress, + default_config.suppressor.nearend_tuning.mask_hf.enr_suppress); + ASSERT_NE(adjusted_config.suppressor.nearend_tuning.max_inc_factor, + default_config.suppressor.nearend_tuning.max_inc_factor); + ASSERT_NE(adjusted_config.suppressor.nearend_tuning.max_dec_factor_lf, + default_config.suppressor.nearend_tuning.max_dec_factor_lf); + ASSERT_NE(adjusted_config.suppressor.normal_tuning.mask_lf.enr_transparent, + default_config.suppressor.normal_tuning.mask_lf.enr_transparent); + ASSERT_NE(adjusted_config.suppressor.normal_tuning.mask_lf.enr_suppress, + default_config.suppressor.normal_tuning.mask_lf.enr_suppress); + ASSERT_NE(adjusted_config.suppressor.normal_tuning.mask_hf.enr_transparent, + default_config.suppressor.normal_tuning.mask_hf.enr_transparent); + ASSERT_NE(adjusted_config.suppressor.normal_tuning.mask_hf.enr_suppress, + default_config.suppressor.normal_tuning.mask_hf.enr_suppress); + ASSERT_NE(adjusted_config.suppressor.normal_tuning.max_inc_factor, + default_config.suppressor.normal_tuning.max_inc_factor); + ASSERT_NE(adjusted_config.suppressor.normal_tuning.max_dec_factor_lf, + default_config.suppressor.normal_tuning.max_dec_factor_lf); + ASSERT_NE(adjusted_config.suppressor.dominant_nearend_detection.enr_threshold, + default_config.suppressor.dominant_nearend_detection.enr_threshold); + ASSERT_NE( + adjusted_config.suppressor.dominant_nearend_detection.enr_exit_threshold, + default_config.suppressor.dominant_nearend_detection.enr_exit_threshold); + ASSERT_NE(adjusted_config.suppressor.dominant_nearend_detection.snr_threshold, + default_config.suppressor.dominant_nearend_detection.snr_threshold); + ASSERT_NE(adjusted_config.suppressor.dominant_nearend_detection.hold_duration, + default_config.suppressor.dominant_nearend_detection.hold_duration); + ASSERT_NE( + adjusted_config.suppressor.dominant_nearend_detection.trigger_threshold, + default_config.suppressor.dominant_nearend_detection.trigger_threshold); + + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.nearend_tuning.mask_lf.enr_transparent, 0.1); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.nearend_tuning.mask_lf.enr_suppress, 0.2); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.nearend_tuning.mask_hf.enr_transparent, 0.3); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.nearend_tuning.mask_hf.enr_suppress, 0.4); + EXPECT_FLOAT_EQ(adjusted_config.suppressor.nearend_tuning.max_inc_factor, + 0.5); + EXPECT_FLOAT_EQ(adjusted_config.suppressor.nearend_tuning.max_dec_factor_lf, + 0.6); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.normal_tuning.mask_lf.enr_transparent, 0.7); + EXPECT_FLOAT_EQ(adjusted_config.suppressor.normal_tuning.mask_lf.enr_suppress, + 0.8); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.normal_tuning.mask_hf.enr_transparent, 0.9); + EXPECT_FLOAT_EQ(adjusted_config.suppressor.normal_tuning.mask_hf.enr_suppress, + 1.0); + EXPECT_FLOAT_EQ(adjusted_config.suppressor.normal_tuning.max_inc_factor, 1.1); + EXPECT_FLOAT_EQ(adjusted_config.suppressor.normal_tuning.max_dec_factor_lf, + 1.2); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.dominant_nearend_detection.enr_threshold, 1.3); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.dominant_nearend_detection.enr_exit_threshold, + 1.4); + EXPECT_FLOAT_EQ( + adjusted_config.suppressor.dominant_nearend_detection.snr_threshold, 1.5); + EXPECT_EQ(adjusted_config.suppressor.dominant_nearend_detection.hold_duration, + 10); + EXPECT_EQ( + adjusted_config.suppressor.dominant_nearend_detection.trigger_threshold, + 11); +} + +// Testing the field trial-based override of the suppressor parameters for +// passing one parameter. +TEST(EchoCanceller3FieldTrials, Aec3SuppressorTuningOverrideOneParam) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-Aec3SuppressorTuningOverride/nearend_tuning_max_inc_factor:0.5/"); + + EchoCanceller3Config default_config; + EchoCanceller3Config adjusted_config = AdjustConfig(default_config); + + ASSERT_EQ(adjusted_config.suppressor.nearend_tuning.mask_lf.enr_transparent, + default_config.suppressor.nearend_tuning.mask_lf.enr_transparent); + ASSERT_EQ(adjusted_config.suppressor.nearend_tuning.mask_lf.enr_suppress, + default_config.suppressor.nearend_tuning.mask_lf.enr_suppress); + ASSERT_EQ(adjusted_config.suppressor.nearend_tuning.mask_hf.enr_transparent, + default_config.suppressor.nearend_tuning.mask_hf.enr_transparent); + ASSERT_EQ(adjusted_config.suppressor.nearend_tuning.mask_hf.enr_suppress, + default_config.suppressor.nearend_tuning.mask_hf.enr_suppress); + ASSERT_EQ(adjusted_config.suppressor.nearend_tuning.max_dec_factor_lf, + default_config.suppressor.nearend_tuning.max_dec_factor_lf); + ASSERT_EQ(adjusted_config.suppressor.normal_tuning.mask_lf.enr_transparent, + default_config.suppressor.normal_tuning.mask_lf.enr_transparent); + ASSERT_EQ(adjusted_config.suppressor.normal_tuning.mask_lf.enr_suppress, + default_config.suppressor.normal_tuning.mask_lf.enr_suppress); + ASSERT_EQ(adjusted_config.suppressor.normal_tuning.mask_hf.enr_transparent, + default_config.suppressor.normal_tuning.mask_hf.enr_transparent); + ASSERT_EQ(adjusted_config.suppressor.normal_tuning.mask_hf.enr_suppress, + default_config.suppressor.normal_tuning.mask_hf.enr_suppress); + ASSERT_EQ(adjusted_config.suppressor.normal_tuning.max_inc_factor, + default_config.suppressor.normal_tuning.max_inc_factor); + ASSERT_EQ(adjusted_config.suppressor.normal_tuning.max_dec_factor_lf, + default_config.suppressor.normal_tuning.max_dec_factor_lf); + ASSERT_EQ(adjusted_config.suppressor.dominant_nearend_detection.enr_threshold, + default_config.suppressor.dominant_nearend_detection.enr_threshold); + ASSERT_EQ( + adjusted_config.suppressor.dominant_nearend_detection.enr_exit_threshold, + default_config.suppressor.dominant_nearend_detection.enr_exit_threshold); + ASSERT_EQ(adjusted_config.suppressor.dominant_nearend_detection.snr_threshold, + default_config.suppressor.dominant_nearend_detection.snr_threshold); + ASSERT_EQ(adjusted_config.suppressor.dominant_nearend_detection.hold_duration, + default_config.suppressor.dominant_nearend_detection.hold_duration); + ASSERT_EQ( + adjusted_config.suppressor.dominant_nearend_detection.trigger_threshold, + default_config.suppressor.dominant_nearend_detection.trigger_threshold); + + ASSERT_NE(adjusted_config.suppressor.nearend_tuning.max_inc_factor, + default_config.suppressor.nearend_tuning.max_inc_factor); + + EXPECT_FLOAT_EQ(adjusted_config.suppressor.nearend_tuning.max_inc_factor, + 0.5); +} + +// Testing the field trial-based that override the exponential decay parameters. +TEST(EchoCanceller3FieldTrials, Aec3UseNearendReverb) { + webrtc::test::ScopedFieldTrials field_trials( + "WebRTC-Aec3UseNearendReverbLen/default_len:0.9,nearend_len:0.8/"); + EchoCanceller3Config default_config; + EchoCanceller3Config adjusted_config = AdjustConfig(default_config); + EXPECT_FLOAT_EQ(adjusted_config.ep_strength.default_len, 0.9); + EXPECT_FLOAT_EQ(adjusted_config.ep_strength.nearend_len, 0.8); +} + +TEST(EchoCanceller3, DetectionOfProperStereo) { + constexpr int kSampleRateHz = 16000; + constexpr int kNumChannels = 2; + AudioBuffer buffer(/*input_rate=*/kSampleRateHz, + /*input_num_channels=*/kNumChannels, + /*input_rate=*/kSampleRateHz, + /*buffer_num_channels=*/kNumChannels, + /*output_rate=*/kSampleRateHz, + /*output_num_channels=*/kNumChannels); + + constexpr size_t kNumBlocksForMonoConfig = 1; + constexpr size_t kNumBlocksForSurroundConfig = 2; + EchoCanceller3Config mono_config; + absl::optional<EchoCanceller3Config> multichannel_config; + + mono_config.multi_channel.detect_stereo_content = true; + mono_config.multi_channel.stereo_detection_threshold = 0.0f; + mono_config.multi_channel.stereo_detection_hysteresis_seconds = 0.0f; + multichannel_config = mono_config; + mono_config.filter.coarse_initial.length_blocks = kNumBlocksForMonoConfig; + multichannel_config->filter.coarse_initial.length_blocks = + kNumBlocksForSurroundConfig; + + EchoCanceller3 aec3(mono_config, multichannel_config, + /*sample_rate_hz=*/kSampleRateHz, + /*num_render_channels=*/kNumChannels, + /*num_capture_input_channels=*/kNumChannels); + + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + + RunAecInStereo(buffer, aec3, 100.0f, 100.0f); + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + + RunAecInStereo(buffer, aec3, 100.0f, 101.0f); + EXPECT_TRUE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForSurroundConfig); +} + +TEST(EchoCanceller3, DetectionOfProperStereoUsingThreshold) { + constexpr int kSampleRateHz = 16000; + constexpr int kNumChannels = 2; + AudioBuffer buffer(/*input_rate=*/kSampleRateHz, + /*input_num_channels=*/kNumChannels, + /*input_rate=*/kSampleRateHz, + /*buffer_num_channels=*/kNumChannels, + /*output_rate=*/kSampleRateHz, + /*output_num_channels=*/kNumChannels); + + constexpr size_t kNumBlocksForMonoConfig = 1; + constexpr size_t kNumBlocksForSurroundConfig = 2; + EchoCanceller3Config mono_config; + absl::optional<EchoCanceller3Config> multichannel_config; + + constexpr float kStereoDetectionThreshold = 2.0f; + mono_config.multi_channel.detect_stereo_content = true; + mono_config.multi_channel.stereo_detection_threshold = + kStereoDetectionThreshold; + mono_config.multi_channel.stereo_detection_hysteresis_seconds = 0.0f; + multichannel_config = mono_config; + mono_config.filter.coarse_initial.length_blocks = kNumBlocksForMonoConfig; + multichannel_config->filter.coarse_initial.length_blocks = + kNumBlocksForSurroundConfig; + + EchoCanceller3 aec3(mono_config, multichannel_config, + /*sample_rate_hz=*/kSampleRateHz, + /*num_render_channels=*/kNumChannels, + /*num_capture_input_channels=*/kNumChannels); + + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + + RunAecInStereo(buffer, aec3, 100.0f, + 100.0f + kStereoDetectionThreshold - 1.0f); + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + + RunAecInStereo(buffer, aec3, 100.0f, + 100.0f + kStereoDetectionThreshold + 10.0f); + EXPECT_TRUE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForSurroundConfig); +} + +TEST(EchoCanceller3, DetectionOfProperStereoUsingHysteresis) { + constexpr int kSampleRateHz = 16000; + constexpr int kNumChannels = 2; + AudioBuffer buffer(/*input_rate=*/kSampleRateHz, + /*input_num_channels=*/kNumChannels, + /*input_rate=*/kSampleRateHz, + /*buffer_num_channels=*/kNumChannels, + /*output_rate=*/kSampleRateHz, + /*output_num_channels=*/kNumChannels); + + constexpr size_t kNumBlocksForMonoConfig = 1; + constexpr size_t kNumBlocksForSurroundConfig = 2; + EchoCanceller3Config mono_config; + absl::optional<EchoCanceller3Config> surround_config; + + mono_config.multi_channel.detect_stereo_content = true; + mono_config.multi_channel.stereo_detection_hysteresis_seconds = 0.5f; + surround_config = mono_config; + mono_config.filter.coarse_initial.length_blocks = kNumBlocksForMonoConfig; + surround_config->filter.coarse_initial.length_blocks = + kNumBlocksForSurroundConfig; + + EchoCanceller3 aec3(mono_config, surround_config, + /*sample_rate_hz=*/kSampleRateHz, + /*num_render_channels=*/kNumChannels, + /*num_capture_input_channels=*/kNumChannels); + + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + + RunAecInStereo(buffer, aec3, 100.0f, 100.0f); + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + + constexpr int kNumFramesPerSecond = 100; + for (int k = 0; + k < static_cast<int>( + kNumFramesPerSecond * + mono_config.multi_channel.stereo_detection_hysteresis_seconds); + ++k) { + RunAecInStereo(buffer, aec3, 100.0f, 101.0f); + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + } + + RunAecInStereo(buffer, aec3, 100.0f, 101.0f); + EXPECT_TRUE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForSurroundConfig); +} + +TEST(EchoCanceller3, StereoContentDetectionForMonoSignals) { + constexpr int kSampleRateHz = 16000; + constexpr int kNumChannels = 2; + AudioBuffer buffer(/*input_rate=*/kSampleRateHz, + /*input_num_channels=*/kNumChannels, + /*input_rate=*/kSampleRateHz, + /*buffer_num_channels=*/kNumChannels, + /*output_rate=*/kSampleRateHz, + /*output_num_channels=*/kNumChannels); + + constexpr size_t kNumBlocksForMonoConfig = 1; + constexpr size_t kNumBlocksForSurroundConfig = 2; + EchoCanceller3Config mono_config; + absl::optional<EchoCanceller3Config> multichannel_config; + + for (bool detect_stereo_content : {false, true}) { + mono_config.multi_channel.detect_stereo_content = detect_stereo_content; + multichannel_config = mono_config; + mono_config.filter.coarse_initial.length_blocks = kNumBlocksForMonoConfig; + multichannel_config->filter.coarse_initial.length_blocks = + kNumBlocksForSurroundConfig; + + AudioBuffer mono_buffer(/*input_rate=*/kSampleRateHz, + /*input_num_channels=*/1, + /*input_rate=*/kSampleRateHz, + /*buffer_num_channels=*/1, + /*output_rate=*/kSampleRateHz, + /*output_num_channels=*/1); + + EchoCanceller3 aec3(mono_config, multichannel_config, + /*sample_rate_hz=*/kSampleRateHz, + /*num_render_channels=*/1, + /*num_capture_input_channels=*/1); + + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + + RunAecInSMono(mono_buffer, aec3, 100.0f); + EXPECT_FALSE(aec3.StereoRenderProcessingActiveForTesting()); + EXPECT_EQ( + aec3.GetActiveConfigForTesting().filter.coarse_initial.length_blocks, + kNumBlocksForMonoConfig); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +TEST(EchoCanceller3InputCheckDeathTest, WrongCaptureNumBandsCheckVerification) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + EchoCanceller3Tester(rate).RunProcessCaptureNumBandsCheckVerification(); + } +} + +// Verifiers that the verification for null input to the capture processing api +// call works. +TEST(EchoCanceller3InputCheckDeathTest, NullCaptureProcessingParameter) { + EXPECT_DEATH( + EchoCanceller3(EchoCanceller3Config(), + /*multichannel_config_=*/absl::nullopt, 16000, 1, 1) + .ProcessCapture(nullptr, false), + ""); +} + +// Verifies the check for correct sample rate. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(EchoCanceller3InputCheckDeathTest, DISABLED_WrongSampleRate) { + ApmDataDumper data_dumper(0); + EXPECT_DEATH( + EchoCanceller3(EchoCanceller3Config(), + /*multichannel_config_=*/absl::nullopt, 8001, 1, 1), + ""); +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc new file mode 100644 index 0000000000..510e4b8a8d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator.cc @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/echo_path_delay_estimator.h" + +#include <array> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +EchoPathDelayEstimator::EchoPathDelayEstimator( + ApmDataDumper* data_dumper, + const EchoCanceller3Config& config, + size_t num_capture_channels) + : data_dumper_(data_dumper), + down_sampling_factor_(config.delay.down_sampling_factor), + sub_block_size_(down_sampling_factor_ != 0 + ? kBlockSize / down_sampling_factor_ + : kBlockSize), + capture_mixer_(num_capture_channels, + config.delay.capture_alignment_mixing), + capture_decimator_(down_sampling_factor_), + matched_filter_( + data_dumper_, + DetectOptimization(), + sub_block_size_, + kMatchedFilterWindowSizeSubBlocks, + config.delay.num_filters, + kMatchedFilterAlignmentShiftSizeSubBlocks, + config.delay.down_sampling_factor == 8 + ? config.render_levels.poor_excitation_render_limit_ds8 + : config.render_levels.poor_excitation_render_limit, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + config.delay.detect_pre_echo), + matched_filter_lag_aggregator_(data_dumper_, + matched_filter_.GetMaxFilterLag(), + config.delay) { + RTC_DCHECK(data_dumper); + RTC_DCHECK(down_sampling_factor_ > 0); +} + +EchoPathDelayEstimator::~EchoPathDelayEstimator() = default; + +void EchoPathDelayEstimator::Reset(bool reset_delay_confidence) { + Reset(true, reset_delay_confidence); +} + +absl::optional<DelayEstimate> EchoPathDelayEstimator::EstimateDelay( + const DownsampledRenderBuffer& render_buffer, + const Block& capture) { + std::array<float, kBlockSize> downsampled_capture_data; + rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(), + sub_block_size_); + + std::array<float, kBlockSize> downmixed_capture; + capture_mixer_.ProduceOutput(capture, downmixed_capture); + capture_decimator_.Decimate(downmixed_capture, downsampled_capture); + data_dumper_->DumpWav("aec3_capture_decimator_output", + downsampled_capture.size(), downsampled_capture.data(), + 16000 / down_sampling_factor_, 1); + matched_filter_.Update(render_buffer, downsampled_capture, + matched_filter_lag_aggregator_.ReliableDelayFound()); + + absl::optional<DelayEstimate> aggregated_matched_filter_lag = + matched_filter_lag_aggregator_.Aggregate( + matched_filter_.GetBestLagEstimate()); + + // Run clockdrift detection. + if (aggregated_matched_filter_lag && + (*aggregated_matched_filter_lag).quality == + DelayEstimate::Quality::kRefined) + clockdrift_detector_.Update( + matched_filter_lag_aggregator_.GetDelayAtHighestPeak()); + + // TODO(peah): Move this logging outside of this class once EchoCanceller3 + // development is done. + data_dumper_->DumpRaw( + "aec3_echo_path_delay_estimator_delay", + aggregated_matched_filter_lag + ? static_cast<int>(aggregated_matched_filter_lag->delay * + down_sampling_factor_) + : -1); + + // Return the detected delay in samples as the aggregated matched filter lag + // compensated by the down sampling factor for the signal being correlated. + if (aggregated_matched_filter_lag) { + aggregated_matched_filter_lag->delay *= down_sampling_factor_; + } + + if (old_aggregated_lag_ && aggregated_matched_filter_lag && + old_aggregated_lag_->delay == aggregated_matched_filter_lag->delay) { + ++consistent_estimate_counter_; + } else { + consistent_estimate_counter_ = 0; + } + old_aggregated_lag_ = aggregated_matched_filter_lag; + constexpr size_t kNumBlocksPerSecondBy2 = kNumBlocksPerSecond / 2; + if (consistent_estimate_counter_ > kNumBlocksPerSecondBy2) { + Reset(false, false); + } + + return aggregated_matched_filter_lag; +} + +void EchoPathDelayEstimator::Reset(bool reset_lag_aggregator, + bool reset_delay_confidence) { + if (reset_lag_aggregator) { + matched_filter_lag_aggregator_.Reset(reset_delay_confidence); + } + matched_filter_.Reset(/*full_reset=*/reset_lag_aggregator); + old_aggregated_lag_ = absl::nullopt; + consistent_estimate_counter_ = 0; +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h new file mode 100644 index 0000000000..b24d0a29ec --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ECHO_PATH_DELAY_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ECHO_PATH_DELAY_ESTIMATOR_H_ + +#include <stddef.h> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/alignment_mixer.h" +#include "modules/audio_processing/aec3/block.h" +#include "modules/audio_processing/aec3/clockdrift_detector.h" +#include "modules/audio_processing/aec3/decimator.h" +#include "modules/audio_processing/aec3/delay_estimate.h" +#include "modules/audio_processing/aec3/matched_filter.h" +#include "modules/audio_processing/aec3/matched_filter_lag_aggregator.h" + +namespace webrtc { + +class ApmDataDumper; +struct DownsampledRenderBuffer; +struct EchoCanceller3Config; + +// Estimates the delay of the echo path. +class EchoPathDelayEstimator { + public: + EchoPathDelayEstimator(ApmDataDumper* data_dumper, + const EchoCanceller3Config& config, + size_t num_capture_channels); + ~EchoPathDelayEstimator(); + + EchoPathDelayEstimator(const EchoPathDelayEstimator&) = delete; + EchoPathDelayEstimator& operator=(const EchoPathDelayEstimator&) = delete; + + // Resets the estimation. If the delay confidence is reset, the reset behavior + // is as if the call is restarted. + void Reset(bool reset_delay_confidence); + + // Produce a delay estimate if such is avaliable. + absl::optional<DelayEstimate> EstimateDelay( + const DownsampledRenderBuffer& render_buffer, + const Block& capture); + + // Log delay estimator properties. + void LogDelayEstimationProperties(int sample_rate_hz, size_t shift) const { + matched_filter_.LogFilterProperties(sample_rate_hz, shift, + down_sampling_factor_); + } + + // Returns the level of detected clockdrift. + ClockdriftDetector::Level Clockdrift() const { + return clockdrift_detector_.ClockdriftLevel(); + } + + private: + ApmDataDumper* const data_dumper_; + const size_t down_sampling_factor_; + const size_t sub_block_size_; + AlignmentMixer capture_mixer_; + Decimator capture_decimator_; + MatchedFilter matched_filter_; + MatchedFilterLagAggregator matched_filter_lag_aggregator_; + absl::optional<DelayEstimate> old_aggregated_lag_; + size_t consistent_estimate_counter_ = 0; + ClockdriftDetector clockdrift_detector_; + + // Internal reset method with more granularity. + void Reset(bool reset_lag_aggregator, bool reset_delay_confidence); +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ECHO_PATH_DELAY_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc new file mode 100644 index 0000000000..e2c101fb04 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_delay_estimator_unittest.cc @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_path_delay_estimator.h" + +#include <algorithm> +#include <string> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +std::string ProduceDebugText(size_t delay, size_t down_sampling_factor) { + rtc::StringBuilder ss; + ss << "Delay: " << delay; + ss << ", Down sampling factor: " << down_sampling_factor; + return ss.Release(); +} + +} // namespace + +class EchoPathDelayEstimatorMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + EchoPathDelayEstimatorMultiChannel, + ::testing::Combine(::testing::Values(1, 2, 3, 6, 8), + ::testing::Values(1, 2, 4))); + +// Verifies that the basic API calls work. +TEST_P(EchoPathDelayEstimatorMultiChannel, BasicApiCalls) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + EchoPathDelayEstimator estimator(&data_dumper, config, num_capture_channels); + Block render(kNumBands, num_render_channels); + Block capture(/*num_bands=*/1, num_capture_channels); + for (size_t k = 0; k < 100; ++k) { + render_delay_buffer->Insert(render); + estimator.EstimateDelay(render_delay_buffer->GetDownsampledRenderBuffer(), + capture); + } +} + +// Verifies that the delay estimator produces correct delay for artificially +// delayed signals. +TEST(EchoPathDelayEstimator, DelayEstimation) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + Random random_generator(42U); + Block render(kNumBands, kNumRenderChannels); + Block capture(/*num_bands=*/1, kNumCaptureChannels); + ApmDataDumper data_dumper(0); + constexpr size_t kDownSamplingFactors[] = {2, 4, 8}; + for (auto down_sampling_factor : kDownSamplingFactors) { + EchoCanceller3Config config; + config.delay.delay_headroom_samples = 0; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = 10; + for (size_t delay_samples : {30, 64, 150, 200, 800, 4000}) { + SCOPED_TRACE(ProduceDebugText(delay_samples, down_sampling_factor)); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels)); + DelayBuffer<float> signal_delay_buffer(delay_samples); + EchoPathDelayEstimator estimator(&data_dumper, config, + kNumCaptureChannels); + + absl::optional<DelayEstimate> estimated_delay_samples; + for (size_t k = 0; k < (500 + (delay_samples) / kBlockSize); ++k) { + RandomizeSampleVector(&random_generator, + render.View(/*band=*/0, /*channel=*/0)); + signal_delay_buffer.Delay(render.View(/*band=*/0, /*channel=*/0), + capture.View(/*band=*/0, /*channel=*/0)); + render_delay_buffer->Insert(render); + + if (k == 0) { + render_delay_buffer->Reset(); + } + + render_delay_buffer->PrepareCaptureProcessing(); + + auto estimate = estimator.EstimateDelay( + render_delay_buffer->GetDownsampledRenderBuffer(), capture); + + if (estimate) { + estimated_delay_samples = estimate; + } + } + + if (estimated_delay_samples) { + // Allow estimated delay to be off by a block as internally the delay is + // quantized with an error up to a block. + size_t delay_ds = delay_samples / down_sampling_factor; + size_t estimated_delay_ds = + estimated_delay_samples->delay / down_sampling_factor; + EXPECT_NEAR(delay_ds, estimated_delay_ds, + kBlockSize / down_sampling_factor); + } else { + ADD_FAILURE(); + } + } + } +} + +// Verifies that the delay estimator does not produce delay estimates for render +// signals of low level. +TEST(EchoPathDelayEstimator, NoDelayEstimatesForLowLevelRenderSignals) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + Random random_generator(42U); + EchoCanceller3Config config; + Block render(kNumBands, kNumRenderChannels); + Block capture(/*num_bands=*/1, kNumCaptureChannels); + ApmDataDumper data_dumper(0); + EchoPathDelayEstimator estimator(&data_dumper, config, kNumCaptureChannels); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + kNumRenderChannels)); + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, + render.View(/*band=*/0, /*channel=*/0)); + for (auto& render_k : render.View(/*band=*/0, /*channel=*/0)) { + render_k *= 100.f / 32767.f; + } + std::copy(render.begin(/*band=*/0, /*channel=*/0), + render.end(/*band=*/0, /*channel=*/0), + capture.begin(/*band*/ 0, /*channel=*/0)); + render_delay_buffer->Insert(render); + render_delay_buffer->PrepareCaptureProcessing(); + EXPECT_FALSE(estimator.EstimateDelay( + render_delay_buffer->GetDownsampledRenderBuffer(), capture)); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for the render blocksize. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(EchoPathDelayEstimatorDeathTest, DISABLED_WrongRenderBlockSize) { + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + EchoPathDelayEstimator estimator(&data_dumper, config, 1); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, 48000, 1)); + Block capture(/*num_bands=*/1, /*num_channels=*/1); + EXPECT_DEATH(estimator.EstimateDelay( + render_delay_buffer->GetDownsampledRenderBuffer(), capture), + ""); +} + +// Verifies the check for non-null data dumper. +TEST(EchoPathDelayEstimatorDeathTest, NullDataDumper) { + EXPECT_DEATH(EchoPathDelayEstimator(nullptr, EchoCanceller3Config(), 1), ""); +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability.cc new file mode 100644 index 0000000000..0ae9cff98e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability.cc @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_path_variability.h" + +namespace webrtc { + +EchoPathVariability::EchoPathVariability(bool gain_change, + DelayAdjustment delay_change, + bool clock_drift) + : gain_change(gain_change), + delay_change(delay_change), + clock_drift(clock_drift) {} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability.h b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability.h new file mode 100644 index 0000000000..78e4f64b2b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ECHO_PATH_VARIABILITY_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ECHO_PATH_VARIABILITY_H_ + +namespace webrtc { + +struct EchoPathVariability { + enum class DelayAdjustment { + kNone, + kBufferFlush, + kNewDetectedDelay + }; + + EchoPathVariability(bool gain_change, + DelayAdjustment delay_change, + bool clock_drift); + + bool AudioPathChanged() const { + return gain_change || delay_change != DelayAdjustment::kNone; + } + bool gain_change; + DelayAdjustment delay_change; + bool clock_drift; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ECHO_PATH_VARIABILITY_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability_unittest.cc new file mode 100644 index 0000000000..0f10f95f72 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_path_variability_unittest.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_path_variability.h" + +#include "test/gtest.h" + +namespace webrtc { + +TEST(EchoPathVariability, CorrectBehavior) { + // Test correct passing and reporting of the gain change information. + EchoPathVariability v( + true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false); + EXPECT_TRUE(v.gain_change); + EXPECT_TRUE(v.delay_change == + EchoPathVariability::DelayAdjustment::kNewDetectedDelay); + EXPECT_TRUE(v.AudioPathChanged()); + EXPECT_FALSE(v.clock_drift); + + v = EchoPathVariability(true, EchoPathVariability::DelayAdjustment::kNone, + false); + EXPECT_TRUE(v.gain_change); + EXPECT_TRUE(v.delay_change == EchoPathVariability::DelayAdjustment::kNone); + EXPECT_TRUE(v.AudioPathChanged()); + EXPECT_FALSE(v.clock_drift); + + v = EchoPathVariability( + false, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false); + EXPECT_FALSE(v.gain_change); + EXPECT_TRUE(v.delay_change == + EchoPathVariability::DelayAdjustment::kNewDetectedDelay); + EXPECT_TRUE(v.AudioPathChanged()); + EXPECT_FALSE(v.clock_drift); + + v = EchoPathVariability(false, EchoPathVariability::DelayAdjustment::kNone, + false); + EXPECT_FALSE(v.gain_change); + EXPECT_TRUE(v.delay_change == EchoPathVariability::DelayAdjustment::kNone); + EXPECT_FALSE(v.AudioPathChanged()); + EXPECT_FALSE(v.clock_drift); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover.cc new file mode 100644 index 0000000000..673d88af03 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover.cc @@ -0,0 +1,521 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/echo_remover.h" + +#include <math.h> +#include <stddef.h> + +#include <algorithm> +#include <array> +#include <atomic> +#include <cmath> +#include <memory> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/comfort_noise_generator.h" +#include "modules/audio_processing/aec3/echo_path_variability.h" +#include "modules/audio_processing/aec3/echo_remover_metrics.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/render_signal_analyzer.h" +#include "modules/audio_processing/aec3/residual_echo_estimator.h" +#include "modules/audio_processing/aec3/subtractor.h" +#include "modules/audio_processing/aec3/subtractor_output.h" +#include "modules/audio_processing/aec3/suppression_filter.h" +#include "modules/audio_processing/aec3/suppression_gain.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" + +namespace webrtc { + +namespace { + +// Maximum number of channels for which the capture channel data is stored on +// the stack. If the number of channels are larger than this, they are stored +// using scratch memory that is pre-allocated on the heap. The reason for this +// partitioning is not to waste heap space for handling the more common numbers +// of channels, while at the same time not limiting the support for higher +// numbers of channels by enforcing the capture channel data to be stored on the +// stack using a fixed maximum value. +constexpr size_t kMaxNumChannelsOnStack = 2; + +// Chooses the number of channels to store on the heap when that is required due +// to the number of capture channels being larger than the pre-defined number +// of channels to store on the stack. +size_t NumChannelsOnHeap(size_t num_capture_channels) { + return num_capture_channels > kMaxNumChannelsOnStack ? num_capture_channels + : 0; +} + +void LinearEchoPower(const FftData& E, + const FftData& Y, + std::array<float, kFftLengthBy2Plus1>* S2) { + for (size_t k = 0; k < E.re.size(); ++k) { + (*S2)[k] = (Y.re[k] - E.re[k]) * (Y.re[k] - E.re[k]) + + (Y.im[k] - E.im[k]) * (Y.im[k] - E.im[k]); + } +} + +// Fades between two input signals using a fix-sized transition. +void SignalTransition(rtc::ArrayView<const float> from, + rtc::ArrayView<const float> to, + rtc::ArrayView<float> out) { + if (from == to) { + RTC_DCHECK_EQ(to.size(), out.size()); + std::copy(to.begin(), to.end(), out.begin()); + } else { + constexpr size_t kTransitionSize = 30; + constexpr float kOneByTransitionSizePlusOne = 1.f / (kTransitionSize + 1); + + RTC_DCHECK_EQ(from.size(), to.size()); + RTC_DCHECK_EQ(from.size(), out.size()); + RTC_DCHECK_LE(kTransitionSize, out.size()); + + for (size_t k = 0; k < kTransitionSize; ++k) { + float a = (k + 1) * kOneByTransitionSizePlusOne; + out[k] = a * to[k] + (1.f - a) * from[k]; + } + + std::copy(to.begin() + kTransitionSize, to.end(), + out.begin() + kTransitionSize); + } +} + +// Computes a windowed (square root Hanning) padded FFT and updates the related +// memory. +void WindowedPaddedFft(const Aec3Fft& fft, + rtc::ArrayView<const float> v, + rtc::ArrayView<float> v_old, + FftData* V) { + fft.PaddedFft(v, v_old, Aec3Fft::Window::kSqrtHanning, V); + std::copy(v.begin(), v.end(), v_old.begin()); +} + +// Class for removing the echo from the capture signal. +class EchoRemoverImpl final : public EchoRemover { + public: + EchoRemoverImpl(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels); + ~EchoRemoverImpl() override; + EchoRemoverImpl(const EchoRemoverImpl&) = delete; + EchoRemoverImpl& operator=(const EchoRemoverImpl&) = delete; + + void GetMetrics(EchoControl::Metrics* metrics) const override; + + // Removes the echo from a block of samples from the capture signal. The + // supplied render signal is assumed to be pre-aligned with the capture + // signal. + void ProcessCapture(EchoPathVariability echo_path_variability, + bool capture_signal_saturation, + const absl::optional<DelayEstimate>& external_delay, + RenderBuffer* render_buffer, + Block* linear_output, + Block* capture) override; + + // Updates the status on whether echo leakage is detected in the output of the + // echo remover. + void UpdateEchoLeakageStatus(bool leakage_detected) override { + echo_leakage_detected_ = leakage_detected; + } + + void SetCaptureOutputUsage(bool capture_output_used) override { + capture_output_used_ = capture_output_used; + } + + private: + // Selects which of the coarse and refined linear filter outputs that is most + // appropriate to pass to the suppressor and forms the linear filter output by + // smoothly transition between those. + void FormLinearFilterOutput(const SubtractorOutput& subtractor_output, + rtc::ArrayView<float> output); + + static std::atomic<int> instance_count_; + const EchoCanceller3Config config_; + const Aec3Fft fft_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const Aec3Optimization optimization_; + const int sample_rate_hz_; + const size_t num_render_channels_; + const size_t num_capture_channels_; + const bool use_coarse_filter_output_; + Subtractor subtractor_; + SuppressionGain suppression_gain_; + ComfortNoiseGenerator cng_; + SuppressionFilter suppression_filter_; + RenderSignalAnalyzer render_signal_analyzer_; + ResidualEchoEstimator residual_echo_estimator_; + bool echo_leakage_detected_ = false; + bool capture_output_used_ = true; + AecState aec_state_; + EchoRemoverMetrics metrics_; + std::vector<std::array<float, kFftLengthBy2>> e_old_; + std::vector<std::array<float, kFftLengthBy2>> y_old_; + size_t block_counter_ = 0; + int gain_change_hangover_ = 0; + bool refined_filter_output_last_selected_ = true; + + std::vector<std::array<float, kFftLengthBy2>> e_heap_; + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_heap_; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_heap_; + std::vector<std::array<float, kFftLengthBy2Plus1>> R2_heap_; + std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded_heap_; + std::vector<std::array<float, kFftLengthBy2Plus1>> S2_linear_heap_; + std::vector<FftData> Y_heap_; + std::vector<FftData> E_heap_; + std::vector<FftData> comfort_noise_heap_; + std::vector<FftData> high_band_comfort_noise_heap_; + std::vector<SubtractorOutput> subtractor_output_heap_; +}; + +std::atomic<int> EchoRemoverImpl::instance_count_(0); + +EchoRemoverImpl::EchoRemoverImpl(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels) + : config_(config), + fft_(), + data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + optimization_(DetectOptimization()), + sample_rate_hz_(sample_rate_hz), + num_render_channels_(num_render_channels), + num_capture_channels_(num_capture_channels), + use_coarse_filter_output_( + config_.filter.enable_coarse_filter_output_usage), + subtractor_(config, + num_render_channels_, + num_capture_channels_, + data_dumper_.get(), + optimization_), + suppression_gain_(config_, + optimization_, + sample_rate_hz, + num_capture_channels), + cng_(config_, optimization_, num_capture_channels_), + suppression_filter_(optimization_, + sample_rate_hz_, + num_capture_channels_), + render_signal_analyzer_(config_), + residual_echo_estimator_(config_, num_render_channels), + aec_state_(config_, num_capture_channels_), + e_old_(num_capture_channels_, {0.f}), + y_old_(num_capture_channels_, {0.f}), + e_heap_(NumChannelsOnHeap(num_capture_channels_), {0.f}), + Y2_heap_(NumChannelsOnHeap(num_capture_channels_)), + E2_heap_(NumChannelsOnHeap(num_capture_channels_)), + R2_heap_(NumChannelsOnHeap(num_capture_channels_)), + R2_unbounded_heap_(NumChannelsOnHeap(num_capture_channels_)), + S2_linear_heap_(NumChannelsOnHeap(num_capture_channels_)), + Y_heap_(NumChannelsOnHeap(num_capture_channels_)), + E_heap_(NumChannelsOnHeap(num_capture_channels_)), + comfort_noise_heap_(NumChannelsOnHeap(num_capture_channels_)), + high_band_comfort_noise_heap_(NumChannelsOnHeap(num_capture_channels_)), + subtractor_output_heap_(NumChannelsOnHeap(num_capture_channels_)) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz)); +} + +EchoRemoverImpl::~EchoRemoverImpl() = default; + +void EchoRemoverImpl::GetMetrics(EchoControl::Metrics* metrics) const { + // Echo return loss (ERL) is inverted to go from gain to attenuation. + metrics->echo_return_loss = -10.0 * std::log10(aec_state_.ErlTimeDomain()); + metrics->echo_return_loss_enhancement = + Log2TodB(aec_state_.FullBandErleLog2()); +} + +void EchoRemoverImpl::ProcessCapture( + EchoPathVariability echo_path_variability, + bool capture_signal_saturation, + const absl::optional<DelayEstimate>& external_delay, + RenderBuffer* render_buffer, + Block* linear_output, + Block* capture) { + ++block_counter_; + const Block& x = render_buffer->GetBlock(0); + Block* y = capture; + RTC_DCHECK(render_buffer); + RTC_DCHECK(y); + RTC_DCHECK_EQ(x.NumBands(), NumBandsForRate(sample_rate_hz_)); + RTC_DCHECK_EQ(y->NumBands(), NumBandsForRate(sample_rate_hz_)); + RTC_DCHECK_EQ(x.NumChannels(), num_render_channels_); + RTC_DCHECK_EQ(y->NumChannels(), num_capture_channels_); + + // Stack allocated data to use when the number of channels is low. + std::array<std::array<float, kFftLengthBy2>, kMaxNumChannelsOnStack> e_stack; + std::array<std::array<float, kFftLengthBy2Plus1>, kMaxNumChannelsOnStack> + Y2_stack; + std::array<std::array<float, kFftLengthBy2Plus1>, kMaxNumChannelsOnStack> + E2_stack; + std::array<std::array<float, kFftLengthBy2Plus1>, kMaxNumChannelsOnStack> + R2_stack; + std::array<std::array<float, kFftLengthBy2Plus1>, kMaxNumChannelsOnStack> + R2_unbounded_stack; + std::array<std::array<float, kFftLengthBy2Plus1>, kMaxNumChannelsOnStack> + S2_linear_stack; + std::array<FftData, kMaxNumChannelsOnStack> Y_stack; + std::array<FftData, kMaxNumChannelsOnStack> E_stack; + std::array<FftData, kMaxNumChannelsOnStack> comfort_noise_stack; + std::array<FftData, kMaxNumChannelsOnStack> high_band_comfort_noise_stack; + std::array<SubtractorOutput, kMaxNumChannelsOnStack> subtractor_output_stack; + + rtc::ArrayView<std::array<float, kFftLengthBy2>> e(e_stack.data(), + num_capture_channels_); + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> Y2( + Y2_stack.data(), num_capture_channels_); + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> E2( + E2_stack.data(), num_capture_channels_); + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2( + R2_stack.data(), num_capture_channels_); + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2_unbounded( + R2_unbounded_stack.data(), num_capture_channels_); + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> S2_linear( + S2_linear_stack.data(), num_capture_channels_); + rtc::ArrayView<FftData> Y(Y_stack.data(), num_capture_channels_); + rtc::ArrayView<FftData> E(E_stack.data(), num_capture_channels_); + rtc::ArrayView<FftData> comfort_noise(comfort_noise_stack.data(), + num_capture_channels_); + rtc::ArrayView<FftData> high_band_comfort_noise( + high_band_comfort_noise_stack.data(), num_capture_channels_); + rtc::ArrayView<SubtractorOutput> subtractor_output( + subtractor_output_stack.data(), num_capture_channels_); + if (NumChannelsOnHeap(num_capture_channels_) > 0) { + // If the stack-allocated space is too small, use the heap for storing the + // microphone data. + e = rtc::ArrayView<std::array<float, kFftLengthBy2>>(e_heap_.data(), + num_capture_channels_); + Y2 = rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>>( + Y2_heap_.data(), num_capture_channels_); + E2 = rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>>( + E2_heap_.data(), num_capture_channels_); + R2 = rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>>( + R2_heap_.data(), num_capture_channels_); + R2_unbounded = rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>>( + R2_unbounded_heap_.data(), num_capture_channels_); + S2_linear = rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>>( + S2_linear_heap_.data(), num_capture_channels_); + Y = rtc::ArrayView<FftData>(Y_heap_.data(), num_capture_channels_); + E = rtc::ArrayView<FftData>(E_heap_.data(), num_capture_channels_); + comfort_noise = rtc::ArrayView<FftData>(comfort_noise_heap_.data(), + num_capture_channels_); + high_band_comfort_noise = rtc::ArrayView<FftData>( + high_band_comfort_noise_heap_.data(), num_capture_channels_); + subtractor_output = rtc::ArrayView<SubtractorOutput>( + subtractor_output_heap_.data(), num_capture_channels_); + } + + data_dumper_->DumpWav("aec3_echo_remover_capture_input", + y->View(/*band=*/0, /*channel=*/0), 16000, 1); + data_dumper_->DumpWav("aec3_echo_remover_render_input", + x.View(/*band=*/0, /*channel=*/0), 16000, 1); + data_dumper_->DumpRaw("aec3_echo_remover_capture_input", + y->View(/*band=*/0, /*channel=*/0)); + data_dumper_->DumpRaw("aec3_echo_remover_render_input", + x.View(/*band=*/0, /*channel=*/0)); + + aec_state_.UpdateCaptureSaturation(capture_signal_saturation); + + if (echo_path_variability.AudioPathChanged()) { + // Ensure that the gain change is only acted on once per frame. + if (echo_path_variability.gain_change) { + if (gain_change_hangover_ == 0) { + constexpr int kMaxBlocksPerFrame = 3; + gain_change_hangover_ = kMaxBlocksPerFrame; + rtc::LoggingSeverity log_level = + config_.delay.log_warning_on_delay_changes ? rtc::LS_WARNING + : rtc::LS_VERBOSE; + RTC_LOG_V(log_level) + << "Gain change detected at block " << block_counter_; + } else { + echo_path_variability.gain_change = false; + } + } + + subtractor_.HandleEchoPathChange(echo_path_variability); + aec_state_.HandleEchoPathChange(echo_path_variability); + + if (echo_path_variability.delay_change != + EchoPathVariability::DelayAdjustment::kNone) { + suppression_gain_.SetInitialState(true); + } + } + if (gain_change_hangover_ > 0) { + --gain_change_hangover_; + } + + // Analyze the render signal. + render_signal_analyzer_.Update(*render_buffer, + aec_state_.MinDirectPathFilterDelay()); + + // State transition. + if (aec_state_.TransitionTriggered()) { + subtractor_.ExitInitialState(); + suppression_gain_.SetInitialState(false); + } + + // Perform linear echo cancellation. + subtractor_.Process(*render_buffer, *y, render_signal_analyzer_, aec_state_, + subtractor_output); + + // Compute spectra. + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + FormLinearFilterOutput(subtractor_output[ch], e[ch]); + WindowedPaddedFft(fft_, y->View(/*band=*/0, ch), y_old_[ch], &Y[ch]); + WindowedPaddedFft(fft_, e[ch], e_old_[ch], &E[ch]); + LinearEchoPower(E[ch], Y[ch], &S2_linear[ch]); + Y[ch].Spectrum(optimization_, Y2[ch]); + E[ch].Spectrum(optimization_, E2[ch]); + } + + // Optionally return the linear filter output. + if (linear_output) { + RTC_DCHECK_GE(1, linear_output->NumBands()); + RTC_DCHECK_EQ(num_capture_channels_, linear_output->NumChannels()); + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + std::copy(e[ch].begin(), e[ch].end(), + linear_output->begin(/*band=*/0, ch)); + } + } + + // Update the AEC state information. + aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponses(), + subtractor_.FilterImpulseResponses(), *render_buffer, E2, + Y2, subtractor_output); + + // Choose the linear output. + const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y; + + data_dumper_->DumpWav("aec3_output_linear", + y->View(/*band=*/0, /*channel=*/0), 16000, 1); + data_dumper_->DumpWav("aec3_output_linear2", kBlockSize, &e[0][0], 16000, 1); + + // Estimate the comfort noise. + cng_.Compute(aec_state_.SaturatedCapture(), Y2, comfort_noise, + high_band_comfort_noise); + + // Only do the below processing if the output of the audio processing module + // is used. + std::array<float, kFftLengthBy2Plus1> G; + if (capture_output_used_) { + // Estimate the residual echo power. + residual_echo_estimator_.Estimate(aec_state_, *render_buffer, S2_linear, Y2, + suppression_gain_.IsDominantNearend(), R2, + R2_unbounded); + + // Suppressor nearend estimate. + if (aec_state_.UsableLinearEstimate()) { + // E2 is bound by Y2. + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + std::transform(E2[ch].begin(), E2[ch].end(), Y2[ch].begin(), + E2[ch].begin(), + [](float a, float b) { return std::min(a, b); }); + } + } + const auto& nearend_spectrum = aec_state_.UsableLinearEstimate() ? E2 : Y2; + + // Suppressor echo estimate. + const auto& echo_spectrum = + aec_state_.UsableLinearEstimate() ? S2_linear : R2; + + // Determine if the suppressor should assume clock drift. + const bool clock_drift = config_.echo_removal_control.has_clock_drift || + echo_path_variability.clock_drift; + + // Compute preferred gains. + float high_bands_gain; + suppression_gain_.GetGain(nearend_spectrum, echo_spectrum, R2, R2_unbounded, + cng_.NoiseSpectrum(), render_signal_analyzer_, + aec_state_, x, clock_drift, &high_bands_gain, &G); + + suppression_filter_.ApplyGain(comfort_noise, high_band_comfort_noise, G, + high_bands_gain, Y_fft, y); + + } else { + G.fill(0.f); + } + + // Update the metrics. + metrics_.Update(aec_state_, cng_.NoiseSpectrum()[0], G); + + // Debug outputs for the purpose of development and analysis. + data_dumper_->DumpWav("aec3_echo_estimate", kBlockSize, + &subtractor_output[0].s_refined[0], 16000, 1); + data_dumper_->DumpRaw("aec3_output", y->View(/*band=*/0, /*channel=*/0)); + data_dumper_->DumpRaw("aec3_narrow_render", + render_signal_analyzer_.NarrowPeakBand() ? 1 : 0); + data_dumper_->DumpRaw("aec3_N2", cng_.NoiseSpectrum()[0]); + data_dumper_->DumpRaw("aec3_suppressor_gain", G); + data_dumper_->DumpWav("aec3_output", y->View(/*band=*/0, /*channel=*/0), + 16000, 1); + data_dumper_->DumpRaw("aec3_using_subtractor_output[0]", + aec_state_.UseLinearFilterOutput() ? 1 : 0); + data_dumper_->DumpRaw("aec3_E2", E2[0]); + data_dumper_->DumpRaw("aec3_S2_linear", S2_linear[0]); + data_dumper_->DumpRaw("aec3_Y2", Y2[0]); + data_dumper_->DumpRaw( + "aec3_X2", render_buffer->Spectrum( + aec_state_.MinDirectPathFilterDelay())[/*channel=*/0]); + data_dumper_->DumpRaw("aec3_R2", R2[0]); + data_dumper_->DumpRaw("aec3_filter_delay", + aec_state_.MinDirectPathFilterDelay()); + data_dumper_->DumpRaw("aec3_capture_saturation", + aec_state_.SaturatedCapture() ? 1 : 0); +} + +void EchoRemoverImpl::FormLinearFilterOutput( + const SubtractorOutput& subtractor_output, + rtc::ArrayView<float> output) { + RTC_DCHECK_EQ(subtractor_output.e_refined.size(), output.size()); + RTC_DCHECK_EQ(subtractor_output.e_coarse.size(), output.size()); + bool use_refined_output = true; + if (use_coarse_filter_output_) { + // As the output of the refined adaptive filter generally should be better + // than the coarse filter output, add a margin and threshold for when + // choosing the coarse filter output. + if (subtractor_output.e2_coarse < 0.9f * subtractor_output.e2_refined && + subtractor_output.y2 > 30.f * 30.f * kBlockSize && + (subtractor_output.s2_refined > 60.f * 60.f * kBlockSize || + subtractor_output.s2_coarse > 60.f * 60.f * kBlockSize)) { + use_refined_output = false; + } else { + // If the refined filter is diverged, choose the filter output that has + // the lowest power. + if (subtractor_output.e2_coarse < subtractor_output.e2_refined && + subtractor_output.y2 < subtractor_output.e2_refined) { + use_refined_output = false; + } + } + } + + SignalTransition(refined_filter_output_last_selected_ + ? subtractor_output.e_refined + : subtractor_output.e_coarse, + use_refined_output ? subtractor_output.e_refined + : subtractor_output.e_coarse, + output); + refined_filter_output_last_selected_ = use_refined_output; +} + +} // namespace + +EchoRemover* EchoRemover::Create(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels) { + return new EchoRemoverImpl(config, sample_rate_hz, num_render_channels, + num_capture_channels); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover.h b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover.h new file mode 100644 index 0000000000..f2f4f5e64d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ECHO_REMOVER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ECHO_REMOVER_H_ + +#include <vector> + +#include "absl/types/optional.h" +#include "api/audio/echo_canceller3_config.h" +#include "api/audio/echo_control.h" +#include "modules/audio_processing/aec3/block.h" +#include "modules/audio_processing/aec3/delay_estimate.h" +#include "modules/audio_processing/aec3/echo_path_variability.h" +#include "modules/audio_processing/aec3/render_buffer.h" + +namespace webrtc { + +// Class for removing the echo from the capture signal. +class EchoRemover { + public: + static EchoRemover* Create(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels, + size_t num_capture_channels); + virtual ~EchoRemover() = default; + + // Get current metrics. + virtual void GetMetrics(EchoControl::Metrics* metrics) const = 0; + + // Removes the echo from a block of samples from the capture signal. The + // supplied render signal is assumed to be pre-aligned with the capture + // signal. + virtual void ProcessCapture( + EchoPathVariability echo_path_variability, + bool capture_signal_saturation, + const absl::optional<DelayEstimate>& external_delay, + RenderBuffer* render_buffer, + Block* linear_output, + Block* capture) = 0; + + // Updates the status on whether echo leakage is detected in the output of the + // echo remover. + virtual void UpdateEchoLeakageStatus(bool leakage_detected) = 0; + + // Specifies whether the capture output will be used. The purpose of this is + // to allow the echo remover to deactivate some of the processing when the + // resulting output is anyway not used, for instance when the endpoint is + // muted. + virtual void SetCaptureOutputUsage(bool capture_output_used) = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ECHO_REMOVER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics.cc new file mode 100644 index 0000000000..c3fc80773a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics.cc @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_remover_metrics.h" + +#include <math.h> +#include <stddef.h> + +#include <algorithm> +#include <cmath> +#include <numeric> + +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" +#include "system_wrappers/include/metrics.h" + +namespace webrtc { + +EchoRemoverMetrics::DbMetric::DbMetric() : DbMetric(0.f, 0.f, 0.f) {} +EchoRemoverMetrics::DbMetric::DbMetric(float sum_value, + float floor_value, + float ceil_value) + : sum_value(sum_value), floor_value(floor_value), ceil_value(ceil_value) {} + +void EchoRemoverMetrics::DbMetric::Update(float value) { + sum_value += value; + floor_value = std::min(floor_value, value); + ceil_value = std::max(ceil_value, value); +} + +void EchoRemoverMetrics::DbMetric::UpdateInstant(float value) { + sum_value = value; + floor_value = std::min(floor_value, value); + ceil_value = std::max(ceil_value, value); +} + +EchoRemoverMetrics::EchoRemoverMetrics() { + ResetMetrics(); +} + +void EchoRemoverMetrics::ResetMetrics() { + erl_time_domain_ = DbMetric(0.f, 10000.f, 0.000f); + erle_time_domain_ = DbMetric(0.f, 0.f, 1000.f); + saturated_capture_ = false; +} + +void EchoRemoverMetrics::Update( + const AecState& aec_state, + const std::array<float, kFftLengthBy2Plus1>& comfort_noise_spectrum, + const std::array<float, kFftLengthBy2Plus1>& suppressor_gain) { + metrics_reported_ = false; + if (++block_counter_ <= kMetricsCollectionBlocks) { + erl_time_domain_.UpdateInstant(aec_state.ErlTimeDomain()); + erle_time_domain_.UpdateInstant(aec_state.FullBandErleLog2()); + saturated_capture_ = saturated_capture_ || aec_state.SaturatedCapture(); + } else { + // Report the metrics over several frames in order to lower the impact of + // the logarithms involved on the computational complexity. + switch (block_counter_) { + case kMetricsCollectionBlocks + 1: + RTC_HISTOGRAM_BOOLEAN( + "WebRTC.Audio.EchoCanceller.UsableLinearEstimate", + static_cast<int>(aec_state.UsableLinearEstimate() ? 1 : 0)); + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.EchoCanceller.FilterDelay", + aec_state.MinDirectPathFilterDelay(), 0, 30, + 31); + RTC_HISTOGRAM_BOOLEAN("WebRTC.Audio.EchoCanceller.CaptureSaturation", + static_cast<int>(saturated_capture_ ? 1 : 0)); + break; + case kMetricsCollectionBlocks + 2: + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.Erl.Value", + aec3::TransformDbMetricForReporting(true, 0.f, 59.f, 30.f, 1.f, + erl_time_domain_.sum_value), + 0, 59, 30); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.Erl.Max", + aec3::TransformDbMetricForReporting(true, 0.f, 59.f, 30.f, 1.f, + erl_time_domain_.ceil_value), + 0, 59, 30); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.Erl.Min", + aec3::TransformDbMetricForReporting(true, 0.f, 59.f, 30.f, 1.f, + erl_time_domain_.floor_value), + 0, 59, 30); + break; + case kMetricsCollectionBlocks + 3: + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.Erle.Value", + aec3::TransformDbMetricForReporting(false, 0.f, 19.f, 0.f, 1.f, + erle_time_domain_.sum_value), + 0, 19, 20); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.Erle.Max", + aec3::TransformDbMetricForReporting(false, 0.f, 19.f, 0.f, 1.f, + erle_time_domain_.ceil_value), + 0, 19, 20); + RTC_HISTOGRAM_COUNTS_LINEAR( + "WebRTC.Audio.EchoCanceller.Erle.Min", + aec3::TransformDbMetricForReporting(false, 0.f, 19.f, 0.f, 1.f, + erle_time_domain_.floor_value), + 0, 19, 20); + metrics_reported_ = true; + RTC_DCHECK_EQ(kMetricsReportingIntervalBlocks, block_counter_); + block_counter_ = 0; + ResetMetrics(); + break; + default: + RTC_DCHECK_NOTREACHED(); + break; + } + } +} + +namespace aec3 { + +void UpdateDbMetric(const std::array<float, kFftLengthBy2Plus1>& value, + std::array<EchoRemoverMetrics::DbMetric, 2>* statistic) { + RTC_DCHECK(statistic); + // Truncation is intended in the band width computation. + constexpr int kNumBands = 2; + constexpr int kBandWidth = 65 / kNumBands; + constexpr float kOneByBandWidth = 1.f / kBandWidth; + RTC_DCHECK_EQ(kNumBands, statistic->size()); + RTC_DCHECK_EQ(65, value.size()); + for (size_t k = 0; k < statistic->size(); ++k) { + float average_band = + std::accumulate(value.begin() + kBandWidth * k, + value.begin() + kBandWidth * (k + 1), 0.f) * + kOneByBandWidth; + (*statistic)[k].Update(average_band); + } +} + +int TransformDbMetricForReporting(bool negate, + float min_value, + float max_value, + float offset, + float scaling, + float value) { + float new_value = 10.f * std::log10(value * scaling + 1e-10f) + offset; + if (negate) { + new_value = -new_value; + } + return static_cast<int>(rtc::SafeClamp(new_value, min_value, max_value)); +} + +} // namespace aec3 + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics.h b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics.h new file mode 100644 index 0000000000..aec8084d78 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ECHO_REMOVER_METRICS_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ECHO_REMOVER_METRICS_H_ + +#include <array> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec_state.h" + +namespace webrtc { + +// Handles the reporting of metrics for the echo remover. +class EchoRemoverMetrics { + public: + struct DbMetric { + DbMetric(); + DbMetric(float sum_value, float floor_value, float ceil_value); + void Update(float value); + void UpdateInstant(float value); + float sum_value; + float floor_value; + float ceil_value; + }; + + EchoRemoverMetrics(); + + EchoRemoverMetrics(const EchoRemoverMetrics&) = delete; + EchoRemoverMetrics& operator=(const EchoRemoverMetrics&) = delete; + + // Updates the metric with new data. + void Update( + const AecState& aec_state, + const std::array<float, kFftLengthBy2Plus1>& comfort_noise_spectrum, + const std::array<float, kFftLengthBy2Plus1>& suppressor_gain); + + // Returns true if the metrics have just been reported, otherwise false. + bool MetricsReported() { return metrics_reported_; } + + private: + // Resets the metrics. + void ResetMetrics(); + + int block_counter_ = 0; + DbMetric erl_time_domain_; + DbMetric erle_time_domain_; + bool saturated_capture_ = false; + bool metrics_reported_ = false; +}; + +namespace aec3 { + +// Updates a banded metric of type DbMetric with the values in the supplied +// array. +void UpdateDbMetric(const std::array<float, kFftLengthBy2Plus1>& value, + std::array<EchoRemoverMetrics::DbMetric, 2>* statistic); + +// Transforms a DbMetric from the linear domain into the logarithmic domain. +int TransformDbMetricForReporting(bool negate, + float min_value, + float max_value, + float offset, + float scaling, + float value); + +} // namespace aec3 + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ECHO_REMOVER_METRICS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics_unittest.cc new file mode 100644 index 0000000000..45b30a9c74 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_metrics_unittest.cc @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_remover_metrics.h" + +#include <math.h> + +#include <cmath> + +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "test/gtest.h" + +namespace webrtc { + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for non-null input. +TEST(UpdateDbMetricDeathTest, NullValue) { + std::array<float, kFftLengthBy2Plus1> value; + value.fill(0.f); + EXPECT_DEATH(aec3::UpdateDbMetric(value, nullptr), ""); +} + +#endif + +// Verifies the updating functionality of UpdateDbMetric. +TEST(UpdateDbMetric, Updating) { + std::array<float, kFftLengthBy2Plus1> value; + std::array<EchoRemoverMetrics::DbMetric, 2> statistic; + statistic.fill(EchoRemoverMetrics::DbMetric(0.f, 100.f, -100.f)); + constexpr float kValue0 = 10.f; + constexpr float kValue1 = 20.f; + std::fill(value.begin(), value.begin() + 32, kValue0); + std::fill(value.begin() + 32, value.begin() + 64, kValue1); + + aec3::UpdateDbMetric(value, &statistic); + EXPECT_FLOAT_EQ(kValue0, statistic[0].sum_value); + EXPECT_FLOAT_EQ(kValue0, statistic[0].ceil_value); + EXPECT_FLOAT_EQ(kValue0, statistic[0].floor_value); + EXPECT_FLOAT_EQ(kValue1, statistic[1].sum_value); + EXPECT_FLOAT_EQ(kValue1, statistic[1].ceil_value); + EXPECT_FLOAT_EQ(kValue1, statistic[1].floor_value); + + aec3::UpdateDbMetric(value, &statistic); + EXPECT_FLOAT_EQ(2.f * kValue0, statistic[0].sum_value); + EXPECT_FLOAT_EQ(kValue0, statistic[0].ceil_value); + EXPECT_FLOAT_EQ(kValue0, statistic[0].floor_value); + EXPECT_FLOAT_EQ(2.f * kValue1, statistic[1].sum_value); + EXPECT_FLOAT_EQ(kValue1, statistic[1].ceil_value); + EXPECT_FLOAT_EQ(kValue1, statistic[1].floor_value); +} + +// Verifies that the TransformDbMetricForReporting method produces the desired +// output for values for dBFS. +TEST(TransformDbMetricForReporting, DbFsScaling) { + std::array<float, kBlockSize> x; + FftData X; + std::array<float, kFftLengthBy2Plus1> X2; + Aec3Fft fft; + x.fill(1000.f); + fft.ZeroPaddedFft(x, Aec3Fft::Window::kRectangular, &X); + X.Spectrum(Aec3Optimization::kNone, X2); + + float offset = -10.f * std::log10(32768.f * 32768.f); + EXPECT_NEAR(offset, -90.3f, 0.1f); + EXPECT_EQ( + static_cast<int>(30.3f), + aec3::TransformDbMetricForReporting( + true, 0.f, 90.f, offset, 1.f / (kBlockSize * kBlockSize), X2[0])); +} + +// Verifies that the TransformDbMetricForReporting method is able to properly +// limit the output. +TEST(TransformDbMetricForReporting, Limits) { + EXPECT_EQ(0, aec3::TransformDbMetricForReporting(false, 0.f, 10.f, 0.f, 1.f, + 0.001f)); + EXPECT_EQ(10, aec3::TransformDbMetricForReporting(false, 0.f, 10.f, 0.f, 1.f, + 100.f)); +} + +// Verifies that the TransformDbMetricForReporting method is able to properly +// negate output. +TEST(TransformDbMetricForReporting, Negate) { + EXPECT_EQ(10, aec3::TransformDbMetricForReporting(true, -20.f, 20.f, 0.f, 1.f, + 0.1f)); + EXPECT_EQ(-10, aec3::TransformDbMetricForReporting(true, -20.f, 20.f, 0.f, + 1.f, 10.f)); +} + +// Verify the Update functionality of DbMetric. +TEST(DbMetric, Update) { + EchoRemoverMetrics::DbMetric metric(0.f, 20.f, -20.f); + constexpr int kNumValues = 100; + constexpr float kValue = 10.f; + for (int k = 0; k < kNumValues; ++k) { + metric.Update(kValue); + } + EXPECT_FLOAT_EQ(kValue * kNumValues, metric.sum_value); + EXPECT_FLOAT_EQ(kValue, metric.ceil_value); + EXPECT_FLOAT_EQ(kValue, metric.floor_value); +} + +// Verify the Update functionality of DbMetric. +TEST(DbMetric, UpdateInstant) { + EchoRemoverMetrics::DbMetric metric(0.f, 20.f, -20.f); + constexpr float kMinValue = -77.f; + constexpr float kMaxValue = 33.f; + constexpr float kLastValue = (kMinValue + kMaxValue) / 2.0f; + for (float value = kMinValue; value <= kMaxValue; value++) + metric.UpdateInstant(value); + metric.UpdateInstant(kLastValue); + EXPECT_FLOAT_EQ(kLastValue, metric.sum_value); + EXPECT_FLOAT_EQ(kMaxValue, metric.ceil_value); + EXPECT_FLOAT_EQ(kMinValue, metric.floor_value); +} + +// Verify the constructor functionality of DbMetric. +TEST(DbMetric, Constructor) { + EchoRemoverMetrics::DbMetric metric; + EXPECT_FLOAT_EQ(0.f, metric.sum_value); + EXPECT_FLOAT_EQ(0.f, metric.ceil_value); + EXPECT_FLOAT_EQ(0.f, metric.floor_value); + + metric = EchoRemoverMetrics::DbMetric(1.f, 2.f, 3.f); + EXPECT_FLOAT_EQ(1.f, metric.sum_value); + EXPECT_FLOAT_EQ(2.f, metric.floor_value); + EXPECT_FLOAT_EQ(3.f, metric.ceil_value); +} + +// Verify the general functionality of EchoRemoverMetrics. +TEST(EchoRemoverMetrics, NormalUsage) { + EchoRemoverMetrics metrics; + AecState aec_state(EchoCanceller3Config{}, 1); + std::array<float, kFftLengthBy2Plus1> comfort_noise_spectrum; + std::array<float, kFftLengthBy2Plus1> suppressor_gain; + comfort_noise_spectrum.fill(10.f); + suppressor_gain.fill(1.f); + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < kMetricsReportingIntervalBlocks - 1; ++k) { + metrics.Update(aec_state, comfort_noise_spectrum, suppressor_gain); + EXPECT_FALSE(metrics.MetricsReported()); + } + metrics.Update(aec_state, comfort_noise_spectrum, suppressor_gain); + EXPECT_TRUE(metrics.MetricsReported()); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_unittest.cc new file mode 100644 index 0000000000..66168ab08d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/echo_remover_unittest.cc @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/echo_remover.h" + +#include <algorithm> +#include <memory> +#include <numeric> +#include <string> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { +std::string ProduceDebugText(int sample_rate_hz) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.Release(); +} + +std::string ProduceDebugText(int sample_rate_hz, int delay) { + rtc::StringBuilder ss(ProduceDebugText(sample_rate_hz)); + ss << ", Delay: " << delay; + return ss.Release(); +} + +} // namespace + +class EchoRemoverMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + EchoRemoverMultiChannel, + ::testing::Combine(::testing::Values(1, 2, 8), + ::testing::Values(1, 2, 8))); + +// Verifies the basic API call sequence +TEST_P(EchoRemoverMultiChannel, BasicApiCalls) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + absl::optional<DelayEstimate> delay_estimate; + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<EchoRemover> remover( + EchoRemover::Create(EchoCanceller3Config(), rate, num_render_channels, + num_capture_channels)); + std::unique_ptr<RenderDelayBuffer> render_buffer(RenderDelayBuffer::Create( + EchoCanceller3Config(), rate, num_render_channels)); + + Block render(NumBandsForRate(rate), num_render_channels); + Block capture(NumBandsForRate(rate), num_capture_channels); + for (size_t k = 0; k < 100; ++k) { + EchoPathVariability echo_path_variability( + k % 3 == 0 ? true : false, + k % 5 == 0 ? EchoPathVariability::DelayAdjustment::kNewDetectedDelay + : EchoPathVariability::DelayAdjustment::kNone, + false); + render_buffer->Insert(render); + render_buffer->PrepareCaptureProcessing(); + + remover->ProcessCapture(echo_path_variability, k % 2 == 0 ? true : false, + delay_estimate, render_buffer->GetRenderBuffer(), + nullptr, &capture); + } + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for the samplerate. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(EchoRemoverDeathTest, DISABLED_WrongSampleRate) { + EXPECT_DEATH(std::unique_ptr<EchoRemover>( + EchoRemover::Create(EchoCanceller3Config(), 8001, 1, 1)), + ""); +} + +// Verifies the check for the number of capture bands. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed.c +TEST(EchoRemoverDeathTest, DISABLED_WrongCaptureNumBands) { + absl::optional<DelayEstimate> delay_estimate; + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<EchoRemover> remover( + EchoRemover::Create(EchoCanceller3Config(), rate, 1, 1)); + std::unique_ptr<RenderDelayBuffer> render_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), rate, 1)); + Block capture(NumBandsForRate(rate == 48000 ? 16000 : rate + 16000), 1); + EchoPathVariability echo_path_variability( + false, EchoPathVariability::DelayAdjustment::kNone, false); + EXPECT_DEATH(remover->ProcessCapture( + echo_path_variability, false, delay_estimate, + render_buffer->GetRenderBuffer(), nullptr, &capture), + ""); + } +} + +// Verifies the check for non-null capture block. +TEST(EchoRemoverDeathTest, NullCapture) { + absl::optional<DelayEstimate> delay_estimate; + std::unique_ptr<EchoRemover> remover( + EchoRemover::Create(EchoCanceller3Config(), 16000, 1, 1)); + std::unique_ptr<RenderDelayBuffer> render_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), 16000, 1)); + EchoPathVariability echo_path_variability( + false, EchoPathVariability::DelayAdjustment::kNone, false); + EXPECT_DEATH(remover->ProcessCapture( + echo_path_variability, false, delay_estimate, + render_buffer->GetRenderBuffer(), nullptr, nullptr), + ""); +} + +#endif + +// Performs a sanity check that the echo_remover is able to properly +// remove echoes. +TEST(EchoRemover, BasicEchoRemoval) { + constexpr int kNumBlocksToProcess = 500; + Random random_generator(42U); + absl::optional<DelayEstimate> delay_estimate; + for (size_t num_channels : {1, 2, 4}) { + for (auto rate : {16000, 32000, 48000}) { + Block x(NumBandsForRate(rate), num_channels); + Block y(NumBandsForRate(rate), num_channels); + EchoPathVariability echo_path_variability( + false, EchoPathVariability::DelayAdjustment::kNone, false); + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(rate, delay_samples)); + EchoCanceller3Config config; + std::unique_ptr<EchoRemover> remover( + EchoRemover::Create(config, rate, num_channels, num_channels)); + std::unique_ptr<RenderDelayBuffer> render_buffer( + RenderDelayBuffer::Create(config, rate, num_channels)); + render_buffer->AlignFromDelay(delay_samples / kBlockSize); + + std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> + delay_buffers(x.NumBands()); + for (size_t band = 0; band < delay_buffers.size(); ++band) { + delay_buffers[band].resize(x.NumChannels()); + } + + for (int band = 0; band < x.NumBands(); ++band) { + for (int channel = 0; channel < x.NumChannels(); ++channel) { + delay_buffers[band][channel].reset( + new DelayBuffer<float>(delay_samples)); + } + } + + float input_energy = 0.f; + float output_energy = 0.f; + for (int k = 0; k < kNumBlocksToProcess; ++k) { + const bool silence = k < 100 || (k % 100 >= 10); + + for (int band = 0; band < x.NumBands(); ++band) { + for (int channel = 0; channel < x.NumChannels(); ++channel) { + if (silence) { + std::fill(x.begin(band, channel), x.end(band, channel), 0.f); + } else { + RandomizeSampleVector(&random_generator, x.View(band, channel)); + } + delay_buffers[band][channel]->Delay(x.View(band, channel), + y.View(band, channel)); + } + } + + if (k > kNumBlocksToProcess / 2) { + input_energy = std::inner_product( + y.begin(/*band=*/0, /*channel=*/0), + y.end(/*band=*/0, /*channel=*/0), + y.begin(/*band=*/0, /*channel=*/0), input_energy); + } + + render_buffer->Insert(x); + render_buffer->PrepareCaptureProcessing(); + + remover->ProcessCapture(echo_path_variability, false, delay_estimate, + render_buffer->GetRenderBuffer(), nullptr, + &y); + + if (k > kNumBlocksToProcess / 2) { + output_energy = std::inner_product( + y.begin(/*band=*/0, /*channel=*/0), + y.end(/*band=*/0, /*channel=*/0), + y.begin(/*band=*/0, /*channel=*/0), output_energy); + } + } + EXPECT_GT(input_energy, 10.f * output_energy); + } + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator.cc new file mode 100644 index 0000000000..01cc33cb80 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator.cc @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/erl_estimator.h" + +#include <algorithm> +#include <numeric> + +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { + +constexpr float kMinErl = 0.01f; +constexpr float kMaxErl = 1000.f; + +} // namespace + +ErlEstimator::ErlEstimator(size_t startup_phase_length_blocks_) + : startup_phase_length_blocks__(startup_phase_length_blocks_) { + erl_.fill(kMaxErl); + hold_counters_.fill(0); + erl_time_domain_ = kMaxErl; + hold_counter_time_domain_ = 0; +} + +ErlEstimator::~ErlEstimator() = default; + +void ErlEstimator::Reset() { + blocks_since_reset_ = 0; +} + +void ErlEstimator::Update( + const std::vector<bool>& converged_filters, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> render_spectra, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + capture_spectra) { + const size_t num_capture_channels = converged_filters.size(); + RTC_DCHECK_EQ(capture_spectra.size(), num_capture_channels); + + // Corresponds to WGN of power -46 dBFS. + constexpr float kX2Min = 44015068.0f; + + const auto first_converged_iter = + std::find(converged_filters.begin(), converged_filters.end(), true); + const bool any_filter_converged = + first_converged_iter != converged_filters.end(); + + if (++blocks_since_reset_ < startup_phase_length_blocks__ || + !any_filter_converged) { + return; + } + + // Use the maximum spectrum across capture and the maximum across render. + std::array<float, kFftLengthBy2Plus1> max_capture_spectrum_data; + std::array<float, kFftLengthBy2Plus1> max_capture_spectrum = + capture_spectra[/*channel=*/0]; + if (num_capture_channels > 1) { + // Initialize using the first channel with a converged filter. + const size_t first_converged = + std::distance(converged_filters.begin(), first_converged_iter); + RTC_DCHECK_GE(first_converged, 0); + RTC_DCHECK_LT(first_converged, num_capture_channels); + max_capture_spectrum_data = capture_spectra[first_converged]; + + for (size_t ch = first_converged + 1; ch < num_capture_channels; ++ch) { + if (!converged_filters[ch]) { + continue; + } + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + max_capture_spectrum_data[k] = + std::max(max_capture_spectrum_data[k], capture_spectra[ch][k]); + } + } + max_capture_spectrum = max_capture_spectrum_data; + } + + const size_t num_render_channels = render_spectra.size(); + std::array<float, kFftLengthBy2Plus1> max_render_spectrum_data; + rtc::ArrayView<const float, kFftLengthBy2Plus1> max_render_spectrum = + render_spectra[/*channel=*/0]; + if (num_render_channels > 1) { + std::copy(render_spectra[0].begin(), render_spectra[0].end(), + max_render_spectrum_data.begin()); + for (size_t ch = 1; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + max_render_spectrum_data[k] = + std::max(max_render_spectrum_data[k], render_spectra[ch][k]); + } + } + max_render_spectrum = max_render_spectrum_data; + } + + const auto& X2 = max_render_spectrum; + const auto& Y2 = max_capture_spectrum; + + // Update the estimates in a maximum statistics manner. + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (X2[k] > kX2Min) { + const float new_erl = Y2[k] / X2[k]; + if (new_erl < erl_[k]) { + hold_counters_[k - 1] = 1000; + erl_[k] += 0.1f * (new_erl - erl_[k]); + erl_[k] = std::max(erl_[k], kMinErl); + } + } + } + + std::for_each(hold_counters_.begin(), hold_counters_.end(), + [](int& a) { --a; }); + std::transform(hold_counters_.begin(), hold_counters_.end(), erl_.begin() + 1, + erl_.begin() + 1, [](int a, float b) { + return a > 0 ? b : std::min(kMaxErl, 2.f * b); + }); + + erl_[0] = erl_[1]; + erl_[kFftLengthBy2] = erl_[kFftLengthBy2 - 1]; + + // Compute ERL over all frequency bins. + const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f); + + if (X2_sum > kX2Min * X2.size()) { + const float Y2_sum = std::accumulate(Y2.begin(), Y2.end(), 0.0f); + const float new_erl = Y2_sum / X2_sum; + if (new_erl < erl_time_domain_) { + hold_counter_time_domain_ = 1000; + erl_time_domain_ += 0.1f * (new_erl - erl_time_domain_); + erl_time_domain_ = std::max(erl_time_domain_, kMinErl); + } + } + + --hold_counter_time_domain_; + erl_time_domain_ = (hold_counter_time_domain_ > 0) + ? erl_time_domain_ + : std::min(kMaxErl, 2.f * erl_time_domain_); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator.h new file mode 100644 index 0000000000..639a52c561 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ERL_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ERL_ESTIMATOR_H_ + +#include <stddef.h> + +#include <array> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Estimates the echo return loss based on the signal spectra. +class ErlEstimator { + public: + explicit ErlEstimator(size_t startup_phase_length_blocks_); + ~ErlEstimator(); + + ErlEstimator(const ErlEstimator&) = delete; + ErlEstimator& operator=(const ErlEstimator&) = delete; + + // Resets the ERL estimation. + void Reset(); + + // Updates the ERL estimate. + void Update(const std::vector<bool>& converged_filters, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + render_spectra, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + capture_spectra); + + // Returns the most recent ERL estimate. + const std::array<float, kFftLengthBy2Plus1>& Erl() const { return erl_; } + float ErlTimeDomain() const { return erl_time_domain_; } + + private: + const size_t startup_phase_length_blocks__; + std::array<float, kFftLengthBy2Plus1> erl_; + std::array<int, kFftLengthBy2Minus1> hold_counters_; + float erl_time_domain_; + int hold_counter_time_domain_; + size_t blocks_since_reset_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ERL_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator_unittest.cc new file mode 100644 index 0000000000..79e5465e3c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/erl_estimator_unittest.cc @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/erl_estimator.h" + +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { +std::string ProduceDebugText(size_t num_render_channels, + size_t num_capture_channels) { + rtc::StringBuilder ss; + ss << "Render channels: " << num_render_channels; + ss << ", Capture channels: " << num_capture_channels; + return ss.Release(); +} + +void VerifyErl(const std::array<float, kFftLengthBy2Plus1>& erl, + float erl_time_domain, + float reference) { + std::for_each(erl.begin(), erl.end(), + [reference](float a) { EXPECT_NEAR(reference, a, 0.001); }); + EXPECT_NEAR(reference, erl_time_domain, 0.001); +} + +} // namespace + +class ErlEstimatorMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + ErlEstimatorMultiChannel, + ::testing::Combine(::testing::Values(1, 2, 8), + ::testing::Values(1, 2, 8))); + +// Verifies that the correct ERL estimates are achieved. +TEST_P(ErlEstimatorMultiChannel, Estimates) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels)); + std::vector<std::array<float, kFftLengthBy2Plus1>> X2(num_render_channels); + for (auto& X2_ch : X2) { + X2_ch.fill(0.f); + } + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); + for (auto& Y2_ch : Y2) { + Y2_ch.fill(0.f); + } + std::vector<bool> converged_filters(num_capture_channels, false); + const size_t converged_idx = num_capture_channels - 1; + converged_filters[converged_idx] = true; + + ErlEstimator estimator(0); + + // Verifies that the ERL estimate is properly reduced to lower values. + for (auto& X2_ch : X2) { + X2_ch.fill(500 * 1000.f * 1000.f); + } + Y2[converged_idx].fill(10 * X2[0][0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); + + // Verifies that the ERL is not immediately increased when the ERL in the + // data increases. + Y2[converged_idx].fill(10000 * X2[0][0]); + for (size_t k = 0; k < 998; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); + + // Verifies that the rate of increase is 3 dB. + estimator.Update(converged_filters, X2, Y2); + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f); + + // Verifies that the maximum ERL is achieved when there are no low RLE + // estimates. + for (size_t k = 0; k < 1000; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); + + // Verifies that the ERL estimate is is not updated for low-level signals + for (auto& X2_ch : X2) { + X2_ch.fill(1000.f * 1000.f); + } + Y2[converged_idx].fill(10 * X2[0][0]); + for (size_t k = 0; k < 200; ++k) { + estimator.Update(converged_filters, X2, Y2); + } + VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator.cc new file mode 100644 index 0000000000..0e3d715c59 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator.cc @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/erle_estimator.h" + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +ErleEstimator::ErleEstimator(size_t startup_phase_length_blocks, + const EchoCanceller3Config& config, + size_t num_capture_channels) + : startup_phase_length_blocks_(startup_phase_length_blocks), + fullband_erle_estimator_(config.erle, num_capture_channels), + subband_erle_estimator_(config, num_capture_channels) { + if (config.erle.num_sections > 1) { + signal_dependent_erle_estimator_ = + std::make_unique<SignalDependentErleEstimator>(config, + num_capture_channels); + } + Reset(true); +} + +ErleEstimator::~ErleEstimator() = default; + +void ErleEstimator::Reset(bool delay_change) { + fullband_erle_estimator_.Reset(); + subband_erle_estimator_.Reset(); + if (signal_dependent_erle_estimator_) { + signal_dependent_erle_estimator_->Reset(); + } + if (delay_change) { + blocks_since_reset_ = 0; + } +} + +void ErleEstimator::Update( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_responses, + rtc::ArrayView<const float, kFftLengthBy2Plus1> + avg_render_spectrum_with_reverb, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> capture_spectra, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + subtractor_spectra, + const std::vector<bool>& converged_filters) { + RTC_DCHECK_EQ(subband_erle_estimator_.Erle(/*onset_compensated=*/true).size(), + capture_spectra.size()); + RTC_DCHECK_EQ(subband_erle_estimator_.Erle(/*onset_compensated=*/true).size(), + subtractor_spectra.size()); + const auto& X2_reverb = avg_render_spectrum_with_reverb; + const auto& Y2 = capture_spectra; + const auto& E2 = subtractor_spectra; + + if (++blocks_since_reset_ < startup_phase_length_blocks_) { + return; + } + + subband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filters); + + if (signal_dependent_erle_estimator_) { + signal_dependent_erle_estimator_->Update( + render_buffer, filter_frequency_responses, X2_reverb, Y2, E2, + subband_erle_estimator_.Erle(/*onset_compensated=*/false), + subband_erle_estimator_.Erle(/*onset_compensated=*/true), + converged_filters); + } + + fullband_erle_estimator_.Update(X2_reverb, Y2, E2, converged_filters); +} + +void ErleEstimator::Dump( + const std::unique_ptr<ApmDataDumper>& data_dumper) const { + fullband_erle_estimator_.Dump(data_dumper); + subband_erle_estimator_.Dump(data_dumper); + if (signal_dependent_erle_estimator_) { + signal_dependent_erle_estimator_->Dump(data_dumper); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator.h new file mode 100644 index 0000000000..55797592a9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ERLE_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ERLE_ESTIMATOR_H_ + +#include <stddef.h> + +#include <array> +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/fullband_erle_estimator.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/signal_dependent_erle_estimator.h" +#include "modules/audio_processing/aec3/subband_erle_estimator.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +// Estimates the echo return loss enhancement. One estimate is done per subband +// and another one is done using the aggreation of energy over all the subbands. +class ErleEstimator { + public: + ErleEstimator(size_t startup_phase_length_blocks, + const EchoCanceller3Config& config, + size_t num_capture_channels); + ~ErleEstimator(); + + // Resets the fullband ERLE estimator and the subbands ERLE estimators. + void Reset(bool delay_change); + + // Updates the ERLE estimates. + void Update( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_responses, + rtc::ArrayView<const float, kFftLengthBy2Plus1> + avg_render_spectrum_with_reverb, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + capture_spectra, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + subtractor_spectra, + const std::vector<bool>& converged_filters); + + // Returns the most recent subband ERLE estimates. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle( + bool onset_compensated) const { + return signal_dependent_erle_estimator_ + ? signal_dependent_erle_estimator_->Erle(onset_compensated) + : subband_erle_estimator_.Erle(onset_compensated); + } + + // Returns the non-capped subband ERLE. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> ErleUnbounded() + const { + // Unbounded ERLE is only used with the subband erle estimator where the + // ERLE is often capped at low values. When the signal dependent ERLE + // estimator is used the capped ERLE is returned. + return !signal_dependent_erle_estimator_ + ? subband_erle_estimator_.ErleUnbounded() + : signal_dependent_erle_estimator_->Erle( + /*onset_compensated=*/false); + } + + // Returns the subband ERLE that are estimated during onsets (only used for + // testing). + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> ErleDuringOnsets() + const { + return subband_erle_estimator_.ErleDuringOnsets(); + } + + // Returns the fullband ERLE estimate. + float FullbandErleLog2() const { + return fullband_erle_estimator_.FullbandErleLog2(); + } + + // Returns an estimation of the current linear filter quality based on the + // current and past fullband ERLE estimates. The returned value is a float + // vector with content between 0 and 1 where 1 indicates that, at this current + // time instant, the linear filter is reaching its maximum subtraction + // performance. + rtc::ArrayView<const absl::optional<float>> GetInstLinearQualityEstimates() + const { + return fullband_erle_estimator_.GetInstLinearQualityEstimates(); + } + + void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const; + + private: + const size_t startup_phase_length_blocks_; + FullBandErleEstimator fullband_erle_estimator_; + SubbandErleEstimator subband_erle_estimator_; + std::unique_ptr<SignalDependentErleEstimator> + signal_dependent_erle_estimator_; + size_t blocks_since_reset_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ERLE_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator_unittest.cc new file mode 100644 index 0000000000..42be7d9c7d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/erle_estimator_unittest.cc @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/erle_estimator.h" + +#include <cmath> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { +constexpr int kLowFrequencyLimit = kFftLengthBy2 / 2; +constexpr float kTrueErle = 10.f; +constexpr float kTrueErleOnsets = 1.0f; +constexpr float kEchoPathGain = 3.f; + +void VerifyErleBands( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle, + float reference_lf, + float reference_hf) { + for (size_t ch = 0; ch < erle.size(); ++ch) { + std::for_each( + erle[ch].begin(), erle[ch].begin() + kLowFrequencyLimit, + [reference_lf](float a) { EXPECT_NEAR(reference_lf, a, 0.001); }); + std::for_each( + erle[ch].begin() + kLowFrequencyLimit, erle[ch].end(), + [reference_hf](float a) { EXPECT_NEAR(reference_hf, a, 0.001); }); + } +} + +void VerifyErle( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle, + float erle_time_domain, + float reference_lf, + float reference_hf) { + VerifyErleBands(erle, reference_lf, reference_hf); + EXPECT_NEAR(kTrueErle, erle_time_domain, 0.5); +} + +void VerifyErleGreaterOrEqual( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle1, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle2) { + for (size_t ch = 0; ch < erle1.size(); ++ch) { + for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) { + EXPECT_GE(erle1[ch][i], erle2[ch][i]); + } + } +} + +void FormFarendTimeFrame(Block* x) { + const std::array<float, kBlockSize> frame = { + 7459.88, 17209.6, 17383, 20768.9, 16816.7, 18386.3, 4492.83, 9675.85, + 6665.52, 14808.6, 9342.3, 7483.28, 19261.7, 4145.98, 1622.18, 13475.2, + 7166.32, 6856.61, 21937, 7263.14, 9569.07, 14919, 8413.32, 7551.89, + 7848.65, 6011.27, 13080.6, 15865.2, 12656, 17459.6, 4263.93, 4503.03, + 9311.79, 21095.8, 12657.9, 13906.6, 19267.2, 11338.1, 16828.9, 11501.6, + 11405, 15031.4, 14541.6, 19765.5, 18346.3, 19350.2, 3157.47, 18095.8, + 1743.68, 21328.2, 19727.5, 7295.16, 10332.4, 11055.5, 20107.4, 14708.4, + 12416.2, 16434, 2454.69, 9840.8, 6867.23, 1615.75, 6059.9, 8394.19}; + for (int band = 0; band < x->NumBands(); ++band) { + for (int channel = 0; channel < x->NumChannels(); ++channel) { + RTC_DCHECK_GE(kBlockSize, frame.size()); + std::copy(frame.begin(), frame.end(), x->begin(band, channel)); + } + } +} + +void FormFarendFrame(const RenderBuffer& render_buffer, + float erle, + std::array<float, kFftLengthBy2Plus1>* X2, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> E2, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> Y2) { + const auto& spectrum_buffer = render_buffer.GetSpectrumBuffer(); + const int num_render_channels = spectrum_buffer.buffer[0].size(); + const int num_capture_channels = Y2.size(); + + X2->fill(0.f); + for (int ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + (*X2)[k] += spectrum_buffer.buffer[spectrum_buffer.write][ch][k] / + num_render_channels; + } + } + + for (int ch = 0; ch < num_capture_channels; ++ch) { + std::transform(X2->begin(), X2->end(), Y2[ch].begin(), + [](float a) { return a * kEchoPathGain * kEchoPathGain; }); + std::transform(Y2[ch].begin(), Y2[ch].end(), E2[ch].begin(), + [erle](float a) { return a / erle; }); + } +} + +void FormNearendFrame( + Block* x, + std::array<float, kFftLengthBy2Plus1>* X2, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> E2, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> Y2) { + for (int band = 0; band < x->NumBands(); ++band) { + for (int ch = 0; ch < x->NumChannels(); ++ch) { + std::fill(x->begin(band, ch), x->end(band, ch), 0.f); + } + } + + X2->fill(0.f); + for (size_t ch = 0; ch < Y2.size(); ++ch) { + Y2[ch].fill(500.f * 1000.f * 1000.f); + E2[ch].fill(Y2[ch][0]); + } +} + +void GetFilterFreq( + size_t delay_headroom_samples, + rtc::ArrayView<std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_response) { + const size_t delay_headroom_blocks = delay_headroom_samples / kBlockSize; + for (size_t ch = 0; ch < filter_frequency_response[0].size(); ++ch) { + for (auto& block_freq_resp : filter_frequency_response) { + block_freq_resp[ch].fill(0.f); + } + + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + filter_frequency_response[delay_headroom_blocks][ch][k] = kEchoPathGain; + } + } +} + +} // namespace + +class ErleEstimatorMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + ErleEstimatorMultiChannel, + ::testing::Combine(::testing::Values(1, 2, 4, 8), + ::testing::Values(1, 2, 8))); + +TEST_P(ErleEstimatorMultiChannel, VerifyErleIncreaseAndHold) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + std::array<float, kFftLengthBy2Plus1> X2; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2(num_capture_channels); + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); + std::vector<bool> converged_filters(num_capture_channels, true); + + EchoCanceller3Config config; + config.erle.onset_detection = true; + + Block x(kNumBands, num_render_channels); + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_response( + config.filter.refined.length_blocks, + std::vector<std::array<float, kFftLengthBy2Plus1>>( + num_capture_channels)); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + + GetFilterFreq(config.delay.delay_headroom_samples, filter_frequency_response); + + ErleEstimator estimator(0, config, num_capture_channels); + + FormFarendTimeFrame(&x); + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + // Verifies that the ERLE estimate is properly increased to higher values. + FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2, E2, + Y2); + for (size_t k = 0; k < 1000; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, converged_filters); + } + VerifyErle(estimator.Erle(/*onset_compensated=*/true), + std::pow(2.f, estimator.FullbandErleLog2()), config.erle.max_l, + config.erle.max_h); + VerifyErleGreaterOrEqual(estimator.Erle(/*onset_compensated=*/false), + estimator.Erle(/*onset_compensated=*/true)); + VerifyErleGreaterOrEqual(estimator.ErleUnbounded(), + estimator.Erle(/*onset_compensated=*/false)); + + FormNearendFrame(&x, &X2, E2, Y2); + // Verifies that the ERLE is not immediately decreased during nearend + // activity. + for (size_t k = 0; k < 50; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, converged_filters); + } + VerifyErle(estimator.Erle(/*onset_compensated=*/true), + std::pow(2.f, estimator.FullbandErleLog2()), config.erle.max_l, + config.erle.max_h); + VerifyErleGreaterOrEqual(estimator.Erle(/*onset_compensated=*/false), + estimator.Erle(/*onset_compensated=*/true)); + VerifyErleGreaterOrEqual(estimator.ErleUnbounded(), + estimator.Erle(/*onset_compensated=*/false)); +} + +TEST_P(ErleEstimatorMultiChannel, VerifyErleTrackingOnOnsets) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + std::array<float, kFftLengthBy2Plus1> X2; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2(num_capture_channels); + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); + std::vector<bool> converged_filters(num_capture_channels, true); + EchoCanceller3Config config; + config.erle.onset_detection = true; + Block x(kNumBands, num_render_channels); + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_response( + config.filter.refined.length_blocks, + std::vector<std::array<float, kFftLengthBy2Plus1>>( + num_capture_channels)); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + + GetFilterFreq(config.delay.delay_headroom_samples, filter_frequency_response); + + ErleEstimator estimator(/*startup_phase_length_blocks=*/0, config, + num_capture_channels); + + FormFarendTimeFrame(&x); + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + + for (size_t burst = 0; burst < 20; ++burst) { + FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErleOnsets, + &X2, E2, Y2); + for (size_t k = 0; k < 10; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + FormFarendFrame(*render_delay_buffer->GetRenderBuffer(), kTrueErle, &X2, E2, + Y2); + for (size_t k = 0; k < 1000; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + FormNearendFrame(&x, &X2, E2, Y2); + for (size_t k = 0; k < 300; ++k) { + render_delay_buffer->Insert(x); + render_delay_buffer->PrepareCaptureProcessing(); + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, + converged_filters); + } + } + VerifyErleBands(estimator.ErleDuringOnsets(), config.erle.min, + config.erle.min); + FormNearendFrame(&x, &X2, E2, Y2); + for (size_t k = 0; k < 1000; k++) { + estimator.Update(*render_delay_buffer->GetRenderBuffer(), + filter_frequency_response, X2, Y2, E2, converged_filters); + } + // Verifies that during ne activity, Erle converges to the Erle for + // onsets. + VerifyErle(estimator.Erle(/*onset_compensated=*/true), + std::pow(2.f, estimator.FullbandErleLog2()), config.erle.min, + config.erle.min); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fft_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/fft_buffer.cc new file mode 100644 index 0000000000..1ce2d31d8f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fft_buffer.cc @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/fft_buffer.h" + +namespace webrtc { + +FftBuffer::FftBuffer(size_t size, size_t num_channels) + : size(static_cast<int>(size)), + buffer(size, std::vector<FftData>(num_channels)) { + for (auto& block : buffer) { + for (auto& channel_fft_data : block) { + channel_fft_data.Clear(); + } + } +} + +FftBuffer::~FftBuffer() = default; + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fft_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/fft_buffer.h new file mode 100644 index 0000000000..4187315863 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fft_buffer.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_FFT_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_FFT_BUFFER_H_ + +#include <stddef.h> + +#include <vector> + +#include "modules/audio_processing/aec3/fft_data.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +// Struct for bundling a circular buffer of FftData objects together with the +// read and write indices. +struct FftBuffer { + FftBuffer(size_t size, size_t num_channels); + ~FftBuffer(); + + int IncIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index < size - 1 ? index + 1 : 0; + } + + int DecIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index > 0 ? index - 1 : size - 1; + } + + int OffsetIndex(int index, int offset) const { + RTC_DCHECK_GE(buffer.size(), offset); + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return (size + index + offset) % size; + } + + void UpdateWriteIndex(int offset) { write = OffsetIndex(write, offset); } + void IncWriteIndex() { write = IncIndex(write); } + void DecWriteIndex() { write = DecIndex(write); } + void UpdateReadIndex(int offset) { read = OffsetIndex(read, offset); } + void IncReadIndex() { read = IncIndex(read); } + void DecReadIndex() { read = DecIndex(read); } + + const int size; + std::vector<std::vector<FftData>> buffer; + int write = 0; + int read = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_FFT_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fft_data.h b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data.h new file mode 100644 index 0000000000..9c25e784aa --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_FFT_DATA_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_FFT_DATA_H_ + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif +#include <algorithm> +#include <array> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Struct that holds imaginary data produced from 128 point real-valued FFTs. +struct FftData { + // Copies the data in src. + void Assign(const FftData& src) { + std::copy(src.re.begin(), src.re.end(), re.begin()); + std::copy(src.im.begin(), src.im.end(), im.begin()); + im[0] = im[kFftLengthBy2] = 0; + } + + // Clears all the imaginary. + void Clear() { + re.fill(0.f); + im.fill(0.f); + } + + // Computes the power spectrum of the data. + void SpectrumAVX2(rtc::ArrayView<float> power_spectrum) const; + + // Computes the power spectrum of the data. + void Spectrum(Aec3Optimization optimization, + rtc::ArrayView<float> power_spectrum) const { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, power_spectrum.size()); + switch (optimization) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + constexpr int kNumFourBinBands = kFftLengthBy2 / 4; + constexpr int kLimit = kNumFourBinBands * 4; + for (size_t k = 0; k < kLimit; k += 4) { + const __m128 r = _mm_loadu_ps(&re[k]); + const __m128 i = _mm_loadu_ps(&im[k]); + const __m128 ii = _mm_mul_ps(i, i); + const __m128 rr = _mm_mul_ps(r, r); + const __m128 rrii = _mm_add_ps(rr, ii); + _mm_storeu_ps(&power_spectrum[k], rrii); + } + power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] + + im[kFftLengthBy2] * im[kFftLengthBy2]; + } break; + case Aec3Optimization::kAvx2: + SpectrumAVX2(power_spectrum); + break; +#endif + default: + std::transform(re.begin(), re.end(), im.begin(), power_spectrum.begin(), + [](float a, float b) { return a * a + b * b; }); + } + } + + // Copy the data from an interleaved array. + void CopyFromPackedArray(const std::array<float, kFftLength>& v) { + re[0] = v[0]; + re[kFftLengthBy2] = v[1]; + im[0] = im[kFftLengthBy2] = 0; + for (size_t k = 1, j = 2; k < kFftLengthBy2; ++k) { + re[k] = v[j++]; + im[k] = v[j++]; + } + } + + // Copies the data into an interleaved array. + void CopyToPackedArray(std::array<float, kFftLength>* v) const { + RTC_DCHECK(v); + (*v)[0] = re[0]; + (*v)[1] = re[kFftLengthBy2]; + for (size_t k = 1, j = 2; k < kFftLengthBy2; ++k) { + (*v)[j++] = re[k]; + (*v)[j++] = im[k]; + } + } + + std::array<float, kFftLengthBy2Plus1> re; + std::array<float, kFftLengthBy2Plus1> im; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_FFT_DATA_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_avx2.cc b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_avx2.cc new file mode 100644 index 0000000000..1fe4bd69c6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_avx2.cc @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/fft_data.h" + +#include <immintrin.h> + +#include "api/array_view.h" + +namespace webrtc { + +// Computes the power spectrum of the data. +void FftData::SpectrumAVX2(rtc::ArrayView<float> power_spectrum) const { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, power_spectrum.size()); + for (size_t k = 0; k < kFftLengthBy2; k += 8) { + __m256 r = _mm256_loadu_ps(&re[k]); + __m256 i = _mm256_loadu_ps(&im[k]); + __m256 ii = _mm256_mul_ps(i, i); + ii = _mm256_fmadd_ps(r, r, ii); + _mm256_storeu_ps(&power_spectrum[k], ii); + } + power_spectrum[kFftLengthBy2] = re[kFftLengthBy2] * re[kFftLengthBy2] + + im[kFftLengthBy2] * im[kFftLengthBy2]; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_gn/moz.build new file mode 100644 index 0000000000..d77163999b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_gn/moz.build @@ -0,0 +1,205 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("fft_data_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_unittest.cc new file mode 100644 index 0000000000..d76fabdbd6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fft_data_unittest.cc @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/fft_data.h" + +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/gtest.h" + +namespace webrtc { + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods are bitexact to their reference +// counterparts. +TEST(FftData, TestSse2Optimizations) { + if (GetCPUInfo(kSSE2) != 0) { + FftData x; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + std::array<float, kFftLengthBy2Plus1> spectrum; + std::array<float, kFftLengthBy2Plus1> spectrum_sse2; + x.Spectrum(Aec3Optimization::kNone, spectrum); + x.Spectrum(Aec3Optimization::kSse2, spectrum_sse2); + EXPECT_EQ(spectrum, spectrum_sse2); + } +} + +// Verifies that the optimized methods are bitexact to their reference +// counterparts. +TEST(FftData, TestAvx2Optimizations) { + if (GetCPUInfo(kAVX2) != 0) { + FftData x; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + std::array<float, kFftLengthBy2Plus1> spectrum; + std::array<float, kFftLengthBy2Plus1> spectrum_avx2; + x.Spectrum(Aec3Optimization::kNone, spectrum); + x.Spectrum(Aec3Optimization::kAvx2, spectrum_avx2); + EXPECT_EQ(spectrum, spectrum_avx2); + } +} +#endif + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for null output in CopyToPackedArray. +TEST(FftDataDeathTest, NonNullCopyToPackedArrayOutput) { + EXPECT_DEATH(FftData().CopyToPackedArray(nullptr), ""); +} + +// Verifies the check for null output in Spectrum. +TEST(FftDataDeathTest, NonNullSpectrumOutput) { + EXPECT_DEATH(FftData().Spectrum(Aec3Optimization::kNone, nullptr), ""); +} + +#endif + +// Verifies that the Assign method properly copies the data from the source and +// ensures that the imaginary components for the DC and Nyquist bins are 0. +TEST(FftData, Assign) { + FftData x; + FftData y; + + x.re.fill(1.f); + x.im.fill(2.f); + y.Assign(x); + EXPECT_EQ(x.re, y.re); + EXPECT_EQ(0.f, y.im[0]); + EXPECT_EQ(0.f, y.im[x.im.size() - 1]); + for (size_t k = 1; k < x.im.size() - 1; ++k) { + EXPECT_EQ(x.im[k], y.im[k]); + } +} + +// Verifies that the Clear method properly clears all the data. +TEST(FftData, Clear) { + FftData x_ref; + FftData x; + + x_ref.re.fill(0.f); + x_ref.im.fill(0.f); + + x.re.fill(1.f); + x.im.fill(2.f); + x.Clear(); + + EXPECT_EQ(x_ref.re, x.re); + EXPECT_EQ(x_ref.im, x.im); +} + +// Verifies that the spectrum is correctly computed. +TEST(FftData, Spectrum) { + FftData x; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + std::array<float, kFftLengthBy2Plus1> spectrum; + x.Spectrum(Aec3Optimization::kNone, spectrum); + + EXPECT_EQ(x.re[0] * x.re[0], spectrum[0]); + EXPECT_EQ(x.re[spectrum.size() - 1] * x.re[spectrum.size() - 1], + spectrum[spectrum.size() - 1]); + for (size_t k = 1; k < spectrum.size() - 1; ++k) { + EXPECT_EQ(x.re[k] * x.re[k] + x.im[k] * x.im[k], spectrum[k]); + } +} + +// Verifies that the functionality in CopyToPackedArray works as intended. +TEST(FftData, CopyToPackedArray) { + FftData x; + std::array<float, kFftLength> x_packed; + + for (size_t k = 0; k < x.re.size(); ++k) { + x.re[k] = k + 1; + } + + x.im[0] = x.im[x.im.size() - 1] = 0.f; + for (size_t k = 1; k < x.im.size() - 1; ++k) { + x.im[k] = 2.f * (k + 1); + } + + x.CopyToPackedArray(&x_packed); + + EXPECT_EQ(x.re[0], x_packed[0]); + EXPECT_EQ(x.re[x.re.size() - 1], x_packed[1]); + for (size_t k = 1; k < x_packed.size() / 2; ++k) { + EXPECT_EQ(x.re[k], x_packed[2 * k]); + EXPECT_EQ(x.im[k], x_packed[2 * k + 1]); + } +} + +// Verifies that the functionality in CopyFromPackedArray works as intended +// (relies on that the functionality in CopyToPackedArray has been verified in +// the test above). +TEST(FftData, CopyFromPackedArray) { + FftData x_ref; + FftData x; + std::array<float, kFftLength> x_packed; + + for (size_t k = 0; k < x_ref.re.size(); ++k) { + x_ref.re[k] = k + 1; + } + + x_ref.im[0] = x_ref.im[x_ref.im.size() - 1] = 0.f; + for (size_t k = 1; k < x_ref.im.size() - 1; ++k) { + x_ref.im[k] = 2.f * (k + 1); + } + + x_ref.CopyToPackedArray(&x_packed); + x.CopyFromPackedArray(x_packed); + + EXPECT_EQ(x_ref.re, x.re); + EXPECT_EQ(x_ref.im, x.im); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer.cc new file mode 100644 index 0000000000..d8fd3aa275 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer.cc @@ -0,0 +1,289 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/filter_analyzer.h" + +#include <math.h> + +#include <algorithm> +#include <array> +#include <numeric> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +size_t FindPeakIndex(rtc::ArrayView<const float> filter_time_domain, + size_t peak_index_in, + size_t start_sample, + size_t end_sample) { + size_t peak_index_out = peak_index_in; + float max_h2 = + filter_time_domain[peak_index_out] * filter_time_domain[peak_index_out]; + for (size_t k = start_sample; k <= end_sample; ++k) { + float tmp = filter_time_domain[k] * filter_time_domain[k]; + if (tmp > max_h2) { + peak_index_out = k; + max_h2 = tmp; + } + } + + return peak_index_out; +} + +} // namespace + +std::atomic<int> FilterAnalyzer::instance_count_(0); + +FilterAnalyzer::FilterAnalyzer(const EchoCanceller3Config& config, + size_t num_capture_channels) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + bounded_erl_(config.ep_strength.bounded_erl), + default_gain_(config.ep_strength.default_gain), + h_highpass_(num_capture_channels, + std::vector<float>( + GetTimeDomainLength(config.filter.refined.length_blocks), + 0.f)), + filter_analysis_states_(num_capture_channels, + FilterAnalysisState(config)), + filter_delays_blocks_(num_capture_channels, 0) { + Reset(); +} + +FilterAnalyzer::~FilterAnalyzer() = default; + +void FilterAnalyzer::Reset() { + blocks_since_reset_ = 0; + ResetRegion(); + for (auto& state : filter_analysis_states_) { + state.Reset(default_gain_); + } + std::fill(filter_delays_blocks_.begin(), filter_delays_blocks_.end(), 0); +} + +void FilterAnalyzer::Update( + rtc::ArrayView<const std::vector<float>> filters_time_domain, + const RenderBuffer& render_buffer, + bool* any_filter_consistent, + float* max_echo_path_gain) { + RTC_DCHECK(any_filter_consistent); + RTC_DCHECK(max_echo_path_gain); + RTC_DCHECK_EQ(filters_time_domain.size(), filter_analysis_states_.size()); + RTC_DCHECK_EQ(filters_time_domain.size(), h_highpass_.size()); + + ++blocks_since_reset_; + SetRegionToAnalyze(filters_time_domain[0].size()); + AnalyzeRegion(filters_time_domain, render_buffer); + + // Aggregate the results for all capture channels. + auto& st_ch0 = filter_analysis_states_[0]; + *any_filter_consistent = st_ch0.consistent_estimate; + *max_echo_path_gain = st_ch0.gain; + min_filter_delay_blocks_ = filter_delays_blocks_[0]; + for (size_t ch = 1; ch < filters_time_domain.size(); ++ch) { + auto& st_ch = filter_analysis_states_[ch]; + *any_filter_consistent = + *any_filter_consistent || st_ch.consistent_estimate; + *max_echo_path_gain = std::max(*max_echo_path_gain, st_ch.gain); + min_filter_delay_blocks_ = + std::min(min_filter_delay_blocks_, filter_delays_blocks_[ch]); + } +} + +void FilterAnalyzer::AnalyzeRegion( + rtc::ArrayView<const std::vector<float>> filters_time_domain, + const RenderBuffer& render_buffer) { + // Preprocess the filter to avoid issues with low-frequency components in the + // filter. + PreProcessFilters(filters_time_domain); + data_dumper_->DumpRaw("aec3_linear_filter_processed_td", h_highpass_[0]); + + constexpr float kOneByBlockSize = 1.f / kBlockSize; + for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) { + RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size()); + RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size()); + + auto& st_ch = filter_analysis_states_[ch]; + RTC_DCHECK_EQ(h_highpass_[ch].size(), filters_time_domain[ch].size()); + RTC_DCHECK_GT(h_highpass_[ch].size(), 0); + st_ch.peak_index = std::min(st_ch.peak_index, h_highpass_[ch].size() - 1); + + st_ch.peak_index = + FindPeakIndex(h_highpass_[ch], st_ch.peak_index, region_.start_sample_, + region_.end_sample_); + filter_delays_blocks_[ch] = st_ch.peak_index >> kBlockSizeLog2; + UpdateFilterGain(h_highpass_[ch], &st_ch); + st_ch.filter_length_blocks = + filters_time_domain[ch].size() * kOneByBlockSize; + + st_ch.consistent_estimate = st_ch.consistent_filter_detector.Detect( + h_highpass_[ch], region_, + render_buffer.GetBlock(-filter_delays_blocks_[ch]), st_ch.peak_index, + filter_delays_blocks_[ch]); + } +} + +void FilterAnalyzer::UpdateFilterGain( + rtc::ArrayView<const float> filter_time_domain, + FilterAnalysisState* st) { + bool sufficient_time_to_converge = + blocks_since_reset_ > 5 * kNumBlocksPerSecond; + + if (sufficient_time_to_converge && st->consistent_estimate) { + st->gain = fabsf(filter_time_domain[st->peak_index]); + } else { + // TODO(peah): Verify whether this check against a float is ok. + if (st->gain) { + st->gain = std::max(st->gain, fabsf(filter_time_domain[st->peak_index])); + } + } + + if (bounded_erl_ && st->gain) { + st->gain = std::max(st->gain, 0.01f); + } +} + +void FilterAnalyzer::PreProcessFilters( + rtc::ArrayView<const std::vector<float>> filters_time_domain) { + for (size_t ch = 0; ch < filters_time_domain.size(); ++ch) { + RTC_DCHECK_LT(region_.start_sample_, filters_time_domain[ch].size()); + RTC_DCHECK_LT(region_.end_sample_, filters_time_domain[ch].size()); + + RTC_DCHECK_GE(h_highpass_[ch].capacity(), filters_time_domain[ch].size()); + h_highpass_[ch].resize(filters_time_domain[ch].size()); + // Minimum phase high-pass filter with cutoff frequency at about 600 Hz. + constexpr std::array<float, 3> h = { + {0.7929742f, -0.36072128f, -0.47047766f}}; + + std::fill(h_highpass_[ch].begin() + region_.start_sample_, + h_highpass_[ch].begin() + region_.end_sample_ + 1, 0.f); + float* h_highpass_ch = h_highpass_[ch].data(); + const float* filters_time_domain_ch = filters_time_domain[ch].data(); + const size_t region_end = region_.end_sample_; + for (size_t k = std::max(h.size() - 1, region_.start_sample_); + k <= region_end; ++k) { + float tmp = h_highpass_ch[k]; + for (size_t j = 0; j < h.size(); ++j) { + tmp += filters_time_domain_ch[k - j] * h[j]; + } + h_highpass_ch[k] = tmp; + } + } +} + +void FilterAnalyzer::ResetRegion() { + region_.start_sample_ = 0; + region_.end_sample_ = 0; +} + +void FilterAnalyzer::SetRegionToAnalyze(size_t filter_size) { + constexpr size_t kNumberBlocksToUpdate = 1; + auto& r = region_; + r.start_sample_ = r.end_sample_ >= filter_size - 1 ? 0 : r.end_sample_ + 1; + r.end_sample_ = + std::min(r.start_sample_ + kNumberBlocksToUpdate * kBlockSize - 1, + filter_size - 1); + + // Check range. + RTC_DCHECK_LT(r.start_sample_, filter_size); + RTC_DCHECK_LT(r.end_sample_, filter_size); + RTC_DCHECK_LE(r.start_sample_, r.end_sample_); +} + +FilterAnalyzer::ConsistentFilterDetector::ConsistentFilterDetector( + const EchoCanceller3Config& config) + : active_render_threshold_(config.render_levels.active_render_limit * + config.render_levels.active_render_limit * + kFftLengthBy2) { + Reset(); +} + +void FilterAnalyzer::ConsistentFilterDetector::Reset() { + significant_peak_ = false; + filter_floor_accum_ = 0.f; + filter_secondary_peak_ = 0.f; + filter_floor_low_limit_ = 0; + filter_floor_high_limit_ = 0; + consistent_estimate_counter_ = 0; + consistent_delay_reference_ = -10; +} + +bool FilterAnalyzer::ConsistentFilterDetector::Detect( + rtc::ArrayView<const float> filter_to_analyze, + const FilterRegion& region, + const Block& x_block, + size_t peak_index, + int delay_blocks) { + if (region.start_sample_ == 0) { + filter_floor_accum_ = 0.f; + filter_secondary_peak_ = 0.f; + filter_floor_low_limit_ = peak_index < 64 ? 0 : peak_index - 64; + filter_floor_high_limit_ = + peak_index > filter_to_analyze.size() - 129 ? 0 : peak_index + 128; + } + + float filter_floor_accum = filter_floor_accum_; + float filter_secondary_peak = filter_secondary_peak_; + for (size_t k = region.start_sample_; + k < std::min(region.end_sample_ + 1, filter_floor_low_limit_); ++k) { + float abs_h = fabsf(filter_to_analyze[k]); + filter_floor_accum += abs_h; + filter_secondary_peak = std::max(filter_secondary_peak, abs_h); + } + + for (size_t k = std::max(filter_floor_high_limit_, region.start_sample_); + k <= region.end_sample_; ++k) { + float abs_h = fabsf(filter_to_analyze[k]); + filter_floor_accum += abs_h; + filter_secondary_peak = std::max(filter_secondary_peak, abs_h); + } + filter_floor_accum_ = filter_floor_accum; + filter_secondary_peak_ = filter_secondary_peak; + + if (region.end_sample_ == filter_to_analyze.size() - 1) { + float filter_floor = filter_floor_accum_ / + (filter_floor_low_limit_ + filter_to_analyze.size() - + filter_floor_high_limit_); + + float abs_peak = fabsf(filter_to_analyze[peak_index]); + significant_peak_ = abs_peak > 10.f * filter_floor && + abs_peak > 2.f * filter_secondary_peak_; + } + + if (significant_peak_) { + bool active_render_block = false; + for (int ch = 0; ch < x_block.NumChannels(); ++ch) { + rtc::ArrayView<const float, kBlockSize> x_channel = + x_block.View(/*band=*/0, ch); + const float x_energy = std::inner_product( + x_channel.begin(), x_channel.end(), x_channel.begin(), 0.f); + if (x_energy > active_render_threshold_) { + active_render_block = true; + break; + } + } + + if (consistent_delay_reference_ == delay_blocks) { + if (active_render_block) { + ++consistent_estimate_counter_; + } + } else { + consistent_estimate_counter_ = 0; + consistent_delay_reference_ = delay_blocks; + } + } + return consistent_estimate_counter_ > 1.5f * kNumBlocksPerSecond; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer.h b/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer.h new file mode 100644 index 0000000000..9aec8b14d7 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer.h @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_FILTER_ANALYZER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_FILTER_ANALYZER_H_ + +#include <stddef.h> + +#include <array> +#include <atomic> +#include <memory> +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block.h" + +namespace webrtc { + +class ApmDataDumper; +class RenderBuffer; + +// Class for analyzing the properties of an adaptive filter. +class FilterAnalyzer { + public: + FilterAnalyzer(const EchoCanceller3Config& config, + size_t num_capture_channels); + ~FilterAnalyzer(); + + FilterAnalyzer(const FilterAnalyzer&) = delete; + FilterAnalyzer& operator=(const FilterAnalyzer&) = delete; + + // Resets the analysis. + void Reset(); + + // Updates the estimates with new input data. + void Update(rtc::ArrayView<const std::vector<float>> filters_time_domain, + const RenderBuffer& render_buffer, + bool* any_filter_consistent, + float* max_echo_path_gain); + + // Returns the delay in blocks for each filter. + rtc::ArrayView<const int> FilterDelaysBlocks() const { + return filter_delays_blocks_; + } + + // Returns the minimum delay of all filters in terms of blocks. + int MinFilterDelayBlocks() const { return min_filter_delay_blocks_; } + + // Returns the number of blocks for the current used filter. + int FilterLengthBlocks() const { + return filter_analysis_states_[0].filter_length_blocks; + } + + // Returns the preprocessed filter. + rtc::ArrayView<const std::vector<float>> GetAdjustedFilters() const { + return h_highpass_; + } + + // Public for testing purposes only. + void SetRegionToAnalyze(size_t filter_size); + + private: + struct FilterAnalysisState; + + void AnalyzeRegion( + rtc::ArrayView<const std::vector<float>> filters_time_domain, + const RenderBuffer& render_buffer); + + void UpdateFilterGain(rtc::ArrayView<const float> filters_time_domain, + FilterAnalysisState* st); + void PreProcessFilters( + rtc::ArrayView<const std::vector<float>> filters_time_domain); + + void ResetRegion(); + + struct FilterRegion { + size_t start_sample_; + size_t end_sample_; + }; + + // This class checks whether the shape of the impulse response has been + // consistent over time. + class ConsistentFilterDetector { + public: + explicit ConsistentFilterDetector(const EchoCanceller3Config& config); + void Reset(); + bool Detect(rtc::ArrayView<const float> filter_to_analyze, + const FilterRegion& region, + const Block& x_block, + size_t peak_index, + int delay_blocks); + + private: + bool significant_peak_; + float filter_floor_accum_; + float filter_secondary_peak_; + size_t filter_floor_low_limit_; + size_t filter_floor_high_limit_; + const float active_render_threshold_; + size_t consistent_estimate_counter_ = 0; + int consistent_delay_reference_ = -10; + }; + + struct FilterAnalysisState { + explicit FilterAnalysisState(const EchoCanceller3Config& config) + : filter_length_blocks(config.filter.refined_initial.length_blocks), + consistent_filter_detector(config) { + Reset(config.ep_strength.default_gain); + } + + void Reset(float default_gain) { + peak_index = 0; + gain = default_gain; + consistent_filter_detector.Reset(); + } + + float gain; + size_t peak_index; + int filter_length_blocks; + bool consistent_estimate = false; + ConsistentFilterDetector consistent_filter_detector; + }; + + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const bool bounded_erl_; + const float default_gain_; + std::vector<std::vector<float>> h_highpass_; + + size_t blocks_since_reset_ = 0; + FilterRegion region_; + + std::vector<FilterAnalysisState> filter_analysis_states_; + std::vector<int> filter_delays_blocks_; + + int min_filter_delay_blocks_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_FILTER_ANALYZER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer_unittest.cc new file mode 100644 index 0000000000..f1e2e4c188 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/filter_analyzer_unittest.cc @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/filter_analyzer.h" + +#include <algorithm> + +#include "test/gmock.h" +#include "test/gtest.h" + +namespace webrtc { + +// Verifies that the filter analyzer handles filter resizes properly. +TEST(FilterAnalyzer, FilterResize) { + EchoCanceller3Config c; + std::vector<float> filter(65, 0.f); + for (size_t num_capture_channels : {1, 2, 4}) { + FilterAnalyzer fa(c, num_capture_channels); + fa.SetRegionToAnalyze(filter.size()); + fa.SetRegionToAnalyze(filter.size()); + filter.resize(32); + fa.SetRegionToAnalyze(filter.size()); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker.cc b/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker.cc new file mode 100644 index 0000000000..3039dcf7f1 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker.cc @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/frame_blocker.h" + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +FrameBlocker::FrameBlocker(size_t num_bands, size_t num_channels) + : num_bands_(num_bands), + num_channels_(num_channels), + buffer_(num_bands_, std::vector<std::vector<float>>(num_channels)) { + RTC_DCHECK_LT(0, num_bands); + RTC_DCHECK_LT(0, num_channels); + for (auto& band : buffer_) { + for (auto& channel : band) { + channel.reserve(kBlockSize); + RTC_DCHECK(channel.empty()); + } + } +} + +FrameBlocker::~FrameBlocker() = default; + +void FrameBlocker::InsertSubFrameAndExtractBlock( + const std::vector<std::vector<rtc::ArrayView<float>>>& sub_frame, + Block* block) { + RTC_DCHECK(block); + RTC_DCHECK_EQ(num_bands_, block->NumBands()); + RTC_DCHECK_EQ(num_bands_, sub_frame.size()); + for (size_t band = 0; band < num_bands_; ++band) { + RTC_DCHECK_EQ(num_channels_, block->NumChannels()); + RTC_DCHECK_EQ(num_channels_, sub_frame[band].size()); + for (size_t channel = 0; channel < num_channels_; ++channel) { + RTC_DCHECK_GE(kBlockSize - 16, buffer_[band][channel].size()); + RTC_DCHECK_EQ(kSubFrameLength, sub_frame[band][channel].size()); + const int samples_to_block = kBlockSize - buffer_[band][channel].size(); + std::copy(buffer_[band][channel].begin(), buffer_[band][channel].end(), + block->begin(band, channel)); + std::copy(sub_frame[band][channel].begin(), + sub_frame[band][channel].begin() + samples_to_block, + block->begin(band, channel) + kBlockSize - samples_to_block); + buffer_[band][channel].clear(); + buffer_[band][channel].insert( + buffer_[band][channel].begin(), + sub_frame[band][channel].begin() + samples_to_block, + sub_frame[band][channel].end()); + } + } +} + +bool FrameBlocker::IsBlockAvailable() const { + return kBlockSize == buffer_[0][0].size(); +} + +void FrameBlocker::ExtractBlock(Block* block) { + RTC_DCHECK(block); + RTC_DCHECK_EQ(num_bands_, block->NumBands()); + RTC_DCHECK_EQ(num_channels_, block->NumChannels()); + RTC_DCHECK(IsBlockAvailable()); + for (size_t band = 0; band < num_bands_; ++band) { + for (size_t channel = 0; channel < num_channels_; ++channel) { + RTC_DCHECK_EQ(kBlockSize, buffer_[band][channel].size()); + std::copy(buffer_[band][channel].begin(), buffer_[band][channel].end(), + block->begin(band, channel)); + buffer_[band][channel].clear(); + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker.h b/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker.h new file mode 100644 index 0000000000..623c812157 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_FRAME_BLOCKER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_FRAME_BLOCKER_H_ + +#include <stddef.h> + +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block.h" + +namespace webrtc { + +// Class for producing 64 sample multiband blocks from frames consisting of 2 +// subframes of 80 samples. +class FrameBlocker { + public: + FrameBlocker(size_t num_bands, size_t num_channels); + ~FrameBlocker(); + FrameBlocker(const FrameBlocker&) = delete; + FrameBlocker& operator=(const FrameBlocker&) = delete; + + // Inserts one 80 sample multiband subframe from the multiband frame and + // extracts one 64 sample multiband block. + void InsertSubFrameAndExtractBlock( + const std::vector<std::vector<rtc::ArrayView<float>>>& sub_frame, + Block* block); + // Reports whether a multiband block of 64 samples is available for + // extraction. + bool IsBlockAvailable() const; + // Extracts a multiband block of 64 samples. + void ExtractBlock(Block* block); + + private: + const size_t num_bands_; + const size_t num_channels_; + std::vector<std::vector<std::vector<float>>> buffer_; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_FRAME_BLOCKER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker_unittest.cc new file mode 100644 index 0000000000..92e393023a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/frame_blocker_unittest.cc @@ -0,0 +1,425 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/frame_blocker.h" + +#include <string> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block_framer.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +float ComputeSampleValue(size_t chunk_counter, + size_t chunk_size, + size_t band, + size_t channel, + size_t sample_index, + int offset) { + float value = + static_cast<int>(chunk_counter * chunk_size + sample_index + channel) + + offset; + return value > 0 ? 5000 * band + value : 0; +} + +void FillSubFrame(size_t sub_frame_counter, + int offset, + std::vector<std::vector<std::vector<float>>>* sub_frame) { + for (size_t band = 0; band < sub_frame->size(); ++band) { + for (size_t channel = 0; channel < (*sub_frame)[band].size(); ++channel) { + for (size_t sample = 0; sample < (*sub_frame)[band][channel].size(); + ++sample) { + (*sub_frame)[band][channel][sample] = ComputeSampleValue( + sub_frame_counter, kSubFrameLength, band, channel, sample, offset); + } + } + } +} + +void FillSubFrameView( + size_t sub_frame_counter, + int offset, + std::vector<std::vector<std::vector<float>>>* sub_frame, + std::vector<std::vector<rtc::ArrayView<float>>>* sub_frame_view) { + FillSubFrame(sub_frame_counter, offset, sub_frame); + for (size_t band = 0; band < sub_frame_view->size(); ++band) { + for (size_t channel = 0; channel < (*sub_frame_view)[band].size(); + ++channel) { + (*sub_frame_view)[band][channel] = rtc::ArrayView<float>( + &(*sub_frame)[band][channel][0], (*sub_frame)[band][channel].size()); + } + } +} + +bool VerifySubFrame( + size_t sub_frame_counter, + int offset, + const std::vector<std::vector<rtc::ArrayView<float>>>& sub_frame_view) { + std::vector<std::vector<std::vector<float>>> reference_sub_frame( + sub_frame_view.size(), + std::vector<std::vector<float>>( + sub_frame_view[0].size(), + std::vector<float>(sub_frame_view[0][0].size(), 0.f))); + FillSubFrame(sub_frame_counter, offset, &reference_sub_frame); + for (size_t band = 0; band < sub_frame_view.size(); ++band) { + for (size_t channel = 0; channel < sub_frame_view[band].size(); ++channel) { + for (size_t sample = 0; sample < sub_frame_view[band][channel].size(); + ++sample) { + if (reference_sub_frame[band][channel][sample] != + sub_frame_view[band][channel][sample]) { + return false; + } + } + } + } + return true; +} + +bool VerifyBlock(size_t block_counter, int offset, const Block& block) { + for (int band = 0; band < block.NumBands(); ++band) { + for (int channel = 0; channel < block.NumChannels(); ++channel) { + for (size_t sample = 0; sample < kBlockSize; ++sample) { + auto it = block.begin(band, channel) + sample; + const float reference_value = ComputeSampleValue( + block_counter, kBlockSize, band, channel, sample, offset); + if (reference_value != *it) { + return false; + } + } + } + } + return true; +} + +// Verifies that the FrameBlocker properly forms blocks out of the frames. +void RunBlockerTest(int sample_rate_hz, size_t num_channels) { + constexpr size_t kNumSubFramesToProcess = 20; + const size_t num_bands = NumBandsForRate(sample_rate_hz); + + Block block(num_bands, num_channels); + std::vector<std::vector<std::vector<float>>> input_sub_frame( + num_bands, std::vector<std::vector<float>>( + num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> input_sub_frame_view( + num_bands, std::vector<rtc::ArrayView<float>>(num_channels)); + FrameBlocker blocker(num_bands, num_channels); + + size_t block_counter = 0; + for (size_t sub_frame_index = 0; sub_frame_index < kNumSubFramesToProcess; + ++sub_frame_index) { + FillSubFrameView(sub_frame_index, 0, &input_sub_frame, + &input_sub_frame_view); + + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &block); + VerifyBlock(block_counter++, 0, block); + + if ((sub_frame_index + 1) % 4 == 0) { + EXPECT_TRUE(blocker.IsBlockAvailable()); + } else { + EXPECT_FALSE(blocker.IsBlockAvailable()); + } + if (blocker.IsBlockAvailable()) { + blocker.ExtractBlock(&block); + VerifyBlock(block_counter++, 0, block); + } + } +} + +// Verifies that the FrameBlocker and BlockFramer work well together and produce +// the expected output. +void RunBlockerAndFramerTest(int sample_rate_hz, size_t num_channels) { + const size_t kNumSubFramesToProcess = 20; + const size_t num_bands = NumBandsForRate(sample_rate_hz); + + Block block(num_bands, num_channels); + std::vector<std::vector<std::vector<float>>> input_sub_frame( + num_bands, std::vector<std::vector<float>>( + num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<std::vector<float>>> output_sub_frame( + num_bands, std::vector<std::vector<float>>( + num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> output_sub_frame_view( + num_bands, std::vector<rtc::ArrayView<float>>(num_channels)); + std::vector<std::vector<rtc::ArrayView<float>>> input_sub_frame_view( + num_bands, std::vector<rtc::ArrayView<float>>(num_channels)); + FrameBlocker blocker(num_bands, num_channels); + BlockFramer framer(num_bands, num_channels); + + for (size_t sub_frame_index = 0; sub_frame_index < kNumSubFramesToProcess; + ++sub_frame_index) { + FillSubFrameView(sub_frame_index, 0, &input_sub_frame, + &input_sub_frame_view); + FillSubFrameView(sub_frame_index, 0, &output_sub_frame, + &output_sub_frame_view); + + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &block); + framer.InsertBlockAndExtractSubFrame(block, &output_sub_frame_view); + + if ((sub_frame_index + 1) % 4 == 0) { + EXPECT_TRUE(blocker.IsBlockAvailable()); + } else { + EXPECT_FALSE(blocker.IsBlockAvailable()); + } + if (blocker.IsBlockAvailable()) { + blocker.ExtractBlock(&block); + framer.InsertBlock(block); + } + if (sub_frame_index > 1) { + EXPECT_TRUE(VerifySubFrame(sub_frame_index, -64, output_sub_frame_view)); + } + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies that the FrameBlocker crashes if the InsertSubFrameAndExtractBlock +// method is called for inputs with the wrong number of bands or band lengths. +void RunWronglySizedInsertAndExtractParametersTest( + int sample_rate_hz, + size_t correct_num_channels, + size_t num_block_bands, + size_t num_block_channels, + size_t num_sub_frame_bands, + size_t num_sub_frame_channels, + size_t sub_frame_length) { + const size_t correct_num_bands = NumBandsForRate(sample_rate_hz); + + Block block(num_block_bands, num_block_channels); + std::vector<std::vector<std::vector<float>>> input_sub_frame( + num_sub_frame_bands, + std::vector<std::vector<float>>( + num_sub_frame_channels, std::vector<float>(sub_frame_length, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> input_sub_frame_view( + input_sub_frame.size(), + std::vector<rtc::ArrayView<float>>(num_sub_frame_channels)); + FillSubFrameView(0, 0, &input_sub_frame, &input_sub_frame_view); + FrameBlocker blocker(correct_num_bands, correct_num_channels); + EXPECT_DEATH( + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &block), ""); +} + +// Verifies that the FrameBlocker crashes if the ExtractBlock method is called +// for inputs with the wrong number of bands or band lengths. +void RunWronglySizedExtractParameterTest(int sample_rate_hz, + size_t correct_num_channels, + size_t num_block_bands, + size_t num_block_channels) { + const size_t correct_num_bands = NumBandsForRate(sample_rate_hz); + + Block correct_block(correct_num_bands, correct_num_channels); + Block wrong_block(num_block_bands, num_block_channels); + std::vector<std::vector<std::vector<float>>> input_sub_frame( + correct_num_bands, + std::vector<std::vector<float>>( + correct_num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> input_sub_frame_view( + input_sub_frame.size(), + std::vector<rtc::ArrayView<float>>(correct_num_channels)); + FillSubFrameView(0, 0, &input_sub_frame, &input_sub_frame_view); + FrameBlocker blocker(correct_num_bands, correct_num_channels); + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &correct_block); + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &correct_block); + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &correct_block); + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &correct_block); + + EXPECT_DEATH(blocker.ExtractBlock(&wrong_block), ""); +} + +// Verifies that the FrameBlocker crashes if the ExtractBlock method is called +// after a wrong number of previous InsertSubFrameAndExtractBlock method calls +// have been made. +void RunWrongExtractOrderTest(int sample_rate_hz, + size_t num_channels, + size_t num_preceeding_api_calls) { + const size_t num_bands = NumBandsForRate(sample_rate_hz); + + Block block(num_bands, num_channels); + std::vector<std::vector<std::vector<float>>> input_sub_frame( + num_bands, std::vector<std::vector<float>>( + num_channels, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> input_sub_frame_view( + input_sub_frame.size(), std::vector<rtc::ArrayView<float>>(num_channels)); + FillSubFrameView(0, 0, &input_sub_frame, &input_sub_frame_view); + FrameBlocker blocker(num_bands, num_channels); + for (size_t k = 0; k < num_preceeding_api_calls; ++k) { + blocker.InsertSubFrameAndExtractBlock(input_sub_frame_view, &block); + } + + EXPECT_DEATH(blocker.ExtractBlock(&block), ""); +} +#endif + +std::string ProduceDebugText(int sample_rate_hz, size_t num_channels) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + ss << ", number of channels: " << num_channels; + return ss.Release(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +TEST(FrameBlockerDeathTest, + WrongNumberOfBandsInBlockForInsertSubFrameAndExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t correct_num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_bands = (correct_num_bands % 3) + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, wrong_num_bands, correct_num_channels, + correct_num_bands, correct_num_channels, kSubFrameLength); + } + } +} + +TEST(FrameBlockerDeathTest, + WrongNumberOfChannelsInBlockForInsertSubFrameAndExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t correct_num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_channels = correct_num_channels + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, wrong_num_channels, + correct_num_bands, correct_num_channels, kSubFrameLength); + } + } +} + +TEST(FrameBlockerDeathTest, + WrongNumberOfBandsInSubFrameForInsertSubFrameAndExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t correct_num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_bands = (correct_num_bands % 3) + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, correct_num_channels, + wrong_num_bands, correct_num_channels, kSubFrameLength); + } + } +} + +TEST(FrameBlockerDeathTest, + WrongNumberOfChannelsInSubFrameForInsertSubFrameAndExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t correct_num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_channels = correct_num_channels + 1; + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, wrong_num_channels, + correct_num_bands, wrong_num_channels, kSubFrameLength); + } + } +} + +TEST(FrameBlockerDeathTest, + WrongNumberOfSamplesInSubFrameForInsertSubFrameAndExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t correct_num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + RunWronglySizedInsertAndExtractParametersTest( + rate, correct_num_channels, correct_num_bands, correct_num_channels, + correct_num_bands, correct_num_channels, kSubFrameLength - 1); + } + } +} + +TEST(FrameBlockerDeathTest, WrongNumberOfBandsInBlockForExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t correct_num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_bands = (correct_num_bands % 3) + 1; + RunWronglySizedExtractParameterTest( + rate, correct_num_channels, wrong_num_bands, correct_num_channels); + } + } +} + +TEST(FrameBlockerDeathTest, WrongNumberOfChannelsInBlockForExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t correct_num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, correct_num_channels)); + const size_t correct_num_bands = NumBandsForRate(rate); + const size_t wrong_num_channels = correct_num_channels + 1; + RunWronglySizedExtractParameterTest( + rate, correct_num_channels, correct_num_bands, wrong_num_channels); + } + } +} + +TEST(FrameBlockerDeathTest, WrongNumberOfPreceedingApiCallsForExtractBlock) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t num_channels : {1, 2, 4, 8}) { + for (size_t num_calls = 0; num_calls < 4; ++num_calls) { + rtc::StringBuilder ss; + ss << "Sample rate: " << rate; + ss << "Num channels: " << num_channels; + ss << ", Num preceeding InsertSubFrameAndExtractBlock calls: " + << num_calls; + + SCOPED_TRACE(ss.str()); + RunWrongExtractOrderTest(rate, num_channels, num_calls); + } + } + } +} + +// Verifies that the verification for 0 number of channels works. +TEST(FrameBlockerDeathTest, ZeroNumberOfChannelsParameter) { + EXPECT_DEATH(FrameBlocker(16000, 0), ""); +} + +// Verifies that the verification for 0 number of bands works. +TEST(FrameBlockerDeathTest, ZeroNumberOfBandsParameter) { + EXPECT_DEATH(FrameBlocker(0, 1), ""); +} + +// Verifiers that the verification for null sub_frame pointer works. +TEST(FrameBlockerDeathTest, NullBlockParameter) { + std::vector<std::vector<std::vector<float>>> sub_frame( + 1, std::vector<std::vector<float>>( + 1, std::vector<float>(kSubFrameLength, 0.f))); + std::vector<std::vector<rtc::ArrayView<float>>> sub_frame_view( + sub_frame.size()); + FillSubFrameView(0, 0, &sub_frame, &sub_frame_view); + EXPECT_DEATH( + FrameBlocker(1, 1).InsertSubFrameAndExtractBlock(sub_frame_view, nullptr), + ""); +} + +#endif + +TEST(FrameBlocker, BlockBitexactness) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, num_channels)); + RunBlockerTest(rate, num_channels); + } + } +} + +TEST(FrameBlocker, BlockerAndFramer) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t num_channels : {1, 2, 4, 8}) { + SCOPED_TRACE(ProduceDebugText(rate, num_channels)); + RunBlockerAndFramerTest(rate, num_channels); + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fullband_erle_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/fullband_erle_estimator.cc new file mode 100644 index 0000000000..e56674e4c9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fullband_erle_estimator.cc @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/fullband_erle_estimator.h" + +#include <algorithm> +#include <memory> +#include <numeric> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { + +namespace { +constexpr float kEpsilon = 1e-3f; +constexpr float kX2BandEnergyThreshold = 44015068.0f; +constexpr int kBlocksToHoldErle = 100; +constexpr int kPointsToAccumulate = 6; +} // namespace + +FullBandErleEstimator::FullBandErleEstimator( + const EchoCanceller3Config::Erle& config, + size_t num_capture_channels) + : min_erle_log2_(FastApproxLog2f(config.min + kEpsilon)), + max_erle_lf_log2_(FastApproxLog2f(config.max_l + kEpsilon)), + hold_counters_instantaneous_erle_(num_capture_channels, 0), + erle_time_domain_log2_(num_capture_channels, min_erle_log2_), + instantaneous_erle_(num_capture_channels, ErleInstantaneous(config)), + linear_filters_qualities_(num_capture_channels) { + Reset(); +} + +FullBandErleEstimator::~FullBandErleEstimator() = default; + +void FullBandErleEstimator::Reset() { + for (auto& instantaneous_erle_ch : instantaneous_erle_) { + instantaneous_erle_ch.Reset(); + } + + UpdateQualityEstimates(); + std::fill(erle_time_domain_log2_.begin(), erle_time_domain_log2_.end(), + min_erle_log2_); + std::fill(hold_counters_instantaneous_erle_.begin(), + hold_counters_instantaneous_erle_.end(), 0); +} + +void FullBandErleEstimator::Update( + rtc::ArrayView<const float> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters) { + for (size_t ch = 0; ch < Y2.size(); ++ch) { + if (converged_filters[ch]) { + // Computes the fullband ERLE. + const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f); + if (X2_sum > kX2BandEnergyThreshold * X2.size()) { + const float Y2_sum = + std::accumulate(Y2[ch].begin(), Y2[ch].end(), 0.0f); + const float E2_sum = + std::accumulate(E2[ch].begin(), E2[ch].end(), 0.0f); + if (instantaneous_erle_[ch].Update(Y2_sum, E2_sum)) { + hold_counters_instantaneous_erle_[ch] = kBlocksToHoldErle; + erle_time_domain_log2_[ch] += + 0.05f * ((instantaneous_erle_[ch].GetInstErleLog2().value()) - + erle_time_domain_log2_[ch]); + erle_time_domain_log2_[ch] = + std::max(erle_time_domain_log2_[ch], min_erle_log2_); + } + } + } + --hold_counters_instantaneous_erle_[ch]; + if (hold_counters_instantaneous_erle_[ch] == 0) { + instantaneous_erle_[ch].ResetAccumulators(); + } + } + + UpdateQualityEstimates(); +} + +void FullBandErleEstimator::Dump( + const std::unique_ptr<ApmDataDumper>& data_dumper) const { + data_dumper->DumpRaw("aec3_fullband_erle_log2", FullbandErleLog2()); + instantaneous_erle_[0].Dump(data_dumper); +} + +void FullBandErleEstimator::UpdateQualityEstimates() { + for (size_t ch = 0; ch < instantaneous_erle_.size(); ++ch) { + linear_filters_qualities_[ch] = + instantaneous_erle_[ch].GetQualityEstimate(); + } +} + +FullBandErleEstimator::ErleInstantaneous::ErleInstantaneous( + const EchoCanceller3Config::Erle& config) + : clamp_inst_quality_to_zero_(config.clamp_quality_estimate_to_zero), + clamp_inst_quality_to_one_(config.clamp_quality_estimate_to_one) { + Reset(); +} + +FullBandErleEstimator::ErleInstantaneous::~ErleInstantaneous() = default; + +bool FullBandErleEstimator::ErleInstantaneous::Update(const float Y2_sum, + const float E2_sum) { + bool update_estimates = false; + E2_acum_ += E2_sum; + Y2_acum_ += Y2_sum; + num_points_++; + if (num_points_ == kPointsToAccumulate) { + if (E2_acum_ > 0.f) { + update_estimates = true; + erle_log2_ = FastApproxLog2f(Y2_acum_ / E2_acum_ + kEpsilon); + } + num_points_ = 0; + E2_acum_ = 0.f; + Y2_acum_ = 0.f; + } + + if (update_estimates) { + UpdateMaxMin(); + UpdateQualityEstimate(); + } + return update_estimates; +} + +void FullBandErleEstimator::ErleInstantaneous::Reset() { + ResetAccumulators(); + max_erle_log2_ = -10.f; // -30 dB. + min_erle_log2_ = 33.f; // 100 dB. + inst_quality_estimate_ = 0.f; +} + +void FullBandErleEstimator::ErleInstantaneous::ResetAccumulators() { + erle_log2_ = absl::nullopt; + inst_quality_estimate_ = 0.f; + num_points_ = 0; + E2_acum_ = 0.f; + Y2_acum_ = 0.f; +} + +void FullBandErleEstimator::ErleInstantaneous::Dump( + const std::unique_ptr<ApmDataDumper>& data_dumper) const { + data_dumper->DumpRaw("aec3_fullband_erle_inst_log2", + erle_log2_ ? *erle_log2_ : -10.f); + data_dumper->DumpRaw( + "aec3_erle_instantaneous_quality", + GetQualityEstimate() ? GetQualityEstimate().value() : 0.f); + data_dumper->DumpRaw("aec3_fullband_erle_max_log2", max_erle_log2_); + data_dumper->DumpRaw("aec3_fullband_erle_min_log2", min_erle_log2_); +} + +void FullBandErleEstimator::ErleInstantaneous::UpdateMaxMin() { + RTC_DCHECK(erle_log2_); + // Adding the forgetting factors for the maximum and minimum and capping the + // result to the incoming value. + max_erle_log2_ -= 0.0004f; // Forget factor, approx 1dB every 3 sec. + max_erle_log2_ = std::max(max_erle_log2_, erle_log2_.value()); + min_erle_log2_ += 0.0004f; // Forget factor, approx 1dB every 3 sec. + min_erle_log2_ = std::min(min_erle_log2_, erle_log2_.value()); +} + +void FullBandErleEstimator::ErleInstantaneous::UpdateQualityEstimate() { + const float alpha = 0.07f; + float quality_estimate = 0.f; + RTC_DCHECK(erle_log2_); + // TODO(peah): Currently, the estimate can become be less than 0; this should + // be corrected. + if (max_erle_log2_ > min_erle_log2_) { + quality_estimate = (erle_log2_.value() - min_erle_log2_) / + (max_erle_log2_ - min_erle_log2_); + } + if (quality_estimate > inst_quality_estimate_) { + inst_quality_estimate_ = quality_estimate; + } else { + inst_quality_estimate_ += + alpha * (quality_estimate - inst_quality_estimate_); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/fullband_erle_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/fullband_erle_estimator.h new file mode 100644 index 0000000000..7a082176d6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/fullband_erle_estimator.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_FULLBAND_ERLE_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_FULLBAND_ERLE_ESTIMATOR_H_ + +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +// Estimates the echo return loss enhancement using the energy of all the +// freuquency bands. +class FullBandErleEstimator { + public: + FullBandErleEstimator(const EchoCanceller3Config::Erle& config, + size_t num_capture_channels); + ~FullBandErleEstimator(); + // Resets the ERLE estimator. + void Reset(); + + // Updates the ERLE estimator. + void Update(rtc::ArrayView<const float> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters); + + // Returns the fullband ERLE estimates in log2 units. + float FullbandErleLog2() const { + float min_erle = erle_time_domain_log2_[0]; + for (size_t ch = 1; ch < erle_time_domain_log2_.size(); ++ch) { + min_erle = std::min(min_erle, erle_time_domain_log2_[ch]); + } + return min_erle; + } + + // Returns an estimation of the current linear filter quality. It returns a + // float number between 0 and 1 mapping 1 to the highest possible quality. + rtc::ArrayView<const absl::optional<float>> GetInstLinearQualityEstimates() + const { + return linear_filters_qualities_; + } + + void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const; + + private: + void UpdateQualityEstimates(); + + class ErleInstantaneous { + public: + explicit ErleInstantaneous(const EchoCanceller3Config::Erle& config); + ~ErleInstantaneous(); + + // Updates the estimator with a new point, returns true + // if the instantaneous ERLE was updated due to having enough + // points for performing the estimate. + bool Update(float Y2_sum, float E2_sum); + // Resets the instantaneous ERLE estimator to its initial state. + void Reset(); + // Resets the members related with an instantaneous estimate. + void ResetAccumulators(); + // Returns the instantaneous ERLE in log2 units. + absl::optional<float> GetInstErleLog2() const { return erle_log2_; } + // Gets an indication between 0 and 1 of the performance of the linear + // filter for the current time instant. + absl::optional<float> GetQualityEstimate() const { + if (erle_log2_) { + float value = inst_quality_estimate_; + if (clamp_inst_quality_to_zero_) { + value = std::max(0.f, value); + } + if (clamp_inst_quality_to_one_) { + value = std::min(1.f, value); + } + return absl::optional<float>(value); + } + return absl::nullopt; + } + void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const; + + private: + void UpdateMaxMin(); + void UpdateQualityEstimate(); + const bool clamp_inst_quality_to_zero_; + const bool clamp_inst_quality_to_one_; + absl::optional<float> erle_log2_; + float inst_quality_estimate_; + float max_erle_log2_; + float min_erle_log2_; + float Y2_acum_; + float E2_acum_; + int num_points_; + }; + + const float min_erle_log2_; + const float max_erle_lf_log2_; + std::vector<int> hold_counters_instantaneous_erle_; + std::vector<float> erle_time_domain_log2_; + std::vector<ErleInstantaneous> instantaneous_erle_; + std::vector<absl::optional<float>> linear_filters_qualities_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_FULLBAND_ERLE_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc new file mode 100644 index 0000000000..af30ff1b9f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.cc @@ -0,0 +1,900 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/matched_filter.h" + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_HAS_NEON) +#include <arm_neon.h> +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif +#include <algorithm> +#include <cstddef> +#include <initializer_list> +#include <iterator> +#include <numeric> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/experiments/field_trial_parser.h" +#include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" + +namespace { + +// Subsample rate used for computing the accumulated error. +// The implementation of some core functions depends on this constant being +// equal to 4. +constexpr int kAccumulatedErrorSubSampleRate = 4; + +void UpdateAccumulatedError( + const rtc::ArrayView<const float> instantaneous_accumulated_error, + const rtc::ArrayView<float> accumulated_error, + float one_over_error_sum_anchor, + float smooth_constant_increases) { + for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) { + float error_norm = + instantaneous_accumulated_error[k] * one_over_error_sum_anchor; + if (error_norm < accumulated_error[k]) { + accumulated_error[k] = error_norm; + } else { + accumulated_error[k] += + smooth_constant_increases * (error_norm - accumulated_error[k]); + } + } +} + +size_t ComputePreEchoLag( + const webrtc::MatchedFilter::PreEchoConfiguration& pre_echo_configuration, + const rtc::ArrayView<const float> accumulated_error, + size_t lag, + size_t alignment_shift_winner) { + RTC_DCHECK_GE(lag, alignment_shift_winner); + size_t pre_echo_lag_estimate = lag - alignment_shift_winner; + size_t maximum_pre_echo_lag = + std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate, + accumulated_error.size()); + switch (pre_echo_configuration.mode) { + case 0: + // Mode 0: Pre echo lag is defined as the first coefficient with an error + // lower than a threshold with a certain decrease slope. + for (size_t k = 1; k < maximum_pre_echo_lag; ++k) { + if (accumulated_error[k] < + pre_echo_configuration.threshold * accumulated_error[k - 1] && + accumulated_error[k] < pre_echo_configuration.threshold) { + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + break; + } + } + break; + case 1: + // Mode 1: Pre echo lag is defined as the first coefficient with an error + // lower than a certain threshold. + for (size_t k = 0; k < maximum_pre_echo_lag; ++k) { + if (accumulated_error[k] < pre_echo_configuration.threshold) { + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + break; + } + } + break; + case 2: + case 3: + // Mode 2,3: Pre echo lag is defined as the closest coefficient to the lag + // with an error lower than a certain threshold. + for (int k = static_cast<int>(maximum_pre_echo_lag) - 1; k >= 0; --k) { + if (accumulated_error[k] > pre_echo_configuration.threshold) { + break; + } + pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1; + } + break; + default: + RTC_DCHECK_NOTREACHED(); + break; + } + return pre_echo_lag_estimate + alignment_shift_winner; +} + +webrtc::MatchedFilter::PreEchoConfiguration FetchPreEchoConfiguration() { + float threshold = 0.5f; + int mode = 0; + const std::string pre_echo_configuration_field_trial = + webrtc::field_trial::FindFullName("WebRTC-Aec3PreEchoConfiguration"); + webrtc::FieldTrialParameter<double> threshold_field_trial_parameter( + /*key=*/"threshold", /*default_value=*/threshold); + webrtc::FieldTrialParameter<int> mode_field_trial_parameter( + /*key=*/"mode", /*default_value=*/mode); + webrtc::ParseFieldTrial( + {&threshold_field_trial_parameter, &mode_field_trial_parameter}, + pre_echo_configuration_field_trial); + float threshold_read = + static_cast<float>(threshold_field_trial_parameter.Get()); + int mode_read = mode_field_trial_parameter.Get(); + if (threshold_read < 1.0f && threshold_read > 0.0f) { + threshold = threshold_read; + } else { + RTC_LOG(LS_ERROR) + << "AEC3: Pre echo configuration: wrong input, threshold = " + << threshold_read << "."; + } + if (mode_read >= 0 && mode_read <= 3) { + mode = mode_read; + } else { + RTC_LOG(LS_ERROR) << "AEC3: Pre echo configuration: wrong input, mode = " + << mode_read << "."; + } + RTC_LOG(LS_INFO) << "AEC3: Pre echo configuration: threshold = " << threshold + << ", mode = " << mode << "."; + return {.threshold = threshold, .mode = mode}; +} + +} // namespace + +namespace webrtc { +namespace aec3 { + +#if defined(WEBRTC_HAS_NEON) + +inline float SumAllElements(float32x4_t elements) { + float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements)); + sum = vpadd_f32(sum, sum); + return vget_lane_f32(sum, 0); +} + +void MatchedFilterCoreWithAccumulatedError_NEON( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* accumulated_error_p = &accumulated_error[0]; + // Initialize values for the accumulation. + float32x4_t x2_sum_128 = vdupq_n_f32(0); + float x2_sum = 0.f; + float s = 0; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; + --k, h_p += 4, x_p += 4, accumulated_error_p++) { + // Load the data into 128 bit vectors. + const float32x4_t x_k = vld1q_f32(x_p); + const float32x4_t h_k = vld1q_f32(h_p); + // Compute and accumulate x * x. + x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k); + // Compute x * h + float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k); + s += SumAllElements(hk_xk_128); + const float e = s - y[i]; + accumulated_error_p[0] += e * e; + } + // Combine the accumulated vector and scalar values. + x2_sum += SumAllElements(x2_sum_128); + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const float32x4_t alpha_128 = vmovq_n_f32(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + float32x4_t h_k = vld1q_f32(h_p); + const float32x4_t x_k = vld1q_f32(x_p); + // Compute h = h + alpha * x. + h_k = vmlaq_f32(h_k, alpha_128, x_k); + // Store the result. + vst1q_f32(h_p, h_k); + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + +void MatchedFilterCore_NEON(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + + if (compute_accumulated_error) { + return MatchedFilterCoreWithAccumulatedError_NEON( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + + RTC_DCHECK_GT(x_size, x_start_index); + const float* x_p = &x[x_start_index]; + const float* h_p = &h[0]; + + // Initialize values for the accumulation. + float32x4_t s_128 = vdupq_n_f32(0); + float32x4_t x2_sum_128 = vdupq_n_f32(0); + float x2_sum = 0.f; + float s = 0; + + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + + // Perform the loop in two chunks. + const int chunk2 = h_size - chunk1; + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + const float32x4_t x_k = vld1q_f32(x_p); + const float32x4_t h_k = vld1q_f32(h_p); + // Compute and accumulate x * x and h * x. + x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k); + s_128 = vmlaq_f32(s_128, h_k, x_k); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + const float x_k = *x_p; + x2_sum += x_k * x_k; + s += *h_p * x_k; + } + + x_p = &x[0]; + } + + // Combine the accumulated vector and scalar values. + s += SumAllElements(s_128); + x2_sum += SumAllElements(x2_sum_128); + + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const float32x4_t alpha_128 = vmovq_n_f32(alpha); + + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = &x[x_start_index]; + + // Perform the loop in two chunks. + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + float32x4_t h_k = vld1q_f32(h_p); + const float32x4_t x_k = vld1q_f32(x_p); + // Compute h = h + alpha * x. + h_k = vmlaq_f32(h_k, alpha_128, x_k); + + // Store the result. + vst1q_f32(h_p, h_k); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + *h_p += alpha * *x_p; + } + + x_p = &x[0]; + } + + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +void MatchedFilterCore_AccumulatedError_SSE2( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 8); + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* a_p = &accumulated_error[0]; + __m128 s_inst_128; + __m128 s_inst_128_4; + __m128 x2_sum_128 = _mm_set1_ps(0); + __m128 x2_sum_128_4 = _mm_set1_ps(0); + __m128 e_128; + float* const s_p = reinterpret_cast<float*>(&s_inst_128); + float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4); + float* const e_p = reinterpret_cast<float*>(&e_128); + float x2_sum = 0.0f; + float s_acum = 0; + // Perform 128 bit vector operations. + const int limit_by_8 = h_size >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) { + // Load the data into 128 bit vectors. + const __m128 x_k = _mm_loadu_ps(x_p); + const __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k_4 = _mm_loadu_ps(x_p + 4); + const __m128 h_k_4 = _mm_loadu_ps(h_p + 4); + const __m128 xx = _mm_mul_ps(x_k, x_k); + const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4); + // Compute and accumulate x * x and h * x. + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4); + s_inst_128 = _mm_mul_ps(h_k, x_k); + s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4); + s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3]; + e_p[0] = s_acum - y[i]; + s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3]; + e_p[1] = s_acum - y[i]; + a_p[0] += e_p[0] * e_p[0]; + a_p[1] += e_p[1] * e_p[1]; + } + // Combine the accumulated vector and scalar values. + x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); + float* v = reinterpret_cast<float*>(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + // Compute the matched filter error. + float e = y[i] - s_acum; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m128 alpha_128 = _mm_set1_ps(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 128 bit vector operations. + const int limit_by_4 = h_size >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k = _mm_loadu_ps(x_p); + // Compute h = h + alpha * x. + const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k); + h_k = _mm_add_ps(h_k, alpha_x); + // Store the result. + _mm_storeu_ps(h_p, h_k); + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + +void MatchedFilterCore_SSE2(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + if (compute_accumulated_error) { + return MatchedFilterCore_AccumulatedError_SSE2( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 4); + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + const float* x_p = &x[x_start_index]; + const float* h_p = &h[0]; + // Initialize values for the accumulation. + __m128 s_128 = _mm_set1_ps(0); + __m128 s_128_4 = _mm_set1_ps(0); + __m128 x2_sum_128 = _mm_set1_ps(0); + __m128 x2_sum_128_4 = _mm_set1_ps(0); + float x2_sum = 0.f; + float s = 0; + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + // Perform the loop in two chunks. + const int chunk2 = h_size - chunk1; + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_8 = limit >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + // Load the data into 128 bit vectors. + const __m128 x_k = _mm_loadu_ps(x_p); + const __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k_4 = _mm_loadu_ps(x_p + 4); + const __m128 h_k_4 = _mm_loadu_ps(h_p + 4); + const __m128 xx = _mm_mul_ps(x_k, x_k); + const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4); + // Compute and accumulate x * x and h * x. + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); + x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4); + const __m128 hx = _mm_mul_ps(h_k, x_k); + const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4); + s_128 = _mm_add_ps(s_128, hx); + s_128_4 = _mm_add_ps(s_128_4, hx_4); + } + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { + const float x_k = *x_p; + x2_sum += x_k * x_k; + s += *h_p * x_k; + } + x_p = &x[0]; + } + // Combine the accumulated vector and scalar values. + x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4); + float* v = reinterpret_cast<float*>(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + s_128 = _mm_add_ps(s_128, s_128_4); + v = reinterpret_cast<float*>(&s_128); + s += v[0] + v[1] + v[2] + v[3]; + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m128 alpha_128 = _mm_set1_ps(alpha); + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = &x[x_start_index]; + // Perform the loop in two chunks. + for (int limit : {chunk1, chunk2}) { + // Perform 128 bit vector operations. + const int limit_by_4 = limit >> 2; + for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) { + // Load the data into 128 bit vectors. + __m128 h_k = _mm_loadu_ps(h_p); + const __m128 x_k = _mm_loadu_ps(x_p); + + // Compute h = h + alpha * x. + const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k); + h_k = _mm_add_ps(h_k, alpha_x); + // Store the result. + _mm_storeu_ps(h_p, h_k); + } + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) { + *h_p += alpha * *x_p; + } + x_p = &x[0]; + } + *filters_updated = true; + } + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} +#endif + +void MatchedFilterCore(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error) { + if (compute_accumulated_error) { + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + } + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + float x2_sum = 0.f; + float s = 0; + size_t x_index = x_start_index; + if (compute_accumulated_error) { + for (size_t k = 0; k < h.size(); ++k) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + if ((k + 1 & 0b11) == 0) { + int idx = k >> 2; + accumulated_error[idx] += (y[i] - s) * (y[i] - s); + } + } + } else { + for (size_t k = 0; k < h.size(); ++k) { + x2_sum += x[x_index] * x[x_index]; + s += h[k] * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } + } + + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + + // filter = filter + smoothing * (y - filter * x) * x / x * x. + size_t x_index = x_start_index; + for (size_t k = 0; k < h.size(); ++k) { + h[k] += alpha * x[x_index]; + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; + } + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; + } +} + +size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) { + if (h.size() < 2) { + return 0; + } + float max_element1 = h[0] * h[0]; + float max_element2 = h[1] * h[1]; + size_t lag_estimate1 = 0; + size_t lag_estimate2 = 1; + const size_t last_index = h.size() - 1; + // Keeping track of even & odd max elements separately typically allows the + // compiler to produce more efficient code. + for (size_t k = 2; k < last_index; k += 2) { + float element1 = h[k] * h[k]; + float element2 = h[k + 1] * h[k + 1]; + if (element1 > max_element1) { + max_element1 = element1; + lag_estimate1 = k; + } + if (element2 > max_element2) { + max_element2 = element2; + lag_estimate2 = k + 1; + } + } + if (max_element2 > max_element1) { + max_element1 = max_element2; + lag_estimate1 = lag_estimate2; + } + // In case of odd h size, we have not yet checked the last element. + float last_element = h[last_index] * h[last_index]; + if (last_element > max_element1) { + return last_index; + } + return lag_estimate1; +} + +} // namespace aec3 + +MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, + Aec3Optimization optimization, + size_t sub_block_size, + size_t window_size_sub_blocks, + int num_matched_filters, + size_t alignment_shift_sub_blocks, + float excitation_limit, + float smoothing_fast, + float smoothing_slow, + float matching_filter_threshold, + bool detect_pre_echo) + : data_dumper_(data_dumper), + optimization_(optimization), + sub_block_size_(sub_block_size), + filter_intra_lag_shift_(alignment_shift_sub_blocks * sub_block_size_), + filters_( + num_matched_filters, + std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)), + filters_offsets_(num_matched_filters, 0), + excitation_limit_(excitation_limit), + smoothing_fast_(smoothing_fast), + smoothing_slow_(smoothing_slow), + matching_filter_threshold_(matching_filter_threshold), + detect_pre_echo_(detect_pre_echo), + pre_echo_config_(FetchPreEchoConfiguration()) { + RTC_DCHECK(data_dumper); + RTC_DCHECK_LT(0, window_size_sub_blocks); + RTC_DCHECK((kBlockSize % sub_block_size) == 0); + RTC_DCHECK((sub_block_size % 4) == 0); + static_assert(kAccumulatedErrorSubSampleRate == 4); + if (detect_pre_echo_) { + accumulated_error_ = std::vector<std::vector<float>>( + num_matched_filters, + std::vector<float>(window_size_sub_blocks * sub_block_size_ / + kAccumulatedErrorSubSampleRate, + 1.0f)); + + instantaneous_accumulated_error_ = + std::vector<float>(window_size_sub_blocks * sub_block_size_ / + kAccumulatedErrorSubSampleRate, + 0.0f); + scratch_memory_ = + std::vector<float>(window_size_sub_blocks * sub_block_size_); + } +} + +MatchedFilter::~MatchedFilter() = default; + +void MatchedFilter::Reset(bool full_reset) { + for (auto& f : filters_) { + std::fill(f.begin(), f.end(), 0.f); + } + + winner_lag_ = absl::nullopt; + reported_lag_estimate_ = absl::nullopt; + if (pre_echo_config_.mode != 3 || full_reset) { + for (auto& e : accumulated_error_) { + std::fill(e.begin(), e.end(), 1.0f); + } + number_pre_echo_updates_ = 0; + } +} + +void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer, + rtc::ArrayView<const float> capture, + bool use_slow_smoothing) { + RTC_DCHECK_EQ(sub_block_size_, capture.size()); + auto& y = capture; + + const float smoothing = + use_slow_smoothing ? smoothing_slow_ : smoothing_fast_; + + const float x2_sum_threshold = + filters_[0].size() * excitation_limit_ * excitation_limit_; + + // Compute anchor for the matched filter error. + float error_sum_anchor = 0.0f; + for (size_t k = 0; k < y.size(); ++k) { + error_sum_anchor += y[k] * y[k]; + } + + // Apply all matched filters. + float winner_error_sum = error_sum_anchor; + winner_lag_ = absl::nullopt; + reported_lag_estimate_ = absl::nullopt; + size_t alignment_shift = 0; + absl::optional<size_t> previous_lag_estimate; + const int num_filters = static_cast<int>(filters_.size()); + int winner_index = -1; + for (int n = 0; n < num_filters; ++n) { + float error_sum = 0.f; + bool filters_updated = false; + const bool compute_pre_echo = + detect_pre_echo_ && n == last_detected_best_lag_filter_; + + size_t x_start_index = + (render_buffer.read + alignment_shift + sub_block_size_ - 1) % + render_buffer.buffer.size(); + + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::MatchedFilterCore_SSE2( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); + break; + case Aec3Optimization::kAvx2: + aec3::MatchedFilterCore_AVX2( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::MatchedFilterCore_NEON( + x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y, + filters_[n], &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_, scratch_memory_); + break; +#endif + default: + aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing, + render_buffer.buffer, y, filters_[n], + &filters_updated, &error_sum, compute_pre_echo, + instantaneous_accumulated_error_); + } + + // Estimate the lag in the matched filter as the distance to the portion in + // the filter that contributes the most to the matched filter output. This + // is detected as the peak of the matched filter. + const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]); + const bool reliable = + lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && + error_sum < matching_filter_threshold_ * error_sum_anchor; + + // Find the best estimate + const size_t lag = lag_estimate + alignment_shift; + if (filters_updated && reliable && error_sum < winner_error_sum) { + winner_error_sum = error_sum; + winner_index = n; + // In case that 2 matched filters return the same winner candidate + // (overlap region), the one with the smaller index is chosen in order + // to search for pre-echoes. + if (previous_lag_estimate && previous_lag_estimate == lag) { + winner_lag_ = previous_lag_estimate; + winner_index = n - 1; + } else { + winner_lag_ = lag; + } + } + previous_lag_estimate = lag; + alignment_shift += filter_intra_lag_shift_; + } + + if (winner_index != -1) { + RTC_DCHECK(winner_lag_.has_value()); + reported_lag_estimate_ = + LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value()); + if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) { + const float energy_threshold = + pre_echo_config_.mode == 3 ? 1.0f : 30.0f * 30.0f * y.size(); + + if (error_sum_anchor > energy_threshold) { + const float smooth_constant_increases = + pre_echo_config_.mode != 3 ? 0.01f : 0.015f; + + UpdateAccumulatedError( + instantaneous_accumulated_error_, accumulated_error_[winner_index], + 1.0f / error_sum_anchor, smooth_constant_increases); + number_pre_echo_updates_++; + } + if (pre_echo_config_.mode != 3 || number_pre_echo_updates_ >= 50) { + reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag( + pre_echo_config_, accumulated_error_[winner_index], + winner_lag_.value(), + winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/); + } else { + reported_lag_estimate_->pre_echo_lag = winner_lag_.value(); + } + } + last_detected_best_lag_filter_ = winner_index; + } + if (ApmDataDumper::IsAvailable()) { + Dump(); + data_dumper_->DumpRaw("error_sum_anchor", error_sum_anchor / y.size()); + data_dumper_->DumpRaw("number_pre_echo_updates", number_pre_echo_updates_); + data_dumper_->DumpRaw("filter_smoothing", smoothing); + } +} + +void MatchedFilter::LogFilterProperties(int sample_rate_hz, + size_t shift, + size_t downsampling_factor) const { + size_t alignment_shift = 0; + constexpr int kFsBy1000 = 16; + for (size_t k = 0; k < filters_.size(); ++k) { + int start = static_cast<int>(alignment_shift * downsampling_factor); + int end = static_cast<int>((alignment_shift + filters_[k].size()) * + downsampling_factor); + RTC_LOG(LS_VERBOSE) << "Filter " << k << ": start: " + << (start - static_cast<int>(shift)) / kFsBy1000 + << " ms, end: " + << (end - static_cast<int>(shift)) / kFsBy1000 + << " ms."; + alignment_shift += filter_intra_lag_shift_; + } +} + +void MatchedFilter::Dump() { + for (size_t n = 0; n < filters_.size(); ++n) { + const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]); + std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h"; + data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]); + std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n); + data_dumper_->DumpRaw(dumper_lag.c_str(), + lag_estimate + n * filter_intra_lag_shift_); + if (detect_pre_echo_) { + std::string dumper_error = + "aec3_correlator_error_" + std::to_string(n) + "_h"; + data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]); + + size_t pre_echo_lag = + ComputePreEchoLag(pre_echo_config_, accumulated_error_[n], + lag_estimate + n * filter_intra_lag_shift_, + n * filter_intra_lag_shift_); + std::string dumper_pre_lag = + "aec3_correlator_pre_echo_lag_" + std::to_string(n); + data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag); + if (static_cast<int>(n) == last_detected_best_lag_filter_) { + data_dumper_->DumpRaw("aec3_pre_echo_delay_winner_inst", pre_echo_lag); + } + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.h b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.h new file mode 100644 index 0000000000..bb54fba2b4 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ + +#include <stddef.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/gtest_prod_util.h" +#include "rtc_base/system/arch.h" + +namespace webrtc { + +class ApmDataDumper; +struct DownsampledRenderBuffer; + +namespace aec3 { + +#if defined(WEBRTC_HAS_NEON) + +// Filter core for the matched filter that is optimized for NEON. +void MatchedFilterCore_NEON(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulation_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); + +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +// Filter core for the matched filter that is optimized for SSE2. +void MatchedFilterCore_SSE2(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); + +// Filter core for the matched filter that is optimized for AVX2. +void MatchedFilterCore_AVX2(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory); + +#endif + +// Filter core for the matched filter. +void MatchedFilterCore(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulation_error, + rtc::ArrayView<float> accumulated_error); + +// Find largest peak of squared values in array. +size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h); + +} // namespace aec3 + +// Produces recursively updated cross-correlation estimates for several signal +// shifts where the intra-shift spacing is uniform. +class MatchedFilter { + public: + // Stores properties for the lag estimate corresponding to a particular signal + // shift. + struct LagEstimate { + LagEstimate() = default; + LagEstimate(size_t lag, size_t pre_echo_lag) + : lag(lag), pre_echo_lag(pre_echo_lag) {} + size_t lag = 0; + size_t pre_echo_lag = 0; + }; + + struct PreEchoConfiguration { + const float threshold; + const int mode; + }; + + MatchedFilter(ApmDataDumper* data_dumper, + Aec3Optimization optimization, + size_t sub_block_size, + size_t window_size_sub_blocks, + int num_matched_filters, + size_t alignment_shift_sub_blocks, + float excitation_limit, + float smoothing_fast, + float smoothing_slow, + float matching_filter_threshold, + bool detect_pre_echo); + + MatchedFilter() = delete; + MatchedFilter(const MatchedFilter&) = delete; + MatchedFilter& operator=(const MatchedFilter&) = delete; + + ~MatchedFilter(); + + // Updates the correlation with the values in the capture buffer. + void Update(const DownsampledRenderBuffer& render_buffer, + rtc::ArrayView<const float> capture, + bool use_slow_smoothing); + + // Resets the matched filter. + void Reset(bool full_reset); + + // Returns the current lag estimates. + absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const { + return reported_lag_estimate_; + } + + // Returns the maximum filter lag. + size_t GetMaxFilterLag() const { + return filters_.size() * filter_intra_lag_shift_ + filters_[0].size(); + } + + // Log matched filter properties. + void LogFilterProperties(int sample_rate_hz, + size_t shift, + size_t downsampling_factor) const; + + private: + FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, + PreEchoConfigurationTest); + FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, + WrongPreEchoConfigurationTest); + + // Only for testing. Gets the pre echo detection configuration. + const PreEchoConfiguration& GetPreEchoConfiguration() const { + return pre_echo_config_; + } + void Dump(); + + ApmDataDumper* const data_dumper_; + const Aec3Optimization optimization_; + const size_t sub_block_size_; + const size_t filter_intra_lag_shift_; + std::vector<std::vector<float>> filters_; + std::vector<std::vector<float>> accumulated_error_; + std::vector<float> instantaneous_accumulated_error_; + std::vector<float> scratch_memory_; + absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_; + absl::optional<size_t> winner_lag_; + int last_detected_best_lag_filter_ = -1; + std::vector<size_t> filters_offsets_; + int number_pre_echo_updates_ = 0; + const float excitation_limit_; + const float smoothing_fast_; + const float smoothing_slow_; + const float matching_filter_threshold_; + const bool detect_pre_echo_; + const PreEchoConfiguration pre_echo_config_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_avx2.cc b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_avx2.cc new file mode 100644 index 0000000000..8c2ffcbd1e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_avx2.cc @@ -0,0 +1,261 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <immintrin.h> + +#include "modules/audio_processing/aec3/matched_filter.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace aec3 { + +// Let ha denote the horizontal of a, and hb the horizontal sum of b +// returns [ha, hb, ha, hb] +inline __m128 hsum_ab(__m256 a, __m256 b) { + __m256 s_256 = _mm256_hadd_ps(a, b); + const __m256i mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + s_256 = _mm256_permutevar8x32_ps(s_256, mask); + __m128 s = _mm_hadd_ps(_mm256_extractf128_ps(s_256, 0), + _mm256_extractf128_ps(s_256, 1)); + s = _mm_hadd_ps(s, s); + return s; +} + +void MatchedFilterCore_AccumulatedError_AVX2( + size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 16); + std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f); + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + RTC_DCHECK_GT(x_size, x_start_index); + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + if (chunk1 != h_size) { + const int chunk2 = h_size - chunk1; + std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin()); + std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1); + } + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + const float* h_p = &h[0]; + float* a_p = &accumulated_error[0]; + __m256 s_inst_hadd_256; + __m256 s_inst_256; + __m256 s_inst_256_8; + __m256 x2_sum_256 = _mm256_set1_ps(0); + __m256 x2_sum_256_8 = _mm256_set1_ps(0); + __m128 e_128; + float x2_sum = 0.0f; + float s_acum = 0; + const int limit_by_16 = h_size >> 4; + for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16, a_p += 4) { + // Load the data into 256 bit vectors. + __m256 x_k = _mm256_loadu_ps(x_p); + __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k_8 = _mm256_loadu_ps(x_p + 8); + __m256 h_k_8 = _mm256_loadu_ps(h_p + 8); + // Compute and accumulate x * x and h * x. + x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256); + x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8); + s_inst_256 = _mm256_mul_ps(h_k, x_k); + s_inst_256_8 = _mm256_mul_ps(h_k_8, x_k_8); + s_inst_hadd_256 = _mm256_hadd_ps(s_inst_256, s_inst_256_8); + s_inst_hadd_256 = _mm256_hadd_ps(s_inst_hadd_256, s_inst_hadd_256); + s_acum += s_inst_hadd_256[0]; + e_128[0] = s_acum - y[i]; + s_acum += s_inst_hadd_256[4]; + e_128[1] = s_acum - y[i]; + s_acum += s_inst_hadd_256[1]; + e_128[2] = s_acum - y[i]; + s_acum += s_inst_hadd_256[5]; + e_128[3] = s_acum - y[i]; + + __m128 accumulated_error = _mm_load_ps(a_p); + accumulated_error = _mm_fmadd_ps(e_128, e_128, accumulated_error); + _mm_storeu_ps(a_p, accumulated_error); + } + // Sum components together. + x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8); + __m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0), + _mm256_extractf128_ps(x2_sum_256, 1)); + // Combine the accumulated vector and scalar values. + float* v = reinterpret_cast<float*>(&x2_sum_128); + x2_sum += v[0] + v[1] + v[2] + v[3]; + + // Compute the matched filter error. + float e = y[i] - s_acum; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m256 alpha_256 = _mm256_set1_ps(alpha); + + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + const float* x_p = + chunk1 != h_size ? scratch_memory.data() : &x[x_start_index]; + // Perform 256 bit vector operations. + const int limit_by_8 = h_size >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + // Load the data into 256 bit vectors. + __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k = _mm256_loadu_ps(x_p); + // Compute h = h + alpha * x. + h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k); + + // Store the result. + _mm256_storeu_ps(h_p, h_k); + } + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + +void MatchedFilterCore_AVX2(size_t x_start_index, + float x2_sum_threshold, + float smoothing, + rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> h, + bool* filters_updated, + float* error_sum, + bool compute_accumulated_error, + rtc::ArrayView<float> accumulated_error, + rtc::ArrayView<float> scratch_memory) { + if (compute_accumulated_error) { + return MatchedFilterCore_AccumulatedError_AVX2( + x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated, + error_sum, accumulated_error, scratch_memory); + } + const int h_size = static_cast<int>(h.size()); + const int x_size = static_cast<int>(x.size()); + RTC_DCHECK_EQ(0, h_size % 8); + + // Process for all samples in the sub-block. + for (size_t i = 0; i < y.size(); ++i) { + // Apply the matched filter as filter * x, and compute x * x. + + RTC_DCHECK_GT(x_size, x_start_index); + const float* x_p = &x[x_start_index]; + const float* h_p = &h[0]; + + // Initialize values for the accumulation. + __m256 s_256 = _mm256_set1_ps(0); + __m256 s_256_8 = _mm256_set1_ps(0); + __m256 x2_sum_256 = _mm256_set1_ps(0); + __m256 x2_sum_256_8 = _mm256_set1_ps(0); + float x2_sum = 0.f; + float s = 0; + + // Compute loop chunk sizes until, and after, the wraparound of the circular + // buffer for x. + const int chunk1 = + std::min(h_size, static_cast<int>(x_size - x_start_index)); + + // Perform the loop in two chunks. + const int chunk2 = h_size - chunk1; + for (int limit : {chunk1, chunk2}) { + // Perform 256 bit vector operations. + const int limit_by_16 = limit >> 4; + for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16) { + // Load the data into 256 bit vectors. + __m256 x_k = _mm256_loadu_ps(x_p); + __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k_8 = _mm256_loadu_ps(x_p + 8); + __m256 h_k_8 = _mm256_loadu_ps(h_p + 8); + // Compute and accumulate x * x and h * x. + x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256); + x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8); + s_256 = _mm256_fmadd_ps(h_k, x_k, s_256); + s_256_8 = _mm256_fmadd_ps(h_k_8, x_k_8, s_256_8); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_16 * 16; k > 0; --k, ++h_p, ++x_p) { + const float x_k = *x_p; + x2_sum += x_k * x_k; + s += *h_p * x_k; + } + + x_p = &x[0]; + } + + // Sum components together. + x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8); + s_256 = _mm256_add_ps(s_256, s_256_8); + __m128 sum = hsum_ab(x2_sum_256, s_256); + x2_sum += sum[0]; + s += sum[1]; + + // Compute the matched filter error. + float e = y[i] - s; + const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f; + (*error_sum) += e * e; + + // Update the matched filter estimate in an NLMS manner. + if (x2_sum > x2_sum_threshold && !saturation) { + RTC_DCHECK_LT(0.f, x2_sum); + const float alpha = smoothing * e / x2_sum; + const __m256 alpha_256 = _mm256_set1_ps(alpha); + + // filter = filter + smoothing * (y - filter * x) * x / x * x. + float* h_p = &h[0]; + x_p = &x[x_start_index]; + + // Perform the loop in two chunks. + for (int limit : {chunk1, chunk2}) { + // Perform 256 bit vector operations. + const int limit_by_8 = limit >> 3; + for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) { + // Load the data into 256 bit vectors. + __m256 h_k = _mm256_loadu_ps(h_p); + __m256 x_k = _mm256_loadu_ps(x_p); + // Compute h = h + alpha * x. + h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k); + + // Store the result. + _mm256_storeu_ps(h_p, h_k); + } + + // Perform non-vector operations for any remaining items. + for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) { + *h_p += alpha * *x_p; + } + + x_p = &x[0]; + } + + *filters_updated = true; + } + + x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_gn/moz.build new file mode 100644 index 0000000000..bae4fa2972 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_gn/moz.build @@ -0,0 +1,205 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("matched_filter_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc new file mode 100644 index 0000000000..bea7868a91 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.cc @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/matched_filter_lag_aggregator.h" + +#include <algorithm> +#include <iterator> + +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { +namespace { +constexpr int kPreEchoHistogramDataNotUpdated = -1; + +int GetDownSamplingBlockSizeLog2(int down_sampling_factor) { + int down_sampling_factor_log2 = 0; + down_sampling_factor >>= 1; + while (down_sampling_factor > 0) { + down_sampling_factor_log2++; + down_sampling_factor >>= 1; + } + return static_cast<int>(kBlockSizeLog2) > down_sampling_factor_log2 + ? static_cast<int>(kBlockSizeLog2) - down_sampling_factor_log2 + : 0; +} +} // namespace + +MatchedFilterLagAggregator::MatchedFilterLagAggregator( + ApmDataDumper* data_dumper, + size_t max_filter_lag, + const EchoCanceller3Config::Delay& delay_config) + : data_dumper_(data_dumper), + thresholds_(delay_config.delay_selection_thresholds), + headroom_(static_cast<int>(delay_config.delay_headroom_samples / + delay_config.down_sampling_factor)), + highest_peak_aggregator_(max_filter_lag) { + if (delay_config.detect_pre_echo) { + pre_echo_lag_aggregator_ = std::make_unique<PreEchoLagAggregator>( + max_filter_lag, delay_config.down_sampling_factor); + } + RTC_DCHECK(data_dumper); + RTC_DCHECK_LE(thresholds_.initial, thresholds_.converged); +} + +MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default; + +void MatchedFilterLagAggregator::Reset(bool hard_reset) { + highest_peak_aggregator_.Reset(); + if (pre_echo_lag_aggregator_ != nullptr) { + pre_echo_lag_aggregator_->Reset(); + } + if (hard_reset) { + significant_candidate_found_ = false; + } +} + +absl::optional<DelayEstimate> MatchedFilterLagAggregator::Aggregate( + const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate) { + if (lag_estimate && pre_echo_lag_aggregator_) { + pre_echo_lag_aggregator_->Dump(data_dumper_); + pre_echo_lag_aggregator_->Aggregate( + std::max(0, static_cast<int>(lag_estimate->pre_echo_lag) - headroom_)); + } + + if (lag_estimate) { + highest_peak_aggregator_.Aggregate( + std::max(0, static_cast<int>(lag_estimate->lag) - headroom_)); + rtc::ArrayView<const int> histogram = highest_peak_aggregator_.histogram(); + int candidate = highest_peak_aggregator_.candidate(); + significant_candidate_found_ = significant_candidate_found_ || + histogram[candidate] > thresholds_.converged; + if (histogram[candidate] > thresholds_.converged || + (histogram[candidate] > thresholds_.initial && + !significant_candidate_found_)) { + DelayEstimate::Quality quality = significant_candidate_found_ + ? DelayEstimate::Quality::kRefined + : DelayEstimate::Quality::kCoarse; + int reported_delay = pre_echo_lag_aggregator_ != nullptr + ? pre_echo_lag_aggregator_->pre_echo_candidate() + : candidate; + return DelayEstimate(quality, reported_delay); + } + } + + return absl::nullopt; +} + +MatchedFilterLagAggregator::HighestPeakAggregator::HighestPeakAggregator( + size_t max_filter_lag) + : histogram_(max_filter_lag + 1, 0) { + histogram_data_.fill(0); +} + +void MatchedFilterLagAggregator::HighestPeakAggregator::Reset() { + std::fill(histogram_.begin(), histogram_.end(), 0); + histogram_data_.fill(0); + histogram_data_index_ = 0; +} + +void MatchedFilterLagAggregator::HighestPeakAggregator::Aggregate(int lag) { + RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]); + RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]); + --histogram_[histogram_data_[histogram_data_index_]]; + histogram_data_[histogram_data_index_] = lag; + RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]); + RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]); + ++histogram_[histogram_data_[histogram_data_index_]]; + histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size(); + candidate_ = + std::distance(histogram_.begin(), + std::max_element(histogram_.begin(), histogram_.end())); +} + +MatchedFilterLagAggregator::PreEchoLagAggregator::PreEchoLagAggregator( + size_t max_filter_lag, + size_t down_sampling_factor) + : block_size_log2_(GetDownSamplingBlockSizeLog2(down_sampling_factor)), + histogram_( + ((max_filter_lag + 1) * down_sampling_factor) >> kBlockSizeLog2, + 0) { + Reset(); +} + +void MatchedFilterLagAggregator::PreEchoLagAggregator::Reset() { + std::fill(histogram_.begin(), histogram_.end(), 0); + histogram_data_.fill(kPreEchoHistogramDataNotUpdated); + histogram_data_index_ = 0; + pre_echo_candidate_ = 0; +} + +void MatchedFilterLagAggregator::PreEchoLagAggregator::Aggregate( + int pre_echo_lag) { + int pre_echo_block_size = pre_echo_lag >> block_size_log2_; + RTC_DCHECK(pre_echo_block_size >= 0 && + pre_echo_block_size < static_cast<int>(histogram_.size())); + pre_echo_block_size = + rtc::SafeClamp(pre_echo_block_size, 0, histogram_.size() - 1); + // Remove the oldest point from the `histogram_`, it ignores the initial + // points where no updates have been done to the `histogram_data_` array. + if (histogram_data_[histogram_data_index_] != + kPreEchoHistogramDataNotUpdated) { + --histogram_[histogram_data_[histogram_data_index_]]; + } + histogram_data_[histogram_data_index_] = pre_echo_block_size; + ++histogram_[histogram_data_[histogram_data_index_]]; + histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size(); + int pre_echo_candidate_block_size = + std::distance(histogram_.begin(), + std::max_element(histogram_.begin(), histogram_.end())); + pre_echo_candidate_ = (pre_echo_candidate_block_size << block_size_log2_); +} + +void MatchedFilterLagAggregator::PreEchoLagAggregator::Dump( + ApmDataDumper* const data_dumper) { + data_dumper->DumpRaw("aec3_pre_echo_delay_candidate", pre_echo_candidate_); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h new file mode 100644 index 0000000000..c0598bf226 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ + +#include <vector> + +#include "absl/types/optional.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/delay_estimate.h" +#include "modules/audio_processing/aec3/matched_filter.h" + +namespace webrtc { + +class ApmDataDumper; + +// Aggregates lag estimates produced by the MatchedFilter class into a single +// reliable combined lag estimate. +class MatchedFilterLagAggregator { + public: + MatchedFilterLagAggregator(ApmDataDumper* data_dumper, + size_t max_filter_lag, + const EchoCanceller3Config::Delay& delay_config); + + MatchedFilterLagAggregator() = delete; + MatchedFilterLagAggregator(const MatchedFilterLagAggregator&) = delete; + MatchedFilterLagAggregator& operator=(const MatchedFilterLagAggregator&) = + delete; + + ~MatchedFilterLagAggregator(); + + // Resets the aggregator. + void Reset(bool hard_reset); + + // Aggregates the provided lag estimates. + absl::optional<DelayEstimate> Aggregate( + const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate); + + // Returns whether a reliable delay estimate has been found. + bool ReliableDelayFound() const { return significant_candidate_found_; } + + // Returns the delay candidate that is computed by looking at the highest peak + // on the matched filters. + int GetDelayAtHighestPeak() const { + return highest_peak_aggregator_.candidate(); + } + + private: + class PreEchoLagAggregator { + public: + PreEchoLagAggregator(size_t max_filter_lag, size_t down_sampling_factor); + void Reset(); + void Aggregate(int pre_echo_lag); + int pre_echo_candidate() const { return pre_echo_candidate_; } + void Dump(ApmDataDumper* const data_dumper); + + private: + const int block_size_log2_; + std::array<int, 250> histogram_data_; + std::vector<int> histogram_; + int histogram_data_index_ = 0; + int pre_echo_candidate_ = 0; + }; + + class HighestPeakAggregator { + public: + explicit HighestPeakAggregator(size_t max_filter_lag); + void Reset(); + void Aggregate(int lag); + int candidate() const { return candidate_; } + rtc::ArrayView<const int> histogram() const { return histogram_; } + + private: + std::vector<int> histogram_; + std::array<int, 250> histogram_data_; + int histogram_data_index_ = 0; + int candidate_ = -1; + }; + + ApmDataDumper* const data_dumper_; + bool significant_candidate_found_ = false; + const EchoCanceller3Config::Delay::DelaySelectionThresholds thresholds_; + const int headroom_; + HighestPeakAggregator highest_peak_aggregator_; + std::unique_ptr<PreEchoLagAggregator> pre_echo_lag_aggregator_; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_LAG_AGGREGATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc new file mode 100644 index 0000000000..6804102584 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_lag_aggregator_unittest.cc @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/matched_filter_lag_aggregator.h" + +#include <sstream> +#include <string> +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +constexpr size_t kNumLagsBeforeDetection = 26; + +} // namespace + +// Verifies that varying lag estimates causes lag estimates to not be deemed +// reliable. +TEST(MatchedFilterLagAggregator, + LagEstimateInvarianceRequiredForAggregatedLag) { + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/100, + config.delay); + + absl::optional<DelayEstimate> aggregated_lag; + for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) { + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/10, /*pre_echo_lag=*/10)); + } + EXPECT_TRUE(aggregated_lag); + + for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) { + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100)); + } + EXPECT_FALSE(aggregated_lag); + + for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) { + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100)); + EXPECT_FALSE(aggregated_lag); + } +} + +// Verifies that lag estimate updates are required to produce an updated lag +// aggregate. +TEST(MatchedFilterLagAggregator, + DISABLED_LagEstimateUpdatesRequiredForAggregatedLag) { + constexpr size_t kLag = 5; + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/kLag, + config.delay); + for (size_t k = 0; k < kNumLagsBeforeDetection * 10; ++k) { + absl::optional<DelayEstimate> aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/kLag, /*pre_echo_lag=*/kLag)); + EXPECT_FALSE(aggregated_lag); + EXPECT_EQ(kLag, aggregated_lag->delay); + } +} + +// Verifies that an aggregated lag is persistent if the lag estimates do not +// change and that an aggregated lag is not produced without gaining lag +// estimate confidence. +TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) { + constexpr size_t kLag1 = 5; + constexpr size_t kLag2 = 10; + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + std::vector<MatchedFilter::LagEstimate> lag_estimates(1); + MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2), + config.delay); + absl::optional<DelayEstimate> aggregated_lag; + for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) { + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/kLag1, /*pre_echo_lag=*/kLag1)); + } + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kLag1, aggregated_lag->delay); + + for (size_t k = 0; k < kNumLagsBeforeDetection * 40; ++k) { + aggregated_lag = aggregator.Aggregate( + MatchedFilter::LagEstimate(/*lag=*/kLag2, /*pre_echo_lag=*/kLag2)); + EXPECT_TRUE(aggregated_lag); + EXPECT_EQ(kLag1, aggregated_lag->delay); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for non-null data dumper. +TEST(MatchedFilterLagAggregatorDeathTest, NullDataDumper) { + EchoCanceller3Config config; + EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 10, config.delay), ""); +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_unittest.cc new file mode 100644 index 0000000000..0a04c7809c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/matched_filter_unittest.cc @@ -0,0 +1,612 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/matched_filter.h" + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif +#include <algorithm> +#include <string> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/decimator.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/field_trial.h" +#include "test/gtest.h" + +namespace webrtc { +namespace aec3 { +namespace { + +std::string ProduceDebugText(size_t delay, size_t down_sampling_factor) { + rtc::StringBuilder ss; + ss << "Delay: " << delay; + ss << ", Down sampling factor: " << down_sampling_factor; + return ss.Release(); +} + +constexpr size_t kNumMatchedFilters = 10; +constexpr size_t kDownSamplingFactors[] = {2, 4, 8}; +constexpr size_t kWindowSizeSubBlocks = 32; +constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4; + +} // namespace + +class MatchedFilterTest : public ::testing::TestWithParam<bool> {}; + +#if defined(WEBRTC_HAS_NEON) +// Verifies that the optimized methods for NEON are similar to their reference +// counterparts. +TEST_P(MatchedFilterTest, TestNeonOptimizations) { + Random random_generator(42U); + constexpr float kSmoothing = 0.7f; + const bool kComputeAccumulatederror = GetParam(); + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + + std::vector<float> x(2000); + RandomizeSampleVector(&random_generator, x); + std::vector<float> y(sub_block_size); + std::vector<float> h_NEON(512); + std::vector<float> h(512); + std::vector<float> accumulated_error(512); + std::vector<float> accumulated_error_NEON(512); + std::vector<float> scratch_memory(512); + + int x_index = 0; + for (int k = 0; k < 1000; ++k) { + RandomizeSampleVector(&random_generator, y); + + bool filters_updated = false; + float error_sum = 0.f; + bool filters_updated_NEON = false; + float error_sum_NEON = 0.f; + + MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, kSmoothing, x, + y, h_NEON, &filters_updated_NEON, &error_sum_NEON, + kComputeAccumulatederror, accumulated_error_NEON, + scratch_memory); + + MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, h, + &filters_updated, &error_sum, kComputeAccumulatederror, + accumulated_error); + + EXPECT_EQ(filters_updated, filters_updated_NEON); + EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f); + + for (size_t j = 0; j < h.size(); ++j) { + EXPECT_NEAR(h[j], h_NEON[j], 0.00001f); + } + + if (kComputeAccumulatederror) { + for (size_t j = 0; j < accumulated_error.size(); ++j) { + float difference = + std::abs(accumulated_error[j] - accumulated_error_NEON[j]); + float relative_difference = accumulated_error[j] > 0 + ? difference / accumulated_error[j] + : difference; + EXPECT_NEAR(relative_difference, 0.0f, 0.02f); + } + } + + x_index = (x_index + sub_block_size) % x.size(); + } + } +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized methods for SSE2 are bitexact to their reference +// counterparts. +TEST_P(MatchedFilterTest, TestSse2Optimizations) { + const bool kComputeAccumulatederror = GetParam(); + bool use_sse2 = (GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + Random random_generator(42U); + constexpr float kSmoothing = 0.7f; + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + std::vector<float> x(2000); + RandomizeSampleVector(&random_generator, x); + std::vector<float> y(sub_block_size); + std::vector<float> h_SSE2(512); + std::vector<float> h(512); + std::vector<float> accumulated_error(512 / 4); + std::vector<float> accumulated_error_SSE2(512 / 4); + std::vector<float> scratch_memory(512); + int x_index = 0; + for (int k = 0; k < 1000; ++k) { + RandomizeSampleVector(&random_generator, y); + + bool filters_updated = false; + float error_sum = 0.f; + bool filters_updated_SSE2 = false; + float error_sum_SSE2 = 0.f; + + MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, kSmoothing, x, + y, h_SSE2, &filters_updated_SSE2, + &error_sum_SSE2, kComputeAccumulatederror, + accumulated_error_SSE2, scratch_memory); + + MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, + h, &filters_updated, &error_sum, + kComputeAccumulatederror, accumulated_error); + + EXPECT_EQ(filters_updated, filters_updated_SSE2); + EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f); + + for (size_t j = 0; j < h.size(); ++j) { + EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f); + } + + for (size_t j = 0; j < accumulated_error.size(); ++j) { + float difference = + std::abs(accumulated_error[j] - accumulated_error_SSE2[j]); + float relative_difference = accumulated_error[j] > 0 + ? difference / accumulated_error[j] + : difference; + EXPECT_NEAR(relative_difference, 0.0f, 0.00001f); + } + + x_index = (x_index + sub_block_size) % x.size(); + } + } + } +} + +TEST_P(MatchedFilterTest, TestAvx2Optimizations) { + bool use_avx2 = (GetCPUInfo(kAVX2) != 0); + const bool kComputeAccumulatederror = GetParam(); + if (use_avx2) { + Random random_generator(42U); + constexpr float kSmoothing = 0.7f; + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + std::vector<float> x(2000); + RandomizeSampleVector(&random_generator, x); + std::vector<float> y(sub_block_size); + std::vector<float> h_AVX2(512); + std::vector<float> h(512); + std::vector<float> accumulated_error(512 / 4); + std::vector<float> accumulated_error_AVX2(512 / 4); + std::vector<float> scratch_memory(512); + int x_index = 0; + for (int k = 0; k < 1000; ++k) { + RandomizeSampleVector(&random_generator, y); + bool filters_updated = false; + float error_sum = 0.f; + bool filters_updated_AVX2 = false; + float error_sum_AVX2 = 0.f; + MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x, + y, h_AVX2, &filters_updated_AVX2, + &error_sum_AVX2, kComputeAccumulatederror, + accumulated_error_AVX2, scratch_memory); + MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, + h, &filters_updated, &error_sum, + kComputeAccumulatederror, accumulated_error); + EXPECT_EQ(filters_updated, filters_updated_AVX2); + EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f); + for (size_t j = 0; j < h.size(); ++j) { + EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f); + } + for (size_t j = 0; j < accumulated_error.size(); j += 4) { + float difference = + std::abs(accumulated_error[j] - accumulated_error_AVX2[j]); + float relative_difference = accumulated_error[j] > 0 + ? difference / accumulated_error[j] + : difference; + EXPECT_NEAR(relative_difference, 0.0f, 0.00001f); + } + x_index = (x_index + sub_block_size) % x.size(); + } + } + } +} + +#endif + +// Verifies that the (optimized) function MaxSquarePeakIndex() produces output +// equal to the corresponding std-functions. +TEST(MatchedFilter, MaxSquarePeakIndex) { + Random random_generator(42U); + constexpr int kMaxLength = 128; + constexpr int kNumIterationsPerLength = 256; + for (int length = 1; length < kMaxLength; ++length) { + std::vector<float> y(length); + for (int i = 0; i < kNumIterationsPerLength; ++i) { + RandomizeSampleVector(&random_generator, y); + + size_t lag_from_function = MaxSquarePeakIndex(y); + size_t lag_from_std = std::distance( + y.begin(), + std::max_element(y.begin(), y.end(), [](float a, float b) -> bool { + return a * a < b * b; + })); + EXPECT_EQ(lag_from_function, lag_from_std); + } + } +} + +// Verifies that the matched filter produces proper lag estimates for +// artificially delayed signals. +TEST_P(MatchedFilterTest, LagEstimation) { + const bool kDetectPreEcho = GetParam(); + Random random_generator(42U); + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + + Block render(kNumBands, kNumChannels); + std::vector<std::vector<float>> capture( + 1, std::vector<float>(kBlockSize, 0.f)); + ApmDataDumper data_dumper(0); + for (size_t delay_samples : {5, 64, 150, 200, 800, 1000}) { + SCOPED_TRACE(ProduceDebugText(delay_samples, down_sampling_factor)); + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = kNumMatchedFilters; + Decimator capture_decimator(down_sampling_factor); + DelayBuffer<float> signal_delay_buffer(down_sampling_factor * + delay_samples); + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, + 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); + + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < (600 + delay_samples / sub_block_size); ++k) { + for (size_t band = 0; band < kNumBands; ++band) { + for (size_t channel = 0; channel < kNumChannels; ++channel) { + RandomizeSampleVector(&random_generator, + render.View(band, channel)); + } + } + signal_delay_buffer.Delay(render.View(/*band=*/0, /*channel=*/0), + capture[0]); + render_delay_buffer->Insert(render); + + if (k == 0) { + render_delay_buffer->Reset(); + } + + render_delay_buffer->PrepareCaptureProcessing(); + std::array<float, kBlockSize> downsampled_capture_data; + rtc::ArrayView<float> downsampled_capture( + downsampled_capture_data.data(), sub_block_size); + capture_decimator.Decimate(capture[0], downsampled_capture); + filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), + downsampled_capture, /*use_slow_smoothing=*/false); + } + + // Obtain the lag estimates. + auto lag_estimate = filter.GetBestLagEstimate(); + EXPECT_TRUE(lag_estimate.has_value()); + + // Verify that the expected most accurate lag estimate is correct. + if (lag_estimate.has_value()) { + EXPECT_EQ(delay_samples, lag_estimate->lag); + EXPECT_EQ(delay_samples, lag_estimate->pre_echo_lag); + } + } + } +} + +// Test the pre echo estimation. +TEST_P(MatchedFilterTest, PreEchoEstimation) { + const bool kDetectPreEcho = GetParam(); + Random random_generator(42U); + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + + Block render(kNumBands, kNumChannels); + std::vector<std::vector<float>> capture( + 1, std::vector<float>(kBlockSize, 0.f)); + std::vector<float> capture_with_pre_echo(kBlockSize, 0.f); + ApmDataDumper data_dumper(0); + // data_dumper.SetActivated(true); + size_t pre_echo_delay_samples = 20e-3 * 16000 / down_sampling_factor; + size_t echo_delay_samples = 50e-3 * 16000 / down_sampling_factor; + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = kNumMatchedFilters; + Decimator capture_decimator(down_sampling_factor); + DelayBuffer<float> signal_echo_delay_buffer(down_sampling_factor * + echo_delay_samples); + DelayBuffer<float> signal_pre_echo_delay_buffer(down_sampling_factor * + pre_echo_delay_samples); + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); + // Analyze the correlation between render and capture. + for (size_t k = 0; k < (600 + echo_delay_samples / sub_block_size); ++k) { + for (size_t band = 0; band < kNumBands; ++band) { + for (size_t channel = 0; channel < kNumChannels; ++channel) { + RandomizeSampleVector(&random_generator, render.View(band, channel)); + } + } + signal_echo_delay_buffer.Delay(render.View(0, 0), capture[0]); + signal_pre_echo_delay_buffer.Delay(render.View(0, 0), + capture_with_pre_echo); + for (size_t k = 0; k < capture[0].size(); ++k) { + constexpr float gain_pre_echo = 0.8f; + capture[0][k] += gain_pre_echo * capture_with_pre_echo[k]; + } + render_delay_buffer->Insert(render); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + std::array<float, kBlockSize> downsampled_capture_data; + rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(), + sub_block_size); + capture_decimator.Decimate(capture[0], downsampled_capture); + filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), + downsampled_capture, /*use_slow_smoothing=*/false); + } + // Obtain the lag estimates. + auto lag_estimate = filter.GetBestLagEstimate(); + EXPECT_TRUE(lag_estimate.has_value()); + // Verify that the expected most accurate lag estimate is correct. + if (lag_estimate.has_value()) { + EXPECT_EQ(echo_delay_samples, lag_estimate->lag); + if (kDetectPreEcho) { + // The pre echo delay is estimated in a subsampled domain and a larger + // error is allowed. + EXPECT_NEAR(pre_echo_delay_samples, lag_estimate->pre_echo_lag, 4); + } else { + // The pre echo delay fallback to the highest mached filter peak when + // its detection is disabled. + EXPECT_EQ(echo_delay_samples, lag_estimate->pre_echo_lag); + } + } + } +} + +// Verifies that the matched filter does not produce reliable and accurate +// estimates for uncorrelated render and capture signals. +TEST_P(MatchedFilterTest, LagNotReliableForUncorrelatedRenderAndCapture) { + const bool kDetectPreEcho = GetParam(); + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + Random random_generator(42U); + for (auto down_sampling_factor : kDownSamplingFactors) { + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = kNumMatchedFilters; + const size_t sub_block_size = kBlockSize / down_sampling_factor; + + Block render(kNumBands, kNumChannels); + std::array<float, kBlockSize> capture_data; + rtc::ArrayView<float> capture(capture_data.data(), sub_block_size); + std::fill(capture.begin(), capture.end(), 0.f); + ApmDataDumper data_dumper(0); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, + render.View(/*band=*/0, /*channel=*/0)); + RandomizeSampleVector(&random_generator, capture); + render_delay_buffer->Insert(render); + filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture, + false); + } + + // Obtain the best lag estimate and Verify that no lag estimates are + // reliable. + auto best_lag_estimates = filter.GetBestLagEstimate(); + EXPECT_FALSE(best_lag_estimates.has_value()); + } +} + +// Verifies that the matched filter does not produce updated lag estimates for +// render signals of low level. +TEST_P(MatchedFilterTest, LagNotUpdatedForLowLevelRender) { + const bool kDetectPreEcho = GetParam(); + Random random_generator(42U); + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + for (auto down_sampling_factor : kDownSamplingFactors) { + const size_t sub_block_size = kBlockSize / down_sampling_factor; + + Block render(kNumBands, kNumChannels); + std::vector<std::vector<float>> capture( + 1, std::vector<float>(kBlockSize, 0.f)); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + MatchedFilter filter( + &data_dumper, DetectOptimization(), sub_block_size, + kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, kDetectPreEcho); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + kNumChannels)); + Decimator capture_decimator(down_sampling_factor); + + // Analyze the correlation between render and capture. + for (size_t k = 0; k < 100; ++k) { + RandomizeSampleVector(&random_generator, render.View(0, 0)); + for (auto& render_k : render.View(0, 0)) { + render_k *= 149.f / 32767.f; + } + std::copy(render.begin(0, 0), render.end(0, 0), capture[0].begin()); + std::array<float, kBlockSize> downsampled_capture_data; + rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(), + sub_block_size); + capture_decimator.Decimate(capture[0], downsampled_capture); + filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), + downsampled_capture, false); + } + + // Verify that no lag estimate has been produced. + auto lag_estimate = filter.GetBestLagEstimate(); + EXPECT_FALSE(lag_estimate.has_value()); + } +} + +INSTANTIATE_TEST_SUITE_P(_, MatchedFilterTest, testing::Values(true, false)); + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +class MatchedFilterDeathTest : public ::testing::TestWithParam<bool> {}; + +// Verifies the check for non-zero windows size. +TEST_P(MatchedFilterDeathTest, ZeroWindowSize) { + const bool kDetectPreEcho = GetParam(); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1, + 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), + ""); +} + +// Verifies the check for non-null data dumper. +TEST_P(MatchedFilterDeathTest, NullDataDumper) { + const bool kDetectPreEcho = GetParam(); + EchoCanceller3Config config; + EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), + ""); +} + +// Verifies the check for that the sub block size is a multiple of 4. +// TODO(peah): Activate the unittest once the required code has been landed. +TEST_P(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) { + const bool kDetectPreEcho = GetParam(); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1, + 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), + ""); +} + +// Verifies the check for that there is an integer number of sub blocks that add +// up to a block size. +// TODO(peah): Activate the unittest once the required code has been landed. +TEST_P(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) { + const bool kDetectPreEcho = GetParam(); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1, + 150, config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + kDetectPreEcho), + ""); +} + +INSTANTIATE_TEST_SUITE_P(_, + MatchedFilterDeathTest, + testing::Values(true, false)); + +#endif + +} // namespace aec3 + +TEST(MatchedFilterFieldTrialTest, PreEchoConfigurationTest) { + float threshold_in = 0.1f; + int mode_in = 2; + rtc::StringBuilder field_trial_name; + field_trial_name << "WebRTC-Aec3PreEchoConfiguration/threshold:" + << threshold_in << ",mode:" << mode_in << "/"; + webrtc::test::ScopedFieldTrials field_trials(field_trial_name.str()); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + MatchedFilter matched_filter( + &data_dumper, DetectOptimization(), + kBlockSize / config.delay.down_sampling_factor, + aec3::kWindowSizeSubBlocks, aec3::kNumMatchedFilters, + aec3::kAlignmentShiftSubBlocks, + config.render_levels.poor_excitation_render_limit, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + config.delay.detect_pre_echo); + + auto& pre_echo_config = matched_filter.GetPreEchoConfiguration(); + EXPECT_EQ(pre_echo_config.threshold, threshold_in); + EXPECT_EQ(pre_echo_config.mode, mode_in); +} + +TEST(MatchedFilterFieldTrialTest, WrongPreEchoConfigurationTest) { + constexpr float kDefaultThreshold = 0.5f; + constexpr int kDefaultMode = 0; + float threshold_in = -0.1f; + int mode_in = 5; + rtc::StringBuilder field_trial_name; + field_trial_name << "WebRTC-Aec3PreEchoConfiguration/threshold:" + << threshold_in << ",mode:" << mode_in << "/"; + webrtc::test::ScopedFieldTrials field_trials(field_trial_name.str()); + ApmDataDumper data_dumper(0); + EchoCanceller3Config config; + MatchedFilter matched_filter( + &data_dumper, DetectOptimization(), + kBlockSize / config.delay.down_sampling_factor, + aec3::kWindowSizeSubBlocks, aec3::kNumMatchedFilters, + aec3::kAlignmentShiftSubBlocks, + config.render_levels.poor_excitation_render_limit, + config.delay.delay_estimate_smoothing, + config.delay.delay_estimate_smoothing_delay_found, + config.delay.delay_candidate_detection_threshold, + config.delay.detect_pre_echo); + + auto& pre_echo_config = matched_filter.GetPreEchoConfiguration(); + EXPECT_EQ(pre_echo_config.threshold, kDefaultThreshold); + EXPECT_EQ(pre_echo_config.mode, kDefaultMode); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_block_processor.cc b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_block_processor.cc new file mode 100644 index 0000000000..c5c33dbd68 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_block_processor.cc @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/mock/mock_block_processor.h" + +namespace webrtc { +namespace test { + +MockBlockProcessor::MockBlockProcessor() = default; +MockBlockProcessor::~MockBlockProcessor() = default; + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_block_processor.h b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_block_processor.h new file mode 100644 index 0000000000..c9ae38c4aa --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_block_processor.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_BLOCK_PROCESSOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_BLOCK_PROCESSOR_H_ + +#include <vector> + +#include "modules/audio_processing/aec3/block_processor.h" +#include "test/gmock.h" + +namespace webrtc { +namespace test { + +class MockBlockProcessor : public BlockProcessor { + public: + MockBlockProcessor(); + virtual ~MockBlockProcessor(); + + MOCK_METHOD(void, + ProcessCapture, + (bool level_change, + bool saturated_microphone_signal, + Block* linear_output, + Block* capture_block), + (override)); + MOCK_METHOD(void, BufferRender, (const Block& block), (override)); + MOCK_METHOD(void, + UpdateEchoLeakageStatus, + (bool leakage_detected), + (override)); + MOCK_METHOD(void, + GetMetrics, + (EchoControl::Metrics * metrics), + (const, override)); + MOCK_METHOD(void, SetAudioBufferDelay, (int delay_ms), (override)); + MOCK_METHOD(void, + SetCaptureOutputUsage, + (bool capture_output_used), + (override)); +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_BLOCK_PROCESSOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_echo_remover.cc b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_echo_remover.cc new file mode 100644 index 0000000000..b903bf0785 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_echo_remover.cc @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/mock/mock_echo_remover.h" + +namespace webrtc { +namespace test { + +MockEchoRemover::MockEchoRemover() = default; +MockEchoRemover::~MockEchoRemover() = default; + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_echo_remover.h b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_echo_remover.h new file mode 100644 index 0000000000..31f075ef0a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_echo_remover.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_ECHO_REMOVER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_ECHO_REMOVER_H_ + +#include <vector> + +#include "absl/types/optional.h" +#include "modules/audio_processing/aec3/echo_path_variability.h" +#include "modules/audio_processing/aec3/echo_remover.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "test/gmock.h" + +namespace webrtc { +namespace test { + +class MockEchoRemover : public EchoRemover { + public: + MockEchoRemover(); + virtual ~MockEchoRemover(); + + MOCK_METHOD(void, + ProcessCapture, + (EchoPathVariability echo_path_variability, + bool capture_signal_saturation, + const absl::optional<DelayEstimate>& delay_estimate, + RenderBuffer* render_buffer, + Block* linear_output, + Block* capture), + (override)); + MOCK_METHOD(void, + UpdateEchoLeakageStatus, + (bool leakage_detected), + (override)); + MOCK_METHOD(void, + GetMetrics, + (EchoControl::Metrics * metrics), + (const, override)); + MOCK_METHOD(void, + SetCaptureOutputUsage, + (bool capture_output_used), + (override)); +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_ECHO_REMOVER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.cc new file mode 100644 index 0000000000..d4ad09b4bc --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/mock/mock_render_delay_buffer.h" + +namespace webrtc { +namespace test { + +MockRenderDelayBuffer::MockRenderDelayBuffer(int sample_rate_hz, + size_t num_channels) + : block_buffer_(GetRenderDelayBufferSize(4, 4, 12), + NumBandsForRate(sample_rate_hz), + num_channels), + spectrum_buffer_(block_buffer_.buffer.size(), num_channels), + fft_buffer_(block_buffer_.buffer.size(), num_channels), + render_buffer_(&block_buffer_, &spectrum_buffer_, &fft_buffer_), + downsampled_render_buffer_(GetDownSampledBufferSize(4, 4)) { + ON_CALL(*this, GetRenderBuffer()) + .WillByDefault( + ::testing::Invoke(this, &MockRenderDelayBuffer::FakeGetRenderBuffer)); + ON_CALL(*this, GetDownsampledRenderBuffer()) + .WillByDefault(::testing::Invoke( + this, &MockRenderDelayBuffer::FakeGetDownsampledRenderBuffer)); +} + +MockRenderDelayBuffer::~MockRenderDelayBuffer() = default; + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h new file mode 100644 index 0000000000..c17fd62caa --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_buffer.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_RENDER_DELAY_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_RENDER_DELAY_BUFFER_H_ + +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "test/gmock.h" + +namespace webrtc { +namespace test { + +class MockRenderDelayBuffer : public RenderDelayBuffer { + public: + MockRenderDelayBuffer(int sample_rate_hz, size_t num_channels); + virtual ~MockRenderDelayBuffer(); + + MOCK_METHOD(void, Reset, (), (override)); + MOCK_METHOD(RenderDelayBuffer::BufferingEvent, + Insert, + (const Block& block), + (override)); + MOCK_METHOD(void, HandleSkippedCaptureProcessing, (), (override)); + MOCK_METHOD(RenderDelayBuffer::BufferingEvent, + PrepareCaptureProcessing, + (), + (override)); + MOCK_METHOD(bool, AlignFromDelay, (size_t delay), (override)); + MOCK_METHOD(void, AlignFromExternalDelay, (), (override)); + MOCK_METHOD(size_t, Delay, (), (const, override)); + MOCK_METHOD(size_t, MaxDelay, (), (const, override)); + MOCK_METHOD(RenderBuffer*, GetRenderBuffer, (), (override)); + MOCK_METHOD(const DownsampledRenderBuffer&, + GetDownsampledRenderBuffer, + (), + (const, override)); + MOCK_METHOD(void, SetAudioBufferDelay, (int delay_ms), (override)); + MOCK_METHOD(bool, HasReceivedBufferDelay, (), (override)); + + private: + RenderBuffer* FakeGetRenderBuffer() { return &render_buffer_; } + const DownsampledRenderBuffer& FakeGetDownsampledRenderBuffer() const { + return downsampled_render_buffer_; + } + BlockBuffer block_buffer_; + SpectrumBuffer spectrum_buffer_; + FftBuffer fft_buffer_; + RenderBuffer render_buffer_; + DownsampledRenderBuffer downsampled_render_buffer_; +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_RENDER_DELAY_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_controller.cc b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_controller.cc new file mode 100644 index 0000000000..4ae2af96bf --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_controller.cc @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/mock/mock_render_delay_controller.h" + +namespace webrtc { +namespace test { + +MockRenderDelayController::MockRenderDelayController() = default; +MockRenderDelayController::~MockRenderDelayController() = default; + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_controller.h b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_controller.h new file mode 100644 index 0000000000..14d499dd28 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/mock/mock_render_delay_controller.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_RENDER_DELAY_CONTROLLER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_RENDER_DELAY_CONTROLLER_H_ + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/aec3/render_delay_controller.h" +#include "test/gmock.h" + +namespace webrtc { +namespace test { + +class MockRenderDelayController : public RenderDelayController { + public: + MockRenderDelayController(); + virtual ~MockRenderDelayController(); + + MOCK_METHOD(void, Reset, (bool reset_delay_statistics), (override)); + MOCK_METHOD(void, LogRenderCall, (), (override)); + MOCK_METHOD(absl::optional<DelayEstimate>, + GetDelay, + (const DownsampledRenderBuffer& render_buffer, + size_t render_delay_buffer_delay, + const Block& capture), + (override)); + MOCK_METHOD(bool, HasClockdrift, (), (const, override)); +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MOCK_MOCK_RENDER_DELAY_CONTROLLER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/moving_average.cc b/third_party/libwebrtc/modules/audio_processing/aec3/moving_average.cc new file mode 100644 index 0000000000..7a81ee89ea --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/moving_average.cc @@ -0,0 +1,60 @@ + +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/moving_average.h" + +#include <algorithm> +#include <functional> + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace aec3 { + +MovingAverage::MovingAverage(size_t num_elem, size_t mem_len) + : num_elem_(num_elem), + mem_len_(mem_len - 1), + scaling_(1.0f / static_cast<float>(mem_len)), + memory_(num_elem * mem_len_, 0.f), + mem_index_(0) { + RTC_DCHECK(num_elem_ > 0); + RTC_DCHECK(mem_len > 0); +} + +MovingAverage::~MovingAverage() = default; + +void MovingAverage::Average(rtc::ArrayView<const float> input, + rtc::ArrayView<float> output) { + RTC_DCHECK(input.size() == num_elem_); + RTC_DCHECK(output.size() == num_elem_); + + // Sum all contributions. + std::copy(input.begin(), input.end(), output.begin()); + for (auto i = memory_.begin(); i < memory_.end(); i += num_elem_) { + std::transform(i, i + num_elem_, output.begin(), output.begin(), + std::plus<float>()); + } + + // Divide by mem_len_. + for (float& o : output) { + o *= scaling_; + } + + // Update memory. + if (mem_len_ > 0) { + std::copy(input.begin(), input.end(), + memory_.begin() + mem_index_ * num_elem_); + mem_index_ = (mem_index_ + 1) % mem_len_; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/moving_average.h b/third_party/libwebrtc/modules/audio_processing/aec3/moving_average.h new file mode 100644 index 0000000000..913d78519c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/moving_average.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MOVING_AVERAGE_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MOVING_AVERAGE_H_ + +#include <stddef.h> + +#include <vector> + +#include "api/array_view.h" + +namespace webrtc { +namespace aec3 { + +class MovingAverage { + public: + // Creates an instance of MovingAverage that accepts inputs of length num_elem + // and averages over mem_len inputs. + MovingAverage(size_t num_elem, size_t mem_len); + ~MovingAverage(); + + // Computes the average of input and mem_len-1 previous inputs and stores the + // result in output. + void Average(rtc::ArrayView<const float> input, rtc::ArrayView<float> output); + + private: + const size_t num_elem_; + const size_t mem_len_; + const float scaling_; + std::vector<float> memory_; + size_t mem_index_; +}; + +} // namespace aec3 +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MOVING_AVERAGE_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/moving_average_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/moving_average_unittest.cc new file mode 100644 index 0000000000..84ba9cbc5b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/moving_average_unittest.cc @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/moving_average.h" + +#include "test/gtest.h" + +namespace webrtc { + +TEST(MovingAverage, Average) { + constexpr size_t num_elem = 4; + constexpr size_t mem_len = 3; + constexpr float e = 1e-6f; + aec3::MovingAverage ma(num_elem, mem_len); + std::array<float, num_elem> data1 = {1, 2, 3, 4}; + std::array<float, num_elem> data2 = {5, 1, 9, 7}; + std::array<float, num_elem> data3 = {3, 3, 5, 6}; + std::array<float, num_elem> data4 = {8, 4, 2, 1}; + std::array<float, num_elem> output; + + ma.Average(data1, output); + EXPECT_NEAR(output[0], data1[0] / 3.0f, e); + EXPECT_NEAR(output[1], data1[1] / 3.0f, e); + EXPECT_NEAR(output[2], data1[2] / 3.0f, e); + EXPECT_NEAR(output[3], data1[3] / 3.0f, e); + + ma.Average(data2, output); + EXPECT_NEAR(output[0], (data1[0] + data2[0]) / 3.0f, e); + EXPECT_NEAR(output[1], (data1[1] + data2[1]) / 3.0f, e); + EXPECT_NEAR(output[2], (data1[2] + data2[2]) / 3.0f, e); + EXPECT_NEAR(output[3], (data1[3] + data2[3]) / 3.0f, e); + + ma.Average(data3, output); + EXPECT_NEAR(output[0], (data1[0] + data2[0] + data3[0]) / 3.0f, e); + EXPECT_NEAR(output[1], (data1[1] + data2[1] + data3[1]) / 3.0f, e); + EXPECT_NEAR(output[2], (data1[2] + data2[2] + data3[2]) / 3.0f, e); + EXPECT_NEAR(output[3], (data1[3] + data2[3] + data3[3]) / 3.0f, e); + + ma.Average(data4, output); + EXPECT_NEAR(output[0], (data2[0] + data3[0] + data4[0]) / 3.0f, e); + EXPECT_NEAR(output[1], (data2[1] + data3[1] + data4[1]) / 3.0f, e); + EXPECT_NEAR(output[2], (data2[2] + data3[2] + data4[2]) / 3.0f, e); + EXPECT_NEAR(output[3], (data2[3] + data3[3] + data4[3]) / 3.0f, e); +} + +TEST(MovingAverage, PassThrough) { + constexpr size_t num_elem = 4; + constexpr size_t mem_len = 1; + constexpr float e = 1e-6f; + aec3::MovingAverage ma(num_elem, mem_len); + std::array<float, num_elem> data1 = {1, 2, 3, 4}; + std::array<float, num_elem> data2 = {5, 1, 9, 7}; + std::array<float, num_elem> data3 = {3, 3, 5, 6}; + std::array<float, num_elem> data4 = {8, 4, 2, 1}; + std::array<float, num_elem> output; + + ma.Average(data1, output); + EXPECT_NEAR(output[0], data1[0], e); + EXPECT_NEAR(output[1], data1[1], e); + EXPECT_NEAR(output[2], data1[2], e); + EXPECT_NEAR(output[3], data1[3], e); + + ma.Average(data2, output); + EXPECT_NEAR(output[0], data2[0], e); + EXPECT_NEAR(output[1], data2[1], e); + EXPECT_NEAR(output[2], data2[2], e); + EXPECT_NEAR(output[3], data2[3], e); + + ma.Average(data3, output); + EXPECT_NEAR(output[0], data3[0], e); + EXPECT_NEAR(output[1], data3[1], e); + EXPECT_NEAR(output[2], data3[2], e); + EXPECT_NEAR(output[3], data3[3], e); + + ma.Average(data4, output); + EXPECT_NEAR(output[0], data4[0], e); + EXPECT_NEAR(output[1], data4[1], e); + EXPECT_NEAR(output[2], data4[2], e); + EXPECT_NEAR(output[3], data4[3], e); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector.cc b/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector.cc new file mode 100644 index 0000000000..98068964d9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector.cc @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/multi_channel_content_detector.h" + +#include <cmath> + +#include "rtc_base/checks.h" +#include "system_wrappers/include/metrics.h" + +namespace webrtc { + +namespace { + +constexpr int kNumFramesPerSecond = 100; + +// Compares the left and right channels in the render `frame` to determine +// whether the signal is a proper stereo signal. To allow for differences +// introduced by hardware drivers, a threshold `detection_threshold` is used for +// the detection. +bool HasStereoContent(const std::vector<std::vector<std::vector<float>>>& frame, + float detection_threshold) { + if (frame[0].size() < 2) { + return false; + } + + for (size_t band = 0; band < frame.size(); ++band) { + for (size_t k = 0; k < frame[band][0].size(); ++k) { + if (std::fabs(frame[band][0][k] - frame[band][1][k]) > + detection_threshold) { + return true; + } + } + } + return false; +} + +// In order to avoid logging metrics for very short lifetimes that are unlikely +// to reflect real calls and that may dilute the "real" data, logging is limited +// to lifetimes of at leats 5 seconds. +constexpr int kMinNumberOfFramesRequiredToLogMetrics = 500; + +// Continuous metrics are logged every 10 seconds. +constexpr int kFramesPer10Seconds = 1000; + +} // namespace + +MultiChannelContentDetector::MetricsLogger::MetricsLogger() {} + +MultiChannelContentDetector::MetricsLogger::~MetricsLogger() { + if (frame_counter_ < kMinNumberOfFramesRequiredToLogMetrics) + return; + + RTC_HISTOGRAM_BOOLEAN( + "WebRTC.Audio.EchoCanceller.PersistentMultichannelContentEverDetected", + any_multichannel_content_detected_ ? 1 : 0); +} + +void MultiChannelContentDetector::MetricsLogger::Update( + bool persistent_multichannel_content_detected) { + ++frame_counter_; + if (persistent_multichannel_content_detected) { + any_multichannel_content_detected_ = true; + ++persistent_multichannel_frame_counter_; + } + + if (frame_counter_ < kMinNumberOfFramesRequiredToLogMetrics) + return; + if (frame_counter_ % kFramesPer10Seconds != 0) + return; + const bool mostly_multichannel_last_10_seconds = + (persistent_multichannel_frame_counter_ >= kFramesPer10Seconds / 2); + RTC_HISTOGRAM_BOOLEAN( + "WebRTC.Audio.EchoCanceller.ProcessingPersistentMultichannelContent", + mostly_multichannel_last_10_seconds ? 1 : 0); + + persistent_multichannel_frame_counter_ = 0; +} + +MultiChannelContentDetector::MultiChannelContentDetector( + bool detect_stereo_content, + int num_render_input_channels, + float detection_threshold, + int stereo_detection_timeout_threshold_seconds, + float stereo_detection_hysteresis_seconds) + : detect_stereo_content_(detect_stereo_content), + detection_threshold_(detection_threshold), + detection_timeout_threshold_frames_( + stereo_detection_timeout_threshold_seconds > 0 + ? absl::make_optional(stereo_detection_timeout_threshold_seconds * + kNumFramesPerSecond) + : absl::nullopt), + stereo_detection_hysteresis_frames_(static_cast<int>( + stereo_detection_hysteresis_seconds * kNumFramesPerSecond)), + metrics_logger_((detect_stereo_content && num_render_input_channels > 1) + ? std::make_unique<MetricsLogger>() + : nullptr), + persistent_multichannel_content_detected_( + !detect_stereo_content && num_render_input_channels > 1) {} + +bool MultiChannelContentDetector::UpdateDetection( + const std::vector<std::vector<std::vector<float>>>& frame) { + if (!detect_stereo_content_) { + RTC_DCHECK_EQ(frame[0].size() > 1, + persistent_multichannel_content_detected_); + return false; + } + + const bool previous_persistent_multichannel_content_detected = + persistent_multichannel_content_detected_; + const bool stereo_detected_in_frame = + HasStereoContent(frame, detection_threshold_); + + consecutive_frames_with_stereo_ = + stereo_detected_in_frame ? consecutive_frames_with_stereo_ + 1 : 0; + frames_since_stereo_detected_last_ = + stereo_detected_in_frame ? 0 : frames_since_stereo_detected_last_ + 1; + + // Detect persistent multichannel content. + if (consecutive_frames_with_stereo_ > stereo_detection_hysteresis_frames_) { + persistent_multichannel_content_detected_ = true; + } + if (detection_timeout_threshold_frames_.has_value() && + frames_since_stereo_detected_last_ >= + *detection_timeout_threshold_frames_) { + persistent_multichannel_content_detected_ = false; + } + + // Detect temporary multichannel content. + temporary_multichannel_content_detected_ = + persistent_multichannel_content_detected_ ? false + : stereo_detected_in_frame; + + if (metrics_logger_) + metrics_logger_->Update(persistent_multichannel_content_detected_); + + return previous_persistent_multichannel_content_detected != + persistent_multichannel_content_detected_; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector.h b/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector.h new file mode 100644 index 0000000000..1742c5fc17 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_MULTI_CHANNEL_CONTENT_DETECTOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_MULTI_CHANNEL_CONTENT_DETECTOR_H_ + +#include <memory> +#include <stddef.h> + +#include <memory> +#include <vector> + +#include "absl/types/optional.h" + +namespace webrtc { + +// Analyzes audio content to determine whether the contained audio is proper +// multichannel, or only upmixed mono. To allow for differences introduced by +// hardware drivers, a threshold `detection_threshold` is used for the +// detection. +// Logs metrics continously and upon destruction. +class MultiChannelContentDetector { + public: + // If |stereo_detection_timeout_threshold_seconds| <= 0, no timeout is + // applied: Once multichannel is detected, the detector remains in that state + // for its lifetime. + MultiChannelContentDetector(bool detect_stereo_content, + int num_render_input_channels, + float detection_threshold, + int stereo_detection_timeout_threshold_seconds, + float stereo_detection_hysteresis_seconds); + + // Compares the left and right channels in the render `frame` to determine + // whether the signal is a proper multichannel signal. Returns a bool + // indicating whether a change in the proper multichannel content was + // detected. + bool UpdateDetection( + const std::vector<std::vector<std::vector<float>>>& frame); + + bool IsProperMultiChannelContentDetected() const { + return persistent_multichannel_content_detected_; + } + + bool IsTemporaryMultiChannelContentDetected() const { + return temporary_multichannel_content_detected_; + } + + private: + // Tracks and logs metrics for the amount of multichannel content detected. + class MetricsLogger { + public: + MetricsLogger(); + + // The destructor logs call summary statistics. + ~MetricsLogger(); + + // Updates and logs metrics. + void Update(bool persistent_multichannel_content_detected); + + private: + int frame_counter_ = 0; + + // Counts the number of frames of persistent multichannel audio observed + // during the current metrics collection interval. + int persistent_multichannel_frame_counter_ = 0; + + // Indicates whether persistent multichannel content has ever been detected. + bool any_multichannel_content_detected_ = false; + }; + + const bool detect_stereo_content_; + const float detection_threshold_; + const absl::optional<int> detection_timeout_threshold_frames_; + const int stereo_detection_hysteresis_frames_; + + // Collects and reports metrics on the amount of multichannel content + // detected. Only created if |num_render_input_channels| > 1 and + // |detect_stereo_content_| is true. + const std::unique_ptr<MetricsLogger> metrics_logger_; + + bool persistent_multichannel_content_detected_; + bool temporary_multichannel_content_detected_ = false; + int64_t frames_since_stereo_detected_last_ = 0; + int64_t consecutive_frames_with_stereo_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_MULTI_CHANNEL_CONTENT_DETECTOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector_unittest.cc new file mode 100644 index 0000000000..8d38dd0991 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/multi_channel_content_detector_unittest.cc @@ -0,0 +1,470 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/multi_channel_content_detector.h" + +#include "system_wrappers/include/metrics.h" +#include "test/gtest.h" + +namespace webrtc { + +TEST(MultiChannelContentDetector, HandlingOfMono) { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/1, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); +} + +TEST(MultiChannelContentDetector, HandlingOfMonoAndDetectionOff) { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/false, + /*num_render_input_channels=*/1, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); +} + +TEST(MultiChannelContentDetector, HandlingOfDetectionOff) { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/false, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + + std::vector<std::vector<std::vector<float>>> frame( + 1, std::vector<std::vector<float>>(2, std::vector<float>(160, 0.0f))); + std::fill(frame[0][0].begin(), frame[0][0].end(), 100.0f); + std::fill(frame[0][1].begin(), frame[0][1].end(), 101.0f); + + EXPECT_FALSE(mc.UpdateDetection(frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + + EXPECT_FALSE(mc.UpdateDetection(frame)); +} + +TEST(MultiChannelContentDetector, InitialDetectionOfStereo) { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); +} + +TEST(MultiChannelContentDetector, DetectionWhenFakeStereo) { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + std::vector<std::vector<std::vector<float>>> frame( + 1, std::vector<std::vector<float>>(2, std::vector<float>(160, 0.0f))); + std::fill(frame[0][0].begin(), frame[0][0].end(), 100.0f); + std::fill(frame[0][1].begin(), frame[0][1].end(), 100.0f); + EXPECT_FALSE(mc.UpdateDetection(frame)); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + + EXPECT_FALSE(mc.UpdateDetection(frame)); +} + +TEST(MultiChannelContentDetector, DetectionWhenStereo) { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + std::vector<std::vector<std::vector<float>>> frame( + 1, std::vector<std::vector<float>>(2, std::vector<float>(160, 0.0f))); + std::fill(frame[0][0].begin(), frame[0][0].end(), 100.0f); + std::fill(frame[0][1].begin(), frame[0][1].end(), 101.0f); + EXPECT_TRUE(mc.UpdateDetection(frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + + EXPECT_FALSE(mc.UpdateDetection(frame)); +} + +TEST(MultiChannelContentDetector, DetectionWhenStereoAfterAWhile) { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + std::vector<std::vector<std::vector<float>>> frame( + 1, std::vector<std::vector<float>>(2, std::vector<float>(160, 0.0f))); + + std::fill(frame[0][0].begin(), frame[0][0].end(), 100.0f); + std::fill(frame[0][1].begin(), frame[0][1].end(), 100.0f); + EXPECT_FALSE(mc.UpdateDetection(frame)); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + + EXPECT_FALSE(mc.UpdateDetection(frame)); + + std::fill(frame[0][0].begin(), frame[0][0].end(), 100.0f); + std::fill(frame[0][1].begin(), frame[0][1].end(), 101.0f); + + EXPECT_TRUE(mc.UpdateDetection(frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + + EXPECT_FALSE(mc.UpdateDetection(frame)); +} + +TEST(MultiChannelContentDetector, DetectionWithStereoBelowThreshold) { + constexpr float kThreshold = 1.0f; + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/kThreshold, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + std::vector<std::vector<std::vector<float>>> frame( + 1, std::vector<std::vector<float>>(2, std::vector<float>(160, 0.0f))); + std::fill(frame[0][0].begin(), frame[0][0].end(), 100.0f); + std::fill(frame[0][1].begin(), frame[0][1].end(), 100.0f + kThreshold); + + EXPECT_FALSE(mc.UpdateDetection(frame)); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + + EXPECT_FALSE(mc.UpdateDetection(frame)); +} + +TEST(MultiChannelContentDetector, DetectionWithStereoAboveThreshold) { + constexpr float kThreshold = 1.0f; + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/kThreshold, + /*stereo_detection_timeout_threshold_seconds=*/0, + /*stereo_detection_hysteresis_seconds=*/0.0f); + std::vector<std::vector<std::vector<float>>> frame( + 1, std::vector<std::vector<float>>(2, std::vector<float>(160, 0.0f))); + std::fill(frame[0][0].begin(), frame[0][0].end(), 100.0f); + std::fill(frame[0][1].begin(), frame[0][1].end(), 100.0f + kThreshold + 0.1f); + + EXPECT_TRUE(mc.UpdateDetection(frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + + EXPECT_FALSE(mc.UpdateDetection(frame)); +} + +class MultiChannelContentDetectorTimeoutBehavior + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<bool, int>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannelContentDetector, + MultiChannelContentDetectorTimeoutBehavior, + ::testing::Combine(::testing::Values(false, true), + ::testing::Values(0, 1, 10))); + +TEST_P(MultiChannelContentDetectorTimeoutBehavior, + TimeOutBehaviorForNonTrueStereo) { + constexpr int kNumFramesPerSecond = 100; + const bool detect_stereo_content = std::get<0>(GetParam()); + const int stereo_detection_timeout_threshold_seconds = + std::get<1>(GetParam()); + const int stereo_detection_timeout_threshold_frames = + stereo_detection_timeout_threshold_seconds * kNumFramesPerSecond; + + MultiChannelContentDetector mc(detect_stereo_content, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + stereo_detection_timeout_threshold_seconds, + /*stereo_detection_hysteresis_seconds=*/0.0f); + std::vector<std::vector<std::vector<float>>> true_stereo_frame = { + {std::vector<float>(160, 100.0f), std::vector<float>(160, 101.0f)}}; + + std::vector<std::vector<std::vector<float>>> fake_stereo_frame = { + {std::vector<float>(160, 100.0f), std::vector<float>(160, 100.0f)}}; + + // Pass fake stereo frames and verify the content detection. + for (int k = 0; k < 10; ++k) { + EXPECT_FALSE(mc.UpdateDetection(fake_stereo_frame)); + if (detect_stereo_content) { + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + } else { + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + } + } + + // Pass a true stereo frame and verify that it is properly detected. + if (detect_stereo_content) { + EXPECT_TRUE(mc.UpdateDetection(true_stereo_frame)); + } else { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + } + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + + // Pass fake stereo frames until any timeouts are about to occur. + for (int k = 0; k < stereo_detection_timeout_threshold_frames - 1; ++k) { + EXPECT_FALSE(mc.UpdateDetection(fake_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + } + + // Pass a fake stereo frame and verify that any timeouts properly occur. + if (detect_stereo_content && stereo_detection_timeout_threshold_frames > 0) { + EXPECT_TRUE(mc.UpdateDetection(fake_stereo_frame)); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + } else { + EXPECT_FALSE(mc.UpdateDetection(fake_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + } + + // Pass fake stereo frames and verify the behavior after any timeout. + for (int k = 0; k < 10; ++k) { + EXPECT_FALSE(mc.UpdateDetection(fake_stereo_frame)); + if (detect_stereo_content && + stereo_detection_timeout_threshold_frames > 0) { + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + } else { + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + } + } +} + +class MultiChannelContentDetectorHysteresisBehavior + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<bool, float>> {}; + +INSTANTIATE_TEST_SUITE_P( + MultiChannelContentDetector, + MultiChannelContentDetectorHysteresisBehavior, + ::testing::Combine(::testing::Values(false, true), + ::testing::Values(0.0f, 0.1f, 0.2f))); + +TEST_P(MultiChannelContentDetectorHysteresisBehavior, + PeriodBeforeStereoDetectionIsTriggered) { + constexpr int kNumFramesPerSecond = 100; + const bool detect_stereo_content = std::get<0>(GetParam()); + const int stereo_detection_hysteresis_seconds = std::get<1>(GetParam()); + const int stereo_detection_hysteresis_frames = + stereo_detection_hysteresis_seconds * kNumFramesPerSecond; + + MultiChannelContentDetector mc( + detect_stereo_content, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/0, + stereo_detection_hysteresis_seconds); + std::vector<std::vector<std::vector<float>>> true_stereo_frame = { + {std::vector<float>(160, 100.0f), std::vector<float>(160, 101.0f)}}; + + std::vector<std::vector<std::vector<float>>> fake_stereo_frame = { + {std::vector<float>(160, 100.0f), std::vector<float>(160, 100.0f)}}; + + // Pass fake stereo frames and verify the content detection. + for (int k = 0; k < 10; ++k) { + EXPECT_FALSE(mc.UpdateDetection(fake_stereo_frame)); + if (detect_stereo_content) { + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + } else { + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + } + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } + + // Pass a two true stereo frames and verify that they are properly detected. + ASSERT_TRUE(stereo_detection_hysteresis_frames > 2 || + stereo_detection_hysteresis_frames == 0); + for (int k = 0; k < 2; ++k) { + if (detect_stereo_content) { + if (stereo_detection_hysteresis_seconds == 0.0f) { + if (k == 0) { + EXPECT_TRUE(mc.UpdateDetection(true_stereo_frame)); + } else { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + } + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } else { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + EXPECT_TRUE(mc.IsTemporaryMultiChannelContentDetected()); + } + } else { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } + } + + if (stereo_detection_hysteresis_seconds == 0.0f) { + return; + } + + // Pass true stereo frames until any timeouts are about to occur. + for (int k = 0; k < stereo_detection_hysteresis_frames - 3; ++k) { + if (detect_stereo_content) { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_FALSE(mc.IsProperMultiChannelContentDetected()); + EXPECT_TRUE(mc.IsTemporaryMultiChannelContentDetected()); + } else { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } + } + + // Pass a true stereo frame and verify that it is properly detected. + if (detect_stereo_content) { + EXPECT_TRUE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } else { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } + + // Pass an additional true stereo frame and verify that it is properly + // detected. + if (detect_stereo_content) { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } else { + EXPECT_FALSE(mc.UpdateDetection(true_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } + + // Pass a fake stereo frame and verify that it is properly detected. + if (detect_stereo_content) { + EXPECT_FALSE(mc.UpdateDetection(fake_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } else { + EXPECT_FALSE(mc.UpdateDetection(fake_stereo_frame)); + EXPECT_TRUE(mc.IsProperMultiChannelContentDetected()); + EXPECT_FALSE(mc.IsTemporaryMultiChannelContentDetected()); + } +} + +class MultiChannelContentDetectorMetricsDisabled + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<bool, int>> {}; + +INSTANTIATE_TEST_SUITE_P( + /*no prefix*/, + MultiChannelContentDetectorMetricsDisabled, + ::testing::Values(std::tuple<bool, int>(false, 2), + std::tuple<bool, int>(true, 1))); + +// Test that no metrics are logged when they are clearly uninteresting and would +// dilute relevant data: when the reference audio is single channel, or when +// dynamic detection is disabled. +TEST_P(MultiChannelContentDetectorMetricsDisabled, ReportsNoMetrics) { + metrics::Reset(); + constexpr int kNumFramesPerSecond = 100; + const bool detect_stereo_content = std::get<0>(GetParam()); + const int channel_count = std::get<1>(GetParam()); + std::vector<std::vector<std::vector<float>>> audio_frame = { + std::vector<std::vector<float>>(channel_count, + std::vector<float>(160, 100.0f))}; + { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/detect_stereo_content, + /*num_render_input_channels=*/channel_count, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/1, + /*stereo_detection_hysteresis_seconds=*/0.0f); + for (int k = 0; k < 20 * kNumFramesPerSecond; ++k) { + mc.UpdateDetection(audio_frame); + } + } + EXPECT_METRIC_EQ( + 0, metrics::NumSamples("WebRTC.Audio.EchoCanceller." + "ProcessingPersistentMultichannelContent")); + EXPECT_METRIC_EQ( + 0, metrics::NumSamples("WebRTC.Audio.EchoCanceller." + "PersistentMultichannelContentEverDetected")); +} + +// Tests that after 3 seconds, no metrics are reported. +TEST(MultiChannelContentDetectorMetrics, ReportsNoMetricsForShortLifetime) { + metrics::Reset(); + constexpr int kNumFramesPerSecond = 100; + constexpr int kTooFewFramesToLogMetrics = 3 * kNumFramesPerSecond; + std::vector<std::vector<std::vector<float>>> audio_frame = { + std::vector<std::vector<float>>(2, std::vector<float>(160, 100.0f))}; + { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/1, + /*stereo_detection_hysteresis_seconds=*/0.0f); + for (int k = 0; k < kTooFewFramesToLogMetrics; ++k) { + mc.UpdateDetection(audio_frame); + } + } + EXPECT_METRIC_EQ( + 0, metrics::NumSamples("WebRTC.Audio.EchoCanceller." + "ProcessingPersistentMultichannelContent")); + EXPECT_METRIC_EQ( + 0, metrics::NumSamples("WebRTC.Audio.EchoCanceller." + "PersistentMultichannelContentEverDetected")); +} + +// Tests that after 25 seconds, metrics are reported. +TEST(MultiChannelContentDetectorMetrics, ReportsMetrics) { + metrics::Reset(); + constexpr int kNumFramesPerSecond = 100; + std::vector<std::vector<std::vector<float>>> true_stereo_frame = { + {std::vector<float>(160, 100.0f), std::vector<float>(160, 101.0f)}}; + std::vector<std::vector<std::vector<float>>> fake_stereo_frame = { + {std::vector<float>(160, 100.0f), std::vector<float>(160, 100.0f)}}; + { + MultiChannelContentDetector mc( + /*detect_stereo_content=*/true, + /*num_render_input_channels=*/2, + /*detection_threshold=*/0.0f, + /*stereo_detection_timeout_threshold_seconds=*/1, + /*stereo_detection_hysteresis_seconds=*/0.0f); + for (int k = 0; k < 10 * kNumFramesPerSecond; ++k) { + mc.UpdateDetection(true_stereo_frame); + } + for (int k = 0; k < 15 * kNumFramesPerSecond; ++k) { + mc.UpdateDetection(fake_stereo_frame); + } + } + // After 10 seconds of true stereo and the remainder fake stereo, we expect + // one lifetime metric sample (multichannel detected) and two periodic samples + // (one multichannel, one mono). + + // Check lifetime metric. + EXPECT_METRIC_EQ( + 1, metrics::NumSamples("WebRTC.Audio.EchoCanceller." + "PersistentMultichannelContentEverDetected")); + EXPECT_METRIC_EQ( + 1, metrics::NumEvents("WebRTC.Audio.EchoCanceller." + "PersistentMultichannelContentEverDetected", 1)); + + // Check periodic metric. + EXPECT_METRIC_EQ( + 2, metrics::NumSamples("WebRTC.Audio.EchoCanceller." + "ProcessingPersistentMultichannelContent")); + EXPECT_METRIC_EQ( + 1, metrics::NumEvents("WebRTC.Audio.EchoCanceller." + "ProcessingPersistentMultichannelContent", 0)); + EXPECT_METRIC_EQ( + 1, metrics::NumEvents("WebRTC.Audio.EchoCanceller." + "ProcessingPersistentMultichannelContent", 1)); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/nearend_detector.h b/third_party/libwebrtc/modules/audio_processing/aec3/nearend_detector.h new file mode 100644 index 0000000000..0d8a06b2cd --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/nearend_detector.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_NEAREND_DETECTOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_NEAREND_DETECTOR_H_ + +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { +// Class for selecting whether the suppressor is in the nearend or echo state. +class NearendDetector { + public: + virtual ~NearendDetector() {} + + // Returns whether the current state is the nearend state. + virtual bool IsNearendState() const = 0; + + // Updates the state selection based on latest spectral estimates. + virtual void Update( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + nearend_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + bool initial_state) = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_NEAREND_DETECTOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain.cc b/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain.cc new file mode 100644 index 0000000000..8e391d6fa6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain.cc @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/refined_filter_update_gain.h" + +#include <algorithm> +#include <functional> + +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/echo_path_variability.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/render_signal_analyzer.h" +#include "modules/audio_processing/aec3/subtractor_output.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace { + +constexpr float kHErrorInitial = 10000.f; +constexpr int kPoorExcitationCounterInitial = 1000; + +} // namespace + +std::atomic<int> RefinedFilterUpdateGain::instance_count_(0); + +RefinedFilterUpdateGain::RefinedFilterUpdateGain( + const EchoCanceller3Config::Filter::RefinedConfiguration& config, + size_t config_change_duration_blocks) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + config_change_duration_blocks_( + static_cast<int>(config_change_duration_blocks)), + poor_excitation_counter_(kPoorExcitationCounterInitial) { + SetConfig(config, true); + H_error_.fill(kHErrorInitial); + RTC_DCHECK_LT(0, config_change_duration_blocks_); + one_by_config_change_duration_blocks_ = 1.f / config_change_duration_blocks_; +} + +RefinedFilterUpdateGain::~RefinedFilterUpdateGain() {} + +void RefinedFilterUpdateGain::HandleEchoPathChange( + const EchoPathVariability& echo_path_variability) { + if (echo_path_variability.gain_change) { + // TODO(bugs.webrtc.org/9526) Handle gain changes. + } + + if (echo_path_variability.delay_change != + EchoPathVariability::DelayAdjustment::kNone) { + H_error_.fill(kHErrorInitial); + } + + if (!echo_path_variability.gain_change) { + poor_excitation_counter_ = kPoorExcitationCounterInitial; + call_counter_ = 0; + } +} + +void RefinedFilterUpdateGain::Compute( + const std::array<float, kFftLengthBy2Plus1>& render_power, + const RenderSignalAnalyzer& render_signal_analyzer, + const SubtractorOutput& subtractor_output, + rtc::ArrayView<const float> erl, + size_t size_partitions, + bool saturated_capture_signal, + bool disallow_leakage_diverged, + FftData* gain_fft) { + RTC_DCHECK(gain_fft); + // Introducing shorter notation to improve readability. + const FftData& E_refined = subtractor_output.E_refined; + const auto& E2_refined = subtractor_output.E2_refined; + const auto& E2_coarse = subtractor_output.E2_coarse; + FftData* G = gain_fft; + const auto& X2 = render_power; + + ++call_counter_; + + UpdateCurrentConfig(); + + if (render_signal_analyzer.PoorSignalExcitation()) { + poor_excitation_counter_ = 0; + } + + // Do not update the filter if the render is not sufficiently excited. + if (++poor_excitation_counter_ < size_partitions || + saturated_capture_signal || call_counter_ <= size_partitions) { + G->re.fill(0.f); + G->im.fill(0.f); + } else { + // Corresponds to WGN of power -39 dBFS. + std::array<float, kFftLengthBy2Plus1> mu; + // mu = H_error / (0.5* H_error* X2 + n * E2). + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + if (X2[k] >= current_config_.noise_gate) { + mu[k] = H_error_[k] / + (0.5f * H_error_[k] * X2[k] + size_partitions * E2_refined[k]); + } else { + mu[k] = 0.f; + } + } + + // Avoid updating the filter close to narrow bands in the render signals. + render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu); + + // H_error = H_error - 0.5 * mu * X2 * H_error. + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + H_error_[k] -= 0.5f * mu[k] * X2[k] * H_error_[k]; + } + + // G = mu * E. + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + G->re[k] = mu[k] * E_refined.re[k]; + G->im[k] = mu[k] * E_refined.im[k]; + } + } + + // H_error = H_error + factor * erl. + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + if (E2_refined[k] <= E2_coarse[k] || disallow_leakage_diverged) { + H_error_[k] += current_config_.leakage_converged * erl[k]; + } else { + H_error_[k] += current_config_.leakage_diverged * erl[k]; + } + + H_error_[k] = std::max(H_error_[k], current_config_.error_floor); + H_error_[k] = std::min(H_error_[k], current_config_.error_ceil); + } + + data_dumper_->DumpRaw("aec3_refined_gain_H_error", H_error_); +} + +void RefinedFilterUpdateGain::UpdateCurrentConfig() { + RTC_DCHECK_GE(config_change_duration_blocks_, config_change_counter_); + if (config_change_counter_ > 0) { + if (--config_change_counter_ > 0) { + auto average = [](float from, float to, float from_weight) { + return from * from_weight + to * (1.f - from_weight); + }; + + float change_factor = + config_change_counter_ * one_by_config_change_duration_blocks_; + + current_config_.leakage_converged = + average(old_target_config_.leakage_converged, + target_config_.leakage_converged, change_factor); + current_config_.leakage_diverged = + average(old_target_config_.leakage_diverged, + target_config_.leakage_diverged, change_factor); + current_config_.error_floor = + average(old_target_config_.error_floor, target_config_.error_floor, + change_factor); + current_config_.error_ceil = + average(old_target_config_.error_ceil, target_config_.error_ceil, + change_factor); + current_config_.noise_gate = + average(old_target_config_.noise_gate, target_config_.noise_gate, + change_factor); + } else { + current_config_ = old_target_config_ = target_config_; + } + } + RTC_DCHECK_LE(0, config_change_counter_); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain.h b/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain.h new file mode 100644 index 0000000000..1a68ebc296 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_REFINED_FILTER_UPDATE_GAIN_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_REFINED_FILTER_UPDATE_GAIN_H_ + +#include <stddef.h> + +#include <array> +#include <atomic> +#include <memory> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +class AdaptiveFirFilter; +class ApmDataDumper; +struct EchoPathVariability; +struct FftData; +class RenderSignalAnalyzer; +struct SubtractorOutput; + +// Provides functionality for computing the adaptive gain for the refined +// filter. +class RefinedFilterUpdateGain { + public: + RefinedFilterUpdateGain( + const EchoCanceller3Config::Filter::RefinedConfiguration& config, + size_t config_change_duration_blocks); + ~RefinedFilterUpdateGain(); + + RefinedFilterUpdateGain(const RefinedFilterUpdateGain&) = delete; + RefinedFilterUpdateGain& operator=(const RefinedFilterUpdateGain&) = delete; + + // Takes action in the case of a known echo path change. + void HandleEchoPathChange(const EchoPathVariability& echo_path_variability); + + // Computes the gain. + void Compute(const std::array<float, kFftLengthBy2Plus1>& render_power, + const RenderSignalAnalyzer& render_signal_analyzer, + const SubtractorOutput& subtractor_output, + rtc::ArrayView<const float> erl, + size_t size_partitions, + bool saturated_capture_signal, + bool disallow_leakage_diverged, + FftData* gain_fft); + + // Sets a new config. + void SetConfig( + const EchoCanceller3Config::Filter::RefinedConfiguration& config, + bool immediate_effect) { + if (immediate_effect) { + old_target_config_ = current_config_ = target_config_ = config; + config_change_counter_ = 0; + } else { + old_target_config_ = current_config_; + target_config_ = config; + config_change_counter_ = config_change_duration_blocks_; + } + } + + private: + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const int config_change_duration_blocks_; + float one_by_config_change_duration_blocks_; + EchoCanceller3Config::Filter::RefinedConfiguration current_config_; + EchoCanceller3Config::Filter::RefinedConfiguration target_config_; + EchoCanceller3Config::Filter::RefinedConfiguration old_target_config_; + std::array<float, kFftLengthBy2Plus1> H_error_; + size_t poor_excitation_counter_; + size_t call_counter_ = 0; + int config_change_counter_ = 0; + + // Updates the current config towards the target config. + void UpdateCurrentConfig(); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_REFINED_FILTER_UPDATE_GAIN_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain_unittest.cc new file mode 100644 index 0000000000..c77c5b53d5 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/refined_filter_update_gain_unittest.cc @@ -0,0 +1,392 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/refined_filter_update_gain.h" + +#include <algorithm> +#include <numeric> +#include <string> +#include <vector> + +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/coarse_filter_update_gain.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/aec3/render_signal_analyzer.h" +#include "modules/audio_processing/aec3/subtractor_output.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/numerics/safe_minmax.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +// Method for performing the simulations needed to test the refined filter +// update gain functionality. +void RunFilterUpdateTest(int num_blocks_to_process, + size_t delay_samples, + int filter_length_blocks, + const std::vector<int>& blocks_with_echo_path_changes, + const std::vector<int>& blocks_with_saturation, + bool use_silent_render_in_second_half, + std::array<float, kBlockSize>* e_last_block, + std::array<float, kBlockSize>* y_last_block, + FftData* G_last_block) { + ApmDataDumper data_dumper(42); + Aec3Optimization optimization = DetectOptimization(); + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + EchoCanceller3Config config; + config.filter.refined.length_blocks = filter_length_blocks; + config.filter.coarse.length_blocks = filter_length_blocks; + AdaptiveFirFilter refined_filter( + config.filter.refined.length_blocks, config.filter.refined.length_blocks, + config.filter.config_change_duration_blocks, kNumRenderChannels, + optimization, &data_dumper); + AdaptiveFirFilter coarse_filter( + config.filter.coarse.length_blocks, config.filter.coarse.length_blocks, + config.filter.config_change_duration_blocks, kNumRenderChannels, + optimization, &data_dumper); + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2( + kNumCaptureChannels, std::vector<std::array<float, kFftLengthBy2Plus1>>( + refined_filter.max_filter_size_partitions(), + std::array<float, kFftLengthBy2Plus1>())); + for (auto& H2_ch : H2) { + for (auto& H2_k : H2_ch) { + H2_k.fill(0.f); + } + } + std::vector<std::vector<float>> h( + kNumCaptureChannels, + std::vector<float>( + GetTimeDomainLength(refined_filter.max_filter_size_partitions()), + 0.f)); + + Aec3Fft fft; + std::array<float, kBlockSize> x_old; + x_old.fill(0.f); + CoarseFilterUpdateGain coarse_gain( + config.filter.coarse, config.filter.config_change_duration_blocks); + RefinedFilterUpdateGain refined_gain( + config.filter.refined, config.filter.config_change_duration_blocks); + Random random_generator(42U); + Block x(kNumBands, kNumRenderChannels); + std::vector<float> y(kBlockSize, 0.f); + config.delay.default_delay = 1; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels)); + AecState aec_state(config, kNumCaptureChannels); + RenderSignalAnalyzer render_signal_analyzer(config); + absl::optional<DelayEstimate> delay_estimate; + std::array<float, kFftLength> s_scratch; + std::array<float, kBlockSize> s; + FftData S; + FftData G; + std::vector<SubtractorOutput> output(kNumCaptureChannels); + for (auto& subtractor_output : output) { + subtractor_output.Reset(); + } + FftData& E_refined = output[0].E_refined; + FftData E_coarse; + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels); + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( + kNumCaptureChannels); + std::array<float, kBlockSize>& e_refined = output[0].e_refined; + std::array<float, kBlockSize>& e_coarse = output[0].e_coarse; + for (auto& Y2_ch : Y2) { + Y2_ch.fill(0.f); + } + + constexpr float kScale = 1.0f / kFftLengthBy2; + + DelayBuffer<float> delay_buffer(delay_samples); + for (int k = 0; k < num_blocks_to_process; ++k) { + // Handle echo path changes. + if (std::find(blocks_with_echo_path_changes.begin(), + blocks_with_echo_path_changes.end(), + k) != blocks_with_echo_path_changes.end()) { + refined_filter.HandleEchoPathChange(); + } + + // Handle saturation. + const bool saturation = + std::find(blocks_with_saturation.begin(), blocks_with_saturation.end(), + k) != blocks_with_saturation.end(); + + // Create the render signal. + if (use_silent_render_in_second_half && k > num_blocks_to_process / 2) { + for (int band = 0; band < x.NumBands(); ++band) { + for (int channel = 0; channel < x.NumChannels(); ++channel) { + std::fill(x.begin(band, channel), x.end(band, channel), 0.f); + } + } + } else { + for (int band = 0; band < x.NumChannels(); ++band) { + for (int channel = 0; channel < x.NumChannels(); ++channel) { + RandomizeSampleVector(&random_generator, x.View(band, channel)); + } + } + } + delay_buffer.Delay(x.View(/*band=*/0, /*channel=*/0), y); + + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + + render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(), + aec_state.MinDirectPathFilterDelay()); + + // Apply the refined filter. + refined_filter.Filter(*render_delay_buffer->GetRenderBuffer(), &S); + fft.Ifft(S, &s_scratch); + std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2, + e_refined.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e_refined.begin(), e_refined.end(), + [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); }); + fft.ZeroPaddedFft(e_refined, Aec3Fft::Window::kRectangular, &E_refined); + for (size_t k = 0; k < kBlockSize; ++k) { + s[k] = kScale * s_scratch[k + kFftLengthBy2]; + } + + // Apply the coarse filter. + coarse_filter.Filter(*render_delay_buffer->GetRenderBuffer(), &S); + fft.Ifft(S, &s_scratch); + std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2, + e_coarse.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e_coarse.begin(), e_coarse.end(), + [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); }); + fft.ZeroPaddedFft(e_coarse, Aec3Fft::Window::kRectangular, &E_coarse); + + // Compute spectra for future use. + E_refined.Spectrum(Aec3Optimization::kNone, output[0].E2_refined); + E_coarse.Spectrum(Aec3Optimization::kNone, output[0].E2_coarse); + + // Adapt the coarse filter. + std::array<float, kFftLengthBy2Plus1> render_power; + render_delay_buffer->GetRenderBuffer()->SpectralSum( + coarse_filter.SizePartitions(), &render_power); + coarse_gain.Compute(render_power, render_signal_analyzer, E_coarse, + coarse_filter.SizePartitions(), saturation, &G); + coarse_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G); + + // Adapt the refined filter + render_delay_buffer->GetRenderBuffer()->SpectralSum( + refined_filter.SizePartitions(), &render_power); + + std::array<float, kFftLengthBy2Plus1> erl; + ComputeErl(optimization, H2[0], erl); + refined_gain.Compute(render_power, render_signal_analyzer, output[0], erl, + refined_filter.SizePartitions(), saturation, false, + &G); + refined_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G, &h[0]); + + // Update the delay. + aec_state.HandleEchoPathChange(EchoPathVariability( + false, EchoPathVariability::DelayAdjustment::kNone, false)); + refined_filter.ComputeFrequencyResponse(&H2[0]); + std::copy(output[0].E2_refined.begin(), output[0].E2_refined.end(), + E2_refined[0].begin()); + aec_state.Update(delay_estimate, H2, h, + *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2, + output); + } + + std::copy(e_refined.begin(), e_refined.end(), e_last_block->begin()); + std::copy(y.begin(), y.end(), y_last_block->begin()); + std::copy(G.re.begin(), G.re.end(), G_last_block->re.begin()); + std::copy(G.im.begin(), G.im.end(), G_last_block->im.begin()); +} + +std::string ProduceDebugText(int filter_length_blocks) { + rtc::StringBuilder ss; + ss << "Length: " << filter_length_blocks; + return ss.Release(); +} + +std::string ProduceDebugText(size_t delay, int filter_length_blocks) { + rtc::StringBuilder ss; + ss << "Delay: " << delay << ", "; + ss << ProduceDebugText(filter_length_blocks); + return ss.Release(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output gain parameter works. +TEST(RefinedFilterUpdateGainDeathTest, NullDataOutputGain) { + ApmDataDumper data_dumper(42); + EchoCanceller3Config config; + RenderSignalAnalyzer analyzer(config); + SubtractorOutput output; + RefinedFilterUpdateGain gain(config.filter.refined, + config.filter.config_change_duration_blocks); + std::array<float, kFftLengthBy2Plus1> render_power; + render_power.fill(0.f); + std::array<float, kFftLengthBy2Plus1> erl; + erl.fill(0.f); + EXPECT_DEATH( + gain.Compute(render_power, analyzer, output, erl, + config.filter.refined.length_blocks, false, false, nullptr), + ""); +} + +#endif + +// Verifies that the gain formed causes the filter using it to converge. +TEST(RefinedFilterUpdateGain, GainCausesFilterToConverge) { + std::vector<int> blocks_with_echo_path_changes; + std::vector<int> blocks_with_saturation; + for (size_t filter_length_blocks : {12, 20, 30}) { + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks)); + + std::array<float, kBlockSize> e; + std::array<float, kBlockSize> y; + FftData G; + + RunFilterUpdateTest(600, delay_samples, filter_length_blocks, + blocks_with_echo_path_changes, blocks_with_saturation, + false, &e, &y, &G); + + // Verify that the refined filter is able to perform well. + // Use different criteria to take overmodelling into account. + if (filter_length_blocks == 12) { + EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + } else { + EXPECT_LT(std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); + } + } + } +} + +// Verifies that the magnitude of the gain on average decreases for a +// persistently exciting signal. +TEST(RefinedFilterUpdateGain, DecreasingGain) { + std::vector<int> blocks_with_echo_path_changes; + std::vector<int> blocks_with_saturation; + + std::array<float, kBlockSize> e; + std::array<float, kBlockSize> y; + FftData G_a; + FftData G_b; + FftData G_c; + std::array<float, kFftLengthBy2Plus1> G_a_power; + std::array<float, kFftLengthBy2Plus1> G_b_power; + std::array<float, kFftLengthBy2Plus1> G_c_power; + + RunFilterUpdateTest(250, 65, 12, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_a); + RunFilterUpdateTest(500, 65, 12, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_b); + RunFilterUpdateTest(750, 65, 12, blocks_with_echo_path_changes, + blocks_with_saturation, false, &e, &y, &G_c); + + G_a.Spectrum(Aec3Optimization::kNone, G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, G_b_power); + G_c.Spectrum(Aec3Optimization::kNone, G_c_power); + + EXPECT_GT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); + + EXPECT_GT(std::accumulate(G_b_power.begin(), G_b_power.end(), 0.), + std::accumulate(G_c_power.begin(), G_c_power.end(), 0.)); +} + +// Verifies that the gain is zero when there is saturation and that the internal +// error estimates cause the gain to increase after a period of saturation. +TEST(RefinedFilterUpdateGain, SaturationBehavior) { + std::vector<int> blocks_with_echo_path_changes; + std::vector<int> blocks_with_saturation; + for (int k = 99; k < 200; ++k) { + blocks_with_saturation.push_back(k); + } + + for (size_t filter_length_blocks : {12, 20, 30}) { + SCOPED_TRACE(ProduceDebugText(filter_length_blocks)); + std::array<float, kBlockSize> e; + std::array<float, kBlockSize> y; + FftData G_a; + FftData G_b; + FftData G_a_ref; + G_a_ref.re.fill(0.f); + G_a_ref.im.fill(0.f); + + std::array<float, kFftLengthBy2Plus1> G_a_power; + std::array<float, kFftLengthBy2Plus1> G_b_power; + + RunFilterUpdateTest(100, 65, filter_length_blocks, + blocks_with_echo_path_changes, blocks_with_saturation, + false, &e, &y, &G_a); + + EXPECT_EQ(G_a_ref.re, G_a.re); + EXPECT_EQ(G_a_ref.im, G_a.im); + + RunFilterUpdateTest(99, 65, filter_length_blocks, + blocks_with_echo_path_changes, blocks_with_saturation, + false, &e, &y, &G_a); + RunFilterUpdateTest(201, 65, filter_length_blocks, + blocks_with_echo_path_changes, blocks_with_saturation, + false, &e, &y, &G_b); + + G_a.Spectrum(Aec3Optimization::kNone, G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, G_b_power); + + EXPECT_LT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); + } +} + +// Verifies that the gain increases after an echo path change. +// TODO(peah): Correct and reactivate this test. +TEST(RefinedFilterUpdateGain, DISABLED_EchoPathChangeBehavior) { + for (size_t filter_length_blocks : {12, 20, 30}) { + SCOPED_TRACE(ProduceDebugText(filter_length_blocks)); + std::vector<int> blocks_with_echo_path_changes; + std::vector<int> blocks_with_saturation; + blocks_with_echo_path_changes.push_back(99); + + std::array<float, kBlockSize> e; + std::array<float, kBlockSize> y; + FftData G_a; + FftData G_b; + std::array<float, kFftLengthBy2Plus1> G_a_power; + std::array<float, kFftLengthBy2Plus1> G_b_power; + + RunFilterUpdateTest(100, 65, filter_length_blocks, + blocks_with_echo_path_changes, blocks_with_saturation, + false, &e, &y, &G_a); + RunFilterUpdateTest(101, 65, filter_length_blocks, + blocks_with_echo_path_changes, blocks_with_saturation, + false, &e, &y, &G_b); + + G_a.Spectrum(Aec3Optimization::kNone, G_a_power); + G_b.Spectrum(Aec3Optimization::kNone, G_b_power); + + EXPECT_LT(std::accumulate(G_a_power.begin(), G_a_power.end(), 0.), + std::accumulate(G_b_power.begin(), G_b_power.end(), 0.)); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer.cc new file mode 100644 index 0000000000..aa511e2b6b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer.cc @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_buffer.h" + +#include <algorithm> +#include <functional> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +RenderBuffer::RenderBuffer(BlockBuffer* block_buffer, + SpectrumBuffer* spectrum_buffer, + FftBuffer* fft_buffer) + : block_buffer_(block_buffer), + spectrum_buffer_(spectrum_buffer), + fft_buffer_(fft_buffer) { + RTC_DCHECK(block_buffer_); + RTC_DCHECK(spectrum_buffer_); + RTC_DCHECK(fft_buffer_); + RTC_DCHECK_EQ(block_buffer_->buffer.size(), fft_buffer_->buffer.size()); + RTC_DCHECK_EQ(spectrum_buffer_->buffer.size(), fft_buffer_->buffer.size()); + RTC_DCHECK_EQ(spectrum_buffer_->read, fft_buffer_->read); + RTC_DCHECK_EQ(spectrum_buffer_->write, fft_buffer_->write); +} + +RenderBuffer::~RenderBuffer() = default; + +void RenderBuffer::SpectralSum( + size_t num_spectra, + std::array<float, kFftLengthBy2Plus1>* X2) const { + X2->fill(0.f); + int position = spectrum_buffer_->read; + for (size_t j = 0; j < num_spectra; ++j) { + for (const auto& channel_spectrum : spectrum_buffer_->buffer[position]) { + for (size_t k = 0; k < X2->size(); ++k) { + (*X2)[k] += channel_spectrum[k]; + } + } + position = spectrum_buffer_->IncIndex(position); + } +} + +void RenderBuffer::SpectralSums( + size_t num_spectra_shorter, + size_t num_spectra_longer, + std::array<float, kFftLengthBy2Plus1>* X2_shorter, + std::array<float, kFftLengthBy2Plus1>* X2_longer) const { + RTC_DCHECK_LE(num_spectra_shorter, num_spectra_longer); + X2_shorter->fill(0.f); + int position = spectrum_buffer_->read; + size_t j = 0; + for (; j < num_spectra_shorter; ++j) { + for (const auto& channel_spectrum : spectrum_buffer_->buffer[position]) { + for (size_t k = 0; k < X2_shorter->size(); ++k) { + (*X2_shorter)[k] += channel_spectrum[k]; + } + } + position = spectrum_buffer_->IncIndex(position); + } + std::copy(X2_shorter->begin(), X2_shorter->end(), X2_longer->begin()); + for (; j < num_spectra_longer; ++j) { + for (const auto& channel_spectrum : spectrum_buffer_->buffer[position]) { + for (size_t k = 0; k < X2_longer->size(); ++k) { + (*X2_longer)[k] += channel_spectrum[k]; + } + } + position = spectrum_buffer_->IncIndex(position); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer.h new file mode 100644 index 0000000000..8adc996087 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_RENDER_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_RENDER_BUFFER_H_ + +#include <stddef.h> + +#include <array> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block_buffer.h" +#include "modules/audio_processing/aec3/fft_buffer.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +// Provides a buffer of the render data for the echo remover. +class RenderBuffer { + public: + RenderBuffer(BlockBuffer* block_buffer, + SpectrumBuffer* spectrum_buffer, + FftBuffer* fft_buffer); + + RenderBuffer() = delete; + RenderBuffer(const RenderBuffer&) = delete; + RenderBuffer& operator=(const RenderBuffer&) = delete; + + ~RenderBuffer(); + + // Get a block. + const Block& GetBlock(int buffer_offset_blocks) const { + int position = + block_buffer_->OffsetIndex(block_buffer_->read, buffer_offset_blocks); + return block_buffer_->buffer[position]; + } + + // Get the spectrum from one of the FFTs in the buffer. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Spectrum( + int buffer_offset_ffts) const { + int position = spectrum_buffer_->OffsetIndex(spectrum_buffer_->read, + buffer_offset_ffts); + return spectrum_buffer_->buffer[position]; + } + + // Returns the circular fft buffer. + rtc::ArrayView<const std::vector<FftData>> GetFftBuffer() const { + return fft_buffer_->buffer; + } + + // Returns the current position in the circular buffer. + size_t Position() const { + RTC_DCHECK_EQ(spectrum_buffer_->read, fft_buffer_->read); + RTC_DCHECK_EQ(spectrum_buffer_->write, fft_buffer_->write); + return fft_buffer_->read; + } + + // Returns the sum of the spectrums for a certain number of FFTs. + void SpectralSum(size_t num_spectra, + std::array<float, kFftLengthBy2Plus1>* X2) const; + + // Returns the sums of the spectrums for two numbers of FFTs. + void SpectralSums(size_t num_spectra_shorter, + size_t num_spectra_longer, + std::array<float, kFftLengthBy2Plus1>* X2_shorter, + std::array<float, kFftLengthBy2Plus1>* X2_longer) const; + + // Gets the recent activity seen in the render signal. + bool GetRenderActivity() const { return render_activity_; } + + // Specifies the recent activity seen in the render signal. + void SetRenderActivity(bool activity) { render_activity_ = activity; } + + // Returns the headroom between the write and the read positions in the + // buffer. + int Headroom() const { + // The write and read indices are decreased over time. + int headroom = + fft_buffer_->write < fft_buffer_->read + ? fft_buffer_->read - fft_buffer_->write + : fft_buffer_->size - fft_buffer_->write + fft_buffer_->read; + + RTC_DCHECK_LE(0, headroom); + RTC_DCHECK_GE(fft_buffer_->size, headroom); + + return headroom; + } + + // Returns a reference to the spectrum buffer. + const SpectrumBuffer& GetSpectrumBuffer() const { return *spectrum_buffer_; } + + // Returns a reference to the block buffer. + const BlockBuffer& GetBlockBuffer() const { return *block_buffer_; } + + private: + const BlockBuffer* const block_buffer_; + const SpectrumBuffer* const spectrum_buffer_; + const FftBuffer* const fft_buffer_; + bool render_activity_ = false; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_RENDER_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer_gn/moz.build new file mode 100644 index 0000000000..b7a10f5d7c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer_gn/moz.build @@ -0,0 +1,205 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("render_buffer_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer_unittest.cc new file mode 100644 index 0000000000..5d9d646e76 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_buffer_unittest.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_buffer.h" + +#include <algorithm> +#include <functional> +#include <vector> + +#include "test/gtest.h" + +namespace webrtc { + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for non-null fft buffer. +TEST(RenderBufferDeathTest, NullExternalFftBuffer) { + BlockBuffer block_buffer(10, 3, 1); + SpectrumBuffer spectrum_buffer(10, 1); + EXPECT_DEATH(RenderBuffer(&block_buffer, &spectrum_buffer, nullptr), ""); +} + +// Verifies the check for non-null spectrum buffer. +TEST(RenderBufferDeathTest, NullExternalSpectrumBuffer) { + FftBuffer fft_buffer(10, 1); + BlockBuffer block_buffer(10, 3, 1); + EXPECT_DEATH(RenderBuffer(&block_buffer, nullptr, &fft_buffer), ""); +} + +// Verifies the check for non-null block buffer. +TEST(RenderBufferDeathTest, NullExternalBlockBuffer) { + FftBuffer fft_buffer(10, 1); + SpectrumBuffer spectrum_buffer(10, 1); + EXPECT_DEATH(RenderBuffer(nullptr, &spectrum_buffer, &fft_buffer), ""); +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer.cc new file mode 100644 index 0000000000..ec5d35507e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer.cc @@ -0,0 +1,519 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_delay_buffer.h" + +#include <string.h> + +#include <algorithm> +#include <atomic> +#include <cmath> +#include <memory> +#include <numeric> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/alignment_mixer.h" +#include "modules/audio_processing/aec3/block_buffer.h" +#include "modules/audio_processing/aec3/decimator.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/aec3/fft_buffer.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { +namespace { + +bool UpdateCaptureCallCounterOnSkippedBlocks() { + return !field_trial::IsEnabled( + "WebRTC-Aec3RenderBufferCallCounterUpdateKillSwitch"); +} + +class RenderDelayBufferImpl final : public RenderDelayBuffer { + public: + RenderDelayBufferImpl(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels); + RenderDelayBufferImpl() = delete; + ~RenderDelayBufferImpl() override; + + void Reset() override; + BufferingEvent Insert(const Block& block) override; + BufferingEvent PrepareCaptureProcessing() override; + void HandleSkippedCaptureProcessing() override; + bool AlignFromDelay(size_t delay) override; + void AlignFromExternalDelay() override; + size_t Delay() const override { return ComputeDelay(); } + size_t MaxDelay() const override { + return blocks_.buffer.size() - 1 - buffer_headroom_; + } + RenderBuffer* GetRenderBuffer() override { return &echo_remover_buffer_; } + + const DownsampledRenderBuffer& GetDownsampledRenderBuffer() const override { + return low_rate_; + } + + int BufferLatency() const; + void SetAudioBufferDelay(int delay_ms) override; + bool HasReceivedBufferDelay() override; + + private: + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const Aec3Optimization optimization_; + const EchoCanceller3Config config_; + const bool update_capture_call_counter_on_skipped_blocks_; + const float render_linear_amplitude_gain_; + const rtc::LoggingSeverity delay_log_level_; + size_t down_sampling_factor_; + const int sub_block_size_; + BlockBuffer blocks_; + SpectrumBuffer spectra_; + FftBuffer ffts_; + absl::optional<size_t> delay_; + RenderBuffer echo_remover_buffer_; + DownsampledRenderBuffer low_rate_; + AlignmentMixer render_mixer_; + Decimator render_decimator_; + const Aec3Fft fft_; + std::vector<float> render_ds_; + const int buffer_headroom_; + bool last_call_was_render_ = false; + int num_api_calls_in_a_row_ = 0; + int max_observed_jitter_ = 1; + int64_t capture_call_counter_ = 0; + int64_t render_call_counter_ = 0; + bool render_activity_ = false; + size_t render_activity_counter_ = 0; + absl::optional<int> external_audio_buffer_delay_; + bool external_audio_buffer_delay_verified_after_reset_ = false; + size_t min_latency_blocks_ = 0; + size_t excess_render_detection_counter_ = 0; + + int MapDelayToTotalDelay(size_t delay) const; + int ComputeDelay() const; + void ApplyTotalDelay(int delay); + void InsertBlock(const Block& block, int previous_write); + bool DetectActiveRender(rtc::ArrayView<const float> x) const; + bool DetectExcessRenderBlocks(); + void IncrementWriteIndices(); + void IncrementLowRateReadIndices(); + void IncrementReadIndices(); + bool RenderOverrun(); + bool RenderUnderrun(); +}; + +std::atomic<int> RenderDelayBufferImpl::instance_count_ = 0; + +RenderDelayBufferImpl::RenderDelayBufferImpl(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + optimization_(DetectOptimization()), + config_(config), + update_capture_call_counter_on_skipped_blocks_( + UpdateCaptureCallCounterOnSkippedBlocks()), + render_linear_amplitude_gain_( + std::pow(10.0f, config_.render_levels.render_power_gain_db / 20.f)), + delay_log_level_(config_.delay.log_warning_on_delay_changes + ? rtc::LS_WARNING + : rtc::LS_VERBOSE), + down_sampling_factor_(config.delay.down_sampling_factor), + sub_block_size_(static_cast<int>(down_sampling_factor_ > 0 + ? kBlockSize / down_sampling_factor_ + : kBlockSize)), + blocks_(GetRenderDelayBufferSize(down_sampling_factor_, + config.delay.num_filters, + config.filter.refined.length_blocks), + NumBandsForRate(sample_rate_hz), + num_render_channels), + spectra_(blocks_.buffer.size(), num_render_channels), + ffts_(blocks_.buffer.size(), num_render_channels), + delay_(config_.delay.default_delay), + echo_remover_buffer_(&blocks_, &spectra_, &ffts_), + low_rate_(GetDownSampledBufferSize(down_sampling_factor_, + config.delay.num_filters)), + render_mixer_(num_render_channels, config.delay.render_alignment_mixing), + render_decimator_(down_sampling_factor_), + fft_(), + render_ds_(sub_block_size_, 0.f), + buffer_headroom_(config.filter.refined.length_blocks) { + RTC_DCHECK_EQ(blocks_.buffer.size(), ffts_.buffer.size()); + RTC_DCHECK_EQ(spectra_.buffer.size(), ffts_.buffer.size()); + for (size_t i = 0; i < blocks_.buffer.size(); ++i) { + RTC_DCHECK_EQ(blocks_.buffer[i].NumChannels(), ffts_.buffer[i].size()); + RTC_DCHECK_EQ(spectra_.buffer[i].size(), ffts_.buffer[i].size()); + } + + Reset(); +} + +RenderDelayBufferImpl::~RenderDelayBufferImpl() = default; + +// Resets the buffer delays and clears the reported delays. +void RenderDelayBufferImpl::Reset() { + last_call_was_render_ = false; + num_api_calls_in_a_row_ = 1; + min_latency_blocks_ = 0; + excess_render_detection_counter_ = 0; + + // Initialize the read index to one sub-block before the write index. + low_rate_.read = low_rate_.OffsetIndex(low_rate_.write, sub_block_size_); + + // Check for any external audio buffer delay and whether it is feasible. + if (external_audio_buffer_delay_) { + const int headroom = 2; + size_t audio_buffer_delay_to_set; + // Minimum delay is 1 (like the low-rate render buffer). + if (*external_audio_buffer_delay_ <= headroom) { + audio_buffer_delay_to_set = 1; + } else { + audio_buffer_delay_to_set = *external_audio_buffer_delay_ - headroom; + } + + audio_buffer_delay_to_set = std::min(audio_buffer_delay_to_set, MaxDelay()); + + // When an external delay estimate is available, use that delay as the + // initial render buffer delay. + ApplyTotalDelay(audio_buffer_delay_to_set); + delay_ = ComputeDelay(); + + external_audio_buffer_delay_verified_after_reset_ = false; + } else { + // If an external delay estimate is not available, use that delay as the + // initial delay. Set the render buffer delays to the default delay. + ApplyTotalDelay(config_.delay.default_delay); + + // Unset the delays which are set by AlignFromDelay. + delay_ = absl::nullopt; + } +} + +// Inserts a new block into the render buffers. +RenderDelayBuffer::BufferingEvent RenderDelayBufferImpl::Insert( + const Block& block) { + ++render_call_counter_; + if (delay_) { + if (!last_call_was_render_) { + last_call_was_render_ = true; + num_api_calls_in_a_row_ = 1; + } else { + if (++num_api_calls_in_a_row_ > max_observed_jitter_) { + max_observed_jitter_ = num_api_calls_in_a_row_; + RTC_LOG_V(delay_log_level_) + << "New max number api jitter observed at render block " + << render_call_counter_ << ": " << num_api_calls_in_a_row_ + << " blocks"; + } + } + } + + // Increase the write indices to where the new blocks should be written. + const int previous_write = blocks_.write; + IncrementWriteIndices(); + + // Allow overrun and do a reset when render overrun occurrs due to more render + // data being inserted than capture data is received. + BufferingEvent event = + RenderOverrun() ? BufferingEvent::kRenderOverrun : BufferingEvent::kNone; + + // Detect and update render activity. + if (!render_activity_) { + render_activity_counter_ += + DetectActiveRender(block.View(/*band=*/0, /*channel=*/0)) ? 1 : 0; + render_activity_ = render_activity_counter_ >= 20; + } + + // Insert the new render block into the specified position. + InsertBlock(block, previous_write); + + if (event != BufferingEvent::kNone) { + Reset(); + } + + return event; +} + +void RenderDelayBufferImpl::HandleSkippedCaptureProcessing() { + if (update_capture_call_counter_on_skipped_blocks_) { + ++capture_call_counter_; + } +} + +// Prepares the render buffers for processing another capture block. +RenderDelayBuffer::BufferingEvent +RenderDelayBufferImpl::PrepareCaptureProcessing() { + RenderDelayBuffer::BufferingEvent event = BufferingEvent::kNone; + ++capture_call_counter_; + + if (delay_) { + if (last_call_was_render_) { + last_call_was_render_ = false; + num_api_calls_in_a_row_ = 1; + } else { + if (++num_api_calls_in_a_row_ > max_observed_jitter_) { + max_observed_jitter_ = num_api_calls_in_a_row_; + RTC_LOG_V(delay_log_level_) + << "New max number api jitter observed at capture block " + << capture_call_counter_ << ": " << num_api_calls_in_a_row_ + << " blocks"; + } + } + } + + if (DetectExcessRenderBlocks()) { + // Too many render blocks compared to capture blocks. Risk of delay ending + // up before the filter used by the delay estimator. + RTC_LOG_V(delay_log_level_) + << "Excess render blocks detected at block " << capture_call_counter_; + Reset(); + event = BufferingEvent::kRenderOverrun; + } else if (RenderUnderrun()) { + // Don't increment the read indices of the low rate buffer if there is a + // render underrun. + RTC_LOG_V(delay_log_level_) + << "Render buffer underrun detected at block " << capture_call_counter_; + IncrementReadIndices(); + // Incrementing the buffer index without increasing the low rate buffer + // index means that the delay is reduced by one. + if (delay_ && *delay_ > 0) + delay_ = *delay_ - 1; + event = BufferingEvent::kRenderUnderrun; + } else { + // Increment the read indices in the render buffers to point to the most + // recent block to use in the capture processing. + IncrementLowRateReadIndices(); + IncrementReadIndices(); + } + + echo_remover_buffer_.SetRenderActivity(render_activity_); + if (render_activity_) { + render_activity_counter_ = 0; + render_activity_ = false; + } + + return event; +} + +// Sets the delay and returns a bool indicating whether the delay was changed. +bool RenderDelayBufferImpl::AlignFromDelay(size_t delay) { + RTC_DCHECK(!config_.delay.use_external_delay_estimator); + if (!external_audio_buffer_delay_verified_after_reset_ && + external_audio_buffer_delay_ && delay_) { + int difference = static_cast<int>(delay) - static_cast<int>(*delay_); + RTC_LOG_V(delay_log_level_) + << "Mismatch between first estimated delay after reset " + "and externally reported audio buffer delay: " + << difference << " blocks"; + external_audio_buffer_delay_verified_after_reset_ = true; + } + if (delay_ && *delay_ == delay) { + return false; + } + delay_ = delay; + + // Compute the total delay and limit the delay to the allowed range. + int total_delay = MapDelayToTotalDelay(*delay_); + total_delay = + std::min(MaxDelay(), static_cast<size_t>(std::max(total_delay, 0))); + + // Apply the delay to the buffers. + ApplyTotalDelay(total_delay); + return true; +} + +void RenderDelayBufferImpl::SetAudioBufferDelay(int delay_ms) { + if (!external_audio_buffer_delay_) { + RTC_LOG_V(delay_log_level_) + << "Receiving a first externally reported audio buffer delay of " + << delay_ms << " ms."; + } + + // Convert delay from milliseconds to blocks (rounded down). + external_audio_buffer_delay_ = delay_ms / 4; +} + +bool RenderDelayBufferImpl::HasReceivedBufferDelay() { + return external_audio_buffer_delay_.has_value(); +} + +// Maps the externally computed delay to the delay used internally. +int RenderDelayBufferImpl::MapDelayToTotalDelay( + size_t external_delay_blocks) const { + const int latency_blocks = BufferLatency(); + return latency_blocks + static_cast<int>(external_delay_blocks); +} + +// Returns the delay (not including call jitter). +int RenderDelayBufferImpl::ComputeDelay() const { + const int latency_blocks = BufferLatency(); + int internal_delay = spectra_.read >= spectra_.write + ? spectra_.read - spectra_.write + : spectra_.size + spectra_.read - spectra_.write; + + return internal_delay - latency_blocks; +} + +// Set the read indices according to the delay. +void RenderDelayBufferImpl::ApplyTotalDelay(int delay) { + RTC_LOG_V(delay_log_level_) + << "Applying total delay of " << delay << " blocks."; + blocks_.read = blocks_.OffsetIndex(blocks_.write, -delay); + spectra_.read = spectra_.OffsetIndex(spectra_.write, delay); + ffts_.read = ffts_.OffsetIndex(ffts_.write, delay); +} + +void RenderDelayBufferImpl::AlignFromExternalDelay() { + RTC_DCHECK(config_.delay.use_external_delay_estimator); + if (external_audio_buffer_delay_) { + const int64_t delay = render_call_counter_ - capture_call_counter_ + + *external_audio_buffer_delay_; + const int64_t delay_with_headroom = + delay - config_.delay.delay_headroom_samples / kBlockSize; + ApplyTotalDelay(delay_with_headroom); + } +} + +// Inserts a block into the render buffers. +void RenderDelayBufferImpl::InsertBlock(const Block& block, + int previous_write) { + auto& b = blocks_; + auto& lr = low_rate_; + auto& ds = render_ds_; + auto& f = ffts_; + auto& s = spectra_; + const size_t num_bands = b.buffer[b.write].NumBands(); + const size_t num_render_channels = b.buffer[b.write].NumChannels(); + RTC_DCHECK_EQ(block.NumBands(), num_bands); + RTC_DCHECK_EQ(block.NumChannels(), num_render_channels); + for (size_t band = 0; band < num_bands; ++band) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + std::copy(block.begin(band, ch), block.end(band, ch), + b.buffer[b.write].begin(band, ch)); + } + } + + if (render_linear_amplitude_gain_ != 1.f) { + for (size_t band = 0; band < num_bands; ++band) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + rtc::ArrayView<float, kBlockSize> b_view = + b.buffer[b.write].View(band, ch); + for (float& sample : b_view) { + sample *= render_linear_amplitude_gain_; + } + } + } + } + + std::array<float, kBlockSize> downmixed_render; + render_mixer_.ProduceOutput(b.buffer[b.write], downmixed_render); + render_decimator_.Decimate(downmixed_render, ds); + data_dumper_->DumpWav("aec3_render_decimator_output", ds.size(), ds.data(), + 16000 / down_sampling_factor_, 1); + std::copy(ds.rbegin(), ds.rend(), lr.buffer.begin() + lr.write); + for (int channel = 0; channel < b.buffer[b.write].NumChannels(); ++channel) { + fft_.PaddedFft(b.buffer[b.write].View(/*band=*/0, channel), + b.buffer[previous_write].View(/*band=*/0, channel), + &f.buffer[f.write][channel]); + f.buffer[f.write][channel].Spectrum(optimization_, + s.buffer[s.write][channel]); + } +} + +bool RenderDelayBufferImpl::DetectActiveRender( + rtc::ArrayView<const float> x) const { + const float x_energy = std::inner_product(x.begin(), x.end(), x.begin(), 0.f); + return x_energy > (config_.render_levels.active_render_limit * + config_.render_levels.active_render_limit) * + kFftLengthBy2; +} + +bool RenderDelayBufferImpl::DetectExcessRenderBlocks() { + bool excess_render_detected = false; + const size_t latency_blocks = static_cast<size_t>(BufferLatency()); + // The recently seen minimum latency in blocks. Should be close to 0. + min_latency_blocks_ = std::min(min_latency_blocks_, latency_blocks); + // After processing a configurable number of blocks the minimum latency is + // checked. + if (++excess_render_detection_counter_ >= + config_.buffering.excess_render_detection_interval_blocks) { + // If the minimum latency is not lower than the threshold there have been + // more render than capture frames. + excess_render_detected = min_latency_blocks_ > + config_.buffering.max_allowed_excess_render_blocks; + // Reset the counter and let the minimum latency be the current latency. + min_latency_blocks_ = latency_blocks; + excess_render_detection_counter_ = 0; + } + + data_dumper_->DumpRaw("aec3_latency_blocks", latency_blocks); + data_dumper_->DumpRaw("aec3_min_latency_blocks", min_latency_blocks_); + data_dumper_->DumpRaw("aec3_excess_render_detected", excess_render_detected); + return excess_render_detected; +} + +// Computes the latency in the buffer (the number of unread sub-blocks). +int RenderDelayBufferImpl::BufferLatency() const { + const DownsampledRenderBuffer& l = low_rate_; + int latency_samples = (l.buffer.size() + l.read - l.write) % l.buffer.size(); + int latency_blocks = latency_samples / sub_block_size_; + return latency_blocks; +} + +// Increments the write indices for the render buffers. +void RenderDelayBufferImpl::IncrementWriteIndices() { + low_rate_.UpdateWriteIndex(-sub_block_size_); + blocks_.IncWriteIndex(); + spectra_.DecWriteIndex(); + ffts_.DecWriteIndex(); +} + +// Increments the read indices of the low rate render buffers. +void RenderDelayBufferImpl::IncrementLowRateReadIndices() { + low_rate_.UpdateReadIndex(-sub_block_size_); +} + +// Increments the read indices for the render buffers. +void RenderDelayBufferImpl::IncrementReadIndices() { + if (blocks_.read != blocks_.write) { + blocks_.IncReadIndex(); + spectra_.DecReadIndex(); + ffts_.DecReadIndex(); + } +} + +// Checks for a render buffer overrun. +bool RenderDelayBufferImpl::RenderOverrun() { + return low_rate_.read == low_rate_.write || blocks_.read == blocks_.write; +} + +// Checks for a render buffer underrun. +bool RenderDelayBufferImpl::RenderUnderrun() { + return low_rate_.read == low_rate_.write; +} + +} // namespace + +RenderDelayBuffer* RenderDelayBuffer::Create(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels) { + return new RenderDelayBufferImpl(config, sample_rate_hz, num_render_channels); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer.h new file mode 100644 index 0000000000..6dc1aefb85 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_BUFFER_H_ + +#include <stddef.h> + +#include <vector> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/block.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/aec3/render_buffer.h" + +namespace webrtc { + +// Class for buffering the incoming render blocks such that these may be +// extracted with a specified delay. +class RenderDelayBuffer { + public: + enum class BufferingEvent { + kNone, + kRenderUnderrun, + kRenderOverrun, + kApiCallSkew + }; + + static RenderDelayBuffer* Create(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_render_channels); + virtual ~RenderDelayBuffer() = default; + + // Resets the buffer alignment. + virtual void Reset() = 0; + + // Inserts a block into the buffer. + virtual BufferingEvent Insert(const Block& block) = 0; + + // Updates the buffers one step based on the specified buffer delay. Returns + // an enum indicating whether there was a special event that occurred. + virtual BufferingEvent PrepareCaptureProcessing() = 0; + + // Called on capture blocks where PrepareCaptureProcessing is not called. + virtual void HandleSkippedCaptureProcessing() = 0; + + // Sets the buffer delay and returns a bool indicating whether the delay + // changed. + virtual bool AlignFromDelay(size_t delay) = 0; + + // Sets the buffer delay from the most recently reported external delay. + virtual void AlignFromExternalDelay() = 0; + + // Gets the buffer delay. + virtual size_t Delay() const = 0; + + // Gets the buffer delay. + virtual size_t MaxDelay() const = 0; + + // Returns the render buffer for the echo remover. + virtual RenderBuffer* GetRenderBuffer() = 0; + + // Returns the downsampled render buffer. + virtual const DownsampledRenderBuffer& GetDownsampledRenderBuffer() const = 0; + + // Returns the maximum non calusal offset that can occur in the delay buffer. + static int DelayEstimatorOffset(const EchoCanceller3Config& config); + + // Provides an optional external estimate of the audio buffer delay. + virtual void SetAudioBufferDelay(int delay_ms) = 0; + + // Returns whether an external delay estimate has been reported via + // SetAudioBufferDelay. + virtual bool HasReceivedBufferDelay() = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer_unittest.cc new file mode 100644 index 0000000000..d51e06a1ac --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_buffer_unittest.cc @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_delay_buffer.h" + +#include <memory> +#include <string> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +std::string ProduceDebugText(int sample_rate_hz) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.Release(); +} + +} // namespace + +// Verifies that the buffer overflow is correctly reported. +TEST(RenderDelayBuffer, BufferOverflow) { + const EchoCanceller3Config config; + for (auto num_channels : {1, 2, 8}) { + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<RenderDelayBuffer> delay_buffer( + RenderDelayBuffer::Create(config, rate, num_channels)); + Block block_to_insert(NumBandsForRate(rate), num_channels); + for (size_t k = 0; k < 10; ++k) { + EXPECT_EQ(RenderDelayBuffer::BufferingEvent::kNone, + delay_buffer->Insert(block_to_insert)); + } + bool overrun_occurred = false; + for (size_t k = 0; k < 1000; ++k) { + RenderDelayBuffer::BufferingEvent event = + delay_buffer->Insert(block_to_insert); + overrun_occurred = + overrun_occurred || + RenderDelayBuffer::BufferingEvent::kRenderOverrun == event; + } + + EXPECT_TRUE(overrun_occurred); + } + } +} + +// Verifies that the check for available block works. +TEST(RenderDelayBuffer, AvailableBlock) { + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + std::unique_ptr<RenderDelayBuffer> delay_buffer(RenderDelayBuffer::Create( + EchoCanceller3Config(), kSampleRateHz, kNumChannels)); + Block input_block(kNumBands, kNumChannels, 1.0f); + EXPECT_EQ(RenderDelayBuffer::BufferingEvent::kNone, + delay_buffer->Insert(input_block)); + delay_buffer->PrepareCaptureProcessing(); +} + +// Verifies the AlignFromDelay method. +TEST(RenderDelayBuffer, AlignFromDelay) { + EchoCanceller3Config config; + std::unique_ptr<RenderDelayBuffer> delay_buffer( + RenderDelayBuffer::Create(config, 16000, 1)); + ASSERT_TRUE(delay_buffer->Delay()); + delay_buffer->Reset(); + size_t initial_internal_delay = 0; + for (size_t delay = initial_internal_delay; + delay < initial_internal_delay + 20; ++delay) { + ASSERT_TRUE(delay_buffer->AlignFromDelay(delay)); + EXPECT_EQ(delay, delay_buffer->Delay()); + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for feasible delay. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(RenderDelayBufferDeathTest, DISABLED_WrongDelay) { + std::unique_ptr<RenderDelayBuffer> delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), 48000, 1)); + EXPECT_DEATH(delay_buffer->AlignFromDelay(21), ""); +} + +// Verifies the check for the number of bands in the inserted blocks. +TEST(RenderDelayBufferDeathTest, WrongNumberOfBands) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<RenderDelayBuffer> delay_buffer(RenderDelayBuffer::Create( + EchoCanceller3Config(), rate, num_channels)); + Block block_to_insert( + NumBandsForRate(rate < 48000 ? rate + 16000 : 16000), num_channels); + EXPECT_DEATH(delay_buffer->Insert(block_to_insert), ""); + } + } +} + +// Verifies the check for the number of channels in the inserted blocks. +TEST(RenderDelayBufferDeathTest, WrongNumberOfChannels) { + for (auto rate : {16000, 32000, 48000}) { + for (size_t num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<RenderDelayBuffer> delay_buffer(RenderDelayBuffer::Create( + EchoCanceller3Config(), rate, num_channels)); + Block block_to_insert(NumBandsForRate(rate), num_channels + 1); + EXPECT_DEATH(delay_buffer->Insert(block_to_insert), ""); + } + } +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller.cc new file mode 100644 index 0000000000..465e77fb7c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller.cc @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "modules/audio_processing/aec3/render_delay_controller.h" + +#include <stddef.h> + +#include <algorithm> +#include <atomic> +#include <memory> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/delay_estimate.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/aec3/echo_path_delay_estimator.h" +#include "modules/audio_processing/aec3/render_delay_controller_metrics.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { + +class RenderDelayControllerImpl final : public RenderDelayController { + public: + RenderDelayControllerImpl(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_capture_channels); + + RenderDelayControllerImpl() = delete; + RenderDelayControllerImpl(const RenderDelayControllerImpl&) = delete; + RenderDelayControllerImpl& operator=(const RenderDelayControllerImpl&) = + delete; + + ~RenderDelayControllerImpl() override; + void Reset(bool reset_delay_confidence) override; + void LogRenderCall() override; + absl::optional<DelayEstimate> GetDelay( + const DownsampledRenderBuffer& render_buffer, + size_t render_delay_buffer_delay, + const Block& capture) override; + bool HasClockdrift() const override; + + private: + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const int hysteresis_limit_blocks_; + absl::optional<DelayEstimate> delay_; + EchoPathDelayEstimator delay_estimator_; + RenderDelayControllerMetrics metrics_; + absl::optional<DelayEstimate> delay_samples_; + size_t capture_call_counter_ = 0; + int delay_change_counter_ = 0; + DelayEstimate::Quality last_delay_estimate_quality_; +}; + +DelayEstimate ComputeBufferDelay( + const absl::optional<DelayEstimate>& current_delay, + int hysteresis_limit_blocks, + DelayEstimate estimated_delay) { + // Compute the buffer delay increase required to achieve the desired latency. + size_t new_delay_blocks = estimated_delay.delay >> kBlockSizeLog2; + // Add hysteresis. + if (current_delay) { + size_t current_delay_blocks = current_delay->delay; + if (new_delay_blocks > current_delay_blocks && + new_delay_blocks <= current_delay_blocks + hysteresis_limit_blocks) { + new_delay_blocks = current_delay_blocks; + } + } + DelayEstimate new_delay = estimated_delay; + new_delay.delay = new_delay_blocks; + return new_delay; +} + +std::atomic<int> RenderDelayControllerImpl::instance_count_(0); + +RenderDelayControllerImpl::RenderDelayControllerImpl( + const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_capture_channels) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + hysteresis_limit_blocks_( + static_cast<int>(config.delay.hysteresis_limit_blocks)), + delay_estimator_(data_dumper_.get(), config, num_capture_channels), + last_delay_estimate_quality_(DelayEstimate::Quality::kCoarse) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz)); + delay_estimator_.LogDelayEstimationProperties(sample_rate_hz, 0); +} + +RenderDelayControllerImpl::~RenderDelayControllerImpl() = default; + +void RenderDelayControllerImpl::Reset(bool reset_delay_confidence) { + delay_ = absl::nullopt; + delay_samples_ = absl::nullopt; + delay_estimator_.Reset(reset_delay_confidence); + delay_change_counter_ = 0; + if (reset_delay_confidence) { + last_delay_estimate_quality_ = DelayEstimate::Quality::kCoarse; + } +} + +void RenderDelayControllerImpl::LogRenderCall() {} + +absl::optional<DelayEstimate> RenderDelayControllerImpl::GetDelay( + const DownsampledRenderBuffer& render_buffer, + size_t render_delay_buffer_delay, + const Block& capture) { + ++capture_call_counter_; + + auto delay_samples = delay_estimator_.EstimateDelay(render_buffer, capture); + + if (delay_samples) { + if (!delay_samples_ || delay_samples->delay != delay_samples_->delay) { + delay_change_counter_ = 0; + } + if (delay_samples_) { + delay_samples_->blocks_since_last_change = + delay_samples_->delay == delay_samples->delay + ? delay_samples_->blocks_since_last_change + 1 + : 0; + delay_samples_->blocks_since_last_update = 0; + delay_samples_->delay = delay_samples->delay; + delay_samples_->quality = delay_samples->quality; + } else { + delay_samples_ = delay_samples; + } + } else { + if (delay_samples_) { + ++delay_samples_->blocks_since_last_change; + ++delay_samples_->blocks_since_last_update; + } + } + + if (delay_change_counter_ < 2 * kNumBlocksPerSecond) { + ++delay_change_counter_; + } + + if (delay_samples_) { + // Compute the render delay buffer delay. + const bool use_hysteresis = + last_delay_estimate_quality_ == DelayEstimate::Quality::kRefined && + delay_samples_->quality == DelayEstimate::Quality::kRefined; + delay_ = ComputeBufferDelay( + delay_, use_hysteresis ? hysteresis_limit_blocks_ : 0, *delay_samples_); + last_delay_estimate_quality_ = delay_samples_->quality; + } + + metrics_.Update( + delay_samples_ ? absl::optional<size_t>(delay_samples_->delay) + : absl::nullopt, + delay_ ? absl::optional<size_t>(delay_->delay) : absl::nullopt, + delay_estimator_.Clockdrift()); + + data_dumper_->DumpRaw("aec3_render_delay_controller_delay", + delay_samples ? delay_samples->delay : 0); + data_dumper_->DumpRaw("aec3_render_delay_controller_buffer_delay", + delay_ ? delay_->delay : 0); + + return delay_; +} + +bool RenderDelayControllerImpl::HasClockdrift() const { + return delay_estimator_.Clockdrift() != ClockdriftDetector::Level::kNone; +} + +} // namespace + +RenderDelayController* RenderDelayController::Create( + const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_capture_channels) { + return new RenderDelayControllerImpl(config, sample_rate_hz, + num_capture_channels); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller.h b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller.h new file mode 100644 index 0000000000..4a18a11e36 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_CONTROLLER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_CONTROLLER_H_ + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/block.h" +#include "modules/audio_processing/aec3/delay_estimate.h" +#include "modules/audio_processing/aec3/downsampled_render_buffer.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +// Class for aligning the render and capture signal using a RenderDelayBuffer. +class RenderDelayController { + public: + static RenderDelayController* Create(const EchoCanceller3Config& config, + int sample_rate_hz, + size_t num_capture_channels); + virtual ~RenderDelayController() = default; + + // Resets the delay controller. If the delay confidence is reset, the reset + // behavior is as if the call is restarted. + virtual void Reset(bool reset_delay_confidence) = 0; + + // Logs a render call. + virtual void LogRenderCall() = 0; + + // Aligns the render buffer content with the capture signal. + virtual absl::optional<DelayEstimate> GetDelay( + const DownsampledRenderBuffer& render_buffer, + size_t render_delay_buffer_delay, + const Block& capture) = 0; + + // Returns true if clockdrift has been detected. + virtual bool HasClockdrift() const = 0; +}; +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_CONTROLLER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics.cc new file mode 100644 index 0000000000..1e0a0f443e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics.cc @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_delay_controller_metrics.h" + +#include <algorithm> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" +#include "system_wrappers/include/metrics.h" + +namespace webrtc { + +namespace { + +enum class DelayReliabilityCategory { + kNone, + kPoor, + kMedium, + kGood, + kExcellent, + kNumCategories +}; +enum class DelayChangesCategory { + kNone, + kFew, + kSeveral, + kMany, + kConstant, + kNumCategories +}; + +} // namespace + +RenderDelayControllerMetrics::RenderDelayControllerMetrics() = default; + +void RenderDelayControllerMetrics::Update( + absl::optional<size_t> delay_samples, + absl::optional<size_t> buffer_delay_blocks, + ClockdriftDetector::Level clockdrift) { + ++call_counter_; + + if (!initial_update) { + size_t delay_blocks; + if (delay_samples) { + ++reliable_delay_estimate_counter_; + // Add an offset by 1 (metric is halved before reporting) to reserve 0 for + // absent delay. + delay_blocks = (*delay_samples) / kBlockSize + 2; + } else { + delay_blocks = 0; + } + + if (delay_blocks != delay_blocks_) { + ++delay_change_counter_; + delay_blocks_ = delay_blocks; + } + + } else if (++initial_call_counter_ == 5 * kNumBlocksPerSecond) { + initial_update = false; + } + + if (call_counter_ == kMetricsReportingIntervalBlocks) { + int value_to_report = static_cast<int>(delay_blocks_); + // Divide by 2 to compress metric range. + value_to_report = std::min(124, value_to_report >> 1); + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.EchoCanceller.EchoPathDelay", + value_to_report, 0, 124, 125); + + // Divide by 2 to compress metric range. + // Offset by 1 to reserve 0 for absent delay. + value_to_report = buffer_delay_blocks ? (*buffer_delay_blocks + 2) >> 1 : 0; + value_to_report = std::min(124, value_to_report); + RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.EchoCanceller.BufferDelay", + value_to_report, 0, 124, 125); + + DelayReliabilityCategory delay_reliability; + if (reliable_delay_estimate_counter_ == 0) { + delay_reliability = DelayReliabilityCategory::kNone; + } else if (reliable_delay_estimate_counter_ > (call_counter_ >> 1)) { + delay_reliability = DelayReliabilityCategory::kExcellent; + } else if (reliable_delay_estimate_counter_ > 100) { + delay_reliability = DelayReliabilityCategory::kGood; + } else if (reliable_delay_estimate_counter_ > 10) { + delay_reliability = DelayReliabilityCategory::kMedium; + } else { + delay_reliability = DelayReliabilityCategory::kPoor; + } + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.Audio.EchoCanceller.ReliableDelayEstimates", + static_cast<int>(delay_reliability), + static_cast<int>(DelayReliabilityCategory::kNumCategories)); + + DelayChangesCategory delay_changes; + if (delay_change_counter_ == 0) { + delay_changes = DelayChangesCategory::kNone; + } else if (delay_change_counter_ > 10) { + delay_changes = DelayChangesCategory::kConstant; + } else if (delay_change_counter_ > 5) { + delay_changes = DelayChangesCategory::kMany; + } else if (delay_change_counter_ > 2) { + delay_changes = DelayChangesCategory::kSeveral; + } else { + delay_changes = DelayChangesCategory::kFew; + } + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.Audio.EchoCanceller.DelayChanges", + static_cast<int>(delay_changes), + static_cast<int>(DelayChangesCategory::kNumCategories)); + + RTC_HISTOGRAM_ENUMERATION( + "WebRTC.Audio.EchoCanceller.Clockdrift", static_cast<int>(clockdrift), + static_cast<int>(ClockdriftDetector::Level::kNumCategories)); + + call_counter_ = 0; + ResetMetrics(); + } +} + +void RenderDelayControllerMetrics::ResetMetrics() { + delay_change_counter_ = 0; + reliable_delay_estimate_counter_ = 0; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics.h b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics.h new file mode 100644 index 0000000000..b81833b43f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_CONTROLLER_METRICS_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_CONTROLLER_METRICS_H_ + +#include <stddef.h> + +#include "absl/types/optional.h" +#include "modules/audio_processing/aec3/clockdrift_detector.h" + +namespace webrtc { + +// Handles the reporting of metrics for the render delay controller. +class RenderDelayControllerMetrics { + public: + RenderDelayControllerMetrics(); + + RenderDelayControllerMetrics(const RenderDelayControllerMetrics&) = delete; + RenderDelayControllerMetrics& operator=(const RenderDelayControllerMetrics&) = + delete; + + // Updates the metric with new data. + void Update(absl::optional<size_t> delay_samples, + absl::optional<size_t> buffer_delay_blocks, + ClockdriftDetector::Level clockdrift); + + private: + // Resets the metrics. + void ResetMetrics(); + + size_t delay_blocks_ = 0; + int reliable_delay_estimate_counter_ = 0; + int delay_change_counter_ = 0; + int call_counter_ = 0; + int initial_call_counter_ = 0; + bool initial_update = true; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_RENDER_DELAY_CONTROLLER_METRICS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics_unittest.cc new file mode 100644 index 0000000000..cf9df6b297 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_metrics_unittest.cc @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_delay_controller_metrics.h" + +#include "absl/types/optional.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "system_wrappers/include/metrics.h" +#include "test/gtest.h" + +namespace webrtc { + +// Verify the general functionality of RenderDelayControllerMetrics. +TEST(RenderDelayControllerMetrics, NormalUsage) { + metrics::Reset(); + + RenderDelayControllerMetrics metrics; + + int expected_num_metric_reports = 0; + + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < kMetricsReportingIntervalBlocks - 1; ++k) { + metrics.Update(absl::nullopt, absl::nullopt, + ClockdriftDetector::Level::kNone); + } + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.EchoPathDelay"), + expected_num_metric_reports); + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.BufferDelay"), + expected_num_metric_reports); + EXPECT_METRIC_EQ(metrics::NumSamples( + "WebRTC.Audio.EchoCanceller.ReliableDelayEstimates"), + expected_num_metric_reports); + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.DelayChanges"), + expected_num_metric_reports); + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.Clockdrift"), + expected_num_metric_reports); + + // We expect metric reports every kMetricsReportingIntervalBlocks blocks. + ++expected_num_metric_reports; + + metrics.Update(absl::nullopt, absl::nullopt, + ClockdriftDetector::Level::kNone); + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.EchoPathDelay"), + expected_num_metric_reports); + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.BufferDelay"), + expected_num_metric_reports); + EXPECT_METRIC_EQ(metrics::NumSamples( + "WebRTC.Audio.EchoCanceller.ReliableDelayEstimates"), + expected_num_metric_reports); + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.DelayChanges"), + expected_num_metric_reports); + EXPECT_METRIC_EQ( + metrics::NumSamples("WebRTC.Audio.EchoCanceller.Clockdrift"), + expected_num_metric_reports); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc new file mode 100644 index 0000000000..e1a54fca9e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_delay_controller_unittest.cc @@ -0,0 +1,334 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_delay_controller.h" + +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/block_processor.h" +#include "modules/audio_processing/aec3/decimator.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +std::string ProduceDebugText(int sample_rate_hz) { + rtc::StringBuilder ss; + ss << "Sample rate: " << sample_rate_hz; + return ss.Release(); +} + +std::string ProduceDebugText(int sample_rate_hz, + size_t delay, + size_t num_render_channels, + size_t num_capture_channels) { + rtc::StringBuilder ss; + ss << ProduceDebugText(sample_rate_hz) << ", Delay: " << delay + << ", Num render channels: " << num_render_channels + << ", Num capture channels: " << num_capture_channels; + return ss.Release(); +} + +constexpr size_t kDownSamplingFactors[] = {2, 4, 8}; + +} // namespace + +// Verifies the output of GetDelay when there are no AnalyzeRender calls. +// TODO(bugs.webrtc.org/11161): Re-enable tests. +TEST(RenderDelayController, DISABLED_NoRenderSignal) { + for (size_t num_render_channels : {1, 2, 8}) { + Block block(/*num_bands=1*/ 1, /*num_channels=*/1); + EchoCanceller3Config config; + for (size_t num_matched_filters = 4; num_matched_filters <= 10; + num_matched_filters++) { + for (auto down_sampling_factor : kDownSamplingFactors) { + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = num_matched_filters; + for (auto rate : {16000, 32000, 48000}) { + SCOPED_TRACE(ProduceDebugText(rate)); + std::unique_ptr<RenderDelayBuffer> delay_buffer( + RenderDelayBuffer::Create(config, rate, num_render_channels)); + std::unique_ptr<RenderDelayController> delay_controller( + RenderDelayController::Create(config, rate, + /*num_capture_channels*/ 1)); + for (size_t k = 0; k < 100; ++k) { + auto delay = delay_controller->GetDelay( + delay_buffer->GetDownsampledRenderBuffer(), + delay_buffer->Delay(), block); + EXPECT_FALSE(delay->delay); + } + } + } + } + } +} + +// Verifies the basic API call sequence. +// TODO(bugs.webrtc.org/11161): Re-enable tests. +TEST(RenderDelayController, DISABLED_BasicApiCalls) { + for (size_t num_capture_channels : {1, 2, 4}) { + for (size_t num_render_channels : {1, 2, 8}) { + Block capture_block(/*num_bands=*/1, num_capture_channels); + absl::optional<DelayEstimate> delay_blocks; + for (size_t num_matched_filters = 4; num_matched_filters <= 10; + num_matched_filters++) { + for (auto down_sampling_factor : kDownSamplingFactors) { + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = num_matched_filters; + config.delay.capture_alignment_mixing.downmix = false; + config.delay.capture_alignment_mixing.adaptive_selection = false; + + for (auto rate : {16000, 32000, 48000}) { + Block render_block(NumBandsForRate(rate), num_render_channels); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, rate, num_render_channels)); + std::unique_ptr<RenderDelayController> delay_controller( + RenderDelayController::Create(EchoCanceller3Config(), rate, + num_capture_channels)); + for (size_t k = 0; k < 10; ++k) { + render_delay_buffer->Insert(render_block); + render_delay_buffer->PrepareCaptureProcessing(); + + delay_blocks = delay_controller->GetDelay( + render_delay_buffer->GetDownsampledRenderBuffer(), + render_delay_buffer->Delay(), capture_block); + } + EXPECT_TRUE(delay_blocks); + EXPECT_FALSE(delay_blocks->delay); + } + } + } + } + } +} + +// Verifies that the RenderDelayController is able to align the signals for +// simple timeshifts between the signals. +// TODO(bugs.webrtc.org/11161): Re-enable tests. +TEST(RenderDelayController, DISABLED_Alignment) { + Random random_generator(42U); + for (size_t num_capture_channels : {1, 2, 4}) { + Block capture_block(/*num_bands=*/1, num_capture_channels); + for (size_t num_matched_filters = 4; num_matched_filters <= 10; + num_matched_filters++) { + for (auto down_sampling_factor : kDownSamplingFactors) { + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = num_matched_filters; + config.delay.capture_alignment_mixing.downmix = false; + config.delay.capture_alignment_mixing.adaptive_selection = false; + + for (size_t num_render_channels : {1, 2, 8}) { + for (auto rate : {16000, 32000, 48000}) { + Block render_block(NumBandsForRate(rate), num_render_channels); + + for (size_t delay_samples : {15, 50, 150, 200, 800, 4000}) { + absl::optional<DelayEstimate> delay_blocks; + SCOPED_TRACE(ProduceDebugText(rate, delay_samples, + num_render_channels, + num_capture_channels)); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, rate, num_render_channels)); + std::unique_ptr<RenderDelayController> delay_controller( + RenderDelayController::Create(config, rate, + num_capture_channels)); + DelayBuffer<float> signal_delay_buffer(delay_samples); + for (size_t k = 0; k < (400 + delay_samples / kBlockSize); ++k) { + for (int band = 0; band < render_block.NumBands(); ++band) { + for (int channel = 0; channel < render_block.NumChannels(); + ++channel) { + RandomizeSampleVector(&random_generator, + render_block.View(band, channel)); + } + } + signal_delay_buffer.Delay( + render_block.View(/*band=*/0, /*channel=*/0), + capture_block.View(/*band=*/0, /*channel=*/0)); + render_delay_buffer->Insert(render_block); + render_delay_buffer->PrepareCaptureProcessing(); + delay_blocks = delay_controller->GetDelay( + render_delay_buffer->GetDownsampledRenderBuffer(), + render_delay_buffer->Delay(), capture_block); + } + ASSERT_TRUE(!!delay_blocks); + + constexpr int kDelayHeadroomBlocks = 1; + size_t expected_delay_blocks = + std::max(0, static_cast<int>(delay_samples / kBlockSize) - + kDelayHeadroomBlocks); + + EXPECT_EQ(expected_delay_blocks, delay_blocks->delay); + } + } + } + } + } + } +} + +// Verifies that the RenderDelayController is able to properly handle noncausal +// delays. +// TODO(bugs.webrtc.org/11161): Re-enable tests. +TEST(RenderDelayController, DISABLED_NonCausalAlignment) { + Random random_generator(42U); + for (size_t num_capture_channels : {1, 2, 4}) { + for (size_t num_render_channels : {1, 2, 8}) { + for (size_t num_matched_filters = 4; num_matched_filters <= 10; + num_matched_filters++) { + for (auto down_sampling_factor : kDownSamplingFactors) { + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = num_matched_filters; + config.delay.capture_alignment_mixing.downmix = false; + config.delay.capture_alignment_mixing.adaptive_selection = false; + for (auto rate : {16000, 32000, 48000}) { + Block render_block(NumBandsForRate(rate), num_render_channels); + Block capture_block(NumBandsForRate(rate), num_capture_channels); + + for (int delay_samples : {-15, -50, -150, -200}) { + absl::optional<DelayEstimate> delay_blocks; + SCOPED_TRACE(ProduceDebugText(rate, -delay_samples, + num_render_channels, + num_capture_channels)); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, rate, num_render_channels)); + std::unique_ptr<RenderDelayController> delay_controller( + RenderDelayController::Create(EchoCanceller3Config(), rate, + num_capture_channels)); + DelayBuffer<float> signal_delay_buffer(-delay_samples); + for (int k = 0; + k < (400 - delay_samples / static_cast<int>(kBlockSize)); + ++k) { + RandomizeSampleVector( + &random_generator, + capture_block.View(/*band=*/0, /*channel=*/0)); + signal_delay_buffer.Delay( + capture_block.View(/*band=*/0, /*channel=*/0), + render_block.View(/*band=*/0, /*channel=*/0)); + render_delay_buffer->Insert(render_block); + render_delay_buffer->PrepareCaptureProcessing(); + delay_blocks = delay_controller->GetDelay( + render_delay_buffer->GetDownsampledRenderBuffer(), + render_delay_buffer->Delay(), capture_block); + } + + ASSERT_FALSE(delay_blocks); + } + } + } + } + } + } +} + +// Verifies that the RenderDelayController is able to align the signals for +// simple timeshifts between the signals when there is jitter in the API calls. +// TODO(bugs.webrtc.org/11161): Re-enable tests. +TEST(RenderDelayController, DISABLED_AlignmentWithJitter) { + Random random_generator(42U); + for (size_t num_capture_channels : {1, 2, 4}) { + for (size_t num_render_channels : {1, 2, 8}) { + Block capture_block( + /*num_bands=*/1, num_capture_channels); + for (size_t num_matched_filters = 4; num_matched_filters <= 10; + num_matched_filters++) { + for (auto down_sampling_factor : kDownSamplingFactors) { + EchoCanceller3Config config; + config.delay.down_sampling_factor = down_sampling_factor; + config.delay.num_filters = num_matched_filters; + config.delay.capture_alignment_mixing.downmix = false; + config.delay.capture_alignment_mixing.adaptive_selection = false; + + for (auto rate : {16000, 32000, 48000}) { + Block render_block(NumBandsForRate(rate), num_render_channels); + for (size_t delay_samples : {15, 50, 300, 800}) { + absl::optional<DelayEstimate> delay_blocks; + SCOPED_TRACE(ProduceDebugText(rate, delay_samples, + num_render_channels, + num_capture_channels)); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, rate, num_render_channels)); + std::unique_ptr<RenderDelayController> delay_controller( + RenderDelayController::Create(config, rate, + num_capture_channels)); + DelayBuffer<float> signal_delay_buffer(delay_samples); + constexpr size_t kMaxTestJitterBlocks = 26; + for (size_t j = 0; j < (1000 + delay_samples / kBlockSize) / + kMaxTestJitterBlocks + + 1; + ++j) { + std::vector<Block> capture_block_buffer; + for (size_t k = 0; k < (kMaxTestJitterBlocks - 1); ++k) { + RandomizeSampleVector( + &random_generator, + render_block.View(/*band=*/0, /*channel=*/0)); + signal_delay_buffer.Delay( + render_block.View(/*band=*/0, /*channel=*/0), + capture_block.View(/*band=*/0, /*channel=*/0)); + capture_block_buffer.push_back(capture_block); + render_delay_buffer->Insert(render_block); + } + for (size_t k = 0; k < (kMaxTestJitterBlocks - 1); ++k) { + render_delay_buffer->PrepareCaptureProcessing(); + delay_blocks = delay_controller->GetDelay( + render_delay_buffer->GetDownsampledRenderBuffer(), + render_delay_buffer->Delay(), capture_block_buffer[k]); + } + } + + constexpr int kDelayHeadroomBlocks = 1; + size_t expected_delay_blocks = + std::max(0, static_cast<int>(delay_samples / kBlockSize) - + kDelayHeadroomBlocks); + if (expected_delay_blocks < 2) { + expected_delay_blocks = 0; + } + + ASSERT_TRUE(delay_blocks); + EXPECT_EQ(expected_delay_blocks, delay_blocks->delay); + } + } + } + } + } + } +} + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for correct sample rate. +// TODO(peah): Re-enable the test once the issue with memory leaks during DEATH +// tests on test bots has been fixed. +TEST(RenderDelayControllerDeathTest, DISABLED_WrongSampleRate) { + for (auto rate : {-1, 0, 8001, 16001}) { + SCOPED_TRACE(ProduceDebugText(rate)); + EchoCanceller3Config config; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, rate, 1)); + EXPECT_DEATH( + std::unique_ptr<RenderDelayController>( + RenderDelayController::Create(EchoCanceller3Config(), rate, 1)), + ""); + } +} + +#endif + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer.cc new file mode 100644 index 0000000000..bfbeb0ec2e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer.cc @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_signal_analyzer.h" + +#include <math.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { +constexpr size_t kCounterThreshold = 5; + +// Identifies local bands with narrow characteristics. +void IdentifySmallNarrowBandRegions( + const RenderBuffer& render_buffer, + const absl::optional<size_t>& delay_partitions, + std::array<size_t, kFftLengthBy2 - 1>* narrow_band_counters) { + RTC_DCHECK(narrow_band_counters); + + if (!delay_partitions) { + narrow_band_counters->fill(0); + return; + } + + std::array<size_t, kFftLengthBy2 - 1> channel_counters; + channel_counters.fill(0); + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> X2 = + render_buffer.Spectrum(*delay_partitions); + for (size_t ch = 0; ch < X2.size(); ++ch) { + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (X2[ch][k] > 3 * std::max(X2[ch][k - 1], X2[ch][k + 1])) { + ++channel_counters[k - 1]; + } + } + } + for (size_t k = 1; k < kFftLengthBy2; ++k) { + (*narrow_band_counters)[k - 1] = + channel_counters[k - 1] > 0 ? (*narrow_band_counters)[k - 1] + 1 : 0; + } +} + +// Identifies whether the signal has a single strong narrow-band component. +void IdentifyStrongNarrowBandComponent(const RenderBuffer& render_buffer, + int strong_peak_freeze_duration, + absl::optional<int>* narrow_peak_band, + size_t* narrow_peak_counter) { + RTC_DCHECK(narrow_peak_band); + RTC_DCHECK(narrow_peak_counter); + if (*narrow_peak_band && + ++(*narrow_peak_counter) > + static_cast<size_t>(strong_peak_freeze_duration)) { + *narrow_peak_band = absl::nullopt; + } + + const Block& x_latest = render_buffer.GetBlock(0); + float max_peak_level = 0.f; + for (int channel = 0; channel < x_latest.NumChannels(); ++channel) { + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2_latest = + render_buffer.Spectrum(0)[channel]; + + // Identify the spectral peak. + const int peak_bin = + static_cast<int>(std::max_element(X2_latest.begin(), X2_latest.end()) - + X2_latest.begin()); + + // Compute the level around the peak. + float non_peak_power = 0.f; + for (int k = std::max(0, peak_bin - 14); k < peak_bin - 4; ++k) { + non_peak_power = std::max(X2_latest[k], non_peak_power); + } + for (int k = peak_bin + 5; + k < std::min(peak_bin + 15, static_cast<int>(kFftLengthBy2Plus1)); + ++k) { + non_peak_power = std::max(X2_latest[k], non_peak_power); + } + + // Assess the render signal strength. + auto result0 = std::minmax_element(x_latest.begin(/*band=*/0, channel), + x_latest.end(/*band=*/0, channel)); + float max_abs = std::max(fabs(*result0.first), fabs(*result0.second)); + + if (x_latest.NumBands() > 1) { + const auto result1 = + std::minmax_element(x_latest.begin(/*band=*/1, channel), + x_latest.end(/*band=*/1, channel)); + max_abs = + std::max(max_abs, static_cast<float>(std::max( + fabs(*result1.first), fabs(*result1.second)))); + } + + // Detect whether the spectral peak has as strong narrowband nature. + const float peak_level = X2_latest[peak_bin]; + if (peak_bin > 0 && max_abs > 100 && peak_level > 100 * non_peak_power) { + // Store the strongest peak across channels. + if (peak_level > max_peak_level) { + max_peak_level = peak_level; + *narrow_peak_band = peak_bin; + *narrow_peak_counter = 0; + } + } + } +} + +} // namespace + +RenderSignalAnalyzer::RenderSignalAnalyzer(const EchoCanceller3Config& config) + : strong_peak_freeze_duration_(config.filter.refined.length_blocks) { + narrow_band_counters_.fill(0); +} +RenderSignalAnalyzer::~RenderSignalAnalyzer() = default; + +void RenderSignalAnalyzer::Update( + const RenderBuffer& render_buffer, + const absl::optional<size_t>& delay_partitions) { + // Identify bands of narrow nature. + IdentifySmallNarrowBandRegions(render_buffer, delay_partitions, + &narrow_band_counters_); + + // Identify the presence of a strong narrow band. + IdentifyStrongNarrowBandComponent(render_buffer, strong_peak_freeze_duration_, + &narrow_peak_band_, &narrow_peak_counter_); +} + +void RenderSignalAnalyzer::MaskRegionsAroundNarrowBands( + std::array<float, kFftLengthBy2Plus1>* v) const { + RTC_DCHECK(v); + + // Set v to zero around narrow band signal regions. + if (narrow_band_counters_[0] > kCounterThreshold) { + (*v)[1] = (*v)[0] = 0.f; + } + for (size_t k = 2; k < kFftLengthBy2 - 1; ++k) { + if (narrow_band_counters_[k - 1] > kCounterThreshold) { + (*v)[k - 2] = (*v)[k - 1] = (*v)[k] = (*v)[k + 1] = (*v)[k + 2] = 0.f; + } + } + if (narrow_band_counters_[kFftLengthBy2 - 2] > kCounterThreshold) { + (*v)[kFftLengthBy2] = (*v)[kFftLengthBy2 - 1] = 0.f; + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer.h b/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer.h new file mode 100644 index 0000000000..2e4aaa4ba7 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_RENDER_SIGNAL_ANALYZER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_RENDER_SIGNAL_ANALYZER_H_ + +#include <algorithm> +#include <array> +#include <cstddef> + +#include "absl/types/optional.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +// Provides functionality for analyzing the properties of the render signal. +class RenderSignalAnalyzer { + public: + explicit RenderSignalAnalyzer(const EchoCanceller3Config& config); + ~RenderSignalAnalyzer(); + + RenderSignalAnalyzer(const RenderSignalAnalyzer&) = delete; + RenderSignalAnalyzer& operator=(const RenderSignalAnalyzer&) = delete; + + // Updates the render signal analysis with the most recent render signal. + void Update(const RenderBuffer& render_buffer, + const absl::optional<size_t>& delay_partitions); + + // Returns true if the render signal is poorly exciting. + bool PoorSignalExcitation() const { + RTC_DCHECK_LT(2, narrow_band_counters_.size()); + return std::any_of(narrow_band_counters_.begin(), + narrow_band_counters_.end(), + [](size_t a) { return a > 10; }); + } + + // Zeros the array around regions with narrow bands signal characteristics. + void MaskRegionsAroundNarrowBands( + std::array<float, kFftLengthBy2Plus1>* v) const; + + absl::optional<int> NarrowPeakBand() const { return narrow_peak_band_; } + + private: + const int strong_peak_freeze_duration_; + std::array<size_t, kFftLengthBy2 - 1> narrow_band_counters_; + absl::optional<int> narrow_peak_band_; + size_t narrow_peak_counter_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_RENDER_SIGNAL_ANALYZER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer_unittest.cc new file mode 100644 index 0000000000..16f6280cb6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/render_signal_analyzer_unittest.cc @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/render_signal_analyzer.h" + +#include <math.h> + +#include <array> +#include <cmath> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +constexpr float kPi = 3.141592f; + +void ProduceSinusoidInNoise(int sample_rate_hz, + size_t sinusoid_channel, + float sinusoidal_frequency_hz, + Random* random_generator, + size_t* sample_counter, + Block* x) { + // Fill x with low-amplitude noise. + for (int band = 0; band < x->NumBands(); ++band) { + for (int channel = 0; channel < x->NumChannels(); ++channel) { + RandomizeSampleVector(random_generator, x->View(band, channel), + /*amplitude=*/500.f); + } + } + // Produce a sinusoid of the specified frequency in the specified channel. + for (size_t k = *sample_counter, j = 0; k < (*sample_counter + kBlockSize); + ++k, ++j) { + x->View(/*band=*/0, sinusoid_channel)[j] += + 32000.f * + std::sin(2.f * kPi * sinusoidal_frequency_hz * k / sample_rate_hz); + } + *sample_counter = *sample_counter + kBlockSize; +} + +void RunNarrowBandDetectionTest(size_t num_channels) { + RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); + Random random_generator(42U); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + Block x(kNumBands, num_channels); + std::array<float, kBlockSize> x_old; + Aec3Fft fft; + EchoCanceller3Config config; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_channels)); + + std::array<float, kFftLengthBy2Plus1> mask; + x_old.fill(0.f); + constexpr int kSinusFrequencyBin = 32; + + auto generate_sinusoid_test = [&](bool known_delay) { + size_t sample_counter = 0; + for (size_t k = 0; k < 100; ++k) { + ProduceSinusoidInNoise(16000, num_channels - 1, + 16000 / 2 * kSinusFrequencyBin / kFftLengthBy2, + &random_generator, &sample_counter, &x); + + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + + analyzer.Update(*render_delay_buffer->GetRenderBuffer(), + known_delay ? absl::optional<size_t>(0) : absl::nullopt); + } + }; + + generate_sinusoid_test(true); + mask.fill(1.f); + analyzer.MaskRegionsAroundNarrowBands(&mask); + for (int k = 0; k < static_cast<int>(mask.size()); ++k) { + EXPECT_EQ(abs(k - kSinusFrequencyBin) <= 2 ? 0.f : 1.f, mask[k]); + } + EXPECT_TRUE(analyzer.PoorSignalExcitation()); + EXPECT_TRUE(static_cast<bool>(analyzer.NarrowPeakBand())); + EXPECT_EQ(*analyzer.NarrowPeakBand(), 32); + + // Verify that no bands are detected as narrow when the delay is unknown. + generate_sinusoid_test(false); + mask.fill(1.f); + analyzer.MaskRegionsAroundNarrowBands(&mask); + std::for_each(mask.begin(), mask.end(), [](float a) { EXPECT_EQ(1.f, a); }); + EXPECT_FALSE(analyzer.PoorSignalExcitation()); +} + +std::string ProduceDebugText(size_t num_channels) { + rtc::StringBuilder ss; + ss << "number of channels: " << num_channels; + return ss.Release(); +} +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) +// Verifies that the check for non-null output parameter works. +TEST(RenderSignalAnalyzerDeathTest, NullMaskOutput) { + RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); + EXPECT_DEATH(analyzer.MaskRegionsAroundNarrowBands(nullptr), ""); +} + +#endif + +// Verify that no narrow bands are detected in a Gaussian noise signal. +TEST(RenderSignalAnalyzer, NoFalseDetectionOfNarrowBands) { + for (auto num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(num_channels)); + RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); + Random random_generator(42U); + Block x(3, num_channels); + std::array<float, kBlockSize> x_old; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), 48000, num_channels)); + std::array<float, kFftLengthBy2Plus1> mask; + x_old.fill(0.f); + + for (int k = 0; k < 100; ++k) { + for (int band = 0; band < x.NumBands(); ++band) { + for (int channel = 0; channel < x.NumChannels(); ++channel) { + RandomizeSampleVector(&random_generator, x.View(band, channel)); + } + } + + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + + analyzer.Update(*render_delay_buffer->GetRenderBuffer(), + absl::optional<size_t>(0)); + } + + mask.fill(1.f); + analyzer.MaskRegionsAroundNarrowBands(&mask); + EXPECT_TRUE(std::all_of(mask.begin(), mask.end(), + [](float a) { return a == 1.f; })); + EXPECT_FALSE(analyzer.PoorSignalExcitation()); + EXPECT_FALSE(static_cast<bool>(analyzer.NarrowPeakBand())); + } +} + +// Verify that a sinusoid signal is detected as narrow bands. +TEST(RenderSignalAnalyzer, NarrowBandDetection) { + for (auto num_channels : {1, 2, 8}) { + SCOPED_TRACE(ProduceDebugText(num_channels)); + RunNarrowBandDetectionTest(num_channels); + } +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator.cc new file mode 100644 index 0000000000..640a3e3cb9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator.cc @@ -0,0 +1,379 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/residual_echo_estimator.h" + +#include <stddef.h> + +#include <algorithm> +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/reverb_model.h" +#include "rtc_base/checks.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { +namespace { + +constexpr float kDefaultTransparentModeGain = 0.01f; + +float GetTransparentModeGain() { + return kDefaultTransparentModeGain; +} + +float GetEarlyReflectionsDefaultModeGain( + const EchoCanceller3Config::EpStrength& config) { + if (field_trial::IsEnabled("WebRTC-Aec3UseLowEarlyReflectionsDefaultGain")) { + return 0.1f; + } + return config.default_gain; +} + +float GetLateReflectionsDefaultModeGain( + const EchoCanceller3Config::EpStrength& config) { + if (field_trial::IsEnabled("WebRTC-Aec3UseLowLateReflectionsDefaultGain")) { + return 0.1f; + } + return config.default_gain; +} + +bool UseErleOnsetCompensationInDominantNearend( + const EchoCanceller3Config::EpStrength& config) { + return config.erle_onset_compensation_in_dominant_nearend || + field_trial::IsEnabled( + "WebRTC-Aec3UseErleOnsetCompensationInDominantNearend"); +} + +// Computes the indexes that will be used for computing spectral power over +// the blocks surrounding the delay. +void GetRenderIndexesToAnalyze( + const SpectrumBuffer& spectrum_buffer, + const EchoCanceller3Config::EchoModel& echo_model, + int filter_delay_blocks, + int* idx_start, + int* idx_stop) { + RTC_DCHECK(idx_start); + RTC_DCHECK(idx_stop); + size_t window_start; + size_t window_end; + window_start = + std::max(0, filter_delay_blocks - + static_cast<int>(echo_model.render_pre_window_size)); + window_end = filter_delay_blocks + + static_cast<int>(echo_model.render_post_window_size); + *idx_start = spectrum_buffer.OffsetIndex(spectrum_buffer.read, window_start); + *idx_stop = spectrum_buffer.OffsetIndex(spectrum_buffer.read, window_end + 1); +} + +// Estimates the residual echo power based on the echo return loss enhancement +// (ERLE) and the linear power estimate. +void LinearEstimate( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2_linear, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> erle, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) { + RTC_DCHECK_EQ(S2_linear.size(), erle.size()); + RTC_DCHECK_EQ(S2_linear.size(), R2.size()); + + const size_t num_capture_channels = R2.size(); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + RTC_DCHECK_LT(0.f, erle[ch][k]); + R2[ch][k] = S2_linear[ch][k] / erle[ch][k]; + } + } +} + +// Estimates the residual echo power based on the estimate of the echo path +// gain. +void NonLinearEstimate( + float echo_path_gain, + const std::array<float, kFftLengthBy2Plus1>& X2, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) { + const size_t num_capture_channels = R2.size(); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + R2[ch][k] = X2[k] * echo_path_gain; + } + } +} + +// Applies a soft noise gate to the echo generating power. +void ApplyNoiseGate(const EchoCanceller3Config::EchoModel& config, + rtc::ArrayView<float, kFftLengthBy2Plus1> X2) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + if (config.noise_gate_power > X2[k]) { + X2[k] = std::max(0.f, X2[k] - config.noise_gate_slope * + (config.noise_gate_power - X2[k])); + } + } +} + +// Estimates the echo generating signal power as gated maximal power over a +// time window. +void EchoGeneratingPower(size_t num_render_channels, + const SpectrumBuffer& spectrum_buffer, + const EchoCanceller3Config::EchoModel& echo_model, + int filter_delay_blocks, + rtc::ArrayView<float, kFftLengthBy2Plus1> X2) { + int idx_stop; + int idx_start; + GetRenderIndexesToAnalyze(spectrum_buffer, echo_model, filter_delay_blocks, + &idx_start, &idx_stop); + + std::fill(X2.begin(), X2.end(), 0.f); + if (num_render_channels == 1) { + for (int k = idx_start; k != idx_stop; k = spectrum_buffer.IncIndex(k)) { + for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) { + X2[j] = std::max(X2[j], spectrum_buffer.buffer[k][/*channel=*/0][j]); + } + } + } else { + for (int k = idx_start; k != idx_stop; k = spectrum_buffer.IncIndex(k)) { + std::array<float, kFftLengthBy2Plus1> render_power; + render_power.fill(0.f); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const auto& channel_power = spectrum_buffer.buffer[k][ch]; + for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) { + render_power[j] += channel_power[j]; + } + } + for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) { + X2[j] = std::max(X2[j], render_power[j]); + } + } + } +} + +} // namespace + +ResidualEchoEstimator::ResidualEchoEstimator(const EchoCanceller3Config& config, + size_t num_render_channels) + : config_(config), + num_render_channels_(num_render_channels), + early_reflections_transparent_mode_gain_(GetTransparentModeGain()), + late_reflections_transparent_mode_gain_(GetTransparentModeGain()), + early_reflections_general_gain_( + GetEarlyReflectionsDefaultModeGain(config_.ep_strength)), + late_reflections_general_gain_( + GetLateReflectionsDefaultModeGain(config_.ep_strength)), + erle_onset_compensation_in_dominant_nearend_( + UseErleOnsetCompensationInDominantNearend(config_.ep_strength)) { + Reset(); +} + +ResidualEchoEstimator::~ResidualEchoEstimator() = default; + +void ResidualEchoEstimator::Estimate( + const AecState& aec_state, + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2_linear, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + bool dominant_nearend, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2_unbounded) { + RTC_DCHECK_EQ(R2.size(), Y2.size()); + RTC_DCHECK_EQ(R2.size(), S2_linear.size()); + + const size_t num_capture_channels = R2.size(); + + // Estimate the power of the stationary noise in the render signal. + UpdateRenderNoisePower(render_buffer); + + // Estimate the residual echo power. + if (aec_state.UsableLinearEstimate()) { + // When there is saturated echo, assume the same spectral content as is + // present in the microphone signal. + if (aec_state.SaturatedEcho()) { + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + std::copy(Y2[ch].begin(), Y2[ch].end(), R2[ch].begin()); + std::copy(Y2[ch].begin(), Y2[ch].end(), R2_unbounded[ch].begin()); + } + } else { + const bool onset_compensated = + erle_onset_compensation_in_dominant_nearend_ || !dominant_nearend; + LinearEstimate(S2_linear, aec_state.Erle(onset_compensated), R2); + LinearEstimate(S2_linear, aec_state.ErleUnbounded(), R2_unbounded); + } + + UpdateReverb(ReverbType::kLinear, aec_state, render_buffer, + dominant_nearend); + AddReverb(R2); + AddReverb(R2_unbounded); + } else { + const float echo_path_gain = + GetEchoPathGain(aec_state, /*gain_for_early_reflections=*/true); + + // When there is saturated echo, assume the same spectral content as is + // present in the microphone signal. + if (aec_state.SaturatedEcho()) { + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + std::copy(Y2[ch].begin(), Y2[ch].end(), R2[ch].begin()); + std::copy(Y2[ch].begin(), Y2[ch].end(), R2_unbounded[ch].begin()); + } + } else { + // Estimate the echo generating signal power. + std::array<float, kFftLengthBy2Plus1> X2; + EchoGeneratingPower(num_render_channels_, + render_buffer.GetSpectrumBuffer(), config_.echo_model, + aec_state.MinDirectPathFilterDelay(), X2); + if (!aec_state.UseStationarityProperties()) { + ApplyNoiseGate(config_.echo_model, X2); + } + + // Subtract the stationary noise power to avoid stationary noise causing + // excessive echo suppression. + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + X2[k] -= config_.echo_model.stationary_gate_slope * X2_noise_floor_[k]; + X2[k] = std::max(0.f, X2[k]); + } + + NonLinearEstimate(echo_path_gain, X2, R2); + NonLinearEstimate(echo_path_gain, X2, R2_unbounded); + } + + if (config_.echo_model.model_reverb_in_nonlinear_mode && + !aec_state.TransparentModeActive()) { + UpdateReverb(ReverbType::kNonLinear, aec_state, render_buffer, + dominant_nearend); + AddReverb(R2); + AddReverb(R2_unbounded); + } + } + + if (aec_state.UseStationarityProperties()) { + // Scale the echo according to echo audibility. + std::array<float, kFftLengthBy2Plus1> residual_scaling; + aec_state.GetResidualEchoScaling(residual_scaling); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + R2[ch][k] *= residual_scaling[k]; + R2_unbounded[ch][k] *= residual_scaling[k]; + } + } + } +} + +void ResidualEchoEstimator::Reset() { + echo_reverb_.Reset(); + X2_noise_floor_counter_.fill(config_.echo_model.noise_floor_hold); + X2_noise_floor_.fill(config_.echo_model.min_noise_floor_power); +} + +void ResidualEchoEstimator::UpdateRenderNoisePower( + const RenderBuffer& render_buffer) { + std::array<float, kFftLengthBy2Plus1> render_power_data; + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> X2 = + render_buffer.Spectrum(0); + rtc::ArrayView<const float, kFftLengthBy2Plus1> render_power = + X2[/*channel=*/0]; + if (num_render_channels_ > 1) { + render_power_data.fill(0.f); + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + const auto& channel_power = X2[ch]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + render_power_data[k] += channel_power[k]; + } + } + render_power = render_power_data; + } + + // Estimate the stationary noise power in a minimum statistics manner. + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + // Decrease rapidly. + if (render_power[k] < X2_noise_floor_[k]) { + X2_noise_floor_[k] = render_power[k]; + X2_noise_floor_counter_[k] = 0; + } else { + // Increase in a delayed, leaky manner. + if (X2_noise_floor_counter_[k] >= + static_cast<int>(config_.echo_model.noise_floor_hold)) { + X2_noise_floor_[k] = std::max(X2_noise_floor_[k] * 1.1f, + config_.echo_model.min_noise_floor_power); + } else { + ++X2_noise_floor_counter_[k]; + } + } + } +} + +// Updates the reverb estimation. +void ResidualEchoEstimator::UpdateReverb(ReverbType reverb_type, + const AecState& aec_state, + const RenderBuffer& render_buffer, + bool dominant_nearend) { + // Choose reverb partition based on what type of echo power model is used. + const size_t first_reverb_partition = + reverb_type == ReverbType::kLinear + ? aec_state.FilterLengthBlocks() + 1 + : aec_state.MinDirectPathFilterDelay() + 1; + + // Compute render power for the reverb. + std::array<float, kFftLengthBy2Plus1> render_power_data; + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> X2 = + render_buffer.Spectrum(first_reverb_partition); + rtc::ArrayView<const float, kFftLengthBy2Plus1> render_power = + X2[/*channel=*/0]; + if (num_render_channels_ > 1) { + render_power_data.fill(0.f); + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + const auto& channel_power = X2[ch]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + render_power_data[k] += channel_power[k]; + } + } + render_power = render_power_data; + } + + // Update the reverb estimate. + float reverb_decay = aec_state.ReverbDecay(/*mild=*/dominant_nearend); + if (reverb_type == ReverbType::kLinear) { + echo_reverb_.UpdateReverb( + render_power, aec_state.GetReverbFrequencyResponse(), reverb_decay); + } else { + const float echo_path_gain = + GetEchoPathGain(aec_state, /*gain_for_early_reflections=*/false); + echo_reverb_.UpdateReverbNoFreqShaping(render_power, echo_path_gain, + reverb_decay); + } +} +// Adds the estimated power of the reverb to the residual echo power. +void ResidualEchoEstimator::AddReverb( + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) const { + const size_t num_capture_channels = R2.size(); + + // Add the reverb power. + rtc::ArrayView<const float, kFftLengthBy2Plus1> reverb_power = + echo_reverb_.reverb(); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + R2[ch][k] += reverb_power[k]; + } + } +} + +// Chooses the echo path gain to use. +float ResidualEchoEstimator::GetEchoPathGain( + const AecState& aec_state, + bool gain_for_early_reflections) const { + float gain_amplitude; + if (aec_state.TransparentModeActive()) { + gain_amplitude = gain_for_early_reflections + ? early_reflections_transparent_mode_gain_ + : late_reflections_transparent_mode_gain_; + } else { + gain_amplitude = gain_for_early_reflections + ? early_reflections_general_gain_ + : late_reflections_general_gain_; + } + return gain_amplitude * gain_amplitude; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator.h new file mode 100644 index 0000000000..c468764002 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_RESIDUAL_ECHO_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_RESIDUAL_ECHO_ESTIMATOR_H_ + +#include <array> +#include <memory> + +#include "absl/types/optional.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/reverb_model.h" +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +class ResidualEchoEstimator { + public: + ResidualEchoEstimator(const EchoCanceller3Config& config, + size_t num_render_channels); + ~ResidualEchoEstimator(); + + ResidualEchoEstimator(const ResidualEchoEstimator&) = delete; + ResidualEchoEstimator& operator=(const ResidualEchoEstimator&) = delete; + + void Estimate( + const AecState& aec_state, + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2_linear, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + bool dominant_nearend, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2, + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2_unbounded); + + private: + enum class ReverbType { kLinear, kNonLinear }; + + // Resets the state. + void Reset(); + + // Updates estimate for the power of the stationary noise component in the + // render signal. + void UpdateRenderNoisePower(const RenderBuffer& render_buffer); + + // Updates the reverb estimation. + void UpdateReverb(ReverbType reverb_type, + const AecState& aec_state, + const RenderBuffer& render_buffer, + bool dominant_nearend); + + // Adds the estimated unmodelled echo power to the residual echo power + // estimate. + void AddReverb( + rtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2) const; + + // Gets the echo path gain to apply. + float GetEchoPathGain(const AecState& aec_state, + bool gain_for_early_reflections) const; + + const EchoCanceller3Config config_; + const size_t num_render_channels_; + const float early_reflections_transparent_mode_gain_; + const float late_reflections_transparent_mode_gain_; + const float early_reflections_general_gain_; + const float late_reflections_general_gain_; + const bool erle_onset_compensation_in_dominant_nearend_; + std::array<float, kFftLengthBy2Plus1> X2_noise_floor_; + std::array<int, kFftLengthBy2Plus1> X2_noise_floor_counter_; + ReverbModel echo_reverb_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_RESIDUAL_ECHO_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc new file mode 100644 index 0000000000..9a7bf0a89c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/residual_echo_estimator_unittest.cc @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/residual_echo_estimator.h" + +#include <numeric> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { +constexpr int kSampleRateHz = 48000; +constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); +constexpr float kEpsilon = 1e-4f; +} // namespace + +class ResidualEchoEstimatorTest { + public: + ResidualEchoEstimatorTest(size_t num_render_channels, + size_t num_capture_channels, + const EchoCanceller3Config& config) + : num_render_channels_(num_render_channels), + num_capture_channels_(num_capture_channels), + config_(config), + estimator_(config_, num_render_channels_), + aec_state_(config_, num_capture_channels_), + render_delay_buffer_(RenderDelayBuffer::Create(config_, + kSampleRateHz, + num_render_channels_)), + E2_refined_(num_capture_channels_), + S2_linear_(num_capture_channels_), + Y2_(num_capture_channels_), + R2_(num_capture_channels_), + R2_unbounded_(num_capture_channels_), + x_(kNumBands, num_render_channels_), + H2_(num_capture_channels_, + std::vector<std::array<float, kFftLengthBy2Plus1>>(10)), + h_(num_capture_channels_, + std::vector<float>( + GetTimeDomainLength(config_.filter.refined.length_blocks), + 0.0f)), + random_generator_(42U), + output_(num_capture_channels_) { + for (auto& H2_ch : H2_) { + for (auto& H2_k : H2_ch) { + H2_k.fill(0.01f); + } + H2_ch[2].fill(10.f); + H2_ch[2][0] = 0.1f; + } + + for (auto& subtractor_output : output_) { + subtractor_output.Reset(); + subtractor_output.s_refined.fill(100.f); + } + y_.fill(0.f); + + constexpr float kLevel = 10.f; + for (auto& E2_refined_ch : E2_refined_) { + E2_refined_ch.fill(kLevel); + } + S2_linear_[0].fill(kLevel); + for (auto& Y2_ch : Y2_) { + Y2_ch.fill(kLevel); + } + } + + void RunOneFrame(bool dominant_nearend) { + RandomizeSampleVector(&random_generator_, + x_.View(/*band=*/0, /*channel=*/0)); + render_delay_buffer_->Insert(x_); + if (first_frame_) { + render_delay_buffer_->Reset(); + first_frame_ = false; + } + render_delay_buffer_->PrepareCaptureProcessing(); + + aec_state_.Update(delay_estimate_, H2_, h_, + *render_delay_buffer_->GetRenderBuffer(), E2_refined_, + Y2_, output_); + + estimator_.Estimate(aec_state_, *render_delay_buffer_->GetRenderBuffer(), + S2_linear_, Y2_, dominant_nearend, R2_, R2_unbounded_); + } + + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> R2() const { + return R2_; + } + + private: + const size_t num_render_channels_; + const size_t num_capture_channels_; + const EchoCanceller3Config& config_; + ResidualEchoEstimator estimator_; + AecState aec_state_; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer_; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined_; + std::vector<std::array<float, kFftLengthBy2Plus1>> S2_linear_; + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_; + std::vector<std::array<float, kFftLengthBy2Plus1>> R2_; + std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded_; + Block x_; + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_; + std::vector<std::vector<float>> h_; + Random random_generator_; + std::vector<SubtractorOutput> output_; + std::array<float, kBlockSize> y_; + absl::optional<DelayEstimate> delay_estimate_; + bool first_frame_ = true; +}; + +class ResidualEchoEstimatorMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + ResidualEchoEstimatorMultiChannel, + ::testing::Combine(::testing::Values(1, 2, 4), + ::testing::Values(1, 2, 4))); + +TEST_P(ResidualEchoEstimatorMultiChannel, BasicTest) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + + EchoCanceller3Config config; + ResidualEchoEstimatorTest residual_echo_estimator_test( + num_render_channels, num_capture_channels, config); + for (int k = 0; k < 1993; ++k) { + residual_echo_estimator_test.RunOneFrame(/*dominant_nearend=*/false); + } +} + +TEST(ResidualEchoEstimatorMultiChannel, ReverbTest) { + const size_t num_render_channels = 1; + const size_t num_capture_channels = 1; + const size_t nFrames = 100; + + EchoCanceller3Config reference_config; + reference_config.ep_strength.default_len = 0.95f; + reference_config.ep_strength.nearend_len = 0.95f; + EchoCanceller3Config config_use_nearend_len = reference_config; + config_use_nearend_len.ep_strength.default_len = 0.95f; + config_use_nearend_len.ep_strength.nearend_len = 0.83f; + + ResidualEchoEstimatorTest reference_residual_echo_estimator_test( + num_render_channels, num_capture_channels, reference_config); + ResidualEchoEstimatorTest use_nearend_len_residual_echo_estimator_test( + num_render_channels, num_capture_channels, config_use_nearend_len); + + std::vector<float> acum_energy_reference_R2(num_capture_channels, 0.0f); + std::vector<float> acum_energy_R2(num_capture_channels, 0.0f); + for (size_t frame = 0; frame < nFrames; ++frame) { + bool dominant_nearend = frame <= nFrames / 2 ? false : true; + reference_residual_echo_estimator_test.RunOneFrame(dominant_nearend); + use_nearend_len_residual_echo_estimator_test.RunOneFrame(dominant_nearend); + const auto& reference_R2 = reference_residual_echo_estimator_test.R2(); + const auto& R2 = use_nearend_len_residual_echo_estimator_test.R2(); + ASSERT_EQ(reference_R2.size(), R2.size()); + for (size_t ch = 0; ch < reference_R2.size(); ++ch) { + float energy_reference_R2 = std::accumulate( + reference_R2[ch].cbegin(), reference_R2[ch].cend(), 0.0f); + float energy_R2 = std::accumulate(R2[ch].cbegin(), R2[ch].cend(), 0.0f); + if (dominant_nearend) { + EXPECT_GE(energy_reference_R2, energy_R2); + } else { + EXPECT_NEAR(energy_reference_R2, energy_R2, kEpsilon); + } + acum_energy_reference_R2[ch] += energy_reference_R2; + acum_energy_R2[ch] += energy_R2; + } + if (frame == nFrames / 2 || frame == nFrames - 1) { + for (size_t ch = 0; ch < acum_energy_reference_R2.size(); ch++) { + if (dominant_nearend) { + EXPECT_GT(acum_energy_reference_R2[ch], acum_energy_R2[ch]); + } else { + EXPECT_NEAR(acum_energy_reference_R2[ch], acum_energy_R2[ch], + kEpsilon); + } + } + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_decay_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_decay_estimator.cc new file mode 100644 index 0000000000..2daf376911 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_decay_estimator.cc @@ -0,0 +1,410 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/reverb_decay_estimator.h" + +#include <stddef.h> + +#include <algorithm> +#include <cmath> +#include <numeric> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { + +constexpr int kEarlyReverbMinSizeBlocks = 3; +constexpr int kBlocksPerSection = 6; +// Linear regression approach assumes symmetric index around 0. +constexpr float kEarlyReverbFirstPointAtLinearRegressors = + -0.5f * kBlocksPerSection * kFftLengthBy2 + 0.5f; + +// Averages the values in a block of size kFftLengthBy2; +float BlockAverage(rtc::ArrayView<const float> v, size_t block_index) { + constexpr float kOneByFftLengthBy2 = 1.f / kFftLengthBy2; + const int i = block_index * kFftLengthBy2; + RTC_DCHECK_GE(v.size(), i + kFftLengthBy2); + const float sum = + std::accumulate(v.begin() + i, v.begin() + i + kFftLengthBy2, 0.f); + return sum * kOneByFftLengthBy2; +} + +// Analyzes the gain in a block. +void AnalyzeBlockGain(const std::array<float, kFftLengthBy2>& h2, + float floor_gain, + float* previous_gain, + bool* block_adapting, + bool* decaying_gain) { + float gain = std::max(BlockAverage(h2, 0), 1e-32f); + *block_adapting = + *previous_gain > 1.1f * gain || *previous_gain < 0.9f * gain; + *decaying_gain = gain > floor_gain; + *previous_gain = gain; +} + +// Arithmetic sum of $2 \sum_{i=0.5}^{(N-1)/2}i^2$ calculated directly. +constexpr float SymmetricArithmetricSum(int N) { + return N * (N * N - 1.0f) * (1.f / 12.f); +} + +// Returns the peak energy of an impulse response. +float BlockEnergyPeak(rtc::ArrayView<const float> h, int peak_block) { + RTC_DCHECK_LE((peak_block + 1) * kFftLengthBy2, h.size()); + RTC_DCHECK_GE(peak_block, 0); + float peak_value = + *std::max_element(h.begin() + peak_block * kFftLengthBy2, + h.begin() + (peak_block + 1) * kFftLengthBy2, + [](float a, float b) { return a * a < b * b; }); + return peak_value * peak_value; +} + +// Returns the average energy of an impulse response block. +float BlockEnergyAverage(rtc::ArrayView<const float> h, int block_index) { + RTC_DCHECK_LE((block_index + 1) * kFftLengthBy2, h.size()); + RTC_DCHECK_GE(block_index, 0); + constexpr float kOneByFftLengthBy2 = 1.f / kFftLengthBy2; + const auto sum_of_squares = [](float a, float b) { return a + b * b; }; + return std::accumulate(h.begin() + block_index * kFftLengthBy2, + h.begin() + (block_index + 1) * kFftLengthBy2, 0.f, + sum_of_squares) * + kOneByFftLengthBy2; +} + +} // namespace + +ReverbDecayEstimator::ReverbDecayEstimator(const EchoCanceller3Config& config) + : filter_length_blocks_(config.filter.refined.length_blocks), + filter_length_coefficients_(GetTimeDomainLength(filter_length_blocks_)), + use_adaptive_echo_decay_(config.ep_strength.default_len < 0.f), + early_reverb_estimator_(config.filter.refined.length_blocks - + kEarlyReverbMinSizeBlocks), + late_reverb_start_(kEarlyReverbMinSizeBlocks), + late_reverb_end_(kEarlyReverbMinSizeBlocks), + previous_gains_(config.filter.refined.length_blocks, 0.f), + decay_(std::fabs(config.ep_strength.default_len)), + mild_decay_(std::fabs(config.ep_strength.nearend_len)) { + RTC_DCHECK_GT(config.filter.refined.length_blocks, + static_cast<size_t>(kEarlyReverbMinSizeBlocks)); +} + +ReverbDecayEstimator::~ReverbDecayEstimator() = default; + +void ReverbDecayEstimator::Update(rtc::ArrayView<const float> filter, + const absl::optional<float>& filter_quality, + int filter_delay_blocks, + bool usable_linear_filter, + bool stationary_signal) { + const int filter_size = static_cast<int>(filter.size()); + + if (stationary_signal) { + return; + } + + bool estimation_feasible = + filter_delay_blocks <= + filter_length_blocks_ - kEarlyReverbMinSizeBlocks - 1; + estimation_feasible = + estimation_feasible && filter_size == filter_length_coefficients_; + estimation_feasible = estimation_feasible && filter_delay_blocks > 0; + estimation_feasible = estimation_feasible && usable_linear_filter; + + if (!estimation_feasible) { + ResetDecayEstimation(); + return; + } + + if (!use_adaptive_echo_decay_) { + return; + } + + const float new_smoothing = filter_quality ? *filter_quality * 0.2f : 0.f; + smoothing_constant_ = std::max(new_smoothing, smoothing_constant_); + if (smoothing_constant_ == 0.f) { + return; + } + + if (block_to_analyze_ < filter_length_blocks_) { + // Analyze the filter and accumulate data for reverb estimation. + AnalyzeFilter(filter); + ++block_to_analyze_; + } else { + // When the filter is fully analyzed, estimate the reverb decay and reset + // the block_to_analyze_ counter. + EstimateDecay(filter, filter_delay_blocks); + } +} + +void ReverbDecayEstimator::ResetDecayEstimation() { + early_reverb_estimator_.Reset(); + late_reverb_decay_estimator_.Reset(0); + block_to_analyze_ = 0; + estimation_region_candidate_size_ = 0; + estimation_region_identified_ = false; + smoothing_constant_ = 0.f; + late_reverb_start_ = 0; + late_reverb_end_ = 0; +} + +void ReverbDecayEstimator::EstimateDecay(rtc::ArrayView<const float> filter, + int peak_block) { + auto& h = filter; + RTC_DCHECK_EQ(0, h.size() % kFftLengthBy2); + + // Reset the block analysis counter. + block_to_analyze_ = + std::min(peak_block + kEarlyReverbMinSizeBlocks, filter_length_blocks_); + + // To estimate the reverb decay, the energy of the first filter section must + // be substantially larger than the last. Also, the first filter section + // energy must not deviate too much from the max peak. + const float first_reverb_gain = BlockEnergyAverage(h, block_to_analyze_); + const size_t h_size_blocks = h.size() >> kFftLengthBy2Log2; + tail_gain_ = BlockEnergyAverage(h, h_size_blocks - 1); + float peak_energy = BlockEnergyPeak(h, peak_block); + const bool sufficient_reverb_decay = first_reverb_gain > 4.f * tail_gain_; + const bool valid_filter = + first_reverb_gain > 2.f * tail_gain_ && peak_energy < 100.f; + + // Estimate the size of the regions with early and late reflections. + const int size_early_reverb = early_reverb_estimator_.Estimate(); + const int size_late_reverb = + std::max(estimation_region_candidate_size_ - size_early_reverb, 0); + + // Only update the reverb decay estimate if the size of the identified late + // reverb is sufficiently large. + if (size_late_reverb >= 5) { + if (valid_filter && late_reverb_decay_estimator_.EstimateAvailable()) { + float decay = std::pow( + 2.0f, late_reverb_decay_estimator_.Estimate() * kFftLengthBy2); + constexpr float kMaxDecay = 0.95f; // ~1 sec min RT60. + constexpr float kMinDecay = 0.02f; // ~15 ms max RT60. + decay = std::max(.97f * decay_, decay); + decay = std::min(decay, kMaxDecay); + decay = std::max(decay, kMinDecay); + decay_ += smoothing_constant_ * (decay - decay_); + } + + // Update length of decay. Must have enough data (number of sections) in + // order to estimate decay rate. + late_reverb_decay_estimator_.Reset(size_late_reverb * kFftLengthBy2); + late_reverb_start_ = + peak_block + kEarlyReverbMinSizeBlocks + size_early_reverb; + late_reverb_end_ = + block_to_analyze_ + estimation_region_candidate_size_ - 1; + } else { + late_reverb_decay_estimator_.Reset(0); + late_reverb_start_ = 0; + late_reverb_end_ = 0; + } + + // Reset variables for the identification of the region for reverb decay + // estimation. + estimation_region_identified_ = !(valid_filter && sufficient_reverb_decay); + estimation_region_candidate_size_ = 0; + + // Stop estimation of the decay until another good filter is received. + smoothing_constant_ = 0.f; + + // Reset early reflections detector. + early_reverb_estimator_.Reset(); +} + +void ReverbDecayEstimator::AnalyzeFilter(rtc::ArrayView<const float> filter) { + auto h = rtc::ArrayView<const float>( + filter.begin() + block_to_analyze_ * kFftLengthBy2, kFftLengthBy2); + + // Compute squared filter coeffiecients for the block to analyze_; + std::array<float, kFftLengthBy2> h2; + std::transform(h.begin(), h.end(), h2.begin(), [](float a) { return a * a; }); + + // Map out the region for estimating the reverb decay. + bool adapting; + bool above_noise_floor; + AnalyzeBlockGain(h2, tail_gain_, &previous_gains_[block_to_analyze_], + &adapting, &above_noise_floor); + + // Count consecutive number of "good" filter sections, where "good" means: + // 1) energy is above noise floor. + // 2) energy of current section has not changed too much from last check. + estimation_region_identified_ = + estimation_region_identified_ || adapting || !above_noise_floor; + if (!estimation_region_identified_) { + ++estimation_region_candidate_size_; + } + + // Accumulate data for reverb decay estimation and for the estimation of early + // reflections. + if (block_to_analyze_ <= late_reverb_end_) { + if (block_to_analyze_ >= late_reverb_start_) { + for (float h2_k : h2) { + float h2_log2 = FastApproxLog2f(h2_k + 1e-10); + late_reverb_decay_estimator_.Accumulate(h2_log2); + early_reverb_estimator_.Accumulate(h2_log2, smoothing_constant_); + } + } else { + for (float h2_k : h2) { + float h2_log2 = FastApproxLog2f(h2_k + 1e-10); + early_reverb_estimator_.Accumulate(h2_log2, smoothing_constant_); + } + } + } +} + +void ReverbDecayEstimator::Dump(ApmDataDumper* data_dumper) const { + data_dumper->DumpRaw("aec3_reverb_decay", decay_); + data_dumper->DumpRaw("aec3_reverb_tail_energy", tail_gain_); + data_dumper->DumpRaw("aec3_reverb_alpha", smoothing_constant_); + data_dumper->DumpRaw("aec3_num_reverb_decay_blocks", + late_reverb_end_ - late_reverb_start_); + data_dumper->DumpRaw("aec3_late_reverb_start", late_reverb_start_); + data_dumper->DumpRaw("aec3_late_reverb_end", late_reverb_end_); + early_reverb_estimator_.Dump(data_dumper); +} + +void ReverbDecayEstimator::LateReverbLinearRegressor::Reset( + int num_data_points) { + RTC_DCHECK_LE(0, num_data_points); + RTC_DCHECK_EQ(0, num_data_points % 2); + const int N = num_data_points; + nz_ = 0.f; + // Arithmetic sum of $2 \sum_{i=0.5}^{(N-1)/2}i^2$ calculated directly. + nn_ = SymmetricArithmetricSum(N); + // The linear regression approach assumes symmetric index around 0. + count_ = N > 0 ? -N * 0.5f + 0.5f : 0.f; + N_ = N; + n_ = 0; +} + +void ReverbDecayEstimator::LateReverbLinearRegressor::Accumulate(float z) { + nz_ += count_ * z; + ++count_; + ++n_; +} + +float ReverbDecayEstimator::LateReverbLinearRegressor::Estimate() { + RTC_DCHECK(EstimateAvailable()); + if (nn_ == 0.f) { + RTC_DCHECK_NOTREACHED(); + return 0.f; + } + return nz_ / nn_; +} + +ReverbDecayEstimator::EarlyReverbLengthEstimator::EarlyReverbLengthEstimator( + int max_blocks) + : numerators_smooth_(max_blocks - kBlocksPerSection, 0.f), + numerators_(numerators_smooth_.size(), 0.f), + coefficients_counter_(0) { + RTC_DCHECK_LE(0, max_blocks); +} + +ReverbDecayEstimator::EarlyReverbLengthEstimator:: + ~EarlyReverbLengthEstimator() = default; + +void ReverbDecayEstimator::EarlyReverbLengthEstimator::Reset() { + coefficients_counter_ = 0; + std::fill(numerators_.begin(), numerators_.end(), 0.f); + block_counter_ = 0; +} + +void ReverbDecayEstimator::EarlyReverbLengthEstimator::Accumulate( + float value, + float smoothing) { + // Each section is composed by kBlocksPerSection blocks and each section + // overlaps with the next one in (kBlocksPerSection - 1) blocks. For example, + // the first section covers the blocks [0:5], the second covers the blocks + // [1:6] and so on. As a result, for each value, kBlocksPerSection sections + // need to be updated. + int first_section_index = std::max(block_counter_ - kBlocksPerSection + 1, 0); + int last_section_index = + std::min(block_counter_, static_cast<int>(numerators_.size() - 1)); + float x_value = static_cast<float>(coefficients_counter_) + + kEarlyReverbFirstPointAtLinearRegressors; + const float value_to_inc = kFftLengthBy2 * value; + float value_to_add = + x_value * value + (block_counter_ - last_section_index) * value_to_inc; + for (int section = last_section_index; section >= first_section_index; + --section, value_to_add += value_to_inc) { + numerators_[section] += value_to_add; + } + + // Check if this update was the last coefficient of the current block. In that + // case, check if we are at the end of one of the sections and update the + // numerator of the linear regressor that is computed in such section. + if (++coefficients_counter_ == kFftLengthBy2) { + if (block_counter_ >= (kBlocksPerSection - 1)) { + size_t section = block_counter_ - (kBlocksPerSection - 1); + RTC_DCHECK_GT(numerators_.size(), section); + RTC_DCHECK_GT(numerators_smooth_.size(), section); + numerators_smooth_[section] += + smoothing * (numerators_[section] - numerators_smooth_[section]); + n_sections_ = section + 1; + } + ++block_counter_; + coefficients_counter_ = 0; + } +} + +// Estimates the size in blocks of the early reverb. The estimation is done by +// comparing the tilt that is estimated in each section. As an optimization +// detail and due to the fact that all the linear regressors that are computed +// shared the same denominator, the comparison of the tilts is done by a +// comparison of the numerator of the linear regressors. +int ReverbDecayEstimator::EarlyReverbLengthEstimator::Estimate() { + constexpr float N = kBlocksPerSection * kFftLengthBy2; + constexpr float nn = SymmetricArithmetricSum(N); + // numerator_11 refers to the quantity that the linear regressor needs in the + // numerator for getting a decay equal to 1.1 (which is not a decay). + // log2(1.1) * nn / kFftLengthBy2. + constexpr float numerator_11 = 0.13750352374993502f * nn / kFftLengthBy2; + // log2(0.8) * nn / kFftLengthBy2. + constexpr float numerator_08 = -0.32192809488736229f * nn / kFftLengthBy2; + constexpr int kNumSectionsToAnalyze = 9; + + if (n_sections_ < kNumSectionsToAnalyze) { + return 0; + } + + // Estimation of the blocks that correspond to early reverberations. The + // estimation is done by analyzing the impulse response. The portions of the + // impulse response whose energy is not decreasing over its coefficients are + // considered to be part of the early reverberations. Furthermore, the blocks + // where the energy is decreasing faster than what it does at the end of the + // impulse response are also considered to be part of the early + // reverberations. The estimation is limited to the first + // kNumSectionsToAnalyze sections. + + RTC_DCHECK_LE(n_sections_, numerators_smooth_.size()); + const float min_numerator_tail = + *std::min_element(numerators_smooth_.begin() + kNumSectionsToAnalyze, + numerators_smooth_.begin() + n_sections_); + int early_reverb_size_minus_1 = 0; + for (int k = 0; k < kNumSectionsToAnalyze; ++k) { + if ((numerators_smooth_[k] > numerator_11) || + (numerators_smooth_[k] < numerator_08 && + numerators_smooth_[k] < 0.9f * min_numerator_tail)) { + early_reverb_size_minus_1 = k; + } + } + + return early_reverb_size_minus_1 == 0 ? 0 : early_reverb_size_minus_1 + 1; +} + +void ReverbDecayEstimator::EarlyReverbLengthEstimator::Dump( + ApmDataDumper* data_dumper) const { + data_dumper->DumpRaw("aec3_er_acum_numerator", numerators_smooth_); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_decay_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_decay_estimator.h new file mode 100644 index 0000000000..fee54210e6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_decay_estimator.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_REVERB_DECAY_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_REVERB_DECAY_ESTIMATOR_H_ + +#include <array> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" // kMaxAdaptiveFilter... + +namespace webrtc { + +class ApmDataDumper; +struct EchoCanceller3Config; + +// Class for estimating the decay of the late reverb. +class ReverbDecayEstimator { + public: + explicit ReverbDecayEstimator(const EchoCanceller3Config& config); + ~ReverbDecayEstimator(); + // Updates the decay estimate. + void Update(rtc::ArrayView<const float> filter, + const absl::optional<float>& filter_quality, + int filter_delay_blocks, + bool usable_linear_filter, + bool stationary_signal); + // Returns the decay for the exponential model. The parameter `mild` indicates + // which exponential decay to return, the default one or a milder one. + float Decay(bool mild) const { + if (use_adaptive_echo_decay_) { + return decay_; + } else { + return mild ? mild_decay_ : decay_; + } + } + // Dumps debug data. + void Dump(ApmDataDumper* data_dumper) const; + + private: + void EstimateDecay(rtc::ArrayView<const float> filter, int peak_block); + void AnalyzeFilter(rtc::ArrayView<const float> filter); + + void ResetDecayEstimation(); + + // Class for estimating the decay of the late reverb from the linear filter. + class LateReverbLinearRegressor { + public: + // Resets the estimator to receive a specified number of data points. + void Reset(int num_data_points); + // Accumulates estimation data. + void Accumulate(float z); + // Estimates the decay. + float Estimate(); + // Returns whether an estimate is available. + bool EstimateAvailable() const { return n_ == N_ && N_ != 0; } + + public: + float nz_ = 0.f; + float nn_ = 0.f; + float count_ = 0.f; + int N_ = 0; + int n_ = 0; + }; + + // Class for identifying the length of the early reverb from the linear + // filter. For identifying the early reverberations, the impulse response is + // divided in sections and the tilt of each section is computed by a linear + // regressor. + class EarlyReverbLengthEstimator { + public: + explicit EarlyReverbLengthEstimator(int max_blocks); + ~EarlyReverbLengthEstimator(); + + // Resets the estimator. + void Reset(); + // Accumulates estimation data. + void Accumulate(float value, float smoothing); + // Estimates the size in blocks of the early reverb. + int Estimate(); + // Dumps debug data. + void Dump(ApmDataDumper* data_dumper) const; + + private: + std::vector<float> numerators_smooth_; + std::vector<float> numerators_; + int coefficients_counter_; + int block_counter_ = 0; + int n_sections_ = 0; + }; + + const int filter_length_blocks_; + const int filter_length_coefficients_; + const bool use_adaptive_echo_decay_; + LateReverbLinearRegressor late_reverb_decay_estimator_; + EarlyReverbLengthEstimator early_reverb_estimator_; + int late_reverb_start_; + int late_reverb_end_; + int block_to_analyze_ = 0; + int estimation_region_candidate_size_ = 0; + bool estimation_region_identified_ = false; + std::vector<float> previous_gains_; + float decay_; + float mild_decay_; + float tail_gain_ = 0.f; + float smoothing_constant_ = 0.f; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_REVERB_DECAY_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_frequency_response.cc b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_frequency_response.cc new file mode 100644 index 0000000000..6e7282a1fc --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_frequency_response.cc @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/reverb_frequency_response.h" + +#include <stddef.h> + +#include <algorithm> +#include <array> +#include <numeric> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { + +// Computes the ratio of the energies between the direct path and the tail. The +// energy is computed in the power spectrum domain discarding the DC +// contributions. +float AverageDecayWithinFilter( + rtc::ArrayView<const float> freq_resp_direct_path, + rtc::ArrayView<const float> freq_resp_tail) { + // Skipping the DC for the ratio computation + constexpr size_t kSkipBins = 1; + RTC_CHECK_EQ(freq_resp_direct_path.size(), freq_resp_tail.size()); + + float direct_path_energy = + std::accumulate(freq_resp_direct_path.begin() + kSkipBins, + freq_resp_direct_path.end(), 0.f); + + if (direct_path_energy == 0.f) { + return 0.f; + } + + float tail_energy = std::accumulate(freq_resp_tail.begin() + kSkipBins, + freq_resp_tail.end(), 0.f); + return tail_energy / direct_path_energy; +} + +} // namespace + +ReverbFrequencyResponse::ReverbFrequencyResponse( + bool use_conservative_tail_frequency_response) + : use_conservative_tail_frequency_response_( + use_conservative_tail_frequency_response) { + tail_response_.fill(0.0f); +} + +ReverbFrequencyResponse::~ReverbFrequencyResponse() = default; + +void ReverbFrequencyResponse::Update( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& + frequency_response, + int filter_delay_blocks, + const absl::optional<float>& linear_filter_quality, + bool stationary_block) { + if (stationary_block || !linear_filter_quality) { + return; + } + + Update(frequency_response, filter_delay_blocks, *linear_filter_quality); +} + +void ReverbFrequencyResponse::Update( + const std::vector<std::array<float, kFftLengthBy2Plus1>>& + frequency_response, + int filter_delay_blocks, + float linear_filter_quality) { + rtc::ArrayView<const float> freq_resp_tail( + frequency_response[frequency_response.size() - 1]); + + rtc::ArrayView<const float> freq_resp_direct_path( + frequency_response[filter_delay_blocks]); + + float average_decay = + AverageDecayWithinFilter(freq_resp_direct_path, freq_resp_tail); + + const float smoothing = 0.2f * linear_filter_quality; + average_decay_ += smoothing * (average_decay - average_decay_); + + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + tail_response_[k] = freq_resp_direct_path[k] * average_decay_; + } + + if (use_conservative_tail_frequency_response_) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + tail_response_[k] = std::max(freq_resp_tail[k], tail_response_[k]); + } + } + + for (size_t k = 1; k < kFftLengthBy2; ++k) { + const float avg_neighbour = + 0.5f * (tail_response_[k - 1] + tail_response_[k + 1]); + tail_response_[k] = std::max(tail_response_[k], avg_neighbour); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_frequency_response.h b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_frequency_response.h new file mode 100644 index 0000000000..69b16b54d0 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_frequency_response.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_REVERB_FREQUENCY_RESPONSE_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_REVERB_FREQUENCY_RESPONSE_H_ + +#include <array> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Class for updating the frequency response for the reverb. +class ReverbFrequencyResponse { + public: + explicit ReverbFrequencyResponse( + bool use_conservative_tail_frequency_response); + ~ReverbFrequencyResponse(); + + // Updates the frequency response estimate of the reverb. + void Update(const std::vector<std::array<float, kFftLengthBy2Plus1>>& + frequency_response, + int filter_delay_blocks, + const absl::optional<float>& linear_filter_quality, + bool stationary_block); + + // Returns the estimated frequency response for the reverb. + rtc::ArrayView<const float> FrequencyResponse() const { + return tail_response_; + } + + private: + void Update(const std::vector<std::array<float, kFftLengthBy2Plus1>>& + frequency_response, + int filter_delay_blocks, + float linear_filter_quality); + + const bool use_conservative_tail_frequency_response_; + float average_decay_ = 0.f; + std::array<float, kFftLengthBy2Plus1> tail_response_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_REVERB_FREQUENCY_RESPONSE_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model.cc b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model.cc new file mode 100644 index 0000000000..e4f3507d31 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model.cc @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/reverb_model.h" + +#include <stddef.h> + +#include <algorithm> +#include <functional> + +#include "api/array_view.h" + +namespace webrtc { + +ReverbModel::ReverbModel() { + Reset(); +} + +ReverbModel::~ReverbModel() = default; + +void ReverbModel::Reset() { + reverb_.fill(0.); +} + +void ReverbModel::UpdateReverbNoFreqShaping( + rtc::ArrayView<const float> power_spectrum, + float power_spectrum_scaling, + float reverb_decay) { + if (reverb_decay > 0) { + // Update the estimate of the reverberant power. + for (size_t k = 0; k < power_spectrum.size(); ++k) { + reverb_[k] = (reverb_[k] + power_spectrum[k] * power_spectrum_scaling) * + reverb_decay; + } + } +} + +void ReverbModel::UpdateReverb( + rtc::ArrayView<const float> power_spectrum, + rtc::ArrayView<const float> power_spectrum_scaling, + float reverb_decay) { + if (reverb_decay > 0) { + // Update the estimate of the reverberant power. + for (size_t k = 0; k < power_spectrum.size(); ++k) { + reverb_[k] = + (reverb_[k] + power_spectrum[k] * power_spectrum_scaling[k]) * + reverb_decay; + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model.h b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model.h new file mode 100644 index 0000000000..5ba54853da --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_H_ + +#include <array> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// The ReverbModel class describes an exponential reverberant model +// that can be applied over power spectrums. +class ReverbModel { + public: + ReverbModel(); + ~ReverbModel(); + + // Resets the state. + void Reset(); + + // Returns the reverb. + rtc::ArrayView<const float, kFftLengthBy2Plus1> reverb() const { + return reverb_; + } + + // The methods UpdateReverbNoFreqShaping and UpdateReverb update the + // estimate of the reverberation contribution to an input/output power + // spectrum. Before applying the exponential reverberant model, the input + // power spectrum is pre-scaled. Use the method UpdateReverb when a different + // scaling should be applied per frequency and UpdateReverb_no_freq_shape if + // the same scaling should be used for all the frequencies. + void UpdateReverbNoFreqShaping(rtc::ArrayView<const float> power_spectrum, + float power_spectrum_scaling, + float reverb_decay); + + // Update the reverb based on new data. + void UpdateReverb(rtc::ArrayView<const float> power_spectrum, + rtc::ArrayView<const float> power_spectrum_scaling, + float reverb_decay); + + private: + + std::array<float, kFftLengthBy2Plus1> reverb_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator.cc new file mode 100644 index 0000000000..5cd7a7870d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/reverb_model_estimator.h" + +namespace webrtc { + +ReverbModelEstimator::ReverbModelEstimator(const EchoCanceller3Config& config, + size_t num_capture_channels) + : reverb_decay_estimators_(num_capture_channels), + reverb_frequency_responses_( + num_capture_channels, + ReverbFrequencyResponse( + config.ep_strength.use_conservative_tail_frequency_response)) { + for (size_t ch = 0; ch < reverb_decay_estimators_.size(); ++ch) { + reverb_decay_estimators_[ch] = + std::make_unique<ReverbDecayEstimator>(config); + } +} + +ReverbModelEstimator::~ReverbModelEstimator() = default; + +void ReverbModelEstimator::Update( + rtc::ArrayView<const std::vector<float>> impulse_responses, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + frequency_responses, + rtc::ArrayView<const absl::optional<float>> linear_filter_qualities, + rtc::ArrayView<const int> filter_delays_blocks, + const std::vector<bool>& usable_linear_estimates, + bool stationary_block) { + const size_t num_capture_channels = reverb_decay_estimators_.size(); + RTC_DCHECK_EQ(num_capture_channels, impulse_responses.size()); + RTC_DCHECK_EQ(num_capture_channels, frequency_responses.size()); + RTC_DCHECK_EQ(num_capture_channels, usable_linear_estimates.size()); + + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + // Estimate the frequency response for the reverb. + reverb_frequency_responses_[ch].Update( + frequency_responses[ch], filter_delays_blocks[ch], + linear_filter_qualities[ch], stationary_block); + + // Estimate the reverb decay, + reverb_decay_estimators_[ch]->Update( + impulse_responses[ch], linear_filter_qualities[ch], + filter_delays_blocks[ch], usable_linear_estimates[ch], + stationary_block); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator.h new file mode 100644 index 0000000000..63bade977f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_ESTIMATOR_H_ + +#include <array> +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" // kFftLengthBy2Plus1 +#include "modules/audio_processing/aec3/reverb_decay_estimator.h" +#include "modules/audio_processing/aec3/reverb_frequency_response.h" + +namespace webrtc { + +class ApmDataDumper; + +// Class for estimating the model parameters for the reverberant echo. +class ReverbModelEstimator { + public: + ReverbModelEstimator(const EchoCanceller3Config& config, + size_t num_capture_channels); + ~ReverbModelEstimator(); + + // Updates the estimates based on new data. + void Update( + rtc::ArrayView<const std::vector<float>> impulse_responses, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + frequency_responses, + rtc::ArrayView<const absl::optional<float>> linear_filter_qualities, + rtc::ArrayView<const int> filter_delays_blocks, + const std::vector<bool>& usable_linear_estimates, + bool stationary_block); + + // Returns the exponential decay of the reverberant echo. The parameter `mild` + // indicates which exponential decay to return, the default one or a milder + // one. + // TODO(peah): Correct to properly support multiple channels. + float ReverbDecay(bool mild) const { + return reverb_decay_estimators_[0]->Decay(mild); + } + + // Return the frequency response of the reverberant echo. + // TODO(peah): Correct to properly support multiple channels. + rtc::ArrayView<const float> GetReverbFrequencyResponse() const { + return reverb_frequency_responses_[0].FrequencyResponse(); + } + + // Dumps debug data. + void Dump(ApmDataDumper* data_dumper) const { + reverb_decay_estimators_[0]->Dump(data_dumper); + } + + private: + std::vector<std::unique_ptr<ReverbDecayEstimator>> reverb_decay_estimators_; + std::vector<ReverbFrequencyResponse> reverb_frequency_responses_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_REVERB_MODEL_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc new file mode 100644 index 0000000000..fb7dcef37f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/reverb_model_estimator_unittest.cc @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/reverb_model_estimator.h" + +#include <algorithm> +#include <array> +#include <cmath> +#include <numeric> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "rtc_base/checks.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { + +EchoCanceller3Config CreateConfigForTest(float default_decay) { + EchoCanceller3Config cfg; + cfg.ep_strength.default_len = default_decay; + cfg.filter.refined.length_blocks = 40; + return cfg; +} + +constexpr int kFilterDelayBlocks = 2; + +} // namespace + +class ReverbModelEstimatorTest { + public: + ReverbModelEstimatorTest(float default_decay, size_t num_capture_channels) + : aec3_config_(CreateConfigForTest(default_decay)), + estimated_decay_(default_decay), + h_(num_capture_channels, + std::vector<float>( + aec3_config_.filter.refined.length_blocks * kBlockSize, + 0.f)), + H2_(num_capture_channels, + std::vector<std::array<float, kFftLengthBy2Plus1>>( + aec3_config_.filter.refined.length_blocks)), + quality_linear_(num_capture_channels, 1.0f) { + CreateImpulseResponseWithDecay(); + } + void RunEstimator(); + float GetDecay(bool mild) { + return mild ? mild_estimated_decay_ : estimated_decay_; + } + float GetTrueDecay() { return kTruePowerDecay; } + float GetPowerTailDb() { return 10.f * std::log10(estimated_power_tail_); } + float GetTruePowerTailDb() { return 10.f * std::log10(true_power_tail_); } + + private: + void CreateImpulseResponseWithDecay(); + static constexpr bool kStationaryBlock = false; + static constexpr float kTruePowerDecay = 0.5f; + const EchoCanceller3Config aec3_config_; + float estimated_decay_; + float mild_estimated_decay_; + float estimated_power_tail_ = 0.f; + float true_power_tail_ = 0.f; + std::vector<std::vector<float>> h_; + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_; + std::vector<absl::optional<float>> quality_linear_; +}; + +void ReverbModelEstimatorTest::CreateImpulseResponseWithDecay() { + const Aec3Fft fft; + for (const auto& h_k : h_) { + RTC_DCHECK_EQ(h_k.size(), + aec3_config_.filter.refined.length_blocks * kBlockSize); + } + for (const auto& H2_k : H2_) { + RTC_DCHECK_EQ(H2_k.size(), aec3_config_.filter.refined.length_blocks); + } + RTC_DCHECK_EQ(kFilterDelayBlocks, 2); + + float decay_sample = std::sqrt(powf(kTruePowerDecay, 1.f / kBlockSize)); + const size_t filter_delay_coefficients = kFilterDelayBlocks * kBlockSize; + for (auto& h_i : h_) { + std::fill(h_i.begin(), h_i.end(), 0.f); + h_i[filter_delay_coefficients] = 1.f; + for (size_t k = filter_delay_coefficients + 1; k < h_i.size(); ++k) { + h_i[k] = h_i[k - 1] * decay_sample; + } + } + + for (size_t ch = 0; ch < H2_.size(); ++ch) { + for (size_t j = 0, k = 0; j < H2_[ch].size(); ++j, k += kBlockSize) { + std::array<float, kFftLength> fft_data; + fft_data.fill(0.f); + std::copy(h_[ch].begin() + k, h_[ch].begin() + k + kBlockSize, + fft_data.begin()); + FftData H_j; + fft.Fft(&fft_data, &H_j); + H_j.Spectrum(Aec3Optimization::kNone, H2_[ch][j]); + } + } + rtc::ArrayView<float> H2_tail(H2_[0][H2_[0].size() - 1]); + true_power_tail_ = std::accumulate(H2_tail.begin(), H2_tail.end(), 0.f); +} +void ReverbModelEstimatorTest::RunEstimator() { + const size_t num_capture_channels = H2_.size(); + constexpr bool kUsableLinearEstimate = true; + ReverbModelEstimator estimator(aec3_config_, num_capture_channels); + std::vector<bool> usable_linear_estimates(num_capture_channels, + kUsableLinearEstimate); + std::vector<int> filter_delay_blocks(num_capture_channels, + kFilterDelayBlocks); + for (size_t k = 0; k < 3000; ++k) { + estimator.Update(h_, H2_, quality_linear_, filter_delay_blocks, + usable_linear_estimates, kStationaryBlock); + } + estimated_decay_ = estimator.ReverbDecay(/*mild=*/false); + mild_estimated_decay_ = estimator.ReverbDecay(/*mild=*/true); + auto freq_resp_tail = estimator.GetReverbFrequencyResponse(); + estimated_power_tail_ = + std::accumulate(freq_resp_tail.begin(), freq_resp_tail.end(), 0.f); +} + +TEST(ReverbModelEstimatorTests, NotChangingDecay) { + constexpr float kDefaultDecay = 0.9f; + for (size_t num_capture_channels : {1, 2, 4, 8}) { + ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels); + test.RunEstimator(); + EXPECT_EQ(test.GetDecay(/*mild=*/false), kDefaultDecay); + EXPECT_EQ(test.GetDecay(/*mild=*/true), + EchoCanceller3Config().ep_strength.nearend_len); + EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); + } +} + +TEST(ReverbModelEstimatorTests, ChangingDecay) { + constexpr float kDefaultDecay = -0.9f; + for (size_t num_capture_channels : {1, 2, 4, 8}) { + ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels); + test.RunEstimator(); + EXPECT_NEAR(test.GetDecay(/*mild=*/false), test.GetTrueDecay(), 0.1f); + EXPECT_NEAR(test.GetDecay(/*mild=*/true), test.GetTrueDecay(), 0.1f); + EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc new file mode 100644 index 0000000000..a5e77092a6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.cc @@ -0,0 +1,416 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/signal_dependent_erle_estimator.h" + +#include <algorithm> +#include <functional> +#include <numeric> + +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { + +namespace { + +constexpr std::array<size_t, SignalDependentErleEstimator::kSubbands + 1> + kBandBoundaries = {1, 8, 16, 24, 32, 48, kFftLengthBy2Plus1}; + +std::array<size_t, kFftLengthBy2Plus1> FormSubbandMap() { + std::array<size_t, kFftLengthBy2Plus1> map_band_to_subband; + size_t subband = 1; + for (size_t k = 0; k < map_band_to_subband.size(); ++k) { + RTC_DCHECK_LT(subband, kBandBoundaries.size()); + if (k >= kBandBoundaries[subband]) { + subband++; + RTC_DCHECK_LT(k, kBandBoundaries[subband]); + } + map_band_to_subband[k] = subband - 1; + } + return map_band_to_subband; +} + +// Defines the size in blocks of the sections that are used for dividing the +// linear filter. The sections are split in a non-linear manner so that lower +// sections that typically represent the direct path have a larger resolution +// than the higher sections which typically represent more reverberant acoustic +// paths. +std::vector<size_t> DefineFilterSectionSizes(size_t delay_headroom_blocks, + size_t num_blocks, + size_t num_sections) { + size_t filter_length_blocks = num_blocks - delay_headroom_blocks; + std::vector<size_t> section_sizes(num_sections); + size_t remaining_blocks = filter_length_blocks; + size_t remaining_sections = num_sections; + size_t estimator_size = 2; + size_t idx = 0; + while (remaining_sections > 1 && + remaining_blocks > estimator_size * remaining_sections) { + RTC_DCHECK_LT(idx, section_sizes.size()); + section_sizes[idx] = estimator_size; + remaining_blocks -= estimator_size; + remaining_sections--; + estimator_size *= 2; + idx++; + } + + size_t last_groups_size = remaining_blocks / remaining_sections; + for (; idx < num_sections; idx++) { + section_sizes[idx] = last_groups_size; + } + section_sizes[num_sections - 1] += + remaining_blocks - last_groups_size * remaining_sections; + return section_sizes; +} + +// Forms the limits in blocks for each filter section. Those sections +// are used for analyzing the echo estimates and investigating which +// linear filter sections contribute most to the echo estimate energy. +std::vector<size_t> SetSectionsBoundaries(size_t delay_headroom_blocks, + size_t num_blocks, + size_t num_sections) { + std::vector<size_t> estimator_boundaries_blocks(num_sections + 1); + if (estimator_boundaries_blocks.size() == 2) { + estimator_boundaries_blocks[0] = 0; + estimator_boundaries_blocks[1] = num_blocks; + return estimator_boundaries_blocks; + } + RTC_DCHECK_GT(estimator_boundaries_blocks.size(), 2); + const std::vector<size_t> section_sizes = + DefineFilterSectionSizes(delay_headroom_blocks, num_blocks, + estimator_boundaries_blocks.size() - 1); + + size_t idx = 0; + size_t current_size_block = 0; + RTC_DCHECK_EQ(section_sizes.size() + 1, estimator_boundaries_blocks.size()); + estimator_boundaries_blocks[0] = delay_headroom_blocks; + for (size_t k = delay_headroom_blocks; k < num_blocks; ++k) { + current_size_block++; + if (current_size_block >= section_sizes[idx]) { + idx = idx + 1; + if (idx == section_sizes.size()) { + break; + } + estimator_boundaries_blocks[idx] = k + 1; + current_size_block = 0; + } + } + estimator_boundaries_blocks[section_sizes.size()] = num_blocks; + return estimator_boundaries_blocks; +} + +std::array<float, SignalDependentErleEstimator::kSubbands> +SetMaxErleSubbands(float max_erle_l, float max_erle_h, size_t limit_subband_l) { + std::array<float, SignalDependentErleEstimator::kSubbands> max_erle; + std::fill(max_erle.begin(), max_erle.begin() + limit_subband_l, max_erle_l); + std::fill(max_erle.begin() + limit_subband_l, max_erle.end(), max_erle_h); + return max_erle; +} + +} // namespace + +SignalDependentErleEstimator::SignalDependentErleEstimator( + const EchoCanceller3Config& config, + size_t num_capture_channels) + : min_erle_(config.erle.min), + num_sections_(config.erle.num_sections), + num_blocks_(config.filter.refined.length_blocks), + delay_headroom_blocks_(config.delay.delay_headroom_samples / kBlockSize), + band_to_subband_(FormSubbandMap()), + max_erle_(SetMaxErleSubbands(config.erle.max_l, + config.erle.max_h, + band_to_subband_[kFftLengthBy2 / 2])), + section_boundaries_blocks_(SetSectionsBoundaries(delay_headroom_blocks_, + num_blocks_, + num_sections_)), + use_onset_detection_(config.erle.onset_detection), + erle_(num_capture_channels), + erle_onset_compensated_(num_capture_channels), + S2_section_accum_( + num_capture_channels, + std::vector<std::array<float, kFftLengthBy2Plus1>>(num_sections_)), + erle_estimators_( + num_capture_channels, + std::vector<std::array<float, kSubbands>>(num_sections_)), + erle_ref_(num_capture_channels), + correction_factors_( + num_capture_channels, + std::vector<std::array<float, kSubbands>>(num_sections_)), + num_updates_(num_capture_channels), + n_active_sections_(num_capture_channels) { + RTC_DCHECK_LE(num_sections_, num_blocks_); + RTC_DCHECK_GE(num_sections_, 1); + Reset(); +} + +SignalDependentErleEstimator::~SignalDependentErleEstimator() = default; + +void SignalDependentErleEstimator::Reset() { + for (size_t ch = 0; ch < erle_.size(); ++ch) { + erle_[ch].fill(min_erle_); + erle_onset_compensated_[ch].fill(min_erle_); + for (auto& erle_estimator : erle_estimators_[ch]) { + erle_estimator.fill(min_erle_); + } + erle_ref_[ch].fill(min_erle_); + for (auto& factor : correction_factors_[ch]) { + factor.fill(1.0f); + } + num_updates_[ch].fill(0); + n_active_sections_[ch].fill(0); + } +} + +// Updates the Erle estimate by analyzing the current input signals. It takes +// the render buffer and the filter frequency response in order to do an +// estimation of the number of sections of the linear filter that are needed +// for getting the majority of the energy in the echo estimate. Based on that +// number of sections, it updates the erle estimation by introducing a +// correction factor to the erle that is given as an input to this method. +void SignalDependentErleEstimator::Update( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_responses, + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + average_erle_onset_compensated, + const std::vector<bool>& converged_filters) { + RTC_DCHECK_GT(num_sections_, 1); + + // Gets the number of filter sections that are needed for achieving 90 % + // of the power spectrum energy of the echo estimate. + ComputeNumberOfActiveFilterSections(render_buffer, + filter_frequency_responses); + + // Updates the correction factors that is used for correcting the erle and + // adapt it to the particular characteristics of the input signal. + UpdateCorrectionFactors(X2, Y2, E2, converged_filters); + + // Applies the correction factor to the input erle for getting a more refined + // erle estimation for the current input signal. + for (size_t ch = 0; ch < erle_.size(); ++ch) { + for (size_t k = 0; k < kFftLengthBy2; ++k) { + RTC_DCHECK_GT(correction_factors_[ch].size(), n_active_sections_[ch][k]); + float correction_factor = + correction_factors_[ch][n_active_sections_[ch][k]] + [band_to_subband_[k]]; + erle_[ch][k] = rtc::SafeClamp(average_erle[ch][k] * correction_factor, + min_erle_, max_erle_[band_to_subband_[k]]); + if (use_onset_detection_) { + erle_onset_compensated_[ch][k] = rtc::SafeClamp( + average_erle_onset_compensated[ch][k] * correction_factor, + min_erle_, max_erle_[band_to_subband_[k]]); + } + } + } +} + +void SignalDependentErleEstimator::Dump( + const std::unique_ptr<ApmDataDumper>& data_dumper) const { + for (auto& erle : erle_estimators_[0]) { + data_dumper->DumpRaw("aec3_all_erle", erle); + } + data_dumper->DumpRaw("aec3_ref_erle", erle_ref_[0]); + for (auto& factor : correction_factors_[0]) { + data_dumper->DumpRaw("aec3_erle_correction_factor", factor); + } +} + +// Estimates for each band the smallest number of sections in the filter that +// together constitute 90% of the estimated echo energy. +void SignalDependentErleEstimator::ComputeNumberOfActiveFilterSections( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_responses) { + RTC_DCHECK_GT(num_sections_, 1); + // Computes an approximation of the power spectrum if the filter would have + // been limited to a certain number of filter sections. + ComputeEchoEstimatePerFilterSection(render_buffer, + filter_frequency_responses); + // For each band, computes the number of filter sections that are needed for + // achieving the 90 % energy in the echo estimate. + ComputeActiveFilterSections(); +} + +void SignalDependentErleEstimator::UpdateCorrectionFactors( + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters) { + for (size_t ch = 0; ch < converged_filters.size(); ++ch) { + if (converged_filters[ch]) { + constexpr float kX2BandEnergyThreshold = 44015068.0f; + constexpr float kSmthConstantDecreases = 0.1f; + constexpr float kSmthConstantIncreases = kSmthConstantDecreases / 2.f; + auto subband_powers = [](rtc::ArrayView<const float> power_spectrum, + rtc::ArrayView<float> power_spectrum_subbands) { + for (size_t subband = 0; subband < kSubbands; ++subband) { + RTC_DCHECK_LE(kBandBoundaries[subband + 1], power_spectrum.size()); + power_spectrum_subbands[subband] = std::accumulate( + power_spectrum.begin() + kBandBoundaries[subband], + power_spectrum.begin() + kBandBoundaries[subband + 1], 0.f); + } + }; + + std::array<float, kSubbands> X2_subbands, E2_subbands, Y2_subbands; + subband_powers(X2, X2_subbands); + subband_powers(E2[ch], E2_subbands); + subband_powers(Y2[ch], Y2_subbands); + std::array<size_t, kSubbands> idx_subbands; + for (size_t subband = 0; subband < kSubbands; ++subband) { + // When aggregating the number of active sections in the filter for + // different bands we choose to take the minimum of all of them. As an + // example, if for one of the bands it is the direct path its refined + // contributor to the final echo estimate, we consider the direct path + // is as well the refined contributor for the subband that contains that + // particular band. That aggregate number of sections will be later used + // as the identifier of the erle estimator that needs to be updated. + RTC_DCHECK_LE(kBandBoundaries[subband + 1], + n_active_sections_[ch].size()); + idx_subbands[subband] = *std::min_element( + n_active_sections_[ch].begin() + kBandBoundaries[subband], + n_active_sections_[ch].begin() + kBandBoundaries[subband + 1]); + } + + std::array<float, kSubbands> new_erle; + std::array<bool, kSubbands> is_erle_updated; + is_erle_updated.fill(false); + new_erle.fill(0.f); + for (size_t subband = 0; subband < kSubbands; ++subband) { + if (X2_subbands[subband] > kX2BandEnergyThreshold && + E2_subbands[subband] > 0) { + new_erle[subband] = Y2_subbands[subband] / E2_subbands[subband]; + RTC_DCHECK_GT(new_erle[subband], 0); + is_erle_updated[subband] = true; + ++num_updates_[ch][subband]; + } + } + + for (size_t subband = 0; subband < kSubbands; ++subband) { + const size_t idx = idx_subbands[subband]; + RTC_DCHECK_LT(idx, erle_estimators_[ch].size()); + float alpha = new_erle[subband] > erle_estimators_[ch][idx][subband] + ? kSmthConstantIncreases + : kSmthConstantDecreases; + alpha = static_cast<float>(is_erle_updated[subband]) * alpha; + erle_estimators_[ch][idx][subband] += + alpha * (new_erle[subband] - erle_estimators_[ch][idx][subband]); + erle_estimators_[ch][idx][subband] = rtc::SafeClamp( + erle_estimators_[ch][idx][subband], min_erle_, max_erle_[subband]); + } + + for (size_t subband = 0; subband < kSubbands; ++subband) { + float alpha = new_erle[subband] > erle_ref_[ch][subband] + ? kSmthConstantIncreases + : kSmthConstantDecreases; + alpha = static_cast<float>(is_erle_updated[subband]) * alpha; + erle_ref_[ch][subband] += + alpha * (new_erle[subband] - erle_ref_[ch][subband]); + erle_ref_[ch][subband] = rtc::SafeClamp(erle_ref_[ch][subband], + min_erle_, max_erle_[subband]); + } + + for (size_t subband = 0; subband < kSubbands; ++subband) { + constexpr int kNumUpdateThr = 50; + if (is_erle_updated[subband] && + num_updates_[ch][subband] > kNumUpdateThr) { + const size_t idx = idx_subbands[subband]; + RTC_DCHECK_GT(erle_ref_[ch][subband], 0.f); + // Computes the ratio between the erle that is updated using all the + // points and the erle that is updated only on signals that share the + // same number of active filter sections. + float new_correction_factor = + erle_estimators_[ch][idx][subband] / erle_ref_[ch][subband]; + + correction_factors_[ch][idx][subband] += + 0.1f * + (new_correction_factor - correction_factors_[ch][idx][subband]); + } + } + } + } +} + +void SignalDependentErleEstimator::ComputeEchoEstimatePerFilterSection( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_responses) { + const SpectrumBuffer& spectrum_render_buffer = + render_buffer.GetSpectrumBuffer(); + const size_t num_render_channels = spectrum_render_buffer.buffer[0].size(); + const size_t num_capture_channels = S2_section_accum_.size(); + const float one_by_num_render_channels = 1.f / num_render_channels; + + RTC_DCHECK_EQ(S2_section_accum_.size(), filter_frequency_responses.size()); + + for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) { + RTC_DCHECK_EQ(S2_section_accum_[capture_ch].size() + 1, + section_boundaries_blocks_.size()); + size_t idx_render = render_buffer.Position(); + idx_render = spectrum_render_buffer.OffsetIndex( + idx_render, section_boundaries_blocks_[0]); + + for (size_t section = 0; section < num_sections_; ++section) { + std::array<float, kFftLengthBy2Plus1> X2_section; + std::array<float, kFftLengthBy2Plus1> H2_section; + X2_section.fill(0.f); + H2_section.fill(0.f); + const size_t block_limit = + std::min(section_boundaries_blocks_[section + 1], + filter_frequency_responses[capture_ch].size()); + for (size_t block = section_boundaries_blocks_[section]; + block < block_limit; ++block) { + for (size_t render_ch = 0; + render_ch < spectrum_render_buffer.buffer[idx_render].size(); + ++render_ch) { + for (size_t k = 0; k < X2_section.size(); ++k) { + X2_section[k] += + spectrum_render_buffer.buffer[idx_render][render_ch][k] * + one_by_num_render_channels; + } + } + std::transform(H2_section.begin(), H2_section.end(), + filter_frequency_responses[capture_ch][block].begin(), + H2_section.begin(), std::plus<float>()); + idx_render = spectrum_render_buffer.IncIndex(idx_render); + } + + std::transform(X2_section.begin(), X2_section.end(), H2_section.begin(), + S2_section_accum_[capture_ch][section].begin(), + std::multiplies<float>()); + } + + for (size_t section = 1; section < num_sections_; ++section) { + std::transform(S2_section_accum_[capture_ch][section - 1].begin(), + S2_section_accum_[capture_ch][section - 1].end(), + S2_section_accum_[capture_ch][section].begin(), + S2_section_accum_[capture_ch][section].begin(), + std::plus<float>()); + } + } +} + +void SignalDependentErleEstimator::ComputeActiveFilterSections() { + for (size_t ch = 0; ch < n_active_sections_.size(); ++ch) { + std::fill(n_active_sections_[ch].begin(), n_active_sections_[ch].end(), 0); + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + size_t section = num_sections_; + float target = 0.9f * S2_section_accum_[ch][num_sections_ - 1][k]; + while (section > 0 && S2_section_accum_[ch][section - 1][k] >= target) { + n_active_sections_[ch][k] = --section; + } + } + } +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.h new file mode 100644 index 0000000000..6847c1ab13 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SIGNAL_DEPENDENT_ERLE_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SIGNAL_DEPENDENT_ERLE_ESTIMATOR_H_ + +#include <memory> +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +// This class estimates the dependency of the Erle to the input signal. By +// looking at the input signal, an estimation on whether the current echo +// estimate is due to the direct path or to a more reverberant one is performed. +// Once that estimation is done, it is possible to refine the average Erle that +// this class receive as an input. +class SignalDependentErleEstimator { + public: + SignalDependentErleEstimator(const EchoCanceller3Config& config, + size_t num_capture_channels); + + ~SignalDependentErleEstimator(); + + void Reset(); + + // Returns the Erle per frequency subband. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle( + bool onset_compensated) const { + return onset_compensated && use_onset_detection_ ? erle_onset_compensated_ + : erle_; + } + + // Updates the Erle estimate. The Erle that is passed as an input is required + // to be an estimation of the average Erle achieved by the linear filter. + void Update( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_response, + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> average_erle, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + average_erle_onset_compensated, + const std::vector<bool>& converged_filters); + + void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const; + + static constexpr size_t kSubbands = 6; + + private: + void ComputeNumberOfActiveFilterSections( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_responses); + + void UpdateCorrectionFactors( + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters); + + void ComputeEchoEstimatePerFilterSection( + const RenderBuffer& render_buffer, + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + filter_frequency_responses); + + void ComputeActiveFilterSections(); + + const float min_erle_; + const size_t num_sections_; + const size_t num_blocks_; + const size_t delay_headroom_blocks_; + const std::array<size_t, kFftLengthBy2Plus1> band_to_subband_; + const std::array<float, kSubbands> max_erle_; + const std::vector<size_t> section_boundaries_blocks_; + const bool use_onset_detection_; + std::vector<std::array<float, kFftLengthBy2Plus1>> erle_; + std::vector<std::array<float, kFftLengthBy2Plus1>> erle_onset_compensated_; + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> + S2_section_accum_; + std::vector<std::vector<std::array<float, kSubbands>>> erle_estimators_; + std::vector<std::array<float, kSubbands>> erle_ref_; + std::vector<std::vector<std::array<float, kSubbands>>> correction_factors_; + std::vector<std::array<int, kSubbands>> num_updates_; + std::vector<std::array<size_t, kFftLengthBy2Plus1>> n_active_sections_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SIGNAL_DEPENDENT_ERLE_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc new file mode 100644 index 0000000000..67927a6c68 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/signal_dependent_erle_estimator_unittest.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/signal_dependent_erle_estimator.h" + +#include <algorithm> +#include <iostream> +#include <string> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { + +namespace { + +void GetActiveFrame(Block* x) { + const std::array<float, kBlockSize> frame = { + 7459.88, 17209.6, 17383, 20768.9, 16816.7, 18386.3, 4492.83, 9675.85, + 6665.52, 14808.6, 9342.3, 7483.28, 19261.7, 4145.98, 1622.18, 13475.2, + 7166.32, 6856.61, 21937, 7263.14, 9569.07, 14919, 8413.32, 7551.89, + 7848.65, 6011.27, 13080.6, 15865.2, 12656, 17459.6, 4263.93, 4503.03, + 9311.79, 21095.8, 12657.9, 13906.6, 19267.2, 11338.1, 16828.9, 11501.6, + 11405, 15031.4, 14541.6, 19765.5, 18346.3, 19350.2, 3157.47, 18095.8, + 1743.68, 21328.2, 19727.5, 7295.16, 10332.4, 11055.5, 20107.4, 14708.4, + 12416.2, 16434, 2454.69, 9840.8, 6867.23, 1615.75, 6059.9, 8394.19}; + for (int band = 0; band < x->NumBands(); ++band) { + for (int channel = 0; channel < x->NumChannels(); ++channel) { + RTC_DCHECK_GE(kBlockSize, frame.size()); + std::copy(frame.begin(), frame.end(), x->begin(band, channel)); + } + } +} + +class TestInputs { + public: + TestInputs(const EchoCanceller3Config& cfg, + size_t num_render_channels, + size_t num_capture_channels); + ~TestInputs(); + const RenderBuffer& GetRenderBuffer() { return *render_buffer_; } + rtc::ArrayView<const float, kFftLengthBy2Plus1> GetX2() { return X2_; } + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetY2() const { + return Y2_; + } + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> GetE2() const { + return E2_; + } + rtc::ArrayView<const std::vector<std::array<float, kFftLengthBy2Plus1>>> + GetH2() const { + return H2_; + } + const std::vector<bool>& GetConvergedFilters() const { + return converged_filters_; + } + void Update(); + + private: + void UpdateCurrentPowerSpectra(); + int n_ = 0; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer_; + RenderBuffer* render_buffer_; + std::array<float, kFftLengthBy2Plus1> X2_; + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2_; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_; + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_; + Block x_; + std::vector<bool> converged_filters_; +}; + +TestInputs::TestInputs(const EchoCanceller3Config& cfg, + size_t num_render_channels, + size_t num_capture_channels) + : render_delay_buffer_( + RenderDelayBuffer::Create(cfg, 16000, num_render_channels)), + Y2_(num_capture_channels), + E2_(num_capture_channels), + H2_(num_capture_channels, + std::vector<std::array<float, kFftLengthBy2Plus1>>( + cfg.filter.refined.length_blocks)), + x_(1, num_render_channels), + converged_filters_(num_capture_channels, true) { + render_delay_buffer_->AlignFromDelay(4); + render_buffer_ = render_delay_buffer_->GetRenderBuffer(); + for (auto& H2_ch : H2_) { + for (auto& H2_p : H2_ch) { + H2_p.fill(0.f); + } + } + for (auto& H2_p : H2_[0]) { + H2_p.fill(1.f); + } +} + +TestInputs::~TestInputs() = default; + +void TestInputs::Update() { + if (n_ % 2 == 0) { + std::fill(x_.begin(/*band=*/0, /*channel=*/0), + x_.end(/*band=*/0, /*channel=*/0), 0.f); + } else { + GetActiveFrame(&x_); + } + + render_delay_buffer_->Insert(x_); + render_delay_buffer_->PrepareCaptureProcessing(); + UpdateCurrentPowerSpectra(); + ++n_; +} + +void TestInputs::UpdateCurrentPowerSpectra() { + const SpectrumBuffer& spectrum_render_buffer = + render_buffer_->GetSpectrumBuffer(); + size_t idx = render_buffer_->Position(); + size_t prev_idx = spectrum_render_buffer.OffsetIndex(idx, 1); + auto& X2 = spectrum_render_buffer.buffer[idx][/*channel=*/0]; + auto& X2_prev = spectrum_render_buffer.buffer[prev_idx][/*channel=*/0]; + std::copy(X2.begin(), X2.end(), X2_.begin()); + for (size_t ch = 0; ch < Y2_.size(); ++ch) { + RTC_DCHECK_EQ(X2.size(), Y2_[ch].size()); + for (size_t k = 0; k < X2.size(); ++k) { + E2_[ch][k] = 0.01f * X2_prev[k]; + Y2_[ch][k] = X2[k] + E2_[ch][k]; + } + } +} + +} // namespace + +class SignalDependentErleEstimatorMultiChannel + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +INSTANTIATE_TEST_SUITE_P(MultiChannel, + SignalDependentErleEstimatorMultiChannel, + ::testing::Combine(::testing::Values(1, 2, 4), + ::testing::Values(1, 2, 4))); + +TEST_P(SignalDependentErleEstimatorMultiChannel, SweepSettings) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + EchoCanceller3Config cfg; + size_t max_length_blocks = 50; + for (size_t blocks = 1; blocks < max_length_blocks; blocks = blocks + 10) { + for (size_t delay_headroom = 0; delay_headroom < 5; ++delay_headroom) { + for (size_t num_sections = 2; num_sections < max_length_blocks; + ++num_sections) { + cfg.filter.refined.length_blocks = blocks; + cfg.filter.refined_initial.length_blocks = + std::min(cfg.filter.refined_initial.length_blocks, blocks); + cfg.delay.delay_headroom_samples = delay_headroom * kBlockSize; + cfg.erle.num_sections = num_sections; + if (EchoCanceller3Config::Validate(&cfg)) { + SignalDependentErleEstimator s(cfg, num_capture_channels); + std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle( + num_capture_channels); + for (auto& e : average_erle) { + e.fill(cfg.erle.max_l); + } + TestInputs inputs(cfg, num_render_channels, num_capture_channels); + for (size_t n = 0; n < 10; ++n) { + inputs.Update(); + s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(), + inputs.GetY2(), inputs.GetE2(), average_erle, average_erle, + inputs.GetConvergedFilters()); + } + } + } + } + } +} + +TEST_P(SignalDependentErleEstimatorMultiChannel, LongerRun) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + EchoCanceller3Config cfg; + cfg.filter.refined.length_blocks = 2; + cfg.filter.refined_initial.length_blocks = 1; + cfg.delay.delay_headroom_samples = 0; + cfg.delay.hysteresis_limit_blocks = 0; + cfg.erle.num_sections = 2; + EXPECT_EQ(EchoCanceller3Config::Validate(&cfg), true); + std::vector<std::array<float, kFftLengthBy2Plus1>> average_erle( + num_capture_channels); + for (auto& e : average_erle) { + e.fill(cfg.erle.max_l); + } + SignalDependentErleEstimator s(cfg, num_capture_channels); + TestInputs inputs(cfg, num_render_channels, num_capture_channels); + for (size_t n = 0; n < 200; ++n) { + inputs.Update(); + s.Update(inputs.GetRenderBuffer(), inputs.GetH2(), inputs.GetX2(), + inputs.GetY2(), inputs.GetE2(), average_erle, average_erle, + inputs.GetConvergedFilters()); + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/spectrum_buffer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/spectrum_buffer.cc new file mode 100644 index 0000000000..fe32ece09c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/spectrum_buffer.cc @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/spectrum_buffer.h" + +#include <algorithm> + +namespace webrtc { + +SpectrumBuffer::SpectrumBuffer(size_t size, size_t num_channels) + : size(static_cast<int>(size)), + buffer(size, + std::vector<std::array<float, kFftLengthBy2Plus1>>(num_channels)) { + for (auto& channel : buffer) { + for (auto& c : channel) { + std::fill(c.begin(), c.end(), 0.f); + } + } +} + +SpectrumBuffer::~SpectrumBuffer() = default; + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/spectrum_buffer.h b/third_party/libwebrtc/modules/audio_processing/aec3/spectrum_buffer.h new file mode 100644 index 0000000000..51e1317f55 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/spectrum_buffer.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SPECTRUM_BUFFER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SPECTRUM_BUFFER_H_ + +#include <stddef.h> + +#include <array> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +// Struct for bundling a circular buffer of one dimensional vector objects +// together with the read and write indices. +struct SpectrumBuffer { + SpectrumBuffer(size_t size, size_t num_channels); + ~SpectrumBuffer(); + + int IncIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index < size - 1 ? index + 1 : 0; + } + + int DecIndex(int index) const { + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + return index > 0 ? index - 1 : size - 1; + } + + int OffsetIndex(int index, int offset) const { + RTC_DCHECK_GE(size, offset); + RTC_DCHECK_EQ(buffer.size(), static_cast<size_t>(size)); + RTC_DCHECK_GE(size + index + offset, 0); + return (size + index + offset) % size; + } + + void UpdateWriteIndex(int offset) { write = OffsetIndex(write, offset); } + void IncWriteIndex() { write = IncIndex(write); } + void DecWriteIndex() { write = DecIndex(write); } + void UpdateReadIndex(int offset) { read = OffsetIndex(read, offset); } + void IncReadIndex() { read = IncIndex(read); } + void DecReadIndex() { read = DecIndex(read); } + + const int size; + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> buffer; + int write = 0; + int read = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SPECTRUM_BUFFER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/stationarity_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/stationarity_estimator.cc new file mode 100644 index 0000000000..4d364041b3 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/stationarity_estimator.cc @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/stationarity_estimator.h" + +#include <algorithm> +#include <array> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/spectrum_buffer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +namespace { +constexpr float kMinNoisePower = 10.f; +constexpr int kHangoverBlocks = kNumBlocksPerSecond / 20; +constexpr int kNBlocksAverageInitPhase = 20; +constexpr int kNBlocksInitialPhase = kNumBlocksPerSecond * 2.; +} // namespace + +StationarityEstimator::StationarityEstimator() + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)) { + Reset(); +} + +StationarityEstimator::~StationarityEstimator() = default; + +void StationarityEstimator::Reset() { + noise_.Reset(); + hangovers_.fill(0); + stationarity_flags_.fill(false); +} + +// Update just the noise estimator. Usefull until the delay is known +void StationarityEstimator::UpdateNoiseEstimator( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> spectrum) { + noise_.Update(spectrum); + data_dumper_->DumpRaw("aec3_stationarity_noise_spectrum", noise_.Spectrum()); + data_dumper_->DumpRaw("aec3_stationarity_is_block_stationary", + IsBlockStationary()); +} + +void StationarityEstimator::UpdateStationarityFlags( + const SpectrumBuffer& spectrum_buffer, + rtc::ArrayView<const float> render_reverb_contribution_spectrum, + int idx_current, + int num_lookahead) { + std::array<int, kWindowLength> indexes; + int num_lookahead_bounded = std::min(num_lookahead, kWindowLength - 1); + int idx = idx_current; + + if (num_lookahead_bounded < kWindowLength - 1) { + int num_lookback = (kWindowLength - 1) - num_lookahead_bounded; + idx = spectrum_buffer.OffsetIndex(idx_current, num_lookback); + } + // For estimating the stationarity properties of the current frame, the + // power for each band is accumulated for several consecutive spectra in the + // method EstimateBandStationarity. + // In order to avoid getting the indexes of the spectra for every band with + // its associated overhead, those indexes are stored in an array and then use + // when the estimation is done. + indexes[0] = idx; + for (size_t k = 1; k < indexes.size(); ++k) { + indexes[k] = spectrum_buffer.DecIndex(indexes[k - 1]); + } + RTC_DCHECK_EQ( + spectrum_buffer.DecIndex(indexes[kWindowLength - 1]), + spectrum_buffer.OffsetIndex(idx_current, -(num_lookahead_bounded + 1))); + + for (size_t k = 0; k < stationarity_flags_.size(); ++k) { + stationarity_flags_[k] = EstimateBandStationarity( + spectrum_buffer, render_reverb_contribution_spectrum, indexes, k); + } + UpdateHangover(); + SmoothStationaryPerFreq(); +} + +bool StationarityEstimator::IsBlockStationary() const { + float acum_stationarity = 0.f; + RTC_DCHECK_EQ(stationarity_flags_.size(), kFftLengthBy2Plus1); + for (size_t band = 0; band < stationarity_flags_.size(); ++band) { + bool st = IsBandStationary(band); + acum_stationarity += static_cast<float>(st); + } + return ((acum_stationarity * (1.f / kFftLengthBy2Plus1)) > 0.75f); +} + +bool StationarityEstimator::EstimateBandStationarity( + const SpectrumBuffer& spectrum_buffer, + rtc::ArrayView<const float> average_reverb, + const std::array<int, kWindowLength>& indexes, + size_t band) const { + constexpr float kThrStationarity = 10.f; + float acum_power = 0.f; + const int num_render_channels = + static_cast<int>(spectrum_buffer.buffer[0].size()); + const float one_by_num_channels = 1.f / num_render_channels; + for (auto idx : indexes) { + for (int ch = 0; ch < num_render_channels; ++ch) { + acum_power += spectrum_buffer.buffer[idx][ch][band] * one_by_num_channels; + } + } + acum_power += average_reverb[band]; + float noise = kWindowLength * GetStationarityPowerBand(band); + RTC_CHECK_LT(0.f, noise); + bool stationary = acum_power < kThrStationarity * noise; + data_dumper_->DumpRaw("aec3_stationarity_long_ratio", acum_power / noise); + return stationary; +} + +bool StationarityEstimator::AreAllBandsStationary() { + for (auto b : stationarity_flags_) { + if (!b) + return false; + } + return true; +} + +void StationarityEstimator::UpdateHangover() { + bool reduce_hangover = AreAllBandsStationary(); + for (size_t k = 0; k < stationarity_flags_.size(); ++k) { + if (!stationarity_flags_[k]) { + hangovers_[k] = kHangoverBlocks; + } else if (reduce_hangover) { + hangovers_[k] = std::max(hangovers_[k] - 1, 0); + } + } +} + +void StationarityEstimator::SmoothStationaryPerFreq() { + std::array<bool, kFftLengthBy2Plus1> all_ahead_stationary_smooth; + for (size_t k = 1; k < kFftLengthBy2Plus1 - 1; ++k) { + all_ahead_stationary_smooth[k] = stationarity_flags_[k - 1] && + stationarity_flags_[k] && + stationarity_flags_[k + 1]; + } + + all_ahead_stationary_smooth[0] = all_ahead_stationary_smooth[1]; + all_ahead_stationary_smooth[kFftLengthBy2Plus1 - 1] = + all_ahead_stationary_smooth[kFftLengthBy2Plus1 - 2]; + + stationarity_flags_ = all_ahead_stationary_smooth; +} + +std::atomic<int> StationarityEstimator::instance_count_(0); + +StationarityEstimator::NoiseSpectrum::NoiseSpectrum() { + Reset(); +} + +StationarityEstimator::NoiseSpectrum::~NoiseSpectrum() = default; + +void StationarityEstimator::NoiseSpectrum::Reset() { + block_counter_ = 0; + noise_spectrum_.fill(kMinNoisePower); +} + +void StationarityEstimator::NoiseSpectrum::Update( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> spectrum) { + RTC_DCHECK_LE(1, spectrum[0].size()); + const int num_render_channels = static_cast<int>(spectrum.size()); + + std::array<float, kFftLengthBy2Plus1> avg_spectrum_data; + rtc::ArrayView<const float> avg_spectrum; + if (num_render_channels == 1) { + avg_spectrum = spectrum[0]; + } else { + // For multiple channels, average the channel spectra before passing to the + // noise spectrum estimator. + avg_spectrum = avg_spectrum_data; + std::copy(spectrum[0].begin(), spectrum[0].end(), + avg_spectrum_data.begin()); + for (int ch = 1; ch < num_render_channels; ++ch) { + for (size_t k = 1; k < kFftLengthBy2Plus1; ++k) { + avg_spectrum_data[k] += spectrum[ch][k]; + } + } + + const float one_by_num_channels = 1.f / num_render_channels; + for (size_t k = 1; k < kFftLengthBy2Plus1; ++k) { + avg_spectrum_data[k] *= one_by_num_channels; + } + } + + ++block_counter_; + float alpha = GetAlpha(); + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + if (block_counter_ <= kNBlocksAverageInitPhase) { + noise_spectrum_[k] += (1.f / kNBlocksAverageInitPhase) * avg_spectrum[k]; + } else { + noise_spectrum_[k] = + UpdateBandBySmoothing(avg_spectrum[k], noise_spectrum_[k], alpha); + } + } +} + +float StationarityEstimator::NoiseSpectrum::GetAlpha() const { + constexpr float kAlpha = 0.004f; + constexpr float kAlphaInit = 0.04f; + constexpr float kTiltAlpha = (kAlphaInit - kAlpha) / kNBlocksInitialPhase; + + if (block_counter_ > (kNBlocksInitialPhase + kNBlocksAverageInitPhase)) { + return kAlpha; + } else { + return kAlphaInit - + kTiltAlpha * (block_counter_ - kNBlocksAverageInitPhase); + } +} + +float StationarityEstimator::NoiseSpectrum::UpdateBandBySmoothing( + float power_band, + float power_band_noise, + float alpha) const { + float power_band_noise_updated = power_band_noise; + if (power_band_noise < power_band) { + RTC_DCHECK_GT(power_band, 0.f); + float alpha_inc = alpha * (power_band_noise / power_band); + if (block_counter_ > kNBlocksInitialPhase) { + if (10.f * power_band_noise < power_band) { + alpha_inc *= 0.1f; + } + } + power_band_noise_updated += alpha_inc * (power_band - power_band_noise); + } else { + power_band_noise_updated += alpha * (power_band - power_band_noise); + power_band_noise_updated = + std::max(power_band_noise_updated, kMinNoisePower); + } + return power_band_noise_updated; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/stationarity_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/stationarity_estimator.h new file mode 100644 index 0000000000..8bcd3b789e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/stationarity_estimator.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_STATIONARITY_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_STATIONARITY_ESTIMATOR_H_ + +#include <stddef.h> + +#include <array> +#include <atomic> +#include <memory> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" // kFftLengthBy2Plus1... +#include "modules/audio_processing/aec3/reverb_model.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +class ApmDataDumper; +struct SpectrumBuffer; + +class StationarityEstimator { + public: + StationarityEstimator(); + ~StationarityEstimator(); + + // Reset the stationarity estimator. + void Reset(); + + // Update just the noise estimator. Usefull until the delay is known + void UpdateNoiseEstimator( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> spectrum); + + // Update the flag indicating whether this current frame is stationary. For + // getting a more robust estimation, it looks at future and/or past frames. + void UpdateStationarityFlags( + const SpectrumBuffer& spectrum_buffer, + rtc::ArrayView<const float> render_reverb_contribution_spectrum, + int idx_current, + int num_lookahead); + + // Returns true if the current band is stationary. + bool IsBandStationary(size_t band) const { + return stationarity_flags_[band] && (hangovers_[band] == 0); + } + + // Returns true if the current block is estimated as stationary. + bool IsBlockStationary() const; + + private: + static constexpr int kWindowLength = 13; + // Returns the power of the stationary noise spectrum at a band. + float GetStationarityPowerBand(size_t k) const { return noise_.Power(k); } + + // Get an estimation of the stationarity for the current band by looking + // at the past/present/future available data. + bool EstimateBandStationarity(const SpectrumBuffer& spectrum_buffer, + rtc::ArrayView<const float> average_reverb, + const std::array<int, kWindowLength>& indexes, + size_t band) const; + + // True if all bands at the current point are stationary. + bool AreAllBandsStationary(); + + // Update the hangover depending on the stationary status of the current + // frame. + void UpdateHangover(); + + // Smooth the stationarity detection by looking at neighbouring frequency + // bands. + void SmoothStationaryPerFreq(); + + class NoiseSpectrum { + public: + NoiseSpectrum(); + ~NoiseSpectrum(); + + // Reset the noise power spectrum estimate state. + void Reset(); + + // Update the noise power spectrum with a new frame. + void Update( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> spectrum); + + // Get the noise estimation power spectrum. + rtc::ArrayView<const float> Spectrum() const { return noise_spectrum_; } + + // Get the noise power spectrum at a certain band. + float Power(size_t band) const { + RTC_DCHECK_LT(band, noise_spectrum_.size()); + return noise_spectrum_[band]; + } + + private: + // Get the update coefficient to be used for the current frame. + float GetAlpha() const; + + // Update the noise power spectrum at a certain band with a new frame. + float UpdateBandBySmoothing(float power_band, + float power_band_noise, + float alpha) const; + std::array<float, kFftLengthBy2Plus1> noise_spectrum_; + size_t block_counter_; + }; + + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + NoiseSpectrum noise_; + std::array<int, kFftLengthBy2Plus1> hangovers_; + std::array<bool, kFftLengthBy2Plus1> stationarity_flags_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_STATIONARITY_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subband_erle_estimator.cc b/third_party/libwebrtc/modules/audio_processing/aec3/subband_erle_estimator.cc new file mode 100644 index 0000000000..dc7f92fd99 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subband_erle_estimator.cc @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/subband_erle_estimator.h" + +#include <algorithm> +#include <functional> + +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { + +namespace { + +constexpr float kX2BandEnergyThreshold = 44015068.0f; +constexpr int kBlocksToHoldErle = 100; +constexpr int kBlocksForOnsetDetection = kBlocksToHoldErle + 150; +constexpr int kPointsToAccumulate = 6; + +std::array<float, kFftLengthBy2Plus1> SetMaxErleBands(float max_erle_l, + float max_erle_h) { + std::array<float, kFftLengthBy2Plus1> max_erle; + std::fill(max_erle.begin(), max_erle.begin() + kFftLengthBy2 / 2, max_erle_l); + std::fill(max_erle.begin() + kFftLengthBy2 / 2, max_erle.end(), max_erle_h); + return max_erle; +} + +bool EnableMinErleDuringOnsets() { + return !field_trial::IsEnabled("WebRTC-Aec3MinErleDuringOnsetsKillSwitch"); +} + +} // namespace + +SubbandErleEstimator::SubbandErleEstimator(const EchoCanceller3Config& config, + size_t num_capture_channels) + : use_onset_detection_(config.erle.onset_detection), + min_erle_(config.erle.min), + max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)), + use_min_erle_during_onsets_(EnableMinErleDuringOnsets()), + accum_spectra_(num_capture_channels), + erle_(num_capture_channels), + erle_onset_compensated_(num_capture_channels), + erle_unbounded_(num_capture_channels), + erle_during_onsets_(num_capture_channels), + coming_onset_(num_capture_channels), + hold_counters_(num_capture_channels) { + Reset(); +} + +SubbandErleEstimator::~SubbandErleEstimator() = default; + +void SubbandErleEstimator::Reset() { + const size_t num_capture_channels = erle_.size(); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + erle_[ch].fill(min_erle_); + erle_onset_compensated_[ch].fill(min_erle_); + erle_unbounded_[ch].fill(min_erle_); + erle_during_onsets_[ch].fill(min_erle_); + coming_onset_[ch].fill(true); + hold_counters_[ch].fill(0); + } + ResetAccumulatedSpectra(); +} + +void SubbandErleEstimator::Update( + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters) { + UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters); + UpdateBands(converged_filters); + + if (use_onset_detection_) { + DecreaseErlePerBandForLowRenderSignals(); + } + + const size_t num_capture_channels = erle_.size(); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + auto& erle = erle_[ch]; + erle[0] = erle[1]; + erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1]; + + auto& erle_oc = erle_onset_compensated_[ch]; + erle_oc[0] = erle_oc[1]; + erle_oc[kFftLengthBy2] = erle_oc[kFftLengthBy2 - 1]; + + auto& erle_u = erle_unbounded_[ch]; + erle_u[0] = erle_u[1]; + erle_u[kFftLengthBy2] = erle_u[kFftLengthBy2 - 1]; + } +} + +void SubbandErleEstimator::Dump( + const std::unique_ptr<ApmDataDumper>& data_dumper) const { + data_dumper->DumpRaw("aec3_erle_onset", ErleDuringOnsets()[0]); +} + +void SubbandErleEstimator::UpdateBands( + const std::vector<bool>& converged_filters) { + const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size()); + for (int ch = 0; ch < num_capture_channels; ++ch) { + // Note that the use of the converged_filter flag already imposed + // a minimum of the erle that can be estimated as that flag would + // be false if the filter is performing poorly. + if (!converged_filters[ch]) { + continue; + } + + if (accum_spectra_.num_points[ch] != kPointsToAccumulate) { + continue; + } + + std::array<float, kFftLengthBy2> new_erle; + std::array<bool, kFftLengthBy2> is_erle_updated; + is_erle_updated.fill(false); + + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (accum_spectra_.E2[ch][k] > 0.f) { + new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k]; + is_erle_updated[k] = true; + } + } + + if (use_onset_detection_) { + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) { + if (coming_onset_[ch][k]) { + coming_onset_[ch][k] = false; + if (!use_min_erle_during_onsets_) { + float alpha = + new_erle[k] < erle_during_onsets_[ch][k] ? 0.3f : 0.15f; + erle_during_onsets_[ch][k] = rtc::SafeClamp( + erle_during_onsets_[ch][k] + + alpha * (new_erle[k] - erle_during_onsets_[ch][k]), + min_erle_, max_erle_[k]); + } + } + hold_counters_[ch][k] = kBlocksForOnsetDetection; + } + } + } + + auto update_erle_band = [](float& erle, float new_erle, + bool low_render_energy, float min_erle, + float max_erle) { + float alpha = 0.05f; + if (new_erle < erle) { + alpha = low_render_energy ? 0.f : 0.1f; + } + erle = + rtc::SafeClamp(erle + alpha * (new_erle - erle), min_erle, max_erle); + }; + + for (size_t k = 1; k < kFftLengthBy2; ++k) { + if (is_erle_updated[k]) { + const bool low_render_energy = accum_spectra_.low_render_energy[ch][k]; + update_erle_band(erle_[ch][k], new_erle[k], low_render_energy, + min_erle_, max_erle_[k]); + if (use_onset_detection_) { + update_erle_band(erle_onset_compensated_[ch][k], new_erle[k], + low_render_energy, min_erle_, max_erle_[k]); + } + + // Virtually unbounded ERLE. + constexpr float kUnboundedErleMax = 100000.0f; + update_erle_band(erle_unbounded_[ch][k], new_erle[k], low_render_energy, + min_erle_, kUnboundedErleMax); + } + } + } +} + +void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() { + const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size()); + for (int ch = 0; ch < num_capture_channels; ++ch) { + for (size_t k = 1; k < kFftLengthBy2; ++k) { + --hold_counters_[ch][k]; + if (hold_counters_[ch][k] <= + (kBlocksForOnsetDetection - kBlocksToHoldErle)) { + if (erle_onset_compensated_[ch][k] > erle_during_onsets_[ch][k]) { + erle_onset_compensated_[ch][k] = + std::max(erle_during_onsets_[ch][k], + 0.97f * erle_onset_compensated_[ch][k]); + RTC_DCHECK_LE(min_erle_, erle_onset_compensated_[ch][k]); + } + if (hold_counters_[ch][k] <= 0) { + coming_onset_[ch][k] = true; + hold_counters_[ch][k] = 0; + } + } + } + } +} + +void SubbandErleEstimator::ResetAccumulatedSpectra() { + for (size_t ch = 0; ch < erle_during_onsets_.size(); ++ch) { + accum_spectra_.Y2[ch].fill(0.f); + accum_spectra_.E2[ch].fill(0.f); + accum_spectra_.num_points[ch] = 0; + accum_spectra_.low_render_energy[ch].fill(false); + } +} + +void SubbandErleEstimator::UpdateAccumulatedSpectra( + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters) { + auto& st = accum_spectra_; + RTC_DCHECK_EQ(st.E2.size(), E2.size()); + RTC_DCHECK_EQ(st.E2.size(), E2.size()); + const int num_capture_channels = static_cast<int>(Y2.size()); + for (int ch = 0; ch < num_capture_channels; ++ch) { + // Note that the use of the converged_filter flag already imposed + // a minimum of the erle that can be estimated as that flag would + // be false if the filter is performing poorly. + if (!converged_filters[ch]) { + continue; + } + + if (st.num_points[ch] == kPointsToAccumulate) { + st.num_points[ch] = 0; + st.Y2[ch].fill(0.f); + st.E2[ch].fill(0.f); + st.low_render_energy[ch].fill(false); + } + + std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(), + st.Y2[ch].begin(), std::plus<float>()); + std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(), + st.E2[ch].begin(), std::plus<float>()); + + for (size_t k = 0; k < X2.size(); ++k) { + st.low_render_energy[ch][k] = + st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold; + } + + ++st.num_points[ch]; + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subband_erle_estimator.h b/third_party/libwebrtc/modules/audio_processing/aec3/subband_erle_estimator.h new file mode 100644 index 0000000000..8bf9c4d645 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subband_erle_estimator.h @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBBAND_ERLE_ESTIMATOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUBBAND_ERLE_ESTIMATOR_H_ + +#include <stddef.h> + +#include <array> +#include <memory> +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +// Estimates the echo return loss enhancement for each frequency subband. +class SubbandErleEstimator { + public: + SubbandErleEstimator(const EchoCanceller3Config& config, + size_t num_capture_channels); + ~SubbandErleEstimator(); + + // Resets the ERLE estimator. + void Reset(); + + // Updates the ERLE estimate. + void Update(rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters); + + // Returns the ERLE estimate. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Erle( + bool onset_compensated) const { + return onset_compensated && use_onset_detection_ ? erle_onset_compensated_ + : erle_; + } + + // Returns the non-capped ERLE estimate. + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> ErleUnbounded() + const { + return erle_unbounded_; + } + + // Returns the ERLE estimate at onsets (only used for testing). + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> ErleDuringOnsets() + const { + return erle_during_onsets_; + } + + void Dump(const std::unique_ptr<ApmDataDumper>& data_dumper) const; + + private: + struct AccumulatedSpectra { + explicit AccumulatedSpectra(size_t num_capture_channels) + : Y2(num_capture_channels), + E2(num_capture_channels), + low_render_energy(num_capture_channels), + num_points(num_capture_channels) {} + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2; + std::vector<std::array<bool, kFftLengthBy2Plus1>> low_render_energy; + std::vector<int> num_points; + }; + + void UpdateAccumulatedSpectra( + rtc::ArrayView<const float, kFftLengthBy2Plus1> X2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, + const std::vector<bool>& converged_filters); + + void ResetAccumulatedSpectra(); + + void UpdateBands(const std::vector<bool>& converged_filters); + void DecreaseErlePerBandForLowRenderSignals(); + + const bool use_onset_detection_; + const float min_erle_; + const std::array<float, kFftLengthBy2Plus1> max_erle_; + const bool use_min_erle_during_onsets_; + AccumulatedSpectra accum_spectra_; + // ERLE without special handling of render onsets. + std::vector<std::array<float, kFftLengthBy2Plus1>> erle_; + // ERLE lowered during render onsets. + std::vector<std::array<float, kFftLengthBy2Plus1>> erle_onset_compensated_; + std::vector<std::array<float, kFftLengthBy2Plus1>> erle_unbounded_; + // Estimation of ERLE during render onsets. + std::vector<std::array<float, kFftLengthBy2Plus1>> erle_during_onsets_; + std::vector<std::array<bool, kFftLengthBy2Plus1>> coming_onset_; + std::vector<std::array<int, kFftLengthBy2Plus1>> hold_counters_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUBBAND_ERLE_ESTIMATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subband_nearend_detector.cc b/third_party/libwebrtc/modules/audio_processing/aec3/subband_nearend_detector.cc new file mode 100644 index 0000000000..2aa400c3af --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subband_nearend_detector.cc @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/subband_nearend_detector.h" + +#include <numeric> + +namespace webrtc { +SubbandNearendDetector::SubbandNearendDetector( + const EchoCanceller3Config::Suppressor::SubbandNearendDetection& config, + size_t num_capture_channels) + : config_(config), + num_capture_channels_(num_capture_channels), + nearend_smoothers_(num_capture_channels_, + aec3::MovingAverage(kFftLengthBy2Plus1, + config_.nearend_average_blocks)), + one_over_subband_length1_( + 1.f / (config_.subband1.high - config_.subband1.low + 1)), + one_over_subband_length2_( + 1.f / (config_.subband2.high - config_.subband2.low + 1)) {} + +void SubbandNearendDetector::Update( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + nearend_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + bool initial_state) { + nearend_state_ = false; + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + const std::array<float, kFftLengthBy2Plus1>& noise = + comfort_noise_spectrum[ch]; + std::array<float, kFftLengthBy2Plus1> nearend; + nearend_smoothers_[ch].Average(nearend_spectrum[ch], nearend); + + // Noise power of the first region. + float noise_power = + std::accumulate(noise.begin() + config_.subband1.low, + noise.begin() + config_.subband1.high + 1, 0.f) * + one_over_subband_length1_; + + // Nearend power of the first region. + float nearend_power_subband1 = + std::accumulate(nearend.begin() + config_.subband1.low, + nearend.begin() + config_.subband1.high + 1, 0.f) * + one_over_subband_length1_; + + // Nearend power of the second region. + float nearend_power_subband2 = + std::accumulate(nearend.begin() + config_.subband2.low, + nearend.begin() + config_.subband2.high + 1, 0.f) * + one_over_subband_length2_; + + // One channel is sufficient to trigger nearend state. + nearend_state_ = + nearend_state_ || + (nearend_power_subband1 < + config_.nearend_threshold * nearend_power_subband2 && + (nearend_power_subband1 > config_.snr_threshold * noise_power)); + } +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subband_nearend_detector.h b/third_party/libwebrtc/modules/audio_processing/aec3/subband_nearend_detector.h new file mode 100644 index 0000000000..8357edb65f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subband_nearend_detector.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBBAND_NEAREND_DETECTOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUBBAND_NEAREND_DETECTOR_H_ + +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/moving_average.h" +#include "modules/audio_processing/aec3/nearend_detector.h" + +namespace webrtc { +// Class for selecting whether the suppressor is in the nearend or echo state. +class SubbandNearendDetector : public NearendDetector { + public: + SubbandNearendDetector( + const EchoCanceller3Config::Suppressor::SubbandNearendDetection& config, + size_t num_capture_channels); + + // Returns whether the current state is the nearend state. + bool IsNearendState() const override { return nearend_state_; } + + // Updates the state selection based on latest spectral estimates. + void Update(rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + nearend_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + bool initial_state) override; + + private: + const EchoCanceller3Config::Suppressor::SubbandNearendDetection config_; + const size_t num_capture_channels_; + std::vector<aec3::MovingAverage> nearend_smoothers_; + const float one_over_subband_length1_; + const float one_over_subband_length2_; + bool nearend_state_ = false; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUBBAND_NEAREND_DETECTOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subtractor.cc b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor.cc new file mode 100644 index 0000000000..aa36bb272a --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor.cc @@ -0,0 +1,364 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/subtractor.h" + +#include <algorithm> +#include <utility> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { + +namespace { + +bool UseCoarseFilterResetHangover() { + return !field_trial::IsEnabled( + "WebRTC-Aec3CoarseFilterResetHangoverKillSwitch"); +} + +void PredictionError(const Aec3Fft& fft, + const FftData& S, + rtc::ArrayView<const float> y, + std::array<float, kBlockSize>* e, + std::array<float, kBlockSize>* s) { + std::array<float, kFftLength> tmp; + fft.Ifft(S, &tmp); + constexpr float kScale = 1.0f / kFftLengthBy2; + std::transform(y.begin(), y.end(), tmp.begin() + kFftLengthBy2, e->begin(), + [&](float a, float b) { return a - b * kScale; }); + + if (s) { + for (size_t k = 0; k < s->size(); ++k) { + (*s)[k] = kScale * tmp[k + kFftLengthBy2]; + } + } +} + +void ScaleFilterOutput(rtc::ArrayView<const float> y, + float factor, + rtc::ArrayView<float> e, + rtc::ArrayView<float> s) { + RTC_DCHECK_EQ(y.size(), e.size()); + RTC_DCHECK_EQ(y.size(), s.size()); + for (size_t k = 0; k < y.size(); ++k) { + s[k] *= factor; + e[k] = y[k] - s[k]; + } +} + +} // namespace + +Subtractor::Subtractor(const EchoCanceller3Config& config, + size_t num_render_channels, + size_t num_capture_channels, + ApmDataDumper* data_dumper, + Aec3Optimization optimization) + : fft_(), + data_dumper_(data_dumper), + optimization_(optimization), + config_(config), + num_capture_channels_(num_capture_channels), + use_coarse_filter_reset_hangover_(UseCoarseFilterResetHangover()), + refined_filters_(num_capture_channels_), + coarse_filter_(num_capture_channels_), + refined_gains_(num_capture_channels_), + coarse_gains_(num_capture_channels_), + filter_misadjustment_estimators_(num_capture_channels_), + poor_coarse_filter_counters_(num_capture_channels_, 0), + coarse_filter_reset_hangover_(num_capture_channels_, 0), + refined_frequency_responses_( + num_capture_channels_, + std::vector<std::array<float, kFftLengthBy2Plus1>>( + std::max(config_.filter.refined_initial.length_blocks, + config_.filter.refined.length_blocks), + std::array<float, kFftLengthBy2Plus1>())), + refined_impulse_responses_( + num_capture_channels_, + std::vector<float>(GetTimeDomainLength(std::max( + config_.filter.refined_initial.length_blocks, + config_.filter.refined.length_blocks)), + 0.f)), + coarse_impulse_responses_(0) { + // Set up the storing of coarse impulse responses if data dumping is + // available. + if (ApmDataDumper::IsAvailable()) { + coarse_impulse_responses_.resize(num_capture_channels_); + const size_t filter_size = GetTimeDomainLength( + std::max(config_.filter.coarse_initial.length_blocks, + config_.filter.coarse.length_blocks)); + for (std::vector<float>& impulse_response : coarse_impulse_responses_) { + impulse_response.resize(filter_size, 0.f); + } + } + + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + refined_filters_[ch] = std::make_unique<AdaptiveFirFilter>( + config_.filter.refined.length_blocks, + config_.filter.refined_initial.length_blocks, + config.filter.config_change_duration_blocks, num_render_channels, + optimization, data_dumper_); + + coarse_filter_[ch] = std::make_unique<AdaptiveFirFilter>( + config_.filter.coarse.length_blocks, + config_.filter.coarse_initial.length_blocks, + config.filter.config_change_duration_blocks, num_render_channels, + optimization, data_dumper_); + refined_gains_[ch] = std::make_unique<RefinedFilterUpdateGain>( + config_.filter.refined_initial, + config_.filter.config_change_duration_blocks); + coarse_gains_[ch] = std::make_unique<CoarseFilterUpdateGain>( + config_.filter.coarse_initial, + config.filter.config_change_duration_blocks); + } + + RTC_DCHECK(data_dumper_); + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + for (auto& H2_k : refined_frequency_responses_[ch]) { + H2_k.fill(0.f); + } + } +} + +Subtractor::~Subtractor() = default; + +void Subtractor::HandleEchoPathChange( + const EchoPathVariability& echo_path_variability) { + const auto full_reset = [&]() { + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + refined_filters_[ch]->HandleEchoPathChange(); + coarse_filter_[ch]->HandleEchoPathChange(); + refined_gains_[ch]->HandleEchoPathChange(echo_path_variability); + coarse_gains_[ch]->HandleEchoPathChange(); + refined_gains_[ch]->SetConfig(config_.filter.refined_initial, true); + coarse_gains_[ch]->SetConfig(config_.filter.coarse_initial, true); + refined_filters_[ch]->SetSizePartitions( + config_.filter.refined_initial.length_blocks, true); + coarse_filter_[ch]->SetSizePartitions( + config_.filter.coarse_initial.length_blocks, true); + } + }; + + if (echo_path_variability.delay_change != + EchoPathVariability::DelayAdjustment::kNone) { + full_reset(); + } + + if (echo_path_variability.gain_change) { + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + refined_gains_[ch]->HandleEchoPathChange(echo_path_variability); + } + } +} + +void Subtractor::ExitInitialState() { + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + refined_gains_[ch]->SetConfig(config_.filter.refined, false); + coarse_gains_[ch]->SetConfig(config_.filter.coarse, false); + refined_filters_[ch]->SetSizePartitions( + config_.filter.refined.length_blocks, false); + coarse_filter_[ch]->SetSizePartitions(config_.filter.coarse.length_blocks, + false); + } +} + +void Subtractor::Process(const RenderBuffer& render_buffer, + const Block& capture, + const RenderSignalAnalyzer& render_signal_analyzer, + const AecState& aec_state, + rtc::ArrayView<SubtractorOutput> outputs) { + RTC_DCHECK_EQ(num_capture_channels_, capture.NumChannels()); + + // Compute the render powers. + const bool same_filter_sizes = refined_filters_[0]->SizePartitions() == + coarse_filter_[0]->SizePartitions(); + std::array<float, kFftLengthBy2Plus1> X2_refined; + std::array<float, kFftLengthBy2Plus1> X2_coarse_data; + auto& X2_coarse = same_filter_sizes ? X2_refined : X2_coarse_data; + if (same_filter_sizes) { + render_buffer.SpectralSum(refined_filters_[0]->SizePartitions(), + &X2_refined); + } else if (refined_filters_[0]->SizePartitions() > + coarse_filter_[0]->SizePartitions()) { + render_buffer.SpectralSums(coarse_filter_[0]->SizePartitions(), + refined_filters_[0]->SizePartitions(), + &X2_coarse, &X2_refined); + } else { + render_buffer.SpectralSums(refined_filters_[0]->SizePartitions(), + coarse_filter_[0]->SizePartitions(), &X2_refined, + &X2_coarse); + } + + // Process all capture channels + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + SubtractorOutput& output = outputs[ch]; + rtc::ArrayView<const float> y = capture.View(/*band=*/0, ch); + FftData& E_refined = output.E_refined; + FftData E_coarse; + std::array<float, kBlockSize>& e_refined = output.e_refined; + std::array<float, kBlockSize>& e_coarse = output.e_coarse; + + FftData S; + FftData& G = S; + + // Form the outputs of the refined and coarse filters. + refined_filters_[ch]->Filter(render_buffer, &S); + PredictionError(fft_, S, y, &e_refined, &output.s_refined); + + coarse_filter_[ch]->Filter(render_buffer, &S); + PredictionError(fft_, S, y, &e_coarse, &output.s_coarse); + + // Compute the signal powers in the subtractor output. + output.ComputeMetrics(y); + + // Adjust the filter if needed. + bool refined_filters_adjusted = false; + filter_misadjustment_estimators_[ch].Update(output); + if (filter_misadjustment_estimators_[ch].IsAdjustmentNeeded()) { + float scale = filter_misadjustment_estimators_[ch].GetMisadjustment(); + refined_filters_[ch]->ScaleFilter(scale); + for (auto& h_k : refined_impulse_responses_[ch]) { + h_k *= scale; + } + ScaleFilterOutput(y, scale, e_refined, output.s_refined); + filter_misadjustment_estimators_[ch].Reset(); + refined_filters_adjusted = true; + } + + // Compute the FFts of the refined and coarse filter outputs. + fft_.ZeroPaddedFft(e_refined, Aec3Fft::Window::kHanning, &E_refined); + fft_.ZeroPaddedFft(e_coarse, Aec3Fft::Window::kHanning, &E_coarse); + + // Compute spectra for future use. + E_coarse.Spectrum(optimization_, output.E2_coarse); + E_refined.Spectrum(optimization_, output.E2_refined); + + // Update the refined filter. + if (!refined_filters_adjusted) { + // Do not allow the performance of the coarse filter to affect the + // adaptation speed of the refined filter just after the coarse filter has + // been reset. + const bool disallow_leakage_diverged = + coarse_filter_reset_hangover_[ch] > 0 && + use_coarse_filter_reset_hangover_; + + std::array<float, kFftLengthBy2Plus1> erl; + ComputeErl(optimization_, refined_frequency_responses_[ch], erl); + refined_gains_[ch]->Compute(X2_refined, render_signal_analyzer, output, + erl, refined_filters_[ch]->SizePartitions(), + aec_state.SaturatedCapture(), + disallow_leakage_diverged, &G); + } else { + G.re.fill(0.f); + G.im.fill(0.f); + } + refined_filters_[ch]->Adapt(render_buffer, G, + &refined_impulse_responses_[ch]); + refined_filters_[ch]->ComputeFrequencyResponse( + &refined_frequency_responses_[ch]); + + if (ch == 0) { + data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.re); + data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.im); + } + + // Update the coarse filter. + poor_coarse_filter_counters_[ch] = + output.e2_refined < output.e2_coarse + ? poor_coarse_filter_counters_[ch] + 1 + : 0; + if (poor_coarse_filter_counters_[ch] < 5) { + coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_coarse, + coarse_filter_[ch]->SizePartitions(), + aec_state.SaturatedCapture(), &G); + coarse_filter_reset_hangover_[ch] = + std::max(coarse_filter_reset_hangover_[ch] - 1, 0); + } else { + poor_coarse_filter_counters_[ch] = 0; + coarse_filter_[ch]->SetFilter(refined_filters_[ch]->SizePartitions(), + refined_filters_[ch]->GetFilter()); + coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_refined, + coarse_filter_[ch]->SizePartitions(), + aec_state.SaturatedCapture(), &G); + coarse_filter_reset_hangover_[ch] = + config_.filter.coarse_reset_hangover_blocks; + } + + if (ApmDataDumper::IsAvailable()) { + RTC_DCHECK_LT(ch, coarse_impulse_responses_.size()); + coarse_filter_[ch]->Adapt(render_buffer, G, + &coarse_impulse_responses_[ch]); + } else { + coarse_filter_[ch]->Adapt(render_buffer, G); + } + + if (ch == 0) { + data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.re); + data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.im); + filter_misadjustment_estimators_[ch].Dump(data_dumper_); + DumpFilters(); + } + + std::for_each(e_refined.begin(), e_refined.end(), + [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); }); + + if (ch == 0) { + data_dumper_->DumpWav("aec3_refined_filters_output", kBlockSize, + &e_refined[0], 16000, 1); + data_dumper_->DumpWav("aec3_coarse_filter_output", kBlockSize, + &e_coarse[0], 16000, 1); + } + } +} + +void Subtractor::FilterMisadjustmentEstimator::Update( + const SubtractorOutput& output) { + e2_acum_ += output.e2_refined; + y2_acum_ += output.y2; + if (++n_blocks_acum_ == n_blocks_) { + if (y2_acum_ > n_blocks_ * 200.f * 200.f * kBlockSize) { + float update = (e2_acum_ / y2_acum_); + if (e2_acum_ > n_blocks_ * 7500.f * 7500.f * kBlockSize) { + // Duration equal to blockSizeMs * n_blocks_ * 4. + overhang_ = 4; + } else { + overhang_ = std::max(overhang_ - 1, 0); + } + + if ((update < inv_misadjustment_) || (overhang_ > 0)) { + inv_misadjustment_ += 0.1f * (update - inv_misadjustment_); + } + } + e2_acum_ = 0.f; + y2_acum_ = 0.f; + n_blocks_acum_ = 0; + } +} + +void Subtractor::FilterMisadjustmentEstimator::Reset() { + e2_acum_ = 0.f; + y2_acum_ = 0.f; + n_blocks_acum_ = 0; + inv_misadjustment_ = 0.f; + overhang_ = 0.f; +} + +void Subtractor::FilterMisadjustmentEstimator::Dump( + ApmDataDumper* data_dumper) const { + data_dumper->DumpRaw("aec3_inv_misadjustment_factor", inv_misadjustment_); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subtractor.h b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor.h new file mode 100644 index 0000000000..86159a3442 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor.h @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_H_ + +#include <math.h> +#include <stddef.h> + +#include <array> +#include <vector> + +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/block.h" +#include "modules/audio_processing/aec3/coarse_filter_update_gain.h" +#include "modules/audio_processing/aec3/echo_path_variability.h" +#include "modules/audio_processing/aec3/refined_filter_update_gain.h" +#include "modules/audio_processing/aec3/render_buffer.h" +#include "modules/audio_processing/aec3/render_signal_analyzer.h" +#include "modules/audio_processing/aec3/subtractor_output.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +// Proves linear echo cancellation functionality +class Subtractor { + public: + Subtractor(const EchoCanceller3Config& config, + size_t num_render_channels, + size_t num_capture_channels, + ApmDataDumper* data_dumper, + Aec3Optimization optimization); + ~Subtractor(); + Subtractor(const Subtractor&) = delete; + Subtractor& operator=(const Subtractor&) = delete; + + // Performs the echo subtraction. + void Process(const RenderBuffer& render_buffer, + const Block& capture, + const RenderSignalAnalyzer& render_signal_analyzer, + const AecState& aec_state, + rtc::ArrayView<SubtractorOutput> outputs); + + void HandleEchoPathChange(const EchoPathVariability& echo_path_variability); + + // Exits the initial state. + void ExitInitialState(); + + // Returns the block-wise frequency responses for the refined adaptive + // filters. + const std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>>& + FilterFrequencyResponses() const { + return refined_frequency_responses_; + } + + // Returns the estimates of the impulse responses for the refined adaptive + // filters. + const std::vector<std::vector<float>>& FilterImpulseResponses() const { + return refined_impulse_responses_; + } + + void DumpFilters() { + data_dumper_->DumpRaw( + "aec3_subtractor_h_refined", + rtc::ArrayView<const float>( + refined_impulse_responses_[0].data(), + GetTimeDomainLength( + refined_filters_[0]->max_filter_size_partitions()))); + if (ApmDataDumper::IsAvailable()) { + RTC_DCHECK_GT(coarse_impulse_responses_.size(), 0); + data_dumper_->DumpRaw( + "aec3_subtractor_h_coarse", + rtc::ArrayView<const float>( + coarse_impulse_responses_[0].data(), + GetTimeDomainLength( + coarse_filter_[0]->max_filter_size_partitions()))); + } + + refined_filters_[0]->DumpFilter("aec3_subtractor_H_refined"); + coarse_filter_[0]->DumpFilter("aec3_subtractor_H_coarse"); + } + + private: + class FilterMisadjustmentEstimator { + public: + FilterMisadjustmentEstimator() = default; + ~FilterMisadjustmentEstimator() = default; + // Update the misadjustment estimator. + void Update(const SubtractorOutput& output); + // GetMisadjustment() Returns a recommended scale for the filter so the + // prediction error energy gets closer to the energy that is seen at the + // microphone input. + float GetMisadjustment() const { + RTC_DCHECK_GT(inv_misadjustment_, 0.0f); + // It is not aiming to adjust all the estimated mismatch. Instead, + // it adjusts half of that estimated mismatch. + return 2.f / sqrtf(inv_misadjustment_); + } + // Returns true if the prediciton error energy is significantly larger + // than the microphone signal energy and, therefore, an adjustment is + // recommended. + bool IsAdjustmentNeeded() const { return inv_misadjustment_ > 10.f; } + void Reset(); + void Dump(ApmDataDumper* data_dumper) const; + + private: + const int n_blocks_ = 4; + int n_blocks_acum_ = 0; + float e2_acum_ = 0.f; + float y2_acum_ = 0.f; + float inv_misadjustment_ = 0.f; + int overhang_ = 0.f; + }; + + const Aec3Fft fft_; + ApmDataDumper* data_dumper_; + const Aec3Optimization optimization_; + const EchoCanceller3Config config_; + const size_t num_capture_channels_; + const bool use_coarse_filter_reset_hangover_; + + std::vector<std::unique_ptr<AdaptiveFirFilter>> refined_filters_; + std::vector<std::unique_ptr<AdaptiveFirFilter>> coarse_filter_; + std::vector<std::unique_ptr<RefinedFilterUpdateGain>> refined_gains_; + std::vector<std::unique_ptr<CoarseFilterUpdateGain>> coarse_gains_; + std::vector<FilterMisadjustmentEstimator> filter_misadjustment_estimators_; + std::vector<size_t> poor_coarse_filter_counters_; + std::vector<int> coarse_filter_reset_hangover_; + std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> + refined_frequency_responses_; + std::vector<std::vector<float>> refined_impulse_responses_; + std::vector<std::vector<float>> coarse_impulse_responses_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output.cc b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output.cc new file mode 100644 index 0000000000..ed80101f06 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/subtractor_output.h" + +#include <numeric> + +namespace webrtc { + +SubtractorOutput::SubtractorOutput() = default; +SubtractorOutput::~SubtractorOutput() = default; + +void SubtractorOutput::Reset() { + s_refined.fill(0.f); + s_coarse.fill(0.f); + e_refined.fill(0.f); + e_coarse.fill(0.f); + E_refined.re.fill(0.f); + E_refined.im.fill(0.f); + E2_refined.fill(0.f); + E2_coarse.fill(0.f); + e2_refined = 0.f; + e2_coarse = 0.f; + s2_refined = 0.f; + s2_coarse = 0.f; + y2 = 0.f; +} + +void SubtractorOutput::ComputeMetrics(rtc::ArrayView<const float> y) { + const auto sum_of_squares = [](float a, float b) { return a + b * b; }; + y2 = std::accumulate(y.begin(), y.end(), 0.f, sum_of_squares); + e2_refined = + std::accumulate(e_refined.begin(), e_refined.end(), 0.f, sum_of_squares); + e2_coarse = + std::accumulate(e_coarse.begin(), e_coarse.end(), 0.f, sum_of_squares); + s2_refined = + std::accumulate(s_refined.begin(), s_refined.end(), 0.f, sum_of_squares); + s2_coarse = + std::accumulate(s_coarse.begin(), s_coarse.end(), 0.f, sum_of_squares); + + s_refined_max_abs = *std::max_element(s_refined.begin(), s_refined.end()); + s_refined_max_abs = + std::max(s_refined_max_abs, + -(*std::min_element(s_refined.begin(), s_refined.end()))); + + s_coarse_max_abs = *std::max_element(s_coarse.begin(), s_coarse.end()); + s_coarse_max_abs = std::max( + s_coarse_max_abs, -(*std::min_element(s_coarse.begin(), s_coarse.end()))); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output.h b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output.h new file mode 100644 index 0000000000..d2d12082c6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_H_ + +#include <array> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/fft_data.h" + +namespace webrtc { + +// Stores the values being returned from the echo subtractor for a single +// capture channel. +struct SubtractorOutput { + SubtractorOutput(); + ~SubtractorOutput(); + + std::array<float, kBlockSize> s_refined; + std::array<float, kBlockSize> s_coarse; + std::array<float, kBlockSize> e_refined; + std::array<float, kBlockSize> e_coarse; + FftData E_refined; + std::array<float, kFftLengthBy2Plus1> E2_refined; + std::array<float, kFftLengthBy2Plus1> E2_coarse; + float s2_refined = 0.f; + float s2_coarse = 0.f; + float e2_refined = 0.f; + float e2_coarse = 0.f; + float y2 = 0.f; + float s_refined_max_abs = 0.f; + float s_coarse_max_abs = 0.f; + + // Reset the struct content. + void Reset(); + + // Updates the powers of the signals. + void ComputeMetrics(rtc::ArrayView<const float> y); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output_analyzer.cc b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output_analyzer.cc new file mode 100644 index 0000000000..baf0600161 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output_analyzer.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/subtractor_output_analyzer.h" + +#include <algorithm> + +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +SubtractorOutputAnalyzer::SubtractorOutputAnalyzer(size_t num_capture_channels) + : filters_converged_(num_capture_channels, false) {} + +void SubtractorOutputAnalyzer::Update( + rtc::ArrayView<const SubtractorOutput> subtractor_output, + bool* any_filter_converged, + bool* any_coarse_filter_converged, + bool* all_filters_diverged) { + RTC_DCHECK(any_filter_converged); + RTC_DCHECK(all_filters_diverged); + RTC_DCHECK_EQ(subtractor_output.size(), filters_converged_.size()); + + *any_filter_converged = false; + *any_coarse_filter_converged = false; + *all_filters_diverged = true; + + for (size_t ch = 0; ch < subtractor_output.size(); ++ch) { + const float y2 = subtractor_output[ch].y2; + const float e2_refined = subtractor_output[ch].e2_refined; + const float e2_coarse = subtractor_output[ch].e2_coarse; + + constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize; + constexpr float kConvergenceThresholdLowLevel = 20 * 20 * kBlockSize; + bool refined_filter_converged = + e2_refined < 0.5f * y2 && y2 > kConvergenceThreshold; + bool coarse_filter_converged_strict = + e2_coarse < 0.05f * y2 && y2 > kConvergenceThreshold; + bool coarse_filter_converged_relaxed = + e2_coarse < 0.2f * y2 && y2 > kConvergenceThresholdLowLevel; + float min_e2 = std::min(e2_refined, e2_coarse); + bool filter_diverged = min_e2 > 1.5f * y2 && y2 > 30.f * 30.f * kBlockSize; + filters_converged_[ch] = + refined_filter_converged || coarse_filter_converged_strict; + + *any_filter_converged = *any_filter_converged || filters_converged_[ch]; + *any_coarse_filter_converged = + *any_coarse_filter_converged || coarse_filter_converged_relaxed; + *all_filters_diverged = *all_filters_diverged && filter_diverged; + } +} + +void SubtractorOutputAnalyzer::HandleEchoPathChange() { + std::fill(filters_converged_.begin(), filters_converged_.end(), false); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output_analyzer.h b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output_analyzer.h new file mode 100644 index 0000000000..32707dbb19 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_output_analyzer.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ + +#include <vector> + +#include "modules/audio_processing/aec3/subtractor_output.h" + +namespace webrtc { + +// Class for analyzing the properties subtractor output. +class SubtractorOutputAnalyzer { + public: + explicit SubtractorOutputAnalyzer(size_t num_capture_channels); + ~SubtractorOutputAnalyzer() = default; + + // Analyses the subtractor output. + void Update(rtc::ArrayView<const SubtractorOutput> subtractor_output, + bool* any_filter_converged, + bool* any_coarse_filter_converged, + bool* all_filters_diverged); + + const std::vector<bool>& ConvergedFilters() const { + return filters_converged_; + } + + // Handle echo path change. + void HandleEchoPathChange(); + + private: + std::vector<bool> filters_converged_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUBTRACTOR_OUTPUT_ANALYZER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_unittest.cc new file mode 100644 index 0000000000..56b9cec9f1 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/subtractor_unittest.cc @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/subtractor.h" + +#include <algorithm> +#include <memory> +#include <numeric> +#include <string> + +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "modules/audio_processing/utility/cascaded_biquad_filter.h" +#include "rtc_base/random.h" +#include "rtc_base/strings/string_builder.h" +#include "test/gtest.h" + +namespace webrtc { +namespace { + +std::vector<float> RunSubtractorTest( + size_t num_render_channels, + size_t num_capture_channels, + int num_blocks_to_process, + int delay_samples, + int refined_filter_length_blocks, + int coarse_filter_length_blocks, + bool uncorrelated_inputs, + const std::vector<int>& blocks_with_echo_path_changes) { + ApmDataDumper data_dumper(42); + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + EchoCanceller3Config config; + config.filter.refined.length_blocks = refined_filter_length_blocks; + config.filter.coarse.length_blocks = coarse_filter_length_blocks; + + Subtractor subtractor(config, num_render_channels, num_capture_channels, + &data_dumper, DetectOptimization()); + absl::optional<DelayEstimate> delay_estimate; + Block x(kNumBands, num_render_channels); + Block y(/*num_bands=*/1, num_capture_channels); + std::array<float, kBlockSize> x_old; + std::vector<SubtractorOutput> output(num_capture_channels); + config.delay.default_delay = 1; + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + RenderSignalAnalyzer render_signal_analyzer(config); + Random random_generator(42U); + Aec3Fft fft; + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); + std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( + num_capture_channels); + std::array<float, kFftLengthBy2Plus1> E2_coarse; + AecState aec_state(config, num_capture_channels); + x_old.fill(0.f); + for (auto& Y2_ch : Y2) { + Y2_ch.fill(0.f); + } + for (auto& E2_refined_ch : E2_refined) { + E2_refined_ch.fill(0.f); + } + E2_coarse.fill(0.f); + + std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer( + num_capture_channels); + for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) { + delay_buffer[capture_ch].resize(num_render_channels); + for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) { + delay_buffer[capture_ch][render_ch] = + std::make_unique<DelayBuffer<float>>(delay_samples); + } + } + + // [B,A] = butter(2,100/8000,'high') + constexpr CascadedBiQuadFilter::BiQuadCoefficients + kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f}, + {-1.94448f, 0.94598f}}; + std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter( + num_render_channels); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch] = + std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1); + } + std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter( + num_capture_channels); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + y_hp_filter[ch] = + std::make_unique<CascadedBiQuadFilter>(kHighPassFilterCoefficients, 1); + } + + for (int k = 0; k < num_blocks_to_process; ++k) { + for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) { + RandomizeSampleVector(&random_generator, x.View(/*band=*/0, render_ch)); + } + if (uncorrelated_inputs) { + for (size_t capture_ch = 0; capture_ch < num_capture_channels; + ++capture_ch) { + RandomizeSampleVector(&random_generator, + y.View(/*band=*/0, capture_ch)); + } + } else { + for (size_t capture_ch = 0; capture_ch < num_capture_channels; + ++capture_ch) { + rtc::ArrayView<float> y_view = y.View(/*band=*/0, capture_ch); + for (size_t render_ch = 0; render_ch < num_render_channels; + ++render_ch) { + std::array<float, kBlockSize> y_channel; + delay_buffer[capture_ch][render_ch]->Delay( + x.View(/*band=*/0, render_ch), y_channel); + for (size_t k = 0; k < kBlockSize; ++k) { + y_view[k] += y_channel[k] / num_render_channels; + } + } + } + } + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch]->Process(x.View(/*band=*/0, ch)); + } + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + y_hp_filter[ch]->Process(y.View(/*band=*/0, ch)); + } + + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(), + aec_state.MinDirectPathFilterDelay()); + + // Handle echo path changes. + if (std::find(blocks_with_echo_path_changes.begin(), + blocks_with_echo_path_changes.end(), + k) != blocks_with_echo_path_changes.end()) { + subtractor.HandleEchoPathChange(EchoPathVariability( + true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, + false)); + } + subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y, + render_signal_analyzer, aec_state, output); + + aec_state.HandleEchoPathChange(EchoPathVariability( + false, EchoPathVariability::DelayAdjustment::kNone, false)); + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(), + subtractor.FilterImpulseResponses(), + *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2, + output); + } + + std::vector<float> results(num_capture_channels); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + const float output_power = std::inner_product( + output[ch].e_refined.begin(), output[ch].e_refined.end(), + output[ch].e_refined.begin(), 0.f); + const float y_power = + std::inner_product(y.begin(/*band=*/0, ch), y.end(/*band=*/0, ch), + y.begin(/*band=*/0, ch), 0.f); + if (y_power == 0.f) { + ADD_FAILURE(); + results[ch] = -1.f; + } + results[ch] = output_power / y_power; + } + return results; +} + +std::string ProduceDebugText(size_t num_render_channels, + size_t num_capture_channels, + size_t delay, + int filter_length_blocks) { + rtc::StringBuilder ss; + ss << "delay: " << delay << ", "; + ss << "filter_length_blocks:" << filter_length_blocks << ", "; + ss << "num_render_channels:" << num_render_channels << ", "; + ss << "num_capture_channels:" << num_capture_channels; + return ss.Release(); +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non data dumper works. +TEST(SubtractorDeathTest, NullDataDumper) { + EXPECT_DEATH( + Subtractor(EchoCanceller3Config(), 1, 1, nullptr, DetectOptimization()), + ""); +} + +#endif + +// Verifies that the subtractor is able to converge on correlated data. +TEST(Subtractor, Convergence) { + std::vector<int> blocks_with_echo_path_changes; + for (size_t filter_length_blocks : {12, 20, 30}) { + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks)); + std::vector<float> echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 2500, delay_samples, filter_length_blocks, filter_length_blocks, + false, blocks_with_echo_path_changes); + + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_GT(0.1f, echo_to_nearend_power); + } + } + } +} + +// Verifies that the subtractor is able to handle the case when the refined +// filter is longer than the coarse filter. +TEST(Subtractor, RefinedFilterLongerThanCoarseFilter) { + std::vector<int> blocks_with_echo_path_changes; + std::vector<float> echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_GT(0.5f, echo_to_nearend_power); + } +} + +// Verifies that the subtractor is able to handle the case when the coarse +// filter is longer than the refined filter. +TEST(Subtractor, CoarseFilterLongerThanRefinedFilter) { + std::vector<int> blocks_with_echo_path_changes; + std::vector<float> echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_GT(0.5f, echo_to_nearend_power); + } +} + +// Verifies that the subtractor does not converge on uncorrelated signals. +TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) { + std::vector<int> blocks_with_echo_path_changes; + for (size_t filter_length_blocks : {12, 20, 30}) { + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks)); + + std::vector<float> echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 3000, delay_samples, filter_length_blocks, filter_length_blocks, + true, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1); + } + } + } +} + +class SubtractorMultiChannelUpToEightRender + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +#if defined(NDEBUG) +INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel, + SubtractorMultiChannelUpToEightRender, + ::testing::Combine(::testing::Values(1, 2, 8), + ::testing::Values(1, 2, 4))); +#else +INSTANTIATE_TEST_SUITE_P(DebugMultiChannel, + SubtractorMultiChannelUpToEightRender, + ::testing::Combine(::testing::Values(1, 2), + ::testing::Values(1, 2))); +#endif + +// Verifies that the subtractor is able to converge on correlated data. +TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + + std::vector<int> blocks_with_echo_path_changes; + size_t num_blocks_to_process = 2500 * num_render_channels; + std::vector<float> echo_to_nearend_powers = RunSubtractorTest( + num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20, + 20, false, blocks_with_echo_path_changes); + + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_GT(0.1f, echo_to_nearend_power); + } +} + +class SubtractorMultiChannelUpToFourRender + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; + +#if defined(NDEBUG) +INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel, + SubtractorMultiChannelUpToFourRender, + ::testing::Combine(::testing::Values(1, 2, 4), + ::testing::Values(1, 2, 4))); +#else +INSTANTIATE_TEST_SUITE_P(DebugMultiChannel, + SubtractorMultiChannelUpToFourRender, + ::testing::Combine(::testing::Values(1, 2), + ::testing::Values(1, 2))); +#endif + +// Verifies that the subtractor does not converge on uncorrelated signals. +TEST_P(SubtractorMultiChannelUpToFourRender, + NonConvergenceOnUncorrelatedSignals) { + const size_t num_render_channels = std::get<0>(GetParam()); + const size_t num_capture_channels = std::get<1>(GetParam()); + + std::vector<int> blocks_with_echo_path_changes; + size_t num_blocks_to_process = 5000 * num_render_channels; + std::vector<float> echo_to_nearend_powers = RunSubtractorTest( + num_render_channels, num_capture_channels, num_blocks_to_process, 64, 20, + 20, true, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_LT(.8f, echo_to_nearend_power); + EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f); + } +} +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter.cc b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter.cc new file mode 100644 index 0000000000..83ded425d5 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter.cc @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/suppression_filter.h" + +#include <algorithm> +#include <cmath> +#include <cstring> +#include <functional> +#include <iterator> + +#include "modules/audio_processing/aec3/vector_math.h" +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { +namespace { + +// Hanning window from Matlab command win = sqrt(hanning(128)). +const float kSqrtHanning[kFftLength] = { + 0.00000000000000f, 0.02454122852291f, 0.04906767432742f, 0.07356456359967f, + 0.09801714032956f, 0.12241067519922f, 0.14673047445536f, 0.17096188876030f, + 0.19509032201613f, 0.21910124015687f, 0.24298017990326f, 0.26671275747490f, + 0.29028467725446f, 0.31368174039889f, 0.33688985339222f, 0.35989503653499f, + 0.38268343236509f, 0.40524131400499f, 0.42755509343028f, 0.44961132965461f, + 0.47139673682600f, 0.49289819222978f, 0.51410274419322f, 0.53499761988710f, + 0.55557023301960f, 0.57580819141785f, 0.59569930449243f, 0.61523159058063f, + 0.63439328416365f, 0.65317284295378f, 0.67155895484702f, 0.68954054473707f, + 0.70710678118655f, 0.72424708295147f, 0.74095112535496f, 0.75720884650648f, + 0.77301045336274f, 0.78834642762661f, 0.80320753148064f, 0.81758481315158f, + 0.83146961230255f, 0.84485356524971f, 0.85772861000027f, 0.87008699110871f, + 0.88192126434835f, 0.89322430119552f, 0.90398929312344f, 0.91420975570353f, + 0.92387953251129f, 0.93299279883474f, 0.94154406518302f, 0.94952818059304f, + 0.95694033573221f, 0.96377606579544f, 0.97003125319454f, 0.97570213003853f, + 0.98078528040323f, 0.98527764238894f, 0.98917650996478f, 0.99247953459871f, + 0.99518472667220f, 0.99729045667869f, 0.99879545620517f, 0.99969881869620f, + 1.00000000000000f, 0.99969881869620f, 0.99879545620517f, 0.99729045667869f, + 0.99518472667220f, 0.99247953459871f, 0.98917650996478f, 0.98527764238894f, + 0.98078528040323f, 0.97570213003853f, 0.97003125319454f, 0.96377606579544f, + 0.95694033573221f, 0.94952818059304f, 0.94154406518302f, 0.93299279883474f, + 0.92387953251129f, 0.91420975570353f, 0.90398929312344f, 0.89322430119552f, + 0.88192126434835f, 0.87008699110871f, 0.85772861000027f, 0.84485356524971f, + 0.83146961230255f, 0.81758481315158f, 0.80320753148064f, 0.78834642762661f, + 0.77301045336274f, 0.75720884650648f, 0.74095112535496f, 0.72424708295147f, + 0.70710678118655f, 0.68954054473707f, 0.67155895484702f, 0.65317284295378f, + 0.63439328416365f, 0.61523159058063f, 0.59569930449243f, 0.57580819141785f, + 0.55557023301960f, 0.53499761988710f, 0.51410274419322f, 0.49289819222978f, + 0.47139673682600f, 0.44961132965461f, 0.42755509343028f, 0.40524131400499f, + 0.38268343236509f, 0.35989503653499f, 0.33688985339222f, 0.31368174039889f, + 0.29028467725446f, 0.26671275747490f, 0.24298017990326f, 0.21910124015687f, + 0.19509032201613f, 0.17096188876030f, 0.14673047445536f, 0.12241067519922f, + 0.09801714032956f, 0.07356456359967f, 0.04906767432742f, 0.02454122852291f}; + +} // namespace + +SuppressionFilter::SuppressionFilter(Aec3Optimization optimization, + int sample_rate_hz, + size_t num_capture_channels) + : optimization_(optimization), + sample_rate_hz_(sample_rate_hz), + num_capture_channels_(num_capture_channels), + fft_(), + e_output_old_(NumBandsForRate(sample_rate_hz_), + std::vector<std::array<float, kFftLengthBy2>>( + num_capture_channels_)) { + RTC_DCHECK(ValidFullBandRate(sample_rate_hz_)); + for (size_t b = 0; b < e_output_old_.size(); ++b) { + for (size_t ch = 0; ch < e_output_old_[b].size(); ++ch) { + e_output_old_[b][ch].fill(0.f); + } + } +} + +SuppressionFilter::~SuppressionFilter() = default; + +void SuppressionFilter::ApplyGain( + rtc::ArrayView<const FftData> comfort_noise, + rtc::ArrayView<const FftData> comfort_noise_high_band, + const std::array<float, kFftLengthBy2Plus1>& suppression_gain, + float high_bands_gain, + rtc::ArrayView<const FftData> E_lowest_band, + Block* e) { + RTC_DCHECK(e); + RTC_DCHECK_EQ(e->NumBands(), NumBandsForRate(sample_rate_hz_)); + + // Comfort noise gain is sqrt(1-g^2), where g is the suppression gain. + std::array<float, kFftLengthBy2Plus1> noise_gain; + for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) { + noise_gain[i] = 1.f - suppression_gain[i] * suppression_gain[i]; + } + aec3::VectorMath(optimization_).Sqrt(noise_gain); + + const float high_bands_noise_scaling = + 0.4f * std::sqrt(1.f - high_bands_gain * high_bands_gain); + + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + FftData E; + + // Analysis filterbank. + E.Assign(E_lowest_band[ch]); + + for (size_t i = 0; i < kFftLengthBy2Plus1; ++i) { + // Apply suppression gains. + float E_real = E.re[i] * suppression_gain[i]; + float E_imag = E.im[i] * suppression_gain[i]; + + // Scale and add the comfort noise. + E.re[i] = E_real + noise_gain[i] * comfort_noise[ch].re[i]; + E.im[i] = E_imag + noise_gain[i] * comfort_noise[ch].im[i]; + } + + // Synthesis filterbank. + std::array<float, kFftLength> e_extended; + constexpr float kIfftNormalization = 2.f / kFftLength; + fft_.Ifft(E, &e_extended); + + auto e0 = e->View(/*band=*/0, ch); + float* e0_old = e_output_old_[0][ch].data(); + + // Window and add the first half of e_extended with the second half of + // e_extended from the previous block. + for (size_t i = 0; i < kFftLengthBy2; ++i) { + float e0_i = e0_old[i] * kSqrtHanning[kFftLengthBy2 + i]; + e0_i += e_extended[i] * kSqrtHanning[i]; + e0[i] = e0_i * kIfftNormalization; + } + + // The second half of e_extended is stored for the succeeding frame. + std::copy(e_extended.begin() + kFftLengthBy2, + e_extended.begin() + kFftLength, + std::begin(e_output_old_[0][ch])); + + // Apply suppression gain to upper bands. + for (int b = 1; b < e->NumBands(); ++b) { + auto e_band = e->View(b, ch); + for (size_t i = 0; i < kFftLengthBy2; ++i) { + e_band[i] *= high_bands_gain; + } + } + + // Add comfort noise to band 1. + if (e->NumBands() > 1) { + E.Assign(comfort_noise_high_band[ch]); + std::array<float, kFftLength> time_domain_high_band_noise; + fft_.Ifft(E, &time_domain_high_band_noise); + + auto e1 = e->View(/*band=*/1, ch); + const float gain = high_bands_noise_scaling * kIfftNormalization; + for (size_t i = 0; i < kFftLengthBy2; ++i) { + e1[i] += time_domain_high_band_noise[i] * gain; + } + } + + // Delay upper bands to match the delay of the filter bank. + for (int b = 1; b < e->NumBands(); ++b) { + auto e_band = e->View(b, ch); + float* e_band_old = e_output_old_[b][ch].data(); + for (size_t i = 0; i < kFftLengthBy2; ++i) { + std::swap(e_band[i], e_band_old[i]); + } + } + + // Clamp output of all bands. + for (int b = 0; b < e->NumBands(); ++b) { + auto e_band = e->View(b, ch); + for (size_t i = 0; i < kFftLengthBy2; ++i) { + e_band[i] = rtc::SafeClamp(e_band[i], -32768.f, 32767.f); + } + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter.h b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter.h new file mode 100644 index 0000000000..c18b2334bf --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_FILTER_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_FILTER_H_ + +#include <array> +#include <vector> + +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec3_fft.h" +#include "modules/audio_processing/aec3/block.h" +#include "modules/audio_processing/aec3/fft_data.h" + +namespace webrtc { + +class SuppressionFilter { + public: + SuppressionFilter(Aec3Optimization optimization, + int sample_rate_hz, + size_t num_capture_channels_); + ~SuppressionFilter(); + + SuppressionFilter(const SuppressionFilter&) = delete; + SuppressionFilter& operator=(const SuppressionFilter&) = delete; + + void ApplyGain(rtc::ArrayView<const FftData> comfort_noise, + rtc::ArrayView<const FftData> comfort_noise_high_bands, + const std::array<float, kFftLengthBy2Plus1>& suppression_gain, + float high_bands_gain, + rtc::ArrayView<const FftData> E_lowest_band, + Block* e); + + private: + const Aec3Optimization optimization_; + const int sample_rate_hz_; + const size_t num_capture_channels_; + const Aec3Fft fft_; + std::vector<std::vector<std::array<float, kFftLengthBy2>>> e_output_old_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_FILTER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter_unittest.cc new file mode 100644 index 0000000000..464f5cfed2 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_filter_unittest.cc @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/suppression_filter.h" + +#include <math.h> + +#include <algorithm> +#include <cmath> +#include <numeric> + +#include "test/gtest.h" + +namespace webrtc { +namespace { + +constexpr float kPi = 3.141592f; + +void ProduceSinusoid(int sample_rate_hz, + float sinusoidal_frequency_hz, + size_t* sample_counter, + Block* x) { + // Produce a sinusoid of the specified frequency. + for (size_t k = *sample_counter, j = 0; k < (*sample_counter + kBlockSize); + ++k, ++j) { + for (int channel = 0; channel < x->NumChannels(); ++channel) { + x->View(/*band=*/0, channel)[j] = + 32767.f * + std::sin(2.f * kPi * sinusoidal_frequency_hz * k / sample_rate_hz); + } + } + *sample_counter = *sample_counter + kBlockSize; + + for (int band = 1; band < x->NumBands(); ++band) { + for (int channel = 0; channel < x->NumChannels(); ++channel) { + std::fill(x->begin(band, channel), x->end(band, channel), 0.f); + } + } +} + +} // namespace + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies the check for null suppressor output. +TEST(SuppressionFilterDeathTest, NullOutput) { + std::vector<FftData> cn(1); + std::vector<FftData> cn_high_bands(1); + std::vector<FftData> E(1); + std::array<float, kFftLengthBy2Plus1> gain; + + EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16000, 1) + .ApplyGain(cn, cn_high_bands, gain, 1.0f, E, nullptr), + ""); +} + +// Verifies the check for allowed sample rate. +TEST(SuppressionFilterDeathTest, ProperSampleRate) { + EXPECT_DEATH(SuppressionFilter(Aec3Optimization::kNone, 16001, 1), ""); +} + +#endif + +// Verifies that no comfort noise is added when the gain is 1. +TEST(SuppressionFilter, ComfortNoiseInUnityGain) { + SuppressionFilter filter(Aec3Optimization::kNone, 48000, 1); + std::vector<FftData> cn(1); + std::vector<FftData> cn_high_bands(1); + std::array<float, kFftLengthBy2Plus1> gain; + std::array<float, kFftLengthBy2> e_old_; + Aec3Fft fft; + + e_old_.fill(0.f); + gain.fill(1.f); + cn[0].re.fill(1.f); + cn[0].im.fill(1.f); + cn_high_bands[0].re.fill(1.f); + cn_high_bands[0].im.fill(1.f); + + Block e(3, kBlockSize); + Block e_ref = e; + + std::vector<FftData> E(1); + fft.PaddedFft(e.View(/*band=*/0, /*channel=*/0), e_old_, + Aec3Fft::Window::kSqrtHanning, &E[0]); + std::copy(e.begin(/*band=*/0, /*channel=*/0), + e.end(/*band=*/0, /*channel=*/0), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); + + for (int band = 0; band < e.NumBands(); ++band) { + for (int channel = 0; channel < e.NumChannels(); ++channel) { + const auto e_view = e.View(band, channel); + const auto e_ref_view = e_ref.View(band, channel); + for (size_t sample = 0; sample < e_view.size(); ++sample) { + EXPECT_EQ(e_ref_view[sample], e_view[sample]); + } + } + } +} + +// Verifies that the suppressor is able to suppress a signal. +TEST(SuppressionFilter, SignalSuppression) { + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + constexpr size_t kNumChannels = 1; + + SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1); + std::vector<FftData> cn(1); + std::vector<FftData> cn_high_bands(1); + std::array<float, kFftLengthBy2> e_old_; + Aec3Fft fft; + std::array<float, kFftLengthBy2Plus1> gain; + Block e(kNumBands, kNumChannels); + e_old_.fill(0.f); + + gain.fill(1.f); + std::for_each(gain.begin() + 10, gain.end(), [](float& a) { a = 0.f; }); + + cn[0].re.fill(0.f); + cn[0].im.fill(0.f); + cn_high_bands[0].re.fill(0.f); + cn_high_bands[0].im.fill(0.f); + + size_t sample_counter = 0; + + float e0_input = 0.f; + float e0_output = 0.f; + for (size_t k = 0; k < 100; ++k) { + ProduceSinusoid(16000, 16000 * 40 / kFftLengthBy2 / 2, &sample_counter, &e); + e0_input = std::inner_product(e.begin(/*band=*/0, /*channel=*/0), + e.end(/*band=*/0, /*channel=*/0), + e.begin(/*band=*/0, /*channel=*/0), e0_input); + + std::vector<FftData> E(1); + fft.PaddedFft(e.View(/*band=*/0, /*channel=*/0), e_old_, + Aec3Fft::Window::kSqrtHanning, &E[0]); + std::copy(e.begin(/*band=*/0, /*channel=*/0), + e.end(/*band=*/0, /*channel=*/0), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); + e0_output = std::inner_product( + e.begin(/*band=*/0, /*channel=*/0), e.end(/*band=*/0, /*channel=*/0), + e.begin(/*band=*/0, /*channel=*/0), e0_output); + } + + EXPECT_LT(e0_output, e0_input / 1000.f); +} + +// Verifies that the suppressor is able to pass through a desired signal while +// applying suppressing for some frequencies. +TEST(SuppressionFilter, SignalTransparency) { + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1); + std::vector<FftData> cn(1); + std::array<float, kFftLengthBy2> e_old_; + Aec3Fft fft; + std::vector<FftData> cn_high_bands(1); + std::array<float, kFftLengthBy2Plus1> gain; + Block e(kNumBands, kNumChannels); + e_old_.fill(0.f); + gain.fill(1.f); + std::for_each(gain.begin() + 30, gain.end(), [](float& a) { a = 0.f; }); + + cn[0].re.fill(0.f); + cn[0].im.fill(0.f); + cn_high_bands[0].re.fill(0.f); + cn_high_bands[0].im.fill(0.f); + + size_t sample_counter = 0; + + float e0_input = 0.f; + float e0_output = 0.f; + for (size_t k = 0; k < 100; ++k) { + ProduceSinusoid(16000, 16000 * 10 / kFftLengthBy2 / 2, &sample_counter, &e); + e0_input = std::inner_product(e.begin(/*band=*/0, /*channel=*/0), + e.end(/*band=*/0, /*channel=*/0), + e.begin(/*band=*/0, /*channel=*/0), e0_input); + + std::vector<FftData> E(1); + fft.PaddedFft(e.View(/*band=*/0, /*channel=*/0), e_old_, + Aec3Fft::Window::kSqrtHanning, &E[0]); + std::copy(e.begin(/*band=*/0, /*channel=*/0), + e.end(/*band=*/0, /*channel=*/0), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); + e0_output = std::inner_product( + e.begin(/*band=*/0, /*channel=*/0), e.end(/*band=*/0, /*channel=*/0), + e.begin(/*band=*/0, /*channel=*/0), e0_output); + } + + EXPECT_LT(0.9f * e0_input, e0_output); +} + +// Verifies that the suppressor delay. +TEST(SuppressionFilter, Delay) { + constexpr size_t kNumChannels = 1; + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + + SuppressionFilter filter(Aec3Optimization::kNone, kSampleRateHz, 1); + std::vector<FftData> cn(1); + std::vector<FftData> cn_high_bands(1); + std::array<float, kFftLengthBy2> e_old_; + Aec3Fft fft; + std::array<float, kFftLengthBy2Plus1> gain; + Block e(kNumBands, kNumChannels); + + gain.fill(1.f); + + cn[0].re.fill(0.f); + cn[0].im.fill(0.f); + cn_high_bands[0].re.fill(0.f); + cn_high_bands[0].im.fill(0.f); + + for (size_t k = 0; k < 100; ++k) { + for (size_t band = 0; band < kNumBands; ++band) { + for (size_t channel = 0; channel < kNumChannels; ++channel) { + auto e_view = e.View(band, channel); + for (size_t sample = 0; sample < kBlockSize; ++sample) { + e_view[sample] = k * kBlockSize + sample + channel; + } + } + } + + std::vector<FftData> E(1); + fft.PaddedFft(e.View(/*band=*/0, /*channel=*/0), e_old_, + Aec3Fft::Window::kSqrtHanning, &E[0]); + std::copy(e.begin(/*band=*/0, /*channel=*/0), + e.end(/*band=*/0, /*channel=*/0), e_old_.begin()); + + filter.ApplyGain(cn, cn_high_bands, gain, 1.f, E, &e); + if (k > 2) { + for (size_t band = 0; band < kNumBands; ++band) { + for (size_t channel = 0; channel < kNumChannels; ++channel) { + const auto e_view = e.View(band, channel); + for (size_t sample = 0; sample < kBlockSize; ++sample) { + EXPECT_NEAR(k * kBlockSize + sample - kBlockSize + channel, + e_view[sample], 0.01); + } + } + } + } + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain.cc b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain.cc new file mode 100644 index 0000000000..037dabaabe --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain.cc @@ -0,0 +1,465 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/suppression_gain.h" + +#include <math.h> +#include <stddef.h> + +#include <algorithm> +#include <numeric> + +#include "modules/audio_processing/aec3/dominant_nearend_detector.h" +#include "modules/audio_processing/aec3/moving_average.h" +#include "modules/audio_processing/aec3/subband_nearend_detector.h" +#include "modules/audio_processing/aec3/vector_math.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { +namespace { + +void LimitLowFrequencyGains(std::array<float, kFftLengthBy2Plus1>* gain) { + // Limit the low frequency gains to avoid the impact of the high-pass filter + // on the lower-frequency gain influencing the overall achieved gain. + (*gain)[0] = (*gain)[1] = std::min((*gain)[1], (*gain)[2]); +} + +void LimitHighFrequencyGains(bool conservative_hf_suppression, + std::array<float, kFftLengthBy2Plus1>* gain) { + // Limit the high frequency gains to avoid echo leakage due to an imperfect + // filter. + constexpr size_t kFirstBandToLimit = (64 * 2000) / 8000; + const float min_upper_gain = (*gain)[kFirstBandToLimit]; + std::for_each( + gain->begin() + kFirstBandToLimit + 1, gain->end(), + [min_upper_gain](float& a) { a = std::min(a, min_upper_gain); }); + (*gain)[kFftLengthBy2] = (*gain)[kFftLengthBy2Minus1]; + + if (conservative_hf_suppression) { + // Limits the gain in the frequencies for which the adaptive filter has not + // converged. + // TODO(peah): Make adaptive to take the actual filter error into account. + constexpr size_t kUpperAccurateBandPlus1 = 29; + + constexpr float oneByBandsInSum = + 1 / static_cast<float>(kUpperAccurateBandPlus1 - 20); + const float hf_gain_bound = + std::accumulate(gain->begin() + 20, + gain->begin() + kUpperAccurateBandPlus1, 0.f) * + oneByBandsInSum; + + std::for_each( + gain->begin() + kUpperAccurateBandPlus1, gain->end(), + [hf_gain_bound](float& a) { a = std::min(a, hf_gain_bound); }); + } +} + +// Scales the echo according to assessed audibility at the other end. +void WeightEchoForAudibility(const EchoCanceller3Config& config, + rtc::ArrayView<const float> echo, + rtc::ArrayView<float> weighted_echo) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, echo.size()); + RTC_DCHECK_EQ(kFftLengthBy2Plus1, weighted_echo.size()); + + auto weigh = [](float threshold, float normalizer, size_t begin, size_t end, + rtc::ArrayView<const float> echo, + rtc::ArrayView<float> weighted_echo) { + for (size_t k = begin; k < end; ++k) { + if (echo[k] < threshold) { + float tmp = (threshold - echo[k]) * normalizer; + weighted_echo[k] = echo[k] * std::max(0.f, 1.f - tmp * tmp); + } else { + weighted_echo[k] = echo[k]; + } + } + }; + + float threshold = config.echo_audibility.floor_power * + config.echo_audibility.audibility_threshold_lf; + float normalizer = 1.f / (threshold - config.echo_audibility.floor_power); + weigh(threshold, normalizer, 0, 3, echo, weighted_echo); + + threshold = config.echo_audibility.floor_power * + config.echo_audibility.audibility_threshold_mf; + normalizer = 1.f / (threshold - config.echo_audibility.floor_power); + weigh(threshold, normalizer, 3, 7, echo, weighted_echo); + + threshold = config.echo_audibility.floor_power * + config.echo_audibility.audibility_threshold_hf; + normalizer = 1.f / (threshold - config.echo_audibility.floor_power); + weigh(threshold, normalizer, 7, kFftLengthBy2Plus1, echo, weighted_echo); +} + +} // namespace + +std::atomic<int> SuppressionGain::instance_count_(0); + +float SuppressionGain::UpperBandsGain( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + const absl::optional<int>& narrow_peak_band, + bool saturated_echo, + const Block& render, + const std::array<float, kFftLengthBy2Plus1>& low_band_gain) const { + RTC_DCHECK_LT(0, render.NumBands()); + if (render.NumBands() == 1) { + return 1.f; + } + const int num_render_channels = render.NumChannels(); + + if (narrow_peak_band && + (*narrow_peak_band > static_cast<int>(kFftLengthBy2Plus1 - 10))) { + return 0.001f; + } + + constexpr size_t kLowBandGainLimit = kFftLengthBy2 / 2; + const float gain_below_8_khz = *std::min_element( + low_band_gain.begin() + kLowBandGainLimit, low_band_gain.end()); + + // Always attenuate the upper bands when there is saturated echo. + if (saturated_echo) { + return std::min(0.001f, gain_below_8_khz); + } + + // Compute the upper and lower band energies. + const auto sum_of_squares = [](float a, float b) { return a + b * b; }; + float low_band_energy = 0.f; + for (int ch = 0; ch < num_render_channels; ++ch) { + const float channel_energy = + std::accumulate(render.begin(/*band=*/0, ch), + render.end(/*band=*/0, ch), 0.0f, sum_of_squares); + low_band_energy = std::max(low_band_energy, channel_energy); + } + float high_band_energy = 0.f; + for (int k = 1; k < render.NumBands(); ++k) { + for (int ch = 0; ch < num_render_channels; ++ch) { + const float energy = std::accumulate( + render.begin(k, ch), render.end(k, ch), 0.f, sum_of_squares); + high_band_energy = std::max(high_band_energy, energy); + } + } + + // If there is more power in the lower frequencies than the upper frequencies, + // or if the power in upper frequencies is low, do not bound the gain in the + // upper bands. + float anti_howling_gain; + const float activation_threshold = + kBlockSize * config_.suppressor.high_bands_suppression + .anti_howling_activation_threshold; + if (high_band_energy < std::max(low_band_energy, activation_threshold)) { + anti_howling_gain = 1.f; + } else { + // In all other cases, bound the gain for upper frequencies. + RTC_DCHECK_LE(low_band_energy, high_band_energy); + RTC_DCHECK_NE(0.f, high_band_energy); + anti_howling_gain = + config_.suppressor.high_bands_suppression.anti_howling_gain * + sqrtf(low_band_energy / high_band_energy); + } + + float gain_bound = 1.f; + if (!dominant_nearend_detector_->IsNearendState()) { + // Bound the upper gain during significant echo activity. + const auto& cfg = config_.suppressor.high_bands_suppression; + auto low_frequency_energy = [](rtc::ArrayView<const float> spectrum) { + RTC_DCHECK_LE(16, spectrum.size()); + return std::accumulate(spectrum.begin() + 1, spectrum.begin() + 16, 0.f); + }; + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + const float echo_sum = low_frequency_energy(echo_spectrum[ch]); + const float noise_sum = low_frequency_energy(comfort_noise_spectrum[ch]); + if (echo_sum > cfg.enr_threshold * noise_sum) { + gain_bound = cfg.max_gain_during_echo; + break; + } + } + } + + // Choose the gain as the minimum of the lower and upper gains. + return std::min(std::min(gain_below_8_khz, anti_howling_gain), gain_bound); +} + +// Computes the gain to reduce the echo to a non audible level. +void SuppressionGain::GainToNoAudibleEcho( + const std::array<float, kFftLengthBy2Plus1>& nearend, + const std::array<float, kFftLengthBy2Plus1>& echo, + const std::array<float, kFftLengthBy2Plus1>& masker, + std::array<float, kFftLengthBy2Plus1>* gain) const { + const auto& p = dominant_nearend_detector_->IsNearendState() ? nearend_params_ + : normal_params_; + for (size_t k = 0; k < gain->size(); ++k) { + float enr = echo[k] / (nearend[k] + 1.f); // Echo-to-nearend ratio. + float emr = echo[k] / (masker[k] + 1.f); // Echo-to-masker (noise) ratio. + float g = 1.0f; + if (enr > p.enr_transparent_[k] && emr > p.emr_transparent_[k]) { + g = (p.enr_suppress_[k] - enr) / + (p.enr_suppress_[k] - p.enr_transparent_[k]); + g = std::max(g, p.emr_transparent_[k] / emr); + } + (*gain)[k] = g; + } +} + +// Compute the minimum gain as the attenuating gain to put the signal just +// above the zero sample values. +void SuppressionGain::GetMinGain( + rtc::ArrayView<const float> weighted_residual_echo, + rtc::ArrayView<const float> last_nearend, + rtc::ArrayView<const float> last_echo, + bool low_noise_render, + bool saturated_echo, + rtc::ArrayView<float> min_gain) const { + if (!saturated_echo) { + const float min_echo_power = + low_noise_render ? config_.echo_audibility.low_render_limit + : config_.echo_audibility.normal_render_limit; + + for (size_t k = 0; k < min_gain.size(); ++k) { + min_gain[k] = weighted_residual_echo[k] > 0.f + ? min_echo_power / weighted_residual_echo[k] + : 1.f; + min_gain[k] = std::min(min_gain[k], 1.f); + } + + if (!initial_state_ || + config_.suppressor.lf_smoothing_during_initial_phase) { + const float& dec = dominant_nearend_detector_->IsNearendState() + ? nearend_params_.max_dec_factor_lf + : normal_params_.max_dec_factor_lf; + + for (int k = 0; k <= config_.suppressor.last_lf_smoothing_band; ++k) { + // Make sure the gains of the low frequencies do not decrease too + // quickly after strong nearend. + if (last_nearend[k] > last_echo[k] || + k <= config_.suppressor.last_permanent_lf_smoothing_band) { + min_gain[k] = std::max(min_gain[k], last_gain_[k] * dec); + min_gain[k] = std::min(min_gain[k], 1.f); + } + } + } + } else { + std::fill(min_gain.begin(), min_gain.end(), 0.f); + } +} + +// Compute the maximum gain by limiting the gain increase from the previous +// gain. +void SuppressionGain::GetMaxGain(rtc::ArrayView<float> max_gain) const { + const auto& inc = dominant_nearend_detector_->IsNearendState() + ? nearend_params_.max_inc_factor + : normal_params_.max_inc_factor; + const auto& floor = config_.suppressor.floor_first_increase; + for (size_t k = 0; k < max_gain.size(); ++k) { + max_gain[k] = std::min(std::max(last_gain_[k] * inc, floor), 1.f); + } +} + +void SuppressionGain::LowerBandGain( + bool low_noise_render, + const AecState& aec_state, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + suppressor_input, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> residual_echo, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> comfort_noise, + bool clock_drift, + std::array<float, kFftLengthBy2Plus1>* gain) { + gain->fill(1.f); + const bool saturated_echo = aec_state.SaturatedEcho(); + std::array<float, kFftLengthBy2Plus1> max_gain; + GetMaxGain(max_gain); + + for (size_t ch = 0; ch < num_capture_channels_; ++ch) { + std::array<float, kFftLengthBy2Plus1> G; + std::array<float, kFftLengthBy2Plus1> nearend; + nearend_smoothers_[ch].Average(suppressor_input[ch], nearend); + + // Weight echo power in terms of audibility. + std::array<float, kFftLengthBy2Plus1> weighted_residual_echo; + WeightEchoForAudibility(config_, residual_echo[ch], weighted_residual_echo); + + std::array<float, kFftLengthBy2Plus1> min_gain; + GetMinGain(weighted_residual_echo, last_nearend_[ch], last_echo_[ch], + low_noise_render, saturated_echo, min_gain); + + GainToNoAudibleEcho(nearend, weighted_residual_echo, comfort_noise[0], &G); + + // Clamp gains. + for (size_t k = 0; k < gain->size(); ++k) { + G[k] = std::max(std::min(G[k], max_gain[k]), min_gain[k]); + (*gain)[k] = std::min((*gain)[k], G[k]); + } + + // Store data required for the gain computation of the next block. + std::copy(nearend.begin(), nearend.end(), last_nearend_[ch].begin()); + std::copy(weighted_residual_echo.begin(), weighted_residual_echo.end(), + last_echo_[ch].begin()); + } + + LimitLowFrequencyGains(gain); + // Use conservative high-frequency gains during clock-drift or when not in + // dominant nearend. + if (!dominant_nearend_detector_->IsNearendState() || clock_drift || + config_.suppressor.conservative_hf_suppression) { + LimitHighFrequencyGains(config_.suppressor.conservative_hf_suppression, + gain); + } + + // Store computed gains. + std::copy(gain->begin(), gain->end(), last_gain_.begin()); + + // Transform gains to amplitude domain. + aec3::VectorMath(optimization_).Sqrt(*gain); +} + +SuppressionGain::SuppressionGain(const EchoCanceller3Config& config, + Aec3Optimization optimization, + int sample_rate_hz, + size_t num_capture_channels) + : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), + optimization_(optimization), + config_(config), + num_capture_channels_(num_capture_channels), + state_change_duration_blocks_( + static_cast<int>(config_.filter.config_change_duration_blocks)), + last_nearend_(num_capture_channels_, {0}), + last_echo_(num_capture_channels_, {0}), + nearend_smoothers_( + num_capture_channels_, + aec3::MovingAverage(kFftLengthBy2Plus1, + config.suppressor.nearend_average_blocks)), + nearend_params_(config_.suppressor.last_lf_band, + config_.suppressor.first_hf_band, + config_.suppressor.nearend_tuning), + normal_params_(config_.suppressor.last_lf_band, + config_.suppressor.first_hf_band, + config_.suppressor.normal_tuning), + use_unbounded_echo_spectrum_(config.suppressor.dominant_nearend_detection + .use_unbounded_echo_spectrum) { + RTC_DCHECK_LT(0, state_change_duration_blocks_); + last_gain_.fill(1.f); + if (config_.suppressor.use_subband_nearend_detection) { + dominant_nearend_detector_ = std::make_unique<SubbandNearendDetector>( + config_.suppressor.subband_nearend_detection, num_capture_channels_); + } else { + dominant_nearend_detector_ = std::make_unique<DominantNearendDetector>( + config_.suppressor.dominant_nearend_detection, num_capture_channels_); + } + RTC_DCHECK(dominant_nearend_detector_); +} + +SuppressionGain::~SuppressionGain() = default; + +void SuppressionGain::GetGain( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + nearend_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum_unbounded, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + const RenderSignalAnalyzer& render_signal_analyzer, + const AecState& aec_state, + const Block& render, + bool clock_drift, + float* high_bands_gain, + std::array<float, kFftLengthBy2Plus1>* low_band_gain) { + RTC_DCHECK(high_bands_gain); + RTC_DCHECK(low_band_gain); + + // Choose residual echo spectrum for dominant nearend detection. + const auto echo = use_unbounded_echo_spectrum_ + ? residual_echo_spectrum_unbounded + : residual_echo_spectrum; + + // Update the nearend state selection. + dominant_nearend_detector_->Update(nearend_spectrum, echo, + comfort_noise_spectrum, initial_state_); + + // Compute gain for the lower band. + bool low_noise_render = low_render_detector_.Detect(render); + LowerBandGain(low_noise_render, aec_state, nearend_spectrum, + residual_echo_spectrum, comfort_noise_spectrum, clock_drift, + low_band_gain); + + // Compute the gain for the upper bands. + const absl::optional<int> narrow_peak_band = + render_signal_analyzer.NarrowPeakBand(); + + *high_bands_gain = + UpperBandsGain(echo_spectrum, comfort_noise_spectrum, narrow_peak_band, + aec_state.SaturatedEcho(), render, *low_band_gain); + + data_dumper_->DumpRaw("aec3_dominant_nearend", + dominant_nearend_detector_->IsNearendState()); +} + +void SuppressionGain::SetInitialState(bool state) { + initial_state_ = state; + if (state) { + initial_state_change_counter_ = state_change_duration_blocks_; + } else { + initial_state_change_counter_ = 0; + } +} + +// Detects when the render signal can be considered to have low power and +// consist of stationary noise. +bool SuppressionGain::LowNoiseRenderDetector::Detect(const Block& render) { + float x2_sum = 0.f; + float x2_max = 0.f; + for (int ch = 0; ch < render.NumChannels(); ++ch) { + for (float x_k : render.View(/*band=*/0, ch)) { + const float x2 = x_k * x_k; + x2_sum += x2; + x2_max = std::max(x2_max, x2); + } + } + x2_sum = x2_sum / render.NumChannels(); + + constexpr float kThreshold = 50.f * 50.f * 64.f; + const bool low_noise_render = + average_power_ < kThreshold && x2_max < 3 * average_power_; + average_power_ = average_power_ * 0.9f + x2_sum * 0.1f; + return low_noise_render; +} + +SuppressionGain::GainParameters::GainParameters( + int last_lf_band, + int first_hf_band, + const EchoCanceller3Config::Suppressor::Tuning& tuning) + : max_inc_factor(tuning.max_inc_factor), + max_dec_factor_lf(tuning.max_dec_factor_lf) { + // Compute per-band masking thresholds. + RTC_DCHECK_LT(last_lf_band, first_hf_band); + auto& lf = tuning.mask_lf; + auto& hf = tuning.mask_hf; + RTC_DCHECK_LT(lf.enr_transparent, lf.enr_suppress); + RTC_DCHECK_LT(hf.enr_transparent, hf.enr_suppress); + for (int k = 0; k < static_cast<int>(kFftLengthBy2Plus1); k++) { + float a; + if (k <= last_lf_band) { + a = 0.f; + } else if (k < first_hf_band) { + a = (k - last_lf_band) / static_cast<float>(first_hf_band - last_lf_band); + } else { + a = 1.f; + } + enr_transparent_[k] = (1 - a) * lf.enr_transparent + a * hf.enr_transparent; + enr_suppress_[k] = (1 - a) * lf.enr_suppress + a * hf.enr_suppress; + emr_transparent_[k] = (1 - a) * lf.emr_transparent + a * hf.emr_transparent; + } +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain.h b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain.h new file mode 100644 index 0000000000..c19ddd7e30 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_ + +#include <array> +#include <atomic> +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/fft_data.h" +#include "modules/audio_processing/aec3/moving_average.h" +#include "modules/audio_processing/aec3/nearend_detector.h" +#include "modules/audio_processing/aec3/render_signal_analyzer.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" + +namespace webrtc { + +class SuppressionGain { + public: + SuppressionGain(const EchoCanceller3Config& config, + Aec3Optimization optimization, + int sample_rate_hz, + size_t num_capture_channels); + ~SuppressionGain(); + + SuppressionGain(const SuppressionGain&) = delete; + SuppressionGain& operator=(const SuppressionGain&) = delete; + + void GetGain( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + nearend_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + residual_echo_spectrum_unbounded, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + const RenderSignalAnalyzer& render_signal_analyzer, + const AecState& aec_state, + const Block& render, + bool clock_drift, + float* high_bands_gain, + std::array<float, kFftLengthBy2Plus1>* low_band_gain); + + bool IsDominantNearend() { + return dominant_nearend_detector_->IsNearendState(); + } + + // Toggles the usage of the initial state. + void SetInitialState(bool state); + + private: + // Computes the gain to apply for the bands beyond the first band. + float UpperBandsGain( + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + comfort_noise_spectrum, + const absl::optional<int>& narrow_peak_band, + bool saturated_echo, + const Block& render, + const std::array<float, kFftLengthBy2Plus1>& low_band_gain) const; + + void GainToNoAudibleEcho(const std::array<float, kFftLengthBy2Plus1>& nearend, + const std::array<float, kFftLengthBy2Plus1>& echo, + const std::array<float, kFftLengthBy2Plus1>& masker, + std::array<float, kFftLengthBy2Plus1>* gain) const; + + void LowerBandGain( + bool stationary_with_low_power, + const AecState& aec_state, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> + suppressor_input, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> residual_echo, + rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> comfort_noise, + bool clock_drift, + std::array<float, kFftLengthBy2Plus1>* gain); + + void GetMinGain(rtc::ArrayView<const float> weighted_residual_echo, + rtc::ArrayView<const float> last_nearend, + rtc::ArrayView<const float> last_echo, + bool low_noise_render, + bool saturated_echo, + rtc::ArrayView<float> min_gain) const; + + void GetMaxGain(rtc::ArrayView<float> max_gain) const; + + class LowNoiseRenderDetector { + public: + bool Detect(const Block& render); + + private: + float average_power_ = 32768.f * 32768.f; + }; + + struct GainParameters { + explicit GainParameters( + int last_lf_band, + int first_hf_band, + const EchoCanceller3Config::Suppressor::Tuning& tuning); + const float max_inc_factor; + const float max_dec_factor_lf; + std::array<float, kFftLengthBy2Plus1> enr_transparent_; + std::array<float, kFftLengthBy2Plus1> enr_suppress_; + std::array<float, kFftLengthBy2Plus1> emr_transparent_; + }; + + static std::atomic<int> instance_count_; + std::unique_ptr<ApmDataDumper> data_dumper_; + const Aec3Optimization optimization_; + const EchoCanceller3Config config_; + const size_t num_capture_channels_; + const int state_change_duration_blocks_; + std::array<float, kFftLengthBy2Plus1> last_gain_; + std::vector<std::array<float, kFftLengthBy2Plus1>> last_nearend_; + std::vector<std::array<float, kFftLengthBy2Plus1>> last_echo_; + LowNoiseRenderDetector low_render_detector_; + bool initial_state_ = true; + int initial_state_change_counter_ = 0; + std::vector<aec3::MovingAverage> nearend_smoothers_; + const GainParameters nearend_params_; + const GainParameters normal_params_; + // Determines if the dominant nearend detector uses the unbounded residual + // echo spectrum. + const bool use_unbounded_echo_spectrum_; + std::unique_ptr<NearendDetector> dominant_nearend_detector_; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc new file mode 100644 index 0000000000..02de706c77 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/suppression_gain_unittest.cc @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/suppression_gain.h" + +#include "modules/audio_processing/aec3/aec_state.h" +#include "modules/audio_processing/aec3/render_delay_buffer.h" +#include "modules/audio_processing/aec3/subtractor.h" +#include "modules/audio_processing/aec3/subtractor_output.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "rtc_base/checks.h" +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/gtest.h" + +namespace webrtc { +namespace aec3 { + +#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) + +// Verifies that the check for non-null output gains works. +TEST(SuppressionGainDeathTest, NullOutputGains) { + std::vector<std::array<float, kFftLengthBy2Plus1>> E2(1, {0.0f}); + std::vector<std::array<float, kFftLengthBy2Plus1>> R2(1, {0.0f}); + std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded(1, {0.0f}); + std::vector<std::array<float, kFftLengthBy2Plus1>> S2(1); + std::vector<std::array<float, kFftLengthBy2Plus1>> N2(1, {0.0f}); + for (auto& S2_k : S2) { + S2_k.fill(0.1f); + } + FftData E; + FftData Y; + E.re.fill(0.0f); + E.im.fill(0.0f); + Y.re.fill(0.0f); + Y.im.fill(0.0f); + + float high_bands_gain; + AecState aec_state(EchoCanceller3Config{}, 1); + EXPECT_DEATH( + SuppressionGain(EchoCanceller3Config{}, DetectOptimization(), 16000, 1) + .GetGain(E2, S2, R2, R2_unbounded, N2, + RenderSignalAnalyzer((EchoCanceller3Config{})), aec_state, + Block(3, 1), false, &high_bands_gain, nullptr), + ""); +} + +#endif + +// Does a sanity check that the gains are correctly computed. +TEST(SuppressionGain, BasicGainComputation) { + constexpr size_t kNumRenderChannels = 1; + constexpr size_t kNumCaptureChannels = 2; + constexpr int kSampleRateHz = 16000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + SuppressionGain suppression_gain(EchoCanceller3Config(), DetectOptimization(), + kSampleRateHz, kNumCaptureChannels); + RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); + float high_bands_gain; + std::vector<std::array<float, kFftLengthBy2Plus1>> E2(kNumCaptureChannels); + std::vector<std::array<float, kFftLengthBy2Plus1>> S2(kNumCaptureChannels, + {0.0f}); + std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels); + std::vector<std::array<float, kFftLengthBy2Plus1>> R2(kNumCaptureChannels); + std::vector<std::array<float, kFftLengthBy2Plus1>> R2_unbounded( + kNumCaptureChannels); + std::vector<std::array<float, kFftLengthBy2Plus1>> N2(kNumCaptureChannels); + std::array<float, kFftLengthBy2Plus1> g; + std::vector<SubtractorOutput> output(kNumCaptureChannels); + Block x(kNumBands, kNumRenderChannels); + EchoCanceller3Config config; + AecState aec_state(config, kNumCaptureChannels); + ApmDataDumper data_dumper(42); + Subtractor subtractor(config, kNumRenderChannels, kNumCaptureChannels, + &data_dumper, DetectOptimization()); + std::unique_ptr<RenderDelayBuffer> render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels)); + absl::optional<DelayEstimate> delay_estimate; + + // Ensure that a strong noise is detected to mask any echoes. + for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) { + E2[ch].fill(10.f); + Y2[ch].fill(10.f); + R2[ch].fill(0.1f); + R2_unbounded[ch].fill(0.1f); + N2[ch].fill(100.0f); + } + for (auto& subtractor_output : output) { + subtractor_output.Reset(); + } + + // Ensure that the gain is no longer forced to zero. + for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) { + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(), + subtractor.FilterImpulseResponses(), + *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); + } + + for (int k = 0; k < 100; ++k) { + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(), + subtractor.FilterImpulseResponses(), + *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); + suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state, + x, false, &high_bands_gain, &g); + } + std::for_each(g.begin(), g.end(), + [](float a) { EXPECT_NEAR(1.0f, a, 0.001f); }); + + // Ensure that a strong nearend is detected to mask any echoes. + for (size_t ch = 0; ch < kNumCaptureChannels; ++ch) { + E2[ch].fill(100.f); + Y2[ch].fill(100.f); + R2[ch].fill(0.1f); + R2_unbounded[ch].fill(0.1f); + S2[ch].fill(0.1f); + N2[ch].fill(0.f); + } + + for (int k = 0; k < 100; ++k) { + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(), + subtractor.FilterImpulseResponses(), + *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); + suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state, + x, false, &high_bands_gain, &g); + } + std::for_each(g.begin(), g.end(), + [](float a) { EXPECT_NEAR(1.0f, a, 0.001f); }); + + // Add a strong echo to one of the channels and ensure that it is suppressed. + E2[1].fill(1000000000.0f); + R2[1].fill(10000000000000.0f); + R2_unbounded[1].fill(10000000000000.0f); + + for (int k = 0; k < 10; ++k) { + suppression_gain.GetGain(E2, S2, R2, R2_unbounded, N2, analyzer, aec_state, + x, false, &high_bands_gain, &g); + } + std::for_each(g.begin(), g.end(), + [](float a) { EXPECT_NEAR(0.0f, a, 0.001f); }); +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/transparent_mode.cc b/third_party/libwebrtc/modules/audio_processing/aec3/transparent_mode.cc new file mode 100644 index 0000000000..489f53f4f1 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/transparent_mode.cc @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/transparent_mode.h" + +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "system_wrappers/include/field_trial.h" + +namespace webrtc { +namespace { + +constexpr size_t kBlocksSinceConvergencedFilterInit = 10000; +constexpr size_t kBlocksSinceConsistentEstimateInit = 10000; + +bool DeactivateTransparentMode() { + return field_trial::IsEnabled("WebRTC-Aec3TransparentModeKillSwitch"); +} + +bool ActivateTransparentModeHmm() { + return field_trial::IsEnabled("WebRTC-Aec3TransparentModeHmm"); +} + +} // namespace + +// Classifier that toggles transparent mode which reduces echo suppression when +// headsets are used. +class TransparentModeImpl : public TransparentMode { + public: + bool Active() const override { return transparency_activated_; } + + void Reset() override { + // Determines if transparent mode is used. + transparency_activated_ = false; + + // The estimated probability of being transparent mode. + prob_transparent_state_ = 0.f; + } + + void Update(int filter_delay_blocks, + bool any_filter_consistent, + bool any_filter_converged, + bool any_coarse_filter_converged, + bool all_filters_diverged, + bool active_render, + bool saturated_capture) override { + // The classifier is implemented as a Hidden Markov Model (HMM) with two + // hidden states: "normal" and "transparent". The estimated probabilities of + // the two states are updated by observing filter convergence during active + // render. The filters are less likely to be reported as converged when + // there is no echo present in the microphone signal. + + // The constants have been obtained by observing active_render and + // any_coarse_filter_converged under varying call scenarios. They + // have further been hand tuned to prefer normal state during uncertain + // regions (to avoid echo leaks). + + // The model is only updated during active render. + if (!active_render) + return; + + // Probability of switching from one state to the other. + constexpr float kSwitch = 0.000001f; + + // Probability of observing converged filters in states "normal" and + // "transparent" during active render. + constexpr float kConvergedNormal = 0.01f; + constexpr float kConvergedTransparent = 0.001f; + + // Probability of transitioning to transparent state from normal state and + // transparent state respectively. + constexpr float kA[2] = {kSwitch, 1.f - kSwitch}; + + // Probability of the two observations (converged filter or not converged + // filter) in normal state and transparent state respectively. + constexpr float kB[2][2] = { + {1.f - kConvergedNormal, kConvergedNormal}, + {1.f - kConvergedTransparent, kConvergedTransparent}}; + + // Probability of the two states before the update. + const float prob_transparent = prob_transparent_state_; + const float prob_normal = 1.f - prob_transparent; + + // Probability of transitioning to transparent state. + const float prob_transition_transparent = + prob_normal * kA[0] + prob_transparent * kA[1]; + const float prob_transition_normal = 1.f - prob_transition_transparent; + + // Observed output. + const int out = static_cast<int>(any_coarse_filter_converged); + + // Joint probabilites of the observed output and respective states. + const float prob_joint_normal = prob_transition_normal * kB[0][out]; + const float prob_joint_transparent = + prob_transition_transparent * kB[1][out]; + + // Conditional probability of transparent state and the observed output. + RTC_DCHECK_GT(prob_joint_normal + prob_joint_transparent, 0.f); + prob_transparent_state_ = + prob_joint_transparent / (prob_joint_normal + prob_joint_transparent); + + // Transparent mode is only activated when its state probability is high. + // Dead zone between activation/deactivation thresholds to avoid switching + // back and forth. + if (prob_transparent_state_ > 0.95f) { + transparency_activated_ = true; + } else if (prob_transparent_state_ < 0.5f) { + transparency_activated_ = false; + } + } + + private: + bool transparency_activated_ = false; + float prob_transparent_state_ = 0.f; +}; + +// Legacy classifier for toggling transparent mode. +class LegacyTransparentModeImpl : public TransparentMode { + public: + explicit LegacyTransparentModeImpl(const EchoCanceller3Config& config) + : linear_and_stable_echo_path_( + config.echo_removal_control.linear_and_stable_echo_path), + active_blocks_since_sane_filter_(kBlocksSinceConsistentEstimateInit), + non_converged_sequence_size_(kBlocksSinceConvergencedFilterInit) {} + + bool Active() const override { return transparency_activated_; } + + void Reset() override { + non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit; + diverged_sequence_size_ = 0; + strong_not_saturated_render_blocks_ = 0; + if (linear_and_stable_echo_path_) { + recent_convergence_during_activity_ = false; + } + } + + void Update(int filter_delay_blocks, + bool any_filter_consistent, + bool any_filter_converged, + bool any_coarse_filter_converged, + bool all_filters_diverged, + bool active_render, + bool saturated_capture) override { + ++capture_block_counter_; + strong_not_saturated_render_blocks_ += + active_render && !saturated_capture ? 1 : 0; + + if (any_filter_consistent && filter_delay_blocks < 5) { + sane_filter_observed_ = true; + active_blocks_since_sane_filter_ = 0; + } else if (active_render) { + ++active_blocks_since_sane_filter_; + } + + bool sane_filter_recently_seen; + if (!sane_filter_observed_) { + sane_filter_recently_seen = + capture_block_counter_ <= 5 * kNumBlocksPerSecond; + } else { + sane_filter_recently_seen = + active_blocks_since_sane_filter_ <= 30 * kNumBlocksPerSecond; + } + + if (any_filter_converged) { + recent_convergence_during_activity_ = true; + active_non_converged_sequence_size_ = 0; + non_converged_sequence_size_ = 0; + ++num_converged_blocks_; + } else { + if (++non_converged_sequence_size_ > 20 * kNumBlocksPerSecond) { + num_converged_blocks_ = 0; + } + + if (active_render && + ++active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) { + recent_convergence_during_activity_ = false; + } + } + + if (!all_filters_diverged) { + diverged_sequence_size_ = 0; + } else if (++diverged_sequence_size_ >= 60) { + // TODO(peah): Change these lines to ensure proper triggering of usable + // filter. + non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit; + } + + if (active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) { + finite_erl_recently_detected_ = false; + } + if (num_converged_blocks_ > 50) { + finite_erl_recently_detected_ = true; + } + + if (finite_erl_recently_detected_) { + transparency_activated_ = false; + } else if (sane_filter_recently_seen && + recent_convergence_during_activity_) { + transparency_activated_ = false; + } else { + const bool filter_should_have_converged = + strong_not_saturated_render_blocks_ > 6 * kNumBlocksPerSecond; + transparency_activated_ = filter_should_have_converged; + } + } + + private: + const bool linear_and_stable_echo_path_; + size_t capture_block_counter_ = 0; + bool transparency_activated_ = false; + size_t active_blocks_since_sane_filter_; + bool sane_filter_observed_ = false; + bool finite_erl_recently_detected_ = false; + size_t non_converged_sequence_size_; + size_t diverged_sequence_size_ = 0; + size_t active_non_converged_sequence_size_ = 0; + size_t num_converged_blocks_ = 0; + bool recent_convergence_during_activity_ = false; + size_t strong_not_saturated_render_blocks_ = 0; +}; + +std::unique_ptr<TransparentMode> TransparentMode::Create( + const EchoCanceller3Config& config) { + if (config.ep_strength.bounded_erl || DeactivateTransparentMode()) { + RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: Disabled"; + return nullptr; + } + if (ActivateTransparentModeHmm()) { + RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: HMM"; + return std::make_unique<TransparentModeImpl>(); + } + RTC_LOG(LS_INFO) << "AEC3 Transparent Mode: Legacy"; + return std::make_unique<LegacyTransparentModeImpl>(config); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/transparent_mode.h b/third_party/libwebrtc/modules/audio_processing/aec3/transparent_mode.h new file mode 100644 index 0000000000..bc5dd0391b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/transparent_mode.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_TRANSPARENT_MODE_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_TRANSPARENT_MODE_H_ + +#include <memory> + +#include "api/audio/echo_canceller3_config.h" +#include "modules/audio_processing/aec3/aec3_common.h" + +namespace webrtc { + +// Class for detecting and toggling the transparent mode which causes the +// suppressor to apply less suppression. +class TransparentMode { + public: + static std::unique_ptr<TransparentMode> Create( + const EchoCanceller3Config& config); + + virtual ~TransparentMode() {} + + // Returns whether the transparent mode should be active. + virtual bool Active() const = 0; + + // Resets the state of the detector. + virtual void Reset() = 0; + + // Updates the detection decision based on new data. + virtual void Update(int filter_delay_blocks, + bool any_filter_consistent, + bool any_filter_converged, + bool any_coarse_filter_converged, + bool all_filters_diverged, + bool active_render, + bool saturated_capture) = 0; +}; + +} // namespace webrtc +#endif // MODULES_AUDIO_PROCESSING_AEC3_TRANSPARENT_MODE_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/vector_math.h b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math.h new file mode 100644 index 0000000000..e4d1381ae1 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math.h @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ + +// Defines WEBRTC_ARCH_X86_FAMILY, used below. +#include "rtc_base/system/arch.h" + +#if defined(WEBRTC_HAS_NEON) +#include <arm_neon.h> +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include <emmintrin.h> +#endif +#include <math.h> + +#include <algorithm> +#include <array> +#include <functional> + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace aec3 { + +// Provides optimizations for mathematical operations based on vectors. +class VectorMath { + public: + explicit VectorMath(Aec3Optimization optimization) + : optimization_(optimization) {} + + // Elementwise square root. + void SqrtAVX2(rtc::ArrayView<float> x); + void Sqrt(rtc::ArrayView<float> x) { + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + __m128 g = _mm_loadu_ps(&x[j]); + g = _mm_sqrt_ps(g); + _mm_storeu_ps(&x[j], g); + } + + for (; j < x_size; ++j) { + x[j] = sqrtf(x[j]); + } + } break; + case Aec3Optimization::kAvx2: + SqrtAVX2(x); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: { + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + float32x4_t g = vld1q_f32(&x[j]); +#if !defined(WEBRTC_ARCH_ARM64) + float32x4_t y = vrsqrteq_f32(g); + + // Code to handle sqrt(0). + // If the input to sqrtf() is zero, a zero will be returned. + // If the input to vrsqrteq_f32() is zero, positive infinity is + // returned. + const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000); + // check for divide by zero + const uint32x4_t div_by_zero = + vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y)); + // zero out the positive infinity results + y = vreinterpretq_f32_u32( + vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y))); + // from arm documentation + // The Newton-Raphson iteration: + // y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2) + // converges to (1/√d) if y0 is the result of VRSQRTE applied to d. + // + // Note: The precision did not improve after 2 iterations. + for (int i = 0; i < 2; i++) { + y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y); + } + // sqrt(g) = g * 1/sqrt(g) + g = vmulq_f32(g, y); +#else + g = vsqrtq_f32(g); +#endif + vst1q_f32(&x[j], g); + } + + for (; j < x_size; ++j) { + x[j] = sqrtf(x[j]); + } + } +#endif + break; + default: + std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); }); + } + } + + // Elementwise vector multiplication z = x * y. + void MultiplyAVX2(rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> z); + void Multiply(rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> z) { + RTC_DCHECK_EQ(z.size(), x.size()); + RTC_DCHECK_EQ(z.size(), y.size()); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const __m128 x_j = _mm_loadu_ps(&x[j]); + const __m128 y_j = _mm_loadu_ps(&y[j]); + const __m128 z_j = _mm_mul_ps(x_j, y_j); + _mm_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] = x[j] * y[j]; + } + } break; + case Aec3Optimization::kAvx2: + MultiplyAVX2(x, y, z); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: { + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const float32x4_t x_j = vld1q_f32(&x[j]); + const float32x4_t y_j = vld1q_f32(&y[j]); + const float32x4_t z_j = vmulq_f32(x_j, y_j); + vst1q_f32(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] = x[j] * y[j]; + } + } break; +#endif + default: + std::transform(x.begin(), x.end(), y.begin(), z.begin(), + std::multiplies<float>()); + } + } + + // Elementwise vector accumulation z += x. + void AccumulateAVX2(rtc::ArrayView<const float> x, rtc::ArrayView<float> z); + void Accumulate(rtc::ArrayView<const float> x, rtc::ArrayView<float> z) { + RTC_DCHECK_EQ(z.size(), x.size()); + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: { + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const __m128 x_j = _mm_loadu_ps(&x[j]); + __m128 z_j = _mm_loadu_ps(&z[j]); + z_j = _mm_add_ps(x_j, z_j); + _mm_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] += x[j]; + } + } break; + case Aec3Optimization::kAvx2: + AccumulateAVX2(x, z); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: { + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 2; + + int j = 0; + for (; j < vector_limit * 4; j += 4) { + const float32x4_t x_j = vld1q_f32(&x[j]); + float32x4_t z_j = vld1q_f32(&z[j]); + z_j = vaddq_f32(z_j, x_j); + vst1q_f32(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] += x[j]; + } + } break; +#endif + default: + std::transform(x.begin(), x.end(), z.begin(), z.begin(), + std::plus<float>()); + } + } + + private: + Aec3Optimization optimization_; +}; + +} // namespace aec3 + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_avx2.cc b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_avx2.cc new file mode 100644 index 0000000000..0b5f3c142e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_avx2.cc @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/vector_math.h" + +#include <immintrin.h> +#include <math.h> + +#include "api/array_view.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace aec3 { + +// Elementwise square root. +void VectorMath::SqrtAVX2(rtc::ArrayView<float> x) { + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 3; + + int j = 0; + for (; j < vector_limit * 8; j += 8) { + __m256 g = _mm256_loadu_ps(&x[j]); + g = _mm256_sqrt_ps(g); + _mm256_storeu_ps(&x[j], g); + } + + for (; j < x_size; ++j) { + x[j] = sqrtf(x[j]); + } +} + +// Elementwise vector multiplication z = x * y. +void VectorMath::MultiplyAVX2(rtc::ArrayView<const float> x, + rtc::ArrayView<const float> y, + rtc::ArrayView<float> z) { + RTC_DCHECK_EQ(z.size(), x.size()); + RTC_DCHECK_EQ(z.size(), y.size()); + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 3; + + int j = 0; + for (; j < vector_limit * 8; j += 8) { + const __m256 x_j = _mm256_loadu_ps(&x[j]); + const __m256 y_j = _mm256_loadu_ps(&y[j]); + const __m256 z_j = _mm256_mul_ps(x_j, y_j); + _mm256_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] = x[j] * y[j]; + } +} + +// Elementwise vector accumulation z += x. +void VectorMath::AccumulateAVX2(rtc::ArrayView<const float> x, + rtc::ArrayView<float> z) { + RTC_DCHECK_EQ(z.size(), x.size()); + const int x_size = static_cast<int>(x.size()); + const int vector_limit = x_size >> 3; + + int j = 0; + for (; j < vector_limit * 8; j += 8) { + const __m256 x_j = _mm256_loadu_ps(&x[j]); + __m256 z_j = _mm256_loadu_ps(&z[j]); + z_j = _mm256_add_ps(x_j, z_j); + _mm256_storeu_ps(&z[j], z_j); + } + + for (; j < x_size; ++j) { + z[j] += x[j]; + } +} + +} // namespace aec3 +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_gn/moz.build b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_gn/moz.build new file mode 100644 index 0000000000..89ee0b6a81 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_gn/moz.build @@ -0,0 +1,205 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + + ### This moz.build was AUTOMATICALLY GENERATED from a GN config, ### + ### DO NOT edit it by hand. ### + +COMPILE_FLAGS["OS_INCLUDES"] = [] +AllowCompilerWarnings() + +DEFINES["ABSL_ALLOCATOR_NOTHROW"] = "1" +DEFINES["RTC_DAV1D_IN_INTERNAL_DECODER_FACTORY"] = True +DEFINES["RTC_ENABLE_VP9"] = True +DEFINES["WEBRTC_ENABLE_PROTOBUF"] = "0" +DEFINES["WEBRTC_LIBRARY_IMPL"] = True +DEFINES["WEBRTC_MOZILLA_BUILD"] = True +DEFINES["WEBRTC_NON_STATIC_TRACE_EVENT_HANDLERS"] = "0" +DEFINES["WEBRTC_STRICT_FIELD_TRIALS"] = "0" + +FINAL_LIBRARY = "webrtc" + + +LOCAL_INCLUDES += [ + "!/ipc/ipdl/_ipdlheaders", + "!/third_party/libwebrtc/gen", + "/ipc/chromium/src", + "/third_party/libwebrtc/", + "/third_party/libwebrtc/third_party/abseil-cpp/", + "/tools/profiler/public" +] + +if not CONFIG["MOZ_DEBUG"]: + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "0" + DEFINES["NDEBUG"] = True + DEFINES["NVALGRIND"] = True + +if CONFIG["MOZ_DEBUG"] == "1": + + DEFINES["DYNAMIC_ANNOTATIONS_ENABLED"] = "1" + +if CONFIG["OS_TARGET"] == "Android": + + DEFINES["ANDROID"] = True + DEFINES["ANDROID_NDK_VERSION_ROLL"] = "r22_1" + DEFINES["HAVE_SYS_UIO_H"] = True + DEFINES["WEBRTC_ANDROID"] = True + DEFINES["WEBRTC_ANDROID_OPENSLES"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_GNU_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + + OS_LIBS += [ + "log" + ] + +if CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["WEBRTC_MAC"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_LIBCPP_HAS_NO_ALIGNED_ALLOCATION"] = True + DEFINES["__ASSERT_MACROS_DEFINE_VERSIONS_WITHOUT_UNDERSCORES"] = "0" + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_AURA"] = "1" + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_NSS_CERTS"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_UDEV"] = True + DEFINES["WEBRTC_LINUX"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["USE_GLIB"] = "1" + DEFINES["USE_OZONE"] = "1" + DEFINES["USE_X11"] = "1" + DEFINES["WEBRTC_BSD"] = True + DEFINES["WEBRTC_POSIX"] = True + DEFINES["_FILE_OFFSET_BITS"] = "64" + DEFINES["_LARGEFILE64_SOURCE"] = True + DEFINES["_LARGEFILE_SOURCE"] = True + DEFINES["__STDC_CONSTANT_MACROS"] = True + DEFINES["__STDC_FORMAT_MACROS"] = True + +if CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["CERT_CHAIN_PARA_HAS_EXTRA_FIELDS"] = True + DEFINES["NOMINMAX"] = True + DEFINES["NTDDI_VERSION"] = "0x0A000000" + DEFINES["PSAPI_VERSION"] = "2" + DEFINES["UNICODE"] = True + DEFINES["USE_AURA"] = "1" + DEFINES["WEBRTC_WIN"] = True + DEFINES["WIN32"] = True + DEFINES["WIN32_LEAN_AND_MEAN"] = True + DEFINES["WINAPI_FAMILY"] = "WINAPI_FAMILY_DESKTOP_APP" + DEFINES["WINVER"] = "0x0A00" + DEFINES["_ATL_NO_OPENGL"] = True + DEFINES["_CRT_RAND_S"] = True + DEFINES["_CRT_SECURE_NO_DEPRECATE"] = True + DEFINES["_ENABLE_EXTENDED_ALIGNED_STORAGE"] = True + DEFINES["_HAS_EXCEPTIONS"] = "0" + DEFINES["_HAS_NODISCARD"] = True + DEFINES["_SCL_SECURE_NO_DEPRECATE"] = True + DEFINES["_SECURE_ATL"] = True + DEFINES["_UNICODE"] = True + DEFINES["_WIN32_WINNT"] = "0x0A00" + DEFINES["_WINDOWS"] = True + DEFINES["__STD_C"] = True + +if CONFIG["CPU_ARCH"] == "aarch64": + + DEFINES["WEBRTC_ARCH_ARM64"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "arm": + + DEFINES["WEBRTC_ARCH_ARM"] = True + DEFINES["WEBRTC_ARCH_ARM_V7"] = True + DEFINES["WEBRTC_HAS_NEON"] = True + +if CONFIG["CPU_ARCH"] == "mips32": + + DEFINES["MIPS32_LE"] = True + DEFINES["MIPS_FPU_LE"] = True + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "mips64": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["CPU_ARCH"] == "x86_64": + + DEFINES["WEBRTC_ENABLE_AVX2"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Android": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Darwin": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "OpenBSD": + + DEFINES["_DEBUG"] = True + +if CONFIG["MOZ_DEBUG"] == "1" and CONFIG["OS_TARGET"] == "WINNT": + + DEFINES["_HAS_ITERATOR_DEBUGGING"] = "0" + +if CONFIG["MOZ_X11"] == "1" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["USE_X11"] = "1" + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support", + "unwind" + ] + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Android": + + OS_LIBS += [ + "android_support" + ] + +if CONFIG["CPU_ARCH"] == "aarch64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "arm" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +if CONFIG["CPU_ARCH"] == "x86_64" and CONFIG["OS_TARGET"] == "Linux": + + DEFINES["_GNU_SOURCE"] = True + +Library("vector_math_gn") diff --git a/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_unittest.cc b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_unittest.cc new file mode 100644 index 0000000000..a9c37e33cf --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/aec3/vector_math_unittest.cc @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/vector_math.h" + +#include <math.h> + +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/gtest.h" + +namespace webrtc { + +#if defined(WEBRTC_HAS_NEON) + +TEST(VectorMath, Sqrt) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_neon; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = (2.f / 3.f) * k; + } + + std::copy(x.begin(), x.end(), z.begin()); + aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z); + std::copy(x.begin(), x.end(), z_neon.begin()); + aec3::VectorMath(Aec3Optimization::kNeon).Sqrt(z_neon); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_NEAR(z[k], z_neon[k], 0.0001f); + EXPECT_NEAR(sqrtf(x[k]), z_neon[k], 0.0001f); + } +} + +TEST(VectorMath, Multiply) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> y; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_neon; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + y[k] = (2.f / 3.f) * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z); + aec3::VectorMath(Aec3Optimization::kNeon).Multiply(x, y, z_neon); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_neon[k]); + EXPECT_FLOAT_EQ(x[k] * y[k], z_neon[k]); + } +} + +TEST(VectorMath, Accumulate) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_neon; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + z[k] = z_neon[k] = 2.f * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z); + aec3::VectorMath(Aec3Optimization::kNeon).Accumulate(x, z_neon); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_neon[k]); + EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_neon[k]); + } +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) + +TEST(VectorMath, Sse2Sqrt) { + if (GetCPUInfo(kSSE2) != 0) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_sse2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = (2.f / 3.f) * k; + } + + std::copy(x.begin(), x.end(), z.begin()); + aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z); + std::copy(x.begin(), x.end(), z_sse2.begin()); + aec3::VectorMath(Aec3Optimization::kSse2).Sqrt(z_sse2); + EXPECT_EQ(z, z_sse2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_sse2[k]); + EXPECT_FLOAT_EQ(sqrtf(x[k]), z_sse2[k]); + } + } +} + +TEST(VectorMath, Avx2Sqrt) { + if (GetCPUInfo(kAVX2) != 0) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_avx2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = (2.f / 3.f) * k; + } + + std::copy(x.begin(), x.end(), z.begin()); + aec3::VectorMath(Aec3Optimization::kNone).Sqrt(z); + std::copy(x.begin(), x.end(), z_avx2.begin()); + aec3::VectorMath(Aec3Optimization::kAvx2).Sqrt(z_avx2); + EXPECT_EQ(z, z_avx2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_avx2[k]); + EXPECT_FLOAT_EQ(sqrtf(x[k]), z_avx2[k]); + } + } +} + +TEST(VectorMath, Sse2Multiply) { + if (GetCPUInfo(kSSE2) != 0) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> y; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_sse2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + y[k] = (2.f / 3.f) * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z); + aec3::VectorMath(Aec3Optimization::kSse2).Multiply(x, y, z_sse2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_sse2[k]); + EXPECT_FLOAT_EQ(x[k] * y[k], z_sse2[k]); + } + } +} + +TEST(VectorMath, Avx2Multiply) { + if (GetCPUInfo(kAVX2) != 0) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> y; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_avx2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + y[k] = (2.f / 3.f) * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Multiply(x, y, z); + aec3::VectorMath(Aec3Optimization::kAvx2).Multiply(x, y, z_avx2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_avx2[k]); + EXPECT_FLOAT_EQ(x[k] * y[k], z_avx2[k]); + } + } +} + +TEST(VectorMath, Sse2Accumulate) { + if (GetCPUInfo(kSSE2) != 0) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_sse2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + z[k] = z_sse2[k] = 2.f * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z); + aec3::VectorMath(Aec3Optimization::kSse2).Accumulate(x, z_sse2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_sse2[k]); + EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_sse2[k]); + } + } +} + +TEST(VectorMath, Avx2Accumulate) { + if (GetCPUInfo(kAVX2) != 0) { + std::array<float, kFftLengthBy2Plus1> x; + std::array<float, kFftLengthBy2Plus1> z; + std::array<float, kFftLengthBy2Plus1> z_avx2; + + for (size_t k = 0; k < x.size(); ++k) { + x[k] = k; + z[k] = z_avx2[k] = 2.f * k; + } + + aec3::VectorMath(Aec3Optimization::kNone).Accumulate(x, z); + aec3::VectorMath(Aec3Optimization::kAvx2).Accumulate(x, z_avx2); + for (size_t k = 0; k < z.size(); ++k) { + EXPECT_FLOAT_EQ(z[k], z_avx2[k]); + EXPECT_FLOAT_EQ(x[k] + 2.f * x[k], z_avx2[k]); + } + } +} +#endif + +} // namespace webrtc |