/* * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source * tree. An additional intellectual property rights grant can be found * in the file PATENTS. All contributing project authors may * be found in the AUTHORS file in the root of the source tree. */ #ifndef MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ #define MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_ #include #include #include "absl/types/optional.h" #include "api/array_view.h" #include "modules/audio_processing/aec3/aec3_common.h" #include "rtc_base/gtest_prod_util.h" #include "rtc_base/system/arch.h" namespace webrtc { class ApmDataDumper; struct DownsampledRenderBuffer; namespace aec3 { #if defined(WEBRTC_HAS_NEON) // Filter core for the matched filter that is optimized for NEON. void MatchedFilterCore_NEON(size_t x_start_index, float x2_sum_threshold, float smoothing, rtc::ArrayView x, rtc::ArrayView y, rtc::ArrayView h, bool* filters_updated, float* error_sum, bool compute_accumulation_error, rtc::ArrayView accumulated_error, rtc::ArrayView scratch_memory); #endif #if defined(WEBRTC_ARCH_X86_FAMILY) // Filter core for the matched filter that is optimized for SSE2. void MatchedFilterCore_SSE2(size_t x_start_index, float x2_sum_threshold, float smoothing, rtc::ArrayView x, rtc::ArrayView y, rtc::ArrayView h, bool* filters_updated, float* error_sum, bool compute_accumulated_error, rtc::ArrayView accumulated_error, rtc::ArrayView scratch_memory); // Filter core for the matched filter that is optimized for AVX2. void MatchedFilterCore_AVX2(size_t x_start_index, float x2_sum_threshold, float smoothing, rtc::ArrayView x, rtc::ArrayView y, rtc::ArrayView h, bool* filters_updated, float* error_sum, bool compute_accumulated_error, rtc::ArrayView accumulated_error, rtc::ArrayView scratch_memory); #endif // Filter core for the matched filter. void MatchedFilterCore(size_t x_start_index, float x2_sum_threshold, float smoothing, rtc::ArrayView x, rtc::ArrayView y, rtc::ArrayView h, bool* filters_updated, float* error_sum, bool compute_accumulation_error, rtc::ArrayView accumulated_error); // Find largest peak of squared values in array. size_t MaxSquarePeakIndex(rtc::ArrayView h); } // namespace aec3 // Produces recursively updated cross-correlation estimates for several signal // shifts where the intra-shift spacing is uniform. class MatchedFilter { public: // Stores properties for the lag estimate corresponding to a particular signal // shift. struct LagEstimate { LagEstimate() = default; LagEstimate(size_t lag, size_t pre_echo_lag) : lag(lag), pre_echo_lag(pre_echo_lag) {} size_t lag = 0; size_t pre_echo_lag = 0; }; struct PreEchoConfiguration { const float threshold; const int mode; }; MatchedFilter(ApmDataDumper* data_dumper, Aec3Optimization optimization, size_t sub_block_size, size_t window_size_sub_blocks, int num_matched_filters, size_t alignment_shift_sub_blocks, float excitation_limit, float smoothing_fast, float smoothing_slow, float matching_filter_threshold, bool detect_pre_echo); MatchedFilter() = delete; MatchedFilter(const MatchedFilter&) = delete; MatchedFilter& operator=(const MatchedFilter&) = delete; ~MatchedFilter(); // Updates the correlation with the values in the capture buffer. void Update(const DownsampledRenderBuffer& render_buffer, rtc::ArrayView capture, bool use_slow_smoothing); // Resets the matched filter. void Reset(bool full_reset); // Returns the current lag estimates. absl::optional GetBestLagEstimate() const { return reported_lag_estimate_; } // Returns the maximum filter lag. size_t GetMaxFilterLag() const { return filters_.size() * filter_intra_lag_shift_ + filters_[0].size(); } // Log matched filter properties. void LogFilterProperties(int sample_rate_hz, size_t shift, size_t downsampling_factor) const; private: FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, PreEchoConfigurationTest); FRIEND_TEST_ALL_PREFIXES(MatchedFilterFieldTrialTest, WrongPreEchoConfigurationTest); // Only for testing. Gets the pre echo detection configuration. const PreEchoConfiguration& GetPreEchoConfiguration() const { return pre_echo_config_; } void Dump(); ApmDataDumper* const data_dumper_; const Aec3Optimization optimization_; const size_t sub_block_size_; const size_t filter_intra_lag_shift_; std::vector> filters_; std::vector> accumulated_error_; std::vector instantaneous_accumulated_error_; std::vector scratch_memory_; absl::optional reported_lag_estimate_; absl::optional winner_lag_; int last_detected_best_lag_filter_ = -1; std::vector filters_offsets_; int number_pre_echo_updates_ = 0; const float excitation_limit_; const float smoothing_fast_; const float smoothing_slow_; const float matching_filter_threshold_; const bool detect_pre_echo_; const PreEchoConfiguration pre_echo_config_; }; } // namespace webrtc #endif // MODULES_AUDIO_PROCESSING_AEC3_MATCHED_FILTER_H_