/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/aec3/reverb_model_estimator.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <numeric>
#include <vector>

#include "absl/types/optional.h"
#include "api/array_view.h"
#include "api/audio/echo_canceller3_config.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "modules/audio_processing/aec3/aec3_fft.h"
#include "modules/audio_processing/aec3/fft_data.h"
#include "rtc_base/checks.h"
#include "test/gtest.h"

namespace webrtc {

namespace {

// Returns an AEC3 config whose `ep_strength.default_len` is `default_decay`
// and whose refined linear filter is 40 blocks long.
EchoCanceller3Config CreateConfigForTest(float default_decay) {
  EchoCanceller3Config cfg;
  cfg.ep_strength.default_len = default_decay;
  cfg.filter.refined.length_blocks = 40;
  return cfg;
}

// Number of all-zero blocks that precede the peak of the synthetic impulse
// response; also reported to the estimator as the filter delay.
constexpr int kFilterDelayBlocks = 2;

}  // namespace

// Fixture that synthesizes an exponentially decaying impulse response with a
// known power decay per block (`kTruePowerDecay`), feeds it repeatedly to a
// ReverbModelEstimator, and exposes both the estimated and the true values so
// the tests can compare them.
class ReverbModelEstimatorTest {
 public:
  // `default_decay` is written into the config's `ep_strength.default_len` and
  // is also the value reported by GetDecay(false) until RunEstimator() runs.
  // One identical impulse response is created per capture channel.
  ReverbModelEstimatorTest(float default_decay, size_t num_capture_channels)
      : aec3_config_(CreateConfigForTest(default_decay)),
        estimated_decay_(default_decay),
        h_(num_capture_channels,
           std::vector<float>(
               aec3_config_.filter.refined.length_blocks * kBlockSize, 0.f)),
        H2_(num_capture_channels,
            std::vector<std::array<float, kFftLengthBy2Plus1>>(
                aec3_config_.filter.refined.length_blocks)),
        quality_linear_(num_capture_channels, 1.0f) {
    CreateImpulseResponseWithDecay();
  }

  // Runs the estimator over the synthetic filter and stores its outputs.
  void RunEstimator();

  // Returns the mild or the regular estimated decay, as stored by the most
  // recent RunEstimator() call.
  float GetDecay(bool mild) const {
    return mild ? mild_estimated_decay_ : estimated_decay_;
  }
  // Power decay per block that was used to build the impulse response.
  float GetTrueDecay() const { return kTruePowerDecay; }
  // Sum of the estimator's reverb frequency response, in dB.
  float GetPowerTailDb() const {
    return 10.f * std::log10(estimated_power_tail_);
  }
  // Sum of the power spectrum of the last filter block, in dB.
  float GetTruePowerTailDb() const {
    return 10.f * std::log10(true_power_tail_);
  }

 private:
  // Fills `h_` with the decaying impulse response and `H2_` with the power
  // spectrum of each of its blocks.
  void CreateImpulseResponseWithDecay();

  static constexpr bool kStationaryBlock = false;
  static constexpr float kTruePowerDecay = 0.5f;

