summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc')
-rw-r--r--third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc91
1 files changed, 91 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc
new file mode 100644
index 0000000000..3ddeec8dba
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
+
+#include <algorithm>
+
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace {
+
+constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
+static_assert(1 << kAutoCorrelationFftOrder >
+ kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
+ "");
+
+} // namespace
+
+AutoCorrelationCalculator::AutoCorrelationCalculator()
+ : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
+ tmp_(fft_.CreateBuffer()),
+ X_(fft_.CreateBuffer()),
+ H_(fft_.CreateBuffer()) {}
+
+AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
+
+// The auto-correlations coefficients are computed as follows:
+// |.........|...........| <- pitch buffer
+// [ x (fixed) ]
+// [ y_0 ]
+// [ y_{m-1} ]
+// x and y are sub-array of equal length; x is never moved, whereas y slides.
+// The cross-correlation between y_0 and x corresponds to the auto-correlation
+// for the maximum pitch period. Hence, the first value in `auto_corr` has an
+// inverted lag equal to 0 that corresponds to a lag equal to the maximum
+// pitch period.
+void AutoCorrelationCalculator::ComputeOnPitchBuffer(
+ rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
+ rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
+ RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
+ RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
+ constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
+ constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
+ static_assert(kConvolutionLength == kFrameSize20ms12kHz,
+ "Mismatch between pitch buffer size, frame size and maximum "
+ "pitch period.");
+ static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
+ "The FFT length is not sufficiently big to avoid cyclic "
+ "convolution errors.");
+ auto tmp = tmp_->GetView();
+
+ // Compute the FFT for the reversed reference frame - i.e.,
+ // pitch_buf[-kConvolutionLength:].
+ std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
+ tmp.begin());
+ std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
+ fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
+
+ // Compute the FFT for the sliding frames chunk. The sliding frames are
+ // defined as pitch_buf[i:i+kConvolutionLength] where i in
+ // [0, kNumLags12kHz). The chunk includes all of them, hence it is
+ // defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
+ std::copy(pitch_buf.begin(),
+ pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
+ tmp.begin());
+ std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
+ fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
+
+ // Convolve in the frequency domain.
+ constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
+ std::fill(tmp.begin(), tmp.end(), 0.f);
+ fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
+ fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
+
+ // Extract the auto-correlation coefficients.
+ std::copy(tmp.begin() + kConvolutionLength - 1,
+ tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
+ auto_corr.begin());
+}
+
+} // namespace rnn_vad
+} // namespace webrtc