/*
 *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"

#include <algorithm>

#include "rtc_base/checks.h"

namespace webrtc {
namespace rnn_vad {
namespace {

// Order (log2 of the length) of the FFT used to compute the auto-correlation.
constexpr int kAutoCorrelationFftOrder = 9;  // Length-512 FFT.
// The FFT must be strictly longer than the zero-padded sliding chunk
// (number of lags + reference frame length) that is transformed.
static_assert(1 << kAutoCorrelationFftOrder >
                  kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
              "The FFT length is not sufficiently big to avoid cyclic "
              "convolution errors.");

}  // namespace

// Creates a real FFT of length `1 << kAutoCorrelationFftOrder` and allocates
// three buffers from it: `tmp_` (scratch buffer for FFT input and for the
// convolution result) plus `X_` and `H_` (frequency-domain buffers).
AutoCorrelationCalculator::AutoCorrelationCalculator()
    : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
      tmp_(fft_.CreateBuffer()),
      X_(fft_.CreateBuffer()),
      H_(fft_.CreateBuffer()) {}

AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;

// The auto-correlation coefficients are computed as follows:
// |.........|...........|  <- pitch buffer
//           [ x (fixed) ]
// [   y_0   ]
//         [ y_{m-1} ]
// x and y are sub-arrays of equal length; x is never moved, whereas y slides.
// The cross-correlation between y_0 and x corresponds to the auto-correlation
// for the maximum pitch period. Hence, the first value in `auto_corr` has an
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
// pitch period.
void AutoCorrelationCalculator::ComputeOnPitchBuffer( rtc::ArrayView pitch_buf, rtc::ArrayView auto_corr) { RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder; constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz; static_assert(kConvolutionLength == kFrameSize20ms12kHz, "Mismatch between pitch buffer size, frame size and maximum " "pitch period."); static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength, "The FFT length is not sufficiently big to avoid cyclic " "convolution errors."); auto tmp = tmp_->GetView(); // Compute the FFT for the reversed reference frame - i.e., // pitch_buf[-kConvolutionLength:]. std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(), tmp.begin()); std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f); fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false); // Compute the FFT for the sliding frames chunk. The sliding frames are // defined as pitch_buf[i:i+kConvolutionLength] where i in // [0, kNumLags12kHz). The chunk includes all of them, hence it is // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength]. std::copy(pitch_buf.begin(), pitch_buf.begin() + kConvolutionLength + kNumLags12kHz, tmp.begin()); std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f); fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false); // Convolve in the frequency domain. constexpr float kScalingFactor = 1.f / static_cast(kFftFrameSize); std::fill(tmp.begin(), tmp.end(), 0.f); fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor); fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false); // Extract the auto-correlation coefficients. std::copy(tmp.begin() + kConvolutionLength - 1, tmp.begin() + kConvolutionLength + kNumLags12kHz - 1, auto_corr.begin()); } } // namespace rnn_vad } // namespace webrtc