  const EchoCanceller3Config aec3_config_;
  float estimated_decay_;
  // Initialized so GetDecay(true) is well defined before RunEstimator().
  float mild_estimated_decay_ = 0.f;
  float estimated_power_tail_ = 0.f;
  float true_power_tail_ = 0.f;
  // Per channel: time-domain filter coefficients.
  std::vector<std::vector<float>> h_;
  // Per channel, per block: power spectrum of that block of `h_`.
  std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> H2_;
  // Per channel linear filter quality handed to the estimator (always 1).
  std::vector<absl::optional<float>> quality_linear_;
};

void ReverbModelEstimatorTest::CreateImpulseResponseWithDecay() {
  const Aec3Fft fft;
  for (const auto& h_k : h_) {
    RTC_DCHECK_EQ(h_k.size(),
                  aec3_config_.filter.refined.length_blocks * kBlockSize);
  }
  for (const auto& H2_k : H2_) {
    RTC_DCHECK_EQ(H2_k.size(), aec3_config_.filter.refined.length_blocks);
  }
  // Compile-time check; `kFilterDelayBlocks` is a constexpr constant.
  static_assert(kFilterDelayBlocks == 2, "kFilterDelayBlocks must be 2.");

  // Per-sample amplitude factor chosen so that the power drops by
  // `kTruePowerDecay` over one block: (decay_sample^2)^kBlockSize.
  const float decay_sample =
      std::sqrt(std::pow(kTruePowerDecay, 1.f / kBlockSize));
  const size_t filter_delay_coefficients = kFilterDelayBlocks * kBlockSize;
  for (auto& h_i : h_) {
    std::fill(h_i.begin(), h_i.end(), 0.f);
    // Unit peak right after the delay, then geometric decay to the end.
    h_i[filter_delay_coefficients] = 1.f;
    for (size_t k = filter_delay_coefficients + 1; k < h_i.size(); ++k) {
      h_i[k] = h_i[k - 1] * decay_sample;
    }
  }

  // Each block of `kBlockSize` coefficients is zero-padded to `kFftLength`
  // and transformed; its power spectrum is stored in `H2_`.
  for (size_t ch = 0; ch < H2_.size(); ++ch) {
    for (size_t j = 0, k = 0; j < H2_[ch].size(); ++j, k += kBlockSize) {
      std::array<float, kFftLength> fft_data;
      fft_data.fill(0.f);
      std::copy(h_[ch].begin() + k, h_[ch].begin() + k + kBlockSize,
                fft_data.begin());
      FftData H_j;
      fft.Fft(&fft_data, &H_j);
      H_j.Spectrum(Aec3Optimization::kNone, H2_[ch][j]);
    }
  }

  // Reference tail power: all channels are identical, so channel 0's last
  // block is used.
  rtc::ArrayView<const float> H2_tail(H2_[0][H2_[0].size() - 1]);
  true_power_tail_ = std::accumulate(H2_tail.begin(), H2_tail.end(), 0.f);
}

void ReverbModelEstimatorTest::RunEstimator() {
  const size_t num_capture_channels = H2_.size();
  constexpr bool kUsableLinearEstimate = true;
  constexpr size_t kNumUpdates = 3000;
  ReverbModelEstimator estimator(aec3_config_, num_capture_channels);
  std::vector<bool> usable_linear_estimates(num_capture_channels,
                                            kUsableLinearEstimate);
  std::vector<int> filter_delay_blocks(num_capture_channels,
                                       kFilterDelayBlocks);
  // Feed the same (static) filter many times so the estimator can converge.
  for (size_t k = 0; k < kNumUpdates; ++k) {
    estimator.Update(h_, H2_, quality_linear_, filter_delay_blocks,
                     usable_linear_estimates, kStationaryBlock);
  }
  estimated_decay_ = estimator.ReverbDecay(/*mild=*/false);
  mild_estimated_decay_ = estimator.ReverbDecay(/*mild=*/true);
  auto freq_resp_tail = estimator.GetReverbFrequencyResponse();
  estimated_power_tail_ =
      std::accumulate(freq_resp_tail.begin(), freq_resp_tail.end(), 0.f);
}

// With a positive default decay the estimated decay is expected to stay at the
// configured value, and the mild decay at the default config's `nearend_len`.
TEST(ReverbModelEstimatorTests, NotChangingDecay) {
  constexpr float kDefaultDecay = 0.9f;
  for (size_t num_capture_channels : {1, 2, 4, 8}) {
    ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels);
    test.RunEstimator();
    EXPECT_EQ(test.GetDecay(/*mild=*/false), kDefaultDecay);
    EXPECT_EQ(test.GetDecay(/*mild=*/true),
              EchoCanceller3Config().ep_strength.nearend_len);
    EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f);
  }
}

// With a negative default decay both decay estimates are expected to converge
// to the true power decay of the synthetic impulse response.
TEST(ReverbModelEstimatorTests, ChangingDecay) {
  constexpr float kDefaultDecay = -0.9f;
  for (size_t num_capture_channels : {1, 2, 4, 8}) {
    ReverbModelEstimatorTest test(kDefaultDecay, num_capture_channels);
    test.RunEstimator();
    EXPECT_NEAR(test.GetDecay(/*mild=*/false), test.GetTrueDecay(), 0.1f);
    EXPECT_NEAR(test.GetDecay(/*mild=*/true), test.GetTrueDecay(), 0.1f);
    EXPECT_NEAR(test.GetPowerTailDb(), test.GetTruePowerTailDb(), 5.f);
  }
}

}  // namespace webrtc