/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"

#include <cmath>
#include <vector>

#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"

namespace webrtc {
namespace rnn_vad {
namespace {

// Integer division of `n` by `m` rounded up. Intended for non-negative `n`
// and positive `m`.
constexpr int ceil(int n, int m) {
  return (n + m - 1) / m;
}

// Number of 10 ms frames required to fill a pitch buffer having size
// `kBufSize24kHz`.
constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
// Number of samples for the test data.
constexpr int kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;

// Verifies that the pitch in Hz is in the detectable range, i.e., that the
// corresponding period (in samples at 24 kHz, truncated to an integer) lies in
// [`kInitialMinPitch24kHz`, `kMaxPitch24kHz`]. Non-positive (or NaN)
// frequencies are never valid.
bool PitchIsValid(float pitch_hz) {
  // Guard the division below: a zero frequency would yield an infinite period
  // whose conversion to `int` is undefined.
  if (!(pitch_hz > 0.f)) {
    return false;
  }
  const int pitch_period =
      static_cast<int>(static_cast<float>(kSampleRate24kHz) / pitch_hz);
  return kInitialMinPitch24kHz <= pitch_period &&
         pitch_period <= kMaxPitch24kHz;
}

// Fills `dst` with a sinusoid having peak value `amplitude` and frequency
// `freq_hz`, sampled at `kSampleRate24kHz` and starting at phase zero.
void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
  for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
    dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
  }
}

// Feeds `features_extractor` with `samples` splitting it in 10 ms frames.
// For every frame, the output is written into `feature_vector`. Returns true
// if silence is detected in the last frame. Trailing samples that do not fill
// a whole frame are ignored; if there is no complete frame, returns true.
bool FeedTestData(FeaturesExtractor& features_extractor,
                  rtc::ArrayView<const float> samples,
                  rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  bool is_silence = true;
  const int num_frames = samples.size() / kFrameSize10ms24kHz;
  for (int i = 0; i < num_frames; ++i) {
    is_silence = features_extractor.CheckSilenceComputeFeatures(
        {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
        feature_vector);
  }
  return is_silence;
}

// Extracts the features for two pure tones and verifies that the pitch field
// values reflect the known tone frequencies.
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
  constexpr float amplitude = 1000.f;
  constexpr float low_pitch_hz = 150.f;
  constexpr float high_pitch_hz = 250.f;
  ASSERT_TRUE(PitchIsValid(low_pitch_hz));
  ASSERT_TRUE(PitchIsValid(high_pitch_hz));

  const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
  FeaturesExtractor features_extractor(cpu_features);
  std::vector<float> samples(kNumTestDataSize);
  std::vector<float> feature_vector(kFeatureVectorSize);
  ASSERT_EQ(kFeatureVectorSize,
            rtc::dchecked_cast<int>(feature_vector.size()));
  rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
      feature_vector.data(), kFeatureVectorSize);

  // Extract the normalized scalar feature that is proportional to the estimated
  // pitch period.
  constexpr int pitch_feature_index = kFeatureVectorSize - 2;
  // Low frequency tone - i.e., high period.
  CreatePureTone(amplitude, low_pitch_hz, samples);
  ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
  float high_pitch_period = feature_vector_view[pitch_feature_index];
  // High frequency tone - i.e., low period.
  // Reset first so that the second measurement does not depend on state left
  // over from the first tone.
  features_extractor.Reset();
  CreatePureTone(amplitude, high_pitch_hz, samples);
  ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
  float low_pitch_period = feature_vector_view[pitch_feature_index];
  // Check.
  EXPECT_LT(low_pitch_period, high_pitch_period);
}

}  // namespace
}  // namespace rnn_vad
}  // namespace webrtc