/*
 *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/transient/transient_suppressor.h"

#include <algorithm>
#include <cstddef>
#include <vector>

#include "absl/types/optional.h"
#include "modules/audio_processing/transient/common.h"
#include "modules/audio_processing/transient/transient_suppressor_impl.h"
#include "test/gtest.h"

namespace webrtc {
namespace {

constexpr int kMono = 1;

// Returns the index of the first non-zero sample in `samples`, or
// `absl::nullopt` when every sample is zero (including when `samples` is
// empty).
absl::optional<size_t> FindFirstNonZeroSample(
    const std::vector<float>& samples) {
  for (size_t i = 0; i < samples.size(); ++i) {
    if (samples[i] != 0.0f) {
      return i;
    }
  }
  return absl::nullopt;
}

}  // namespace

// Runs the typing-detection test once for each `TransientSuppressor::VadMode`.
class TransientSuppressorVadModeParametrization
    : public ::testing::TestWithParam<TransientSuppressor::VadMode> {};

TEST_P(TransientSuppressorVadModeParametrization,
       TypingDetectionLogicWorksAsExpectedForMono) {
  TransientSuppressorImpl ts(GetParam(), ts::kSampleRate16kHz,
                             ts::kSampleRate16kHz, kMono);

  // Each key-press enables detection.
  EXPECT_FALSE(ts.detection_enabled_);
  ts.UpdateKeypress(true);
  EXPECT_TRUE(ts.detection_enabled_);

  // It takes four seconds without any key-press to disable the detection.
  for (int time_ms = 0; time_ms < 3990; time_ms += ts::kChunkSizeMs) {
    ts.UpdateKeypress(false);
    EXPECT_TRUE(ts.detection_enabled_);
  }
  ts.UpdateKeypress(false);
  EXPECT_FALSE(ts.detection_enabled_);

  // Key-presses that are more than a second apart from each other don't enable
  // suppression.
  for (int i = 0; i < 100; ++i) {
    EXPECT_FALSE(ts.suppression_enabled_);
    ts.UpdateKeypress(true);
    EXPECT_TRUE(ts.detection_enabled_);
    EXPECT_FALSE(ts.suppression_enabled_);
    for (int time_ms = 0; time_ms < 990; time_ms += ts::kChunkSizeMs) {
      ts.UpdateKeypress(false);
      EXPECT_TRUE(ts.detection_enabled_);
      EXPECT_FALSE(ts.suppression_enabled_);
    }
    ts.UpdateKeypress(false);
  }

  // Two consecutive key-presses are enough to enable suppression.
  ts.UpdateKeypress(true);
  EXPECT_FALSE(ts.suppression_enabled_);
  ts.UpdateKeypress(true);
  EXPECT_TRUE(ts.suppression_enabled_);

  // Once suppression is enabled, key-presses roughly one second apart from
  // each other disable neither detection nor suppression.
  for (int i = 0; i < 100; ++i) {
    for (int time_ms = 0; time_ms < 1000; time_ms += ts::kChunkSizeMs) {
      ts.UpdateKeypress(false);
      EXPECT_TRUE(ts.detection_enabled_);
      EXPECT_TRUE(ts.suppression_enabled_);
    }
    ts.UpdateKeypress(true);
    EXPECT_TRUE(ts.detection_enabled_);
    EXPECT_TRUE(ts.suppression_enabled_);
  }

  // It takes four seconds without any key-press to disable both detection and
  // suppression.
  for (int time_ms = 0; time_ms < 3990; time_ms += ts::kChunkSizeMs) {
    ts.UpdateKeypress(false);
    EXPECT_TRUE(ts.detection_enabled_);
    EXPECT_TRUE(ts.suppression_enabled_);
  }
  for (int time_ms = 0; time_ms < 1000; time_ms += ts::kChunkSizeMs) {
    ts.UpdateKeypress(false);
    EXPECT_FALSE(ts.detection_enabled_);
    EXPECT_FALSE(ts.suppression_enabled_);
  }
}

INSTANTIATE_TEST_SUITE_P(
    TransientSuppressorImplTest,
    TransientSuppressorVadModeParametrization,
    ::testing::Values(TransientSuppressor::VadMode::kDefault,
                      TransientSuppressor::VadMode::kRnnVad,
                      TransientSuppressor::VadMode::kNoVad));

// Runs the temporal-alignment test once for each sample rate (in Hz).
class TransientSuppressorSampleRateParametrization
    : public ::testing::TestWithParam<int> {};

// Checks that voice probability and processed audio data are temporally aligned
// after `Suppress()` is called.
TEST_P(TransientSuppressorSampleRateParametrization,
       CheckAudioAndVoiceProbabilityTemporallyAligned) {
  const int sample_rate_hz = GetParam();
  TransientSuppressorImpl ts(TransientSuppressor::VadMode::kDefault,
                             sample_rate_hz,
                             /*detection_rate_hz=*/sample_rate_hz, kMono);

  const int frame_size = sample_rate_hz * ts::kChunkSizeMs / 1000;
  std::vector<float> frame(frame_size);

  // NOTE(review): if no non-zero sample is observed within `kMaxAttempts`
  // frames, the loop ends without checking the alignment and the test passes.
  constexpr int kMaxAttempts = 3;
  for (int i = 0; i < kMaxAttempts; ++i) {
    SCOPED_TRACE(i);

    // Call `Suppress()` on frames of non-zero audio samples.
    std::fill(frame.begin(), frame.end(), 1000.0f);
    const float delayed_voice_probability = ts.Suppress(
        frame.data(), frame.size(), kMono, /*detection_data=*/nullptr,
        /*detection_length=*/frame_size, /*reference_data=*/nullptr,
        /*reference_length=*/frame_size, /*voice_probability=*/1.0f,
        /*key_pressed=*/false);

    // Detect the algorithmic delay of `TransientSuppressorImpl`.
    const absl::optional<size_t> frame_delay = FindFirstNonZeroSample(frame);

    // Check that the delayed voice probability is delayed according to the
    // measured delay.
    if (frame_delay.has_value()) {
      if (*frame_delay == 0) {
        // When the delay is an integer multiple of the frame duration,
        // `Suppress()` returns a copy of a previously observed voice
        // probability value.
        EXPECT_EQ(delayed_voice_probability, 1.0f);
      } else {
        // Instead, when the delay is fractional, `Suppress()` returns an
        // interpolated value. Since the exact value depends on the
        // interpolation method, we only check that the delayed voice
        // probability is not zero as it must converge towards the previously
        // observed value.
        EXPECT_GT(delayed_voice_probability, 0.0f);
      }
      break;
    } else {
      // The algorithmic delay is longer than the duration of a single frame.
      // Until the delay is detected, the delayed voice probability is zero.
      EXPECT_EQ(delayed_voice_probability, 0.0f);
    }
  }
}

INSTANTIATE_TEST_SUITE_P(TransientSuppressorImplTest,
                         TransientSuppressorSampleRateParametrization,
                         ::testing::Values(ts::kSampleRate8kHz,
                                           ts::kSampleRate16kHz,
                                           ts::kSampleRate32kHz,
                                           ts::kSampleRate48kHz));

}  // namespace webrtc