summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc
blob: 96f956adfeb398a72b6a651205f3fd7b97a1868c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"

#include <cmath>
#include <vector>

#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"

namespace webrtc {
namespace rnn_vad {
namespace {

// Integer division of `n` by `m` with the quotient rounded up. Meant for a
// non-negative `n` and a strictly positive `m`.
constexpr int ceil(int n, int m) {
  // Adding `m - 1` before dividing bumps the truncated quotient by one
  // whenever `n` is not an exact multiple of `m`.
  const int rounding_bias = m - 1;
  return (n + rounding_bias) / m;
}

// Smallest number of whole 10 ms frames that covers a pitch buffer of
// `kBufSize24kHz` samples (rounded up).
constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
// Length, in samples, of the test signal: `kNumTestDataFrames` full frames.
constexpr int kNumTestDataSize = kFrameSize10ms24kHz * kNumTestDataFrames;

// Verifies that the pitch in Hz is in the detectable range, i.e., that the
// corresponding pitch period (`kSampleRate24kHz` divided by `pitch_hz`,
// truncated to a whole number of samples) lies in
// [`kInitialMinPitch24kHz`, `kMaxPitch24kHz`]. Returns false when the period
// is NaN, infinite or otherwise out of that interval.
bool PitchIsValid(float pitch_hz) {
  // Truncate in the floating point domain instead of converting to `int`:
  // a float-to-int conversion is undefined behavior when the value is NaN,
  // infinite or too large for `int`, whereas the range check below is well
  // defined and fails for all of those values.
  const float pitch_period =
      std::trunc(static_cast<float>(kSampleRate24kHz) / pitch_hz);
  return static_cast<float>(kInitialMinPitch24kHz) <= pitch_period &&
         pitch_period <= static_cast<float>(kMaxPitch24kHz);
}

// Overwrites every element of `dst` with a sinusoid scaled by `amplitude`;
// sample i is computed with phase 2 * pi * `freq_hz` * i / `kSampleRate24kHz`.
void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
  // Part of the sine argument that does not depend on the sample index.
  // `auto` keeps the arithmetic type of the expression unchanged.
  const auto phase_scale = 2.f * kPi * freq_hz;
  int sample_index = 0;
  while (rtc::SafeLt(sample_index, dst.size())) {
    const auto phase = phase_scale * sample_index / kSampleRate24kHz;
    dst[sample_index] = amplitude * std::sin(phase);
    ++sample_index;
  }
}

// Splits `samples` into consecutive 10 ms frames and pushes them one by one
// into `features_extractor`; trailing samples that do not fill a whole frame
// are ignored. Each call overwrites `feature_vector`, so on return it holds the
// output for the last frame. Returns the silence flag reported for the last
// frame, or true if `samples` holds no complete frame.
bool FeedTestData(FeaturesExtractor& features_extractor,
                  rtc::ArrayView<const float> samples,
                  rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
  // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
  // FloatingPointExceptionObserver fpe_observer;
  const int frame_count = samples.size() / kFrameSize10ms24kHz;
  const float* frame_begin = samples.data();
  bool last_frame_is_silence = true;
  for (int frame = 0; frame < frame_count; ++frame) {
    last_frame_is_silence = features_extractor.CheckSilenceComputeFeatures(
        {frame_begin, kFrameSize10ms24kHz}, feature_vector);
    // Move to the first sample of the next frame.
    frame_begin += kFrameSize10ms24kHz;
  }
  return last_frame_is_silence;
}

// Runs the features extractor on two pure tones with known frequencies and
// checks that the pitch feature orders them as expected: the tone with the
// lower frequency must yield the larger pitch period feature.
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
  constexpr float kToneAmplitude = 1000.f;
  constexpr float kLowToneHz = 150.f;
  constexpr float kHighToneHz = 250.f;
  // Both tones must fall in the range that the pitch search can detect.
  ASSERT_TRUE(PitchIsValid(kLowToneHz));
  ASSERT_TRUE(PitchIsValid(kHighToneHz));

  const AvailableCpuFeatures available_cpu_features = GetAvailableCpuFeatures();
  FeaturesExtractor extractor(available_cpu_features);
  std::vector<float> tone(kNumTestDataSize);
  std::vector<float> features(kFeatureVectorSize);
  ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(features.size()));
  rtc::ArrayView<float, kFeatureVectorSize> features_view(features.data(),
                                                          kFeatureVectorSize);

  // Position, in the feature vector, of the normalized scalar that is
  // proportional to the estimated pitch period.
  constexpr int kPitchFeatureIndex = kFeatureVectorSize - 2;

  // First pass: low frequency tone, hence long period.
  CreatePureTone(kToneAmplitude, kLowToneHz, tone);
  ASSERT_FALSE(FeedTestData(extractor, tone, features_view));
  const float long_period_feature = features_view[kPitchFeatureIndex];

  // Second pass, after a reset: high frequency tone, hence short period.
  extractor.Reset();
  CreatePureTone(kToneAmplitude, kHighToneHz, tone);
  ASSERT_FALSE(FeedTestData(extractor, tone, features_view));
  const float short_period_feature = features_view[kPitchFeatureIndex];

  // The higher tone must map to the smaller pitch period feature.
  EXPECT_LT(short_period_feature, long_period_feature);
}

}  // namespace
}  // namespace rnn_vad
}  // namespace webrtc