summary | refs | log | tree | commit | diff | stats
path: root/third_party/libwebrtc/modules/audio_processing/transient/transient_suppressor_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/transient/transient_suppressor_unittest.cc')
-rw-r--r-- third_party/libwebrtc/modules/audio_processing/transient/transient_suppressor_unittest.cc 175
1 files changed, 175 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/transient/transient_suppressor_unittest.cc b/third_party/libwebrtc/modules/audio_processing/transient/transient_suppressor_unittest.cc
new file mode 100644
index 0000000000..ab48504af6
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/transient/transient_suppressor_unittest.cc
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/transient/transient_suppressor.h"
+
+#include <algorithm>
+#include <iterator>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "modules/audio_processing/transient/common.h"
+#include "modules/audio_processing/transient/transient_suppressor_impl.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace {
+
+// Single-channel ("mono") value handed to the suppressor by the tests in this
+// file.
+constexpr int kMono = 1;
+
+// Returns the index of the first sample in `samples` that differs from zero,
+// or `absl::nullopt` when `samples` is empty or contains only zeros.
+absl::optional<int> FindFirstNonZeroSample(const std::vector<float>& samples) {
+  const auto first_non_zero =
+      std::find_if(samples.begin(), samples.end(),
+                   [](float sample) { return sample != 0.0f; });
+  if (first_non_zero == samples.end()) {
+    return absl::nullopt;
+  }
+  // `std::distance()` yields a signed difference type; the cast makes the
+  // conversion to the `int` payload explicit instead of silently narrowing.
+  return static_cast<int>(std::distance(samples.begin(), first_non_zero));
+}
+
+}  // namespace
+
+// Fixture for tests that run once per `TransientSuppressor::VadMode` value.
+class TransientSuppressorVadModeParametrization
+    : public ::testing::TestWithParam<TransientSuppressor::VadMode> {};
+
+// Drives `UpdateKeypress()` on a mono suppressor configured with
+// `ts::kSampleRate16kHz` and checks how the `detection_enabled_` and
+// `suppression_enabled_` flags react to the timing of key-presses. The test
+// treats every `UpdateKeypress()` call as one chunk of `ts::kChunkSizeMs`.
+TEST_P(TransientSuppressorVadModeParametrization,
+       TypingDetectionLogicWorksAsExpectedForMono) {
+  // Durations, in milliseconds, covered by the quiet periods below.
+  constexpr int kAlmostFourSecondsMs = 3990;
+  constexpr int kAlmostOneSecondMs = 990;
+  constexpr int kOneSecondMs = 1000;
+  // Number of times each repeated key-press pattern is exercised.
+  constexpr int kNumRepetitions = 100;
+
+  TransientSuppressorImpl suppressor(GetParam(), ts::kSampleRate16kHz,
+                                     /*detection_rate_hz=*/ts::kSampleRate16kHz,
+                                     kMono);
+
+  // Detection starts disabled; a single key-press turns it on.
+  EXPECT_FALSE(suppressor.detection_enabled_);
+  suppressor.UpdateKeypress(true);
+  EXPECT_TRUE(suppressor.detection_enabled_);
+
+  // Detection stays on for just under four seconds without a key-press, and
+  // the next chunk without one turns it off.
+  for (int elapsed_ms = 0; elapsed_ms < kAlmostFourSecondsMs;
+       elapsed_ms += ts::kChunkSizeMs) {
+    suppressor.UpdateKeypress(false);
+    EXPECT_TRUE(suppressor.detection_enabled_);
+  }
+  suppressor.UpdateKeypress(false);
+  EXPECT_FALSE(suppressor.detection_enabled_);
+
+  // Isolated key-presses, each more than a second away from the next one,
+  // keep detection on but never enable suppression.
+  for (int repetition = 0; repetition < kNumRepetitions; ++repetition) {
+    EXPECT_FALSE(suppressor.suppression_enabled_);
+    suppressor.UpdateKeypress(true);
+    EXPECT_TRUE(suppressor.detection_enabled_);
+    EXPECT_FALSE(suppressor.suppression_enabled_);
+    for (int elapsed_ms = 0; elapsed_ms < kAlmostOneSecondMs;
+         elapsed_ms += ts::kChunkSizeMs) {
+      suppressor.UpdateKeypress(false);
+      EXPECT_TRUE(suppressor.detection_enabled_);
+      EXPECT_FALSE(suppressor.suppression_enabled_);
+    }
+    // One more quiet chunk before the next isolated key-press.
+    suppressor.UpdateKeypress(false);
+  }
+
+  // A second key-press right after the first one is what enables suppression.
+  suppressor.UpdateKeypress(true);
+  EXPECT_FALSE(suppressor.suppression_enabled_);
+  suppressor.UpdateKeypress(true);
+  EXPECT_TRUE(suppressor.suppression_enabled_);
+
+  // While key-presses keep arriving less than a second apart, both detection
+  // and suppression remain enabled.
+  for (int repetition = 0; repetition < kNumRepetitions; ++repetition) {
+    for (int elapsed_ms = 0; elapsed_ms < kOneSecondMs;
+         elapsed_ms += ts::kChunkSizeMs) {
+      suppressor.UpdateKeypress(false);
+      EXPECT_TRUE(suppressor.detection_enabled_);
+      EXPECT_TRUE(suppressor.suppression_enabled_);
+    }
+    suppressor.UpdateKeypress(true);
+    EXPECT_TRUE(suppressor.detection_enabled_);
+    EXPECT_TRUE(suppressor.suppression_enabled_);
+  }
+
+  // Both flags hold for just under four seconds without a key-press...
+  for (int elapsed_ms = 0; elapsed_ms < kAlmostFourSecondsMs;
+       elapsed_ms += ts::kChunkSizeMs) {
+    suppressor.UpdateKeypress(false);
+    EXPECT_TRUE(suppressor.detection_enabled_);
+    EXPECT_TRUE(suppressor.suppression_enabled_);
+  }
+  // ...and are both cleared from then on as the quiet period continues.
+  for (int elapsed_ms = 0; elapsed_ms < kOneSecondMs;
+       elapsed_ms += ts::kChunkSizeMs) {
+    suppressor.UpdateKeypress(false);
+    EXPECT_FALSE(suppressor.detection_enabled_);
+    EXPECT_FALSE(suppressor.suppression_enabled_);
+  }
+}
+
+// Runs the typing-detection test once for each VAD mode listed here.
+INSTANTIATE_TEST_SUITE_P(
+    TransientSuppressorImplTest,
+    TransientSuppressorVadModeParametrization,
+    ::testing::Values(TransientSuppressor::VadMode::kDefault,
+                      TransientSuppressor::VadMode::kRnnVad,
+                      TransientSuppressor::VadMode::kNoVad));
+
+// Fixture for tests that run once per sample rate, given in Hz.
+class TransientSuppressorSampleRateParametrization
+    : public ::testing::TestWithParam<int> {};
+
+// Verifies that the voice probability returned by `Suppress()` is delayed
+// consistently with the audio that `Suppress()` leaves in the frame buffer.
+TEST_P(TransientSuppressorSampleRateParametrization,
+       CheckAudioAndVoiceProbabilityTemporallyAligned) {
+  const int sample_rate_hz = GetParam();
+  TransientSuppressorImpl suppressor(TransientSuppressor::VadMode::kDefault,
+                                     sample_rate_hz,
+                                     /*detection_rate_hz=*/sample_rate_hz,
+                                     kMono);
+
+  // One frame holds `ts::kChunkSizeMs` worth of samples.
+  const int samples_per_frame = sample_rate_hz * ts::kChunkSizeMs / 1000;
+  std::vector<float> audio(samples_per_frame);
+
+  constexpr int kMaxAttempts = 3;
+  for (int attempt = 0; attempt < kMaxAttempts; ++attempt) {
+    SCOPED_TRACE(attempt);
+
+    // Feed a frame made only of non-zero samples together with a voice
+    // probability of one.
+    audio.assign(audio.size(), 1000.0f);
+    const float delayed_voice_probability = suppressor.Suppress(
+        audio.data(), audio.size(), kMono, /*detection_data=*/nullptr,
+        /*detection_length=*/samples_per_frame, /*reference_data=*/nullptr,
+        /*reference_length=*/samples_per_frame, /*voice_probability=*/1.0f,
+        /*key_pressed=*/false);
+
+    // The position of the first non-zero output sample reveals the
+    // algorithmic delay of `TransientSuppressorImpl`.
+    const absl::optional<int> first_non_zero = FindFirstNonZeroSample(audio);
+
+    if (!first_non_zero.has_value()) {
+      // The whole frame is still zero, i.e. the delay exceeds the frames fed
+      // so far; until audio comes out, the delayed voice probability must be
+      // zero as well.
+      EXPECT_EQ(delayed_voice_probability, 0.0f);
+      continue;
+    }
+
+    if (*first_non_zero == 0) {
+      // The delay is an integer multiple of the frame duration, so
+      // `Suppress()` hands back an unmodified copy of a voice probability it
+      // was given earlier.
+      EXPECT_EQ(delayed_voice_probability, 1.0f);
+    } else {
+      // The delay is fractional, so `Suppress()` returns an interpolated
+      // value. Its exact magnitude depends on the interpolation method;
+      // the only requirement checked is that it is above zero, since it has
+      // to converge towards the previously observed value.
+      EXPECT_GT(delayed_voice_probability, 0.0f);
+    }
+    // The delay has been measured; no further attempts are needed.
+    break;
+  }
+}
+
+// Runs the alignment test once for each sample rate listed here.
+INSTANTIATE_TEST_SUITE_P(TransientSuppressorImplTest,
+                         TransientSuppressorSampleRateParametrization,
+                         ::testing::Values(ts::kSampleRate8kHz,
+                                           ts::kSampleRate16kHz,
+                                           ts::kSampleRate32kHz,
+                                           ts::kSampleRate48kHz));
+
+} // namespace webrtc