summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc
blob: af005833c1244772cebfa718c6930d58536a961c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"

#include <algorithm>
#include <array>

#include "test/gtest.h"

namespace webrtc {
namespace rnn_vad {
namespace {

template <typename T, int S, int N>
void TestSequenceBufferPushOp() {
  SCOPED_TRACE(S);
  SCOPED_TRACE(N);
  SequenceBuffer<T, S, N> seq_buf;
  auto seq_buf_view = seq_buf.GetBufferView();
  std::array<T, N> chunk;

  // Check that a chunk is fully gone after ceil(S / N) push ops.
  chunk.fill(1);
  seq_buf.Push(chunk);
  chunk.fill(0);
  constexpr int required_push_ops = (S % N) ? S / N + 1 : S / N;
  for (int i = 0; i < required_push_ops - 1; ++i) {
    SCOPED_TRACE(i);
    seq_buf.Push(chunk);
    // Still in the buffer.
    const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end());
    EXPECT_EQ(1, *m);
  }
  // Gone after another push.
  seq_buf.Push(chunk);
  const auto* m = std::max_element(seq_buf_view.begin(), seq_buf_view.end());
  EXPECT_EQ(0, *m);

  // Check that the last item moves left by N positions after a push op.
  if (S > N) {
    // Fill in with non-zero values.
    for (int i = 0; i < N; ++i)
      chunk[i] = static_cast<T>(i + 1);
    seq_buf.Push(chunk);
    // With the next Push(), `last` will be moved left by N positions.
    const T last = chunk[N - 1];
    for (int i = 0; i < N; ++i)
      chunk[i] = static_cast<T>(last + i + 1);
    seq_buf.Push(chunk);
    EXPECT_EQ(last, seq_buf_view[S - N - 1]);
  }
}

// Checks the size accessors and that the buffer view reflects pushed data.
TEST(RnnVadTest, SequenceBufferGetters) {
  constexpr int kBufferSize = 8;
  constexpr int kChunkSize = 8;
  SequenceBuffer<int, kBufferSize, kChunkSize> buffer;
  EXPECT_EQ(kBufferSize, buffer.size());
  EXPECT_EQ(kChunkSize, buffer.chunks_size());

  // Before any push, the first and last items of the view read as zero.
  auto view = buffer.GetBufferView();
  const auto last_index = view.size() - 1;
  EXPECT_EQ(0, view[0]);
  EXPECT_EQ(0, view[last_index]);

  // Chunk size equals buffer size here, so after a single push the view is
  // expected to start with the chunk's first value and end with its last.
  constexpr std::array<int, kChunkSize> kChunk = {10, 20, 30, 40,
                                                  50, 60, 70, 80};
  buffer.Push(kChunk);
  EXPECT_EQ(10, *view.begin());
  EXPECT_EQ(80, *(view.end() - 1));
}

// Push semantics with an unsigned 8-bit element type.
TEST(RnnVadTest, SequenceBufferPushOpsUnsigned) {
  using Elem = uint8_t;
  TestSequenceBufferPushOp<Elem, 32, 8>();   // N is a quarter of S.
  TestSequenceBufferPushOp<Elem, 32, 16>();  // N is half of S.
  TestSequenceBufferPushOp<Elem, 32, 32>();  // N equals S.
  TestSequenceBufferPushOp<Elem, 23, 7>();   // S is not a multiple of N.
}

// Push semantics with a signed integer element type.
TEST(RnnVadTest, SequenceBufferPushOpsSigned) {
  using Elem = int;
  TestSequenceBufferPushOp<Elem, 32, 8>();   // N is a quarter of S.
  TestSequenceBufferPushOp<Elem, 32, 16>();  // N is half of S.
  TestSequenceBufferPushOp<Elem, 32, 32>();  // N equals S.
  TestSequenceBufferPushOp<Elem, 23, 7>();   // S is not a multiple of N.
}

// Push semantics with a floating-point element type.
TEST(RnnVadTest, SequenceBufferPushOpsFloating) {
  using Elem = float;
  TestSequenceBufferPushOp<Elem, 32, 8>();   // N is a quarter of S.
  TestSequenceBufferPushOp<Elem, 32, 16>();  // N is half of S.
  TestSequenceBufferPushOp<Elem, 32, 32>();  // N equals S.
  TestSequenceBufferPushOp<Elem, 23, 7>();   // S is not a multiple of N.
}

}  // namespace
}  // namespace rnn_vad
}  // namespace webrtc