diff options
Diffstat (limited to '')
-rw-r--r-- | third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc new file mode 100644 index 0000000000..3ddeec8dba --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" + +#include <algorithm> + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace rnn_vad { +namespace { + +constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. +static_assert(1 << kAutoCorrelationFftOrder > + kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz, + ""); + +} // namespace + +AutoCorrelationCalculator::AutoCorrelationCalculator() + : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal), + tmp_(fft_.CreateBuffer()), + X_(fft_.CreateBuffer()), + H_(fft_.CreateBuffer()) {} + +AutoCorrelationCalculator::~AutoCorrelationCalculator() = default; + +// The auto-correlations coefficients are computed as follows: +// |.........|...........| <- pitch buffer +// [ x (fixed) ] +// [ y_0 ] +// [ y_{m-1} ] +// x and y are sub-array of equal length; x is never moved, whereas y slides. +// The cross-correlation between y_0 and x corresponds to the auto-correlation +// for the maximum pitch period. Hence, the first value in `auto_corr` has an +// inverted lag equal to 0 that corresponds to a lag equal to the maximum +// pitch period. +void AutoCorrelationCalculator::ComputeOnPitchBuffer( + rtc::ArrayView<const float, kBufSize12kHz> pitch_buf, + rtc::ArrayView<float, kNumLags12kHz> auto_corr) { + RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); + RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); + constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder; + constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz; + static_assert(kConvolutionLength == kFrameSize20ms12kHz, + "Mismatch between pitch buffer size, frame size and maximum " + "pitch period."); + static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength, + "The FFT length is not sufficiently big to avoid cyclic " + "convolution errors."); + auto tmp = tmp_->GetView(); + + // Compute the FFT for the reversed reference frame - i.e., + // pitch_buf[-kConvolutionLength:]. + std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(), + tmp.begin()); + std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f); + fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false); + + // Compute the FFT for the sliding frames chunk. The sliding frames are + // defined as pitch_buf[i:i+kConvolutionLength] where i in + // [0, kNumLags12kHz). The chunk includes all of them, hence it is + // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength]. + std::copy(pitch_buf.begin(), + pitch_buf.begin() + kConvolutionLength + kNumLags12kHz, + tmp.begin()); + std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f); + fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false); + + // Convolve in the frequency domain. + constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize); + std::fill(tmp.begin(), tmp.end(), 0.f); + fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor); + fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false); + + // Extract the auto-correlation coefficients. + std::copy(tmp.begin() + kConvolutionLength - 1, + tmp.begin() + kConvolutionLength + kNumLags12kHz - 1, + auto_corr.begin()); +} + +} // namespace rnn_vad +} // namespace webrtc |