summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc')
-rw-r--r--third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc198
1 files changed, 198 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
new file mode 100644
index 0000000000..ef37410caa
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
+
+#include "rtc_base/checks.h"
+#include "rtc_base/numerics/safe_conversions.h"
+#include "third_party/rnnoise/src/rnn_activations.h"
+#include "third_party/rnnoise/src/rnn_vad_weights.h"
+
+namespace webrtc {
+namespace rnn_vad {
+namespace {
+
+constexpr int kNumGruGates = 3; // Update, reset, output.
+
+std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
+ int output_size) {
+ // Transpose, cast and scale.
+ // `n` is the size of the first dimension of the 3-dim tensor `weights`.
+ const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
+ output_size * kNumGruGates);
+ const int stride_src = kNumGruGates * output_size;
+ const int stride_dst = n * output_size;
+ std::vector<float> tensor_dst(tensor_src.size());
+ for (int g = 0; g < kNumGruGates; ++g) {
+ for (int o = 0; o < output_size; ++o) {
+ for (int i = 0; i < n; ++i) {
+ tensor_dst[g * stride_dst + o * n + i] =
+ ::rnnoise::kWeightsScale *
+ static_cast<float>(
+ tensor_src[i * stride_src + g * output_size + o]);
+ }
+ }
+ }
+ return tensor_dst;
+}
+
+// Computes the output for the update or the reset gate.
+// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
+// - `g`: output gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `s`: state gate vector
+// - `b`: bias vector
+void ComputeUpdateResetGate(int input_size,
+ int output_size,
+ const VectorMath& vector_math,
+ rtc::ArrayView<const float> input,
+ rtc::ArrayView<const float> state,
+ rtc::ArrayView<const float> bias,
+ rtc::ArrayView<const float> weights,
+ rtc::ArrayView<const float> recurrent_weights,
+ rtc::ArrayView<float> gate) {
+ RTC_DCHECK_EQ(input.size(), input_size);
+ RTC_DCHECK_EQ(state.size(), output_size);
+ RTC_DCHECK_EQ(bias.size(), output_size);
+ RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+ RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+ RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
+ for (int o = 0; o < output_size; ++o) {
+ float x = bias[o];
+ x += vector_math.DotProduct(input,
+ weights.subview(o * input_size, input_size));
+ x += vector_math.DotProduct(
+ state, recurrent_weights.subview(o * output_size, output_size));
+ gate[o] = ::rnnoise::SigmoidApproximated(x);
+ }
+}
+
+// Computes the output for the state gate.
+// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
+// - `s'`: output state gate vector
+// - `s`: previous state gate vector
+// - `u`: update gate vector
+// - `W`: weights matrix
+// - `i`: input vector
+// - `R`: recurrent weights matrix
+// - `r`: reset gate vector
+// - `b`: bias vector
+// - `.*` element-wise product
+void ComputeStateGate(int input_size,
+ int output_size,
+ const VectorMath& vector_math,
+ rtc::ArrayView<const float> input,
+ rtc::ArrayView<const float> update,
+ rtc::ArrayView<const float> reset,
+ rtc::ArrayView<const float> bias,
+ rtc::ArrayView<const float> weights,
+ rtc::ArrayView<const float> recurrent_weights,
+ rtc::ArrayView<float> state) {
+ RTC_DCHECK_EQ(input.size(), input_size);
+ RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
+ RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
+ RTC_DCHECK_EQ(bias.size(), output_size);
+ RTC_DCHECK_EQ(weights.size(), input_size * output_size);
+ RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
+ RTC_DCHECK_EQ(state.size(), output_size);
+ std::array<float, kGruLayerMaxUnits> reset_x_state;
+ for (int o = 0; o < output_size; ++o) {
+ reset_x_state[o] = state[o] * reset[o];
+ }
+ for (int o = 0; o < output_size; ++o) {
+ float x = bias[o];
+ x += vector_math.DotProduct(input,
+ weights.subview(o * input_size, input_size));
+ x += vector_math.DotProduct(
+ {reset_x_state.data(), static_cast<size_t>(output_size)},
+ recurrent_weights.subview(o * output_size, output_size));
+ state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
+ }
+}
+
+} // namespace
+
+GatedRecurrentLayer::GatedRecurrentLayer(
+ const int input_size,
+ const int output_size,
+ const rtc::ArrayView<const int8_t> bias,
+ const rtc::ArrayView<const int8_t> weights,
+ const rtc::ArrayView<const int8_t> recurrent_weights,
+ const AvailableCpuFeatures& cpu_features,
+ absl::string_view layer_name)
+ : input_size_(input_size),
+ output_size_(output_size),
+ bias_(PreprocessGruTensor(bias, output_size)),
+ weights_(PreprocessGruTensor(weights, output_size)),
+ recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
+ vector_math_(cpu_features) {
+ RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
+ << "Insufficient GRU layer over-allocation (" << layer_name << ").";
+ RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
+ << "Mismatching output size and bias terms array size (" << layer_name
+ << ").";
+ RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
+ << "Mismatching input-output size and weight coefficients array size ("
+ << layer_name << ").";
+ RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
+ recurrent_weights_.size())
+ << "Mismatching input-output size and recurrent weight coefficients array"
+ " size ("
+ << layer_name << ").";
+ Reset();
+}
+
+GatedRecurrentLayer::~GatedRecurrentLayer() = default;
+
+void GatedRecurrentLayer::Reset() {
+ state_.fill(0.f);
+}
+
+void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
+ RTC_DCHECK_EQ(input.size(), input_size_);
+
+ // The tensors below are organized as a sequence of flattened tensors for the
+ // `update`, `reset` and `state` gates.
+ rtc::ArrayView<const float> bias(bias_);
+ rtc::ArrayView<const float> weights(weights_);
+ rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
+ // Strides to access to the flattened tensors for a specific gate.
+ const int stride_weights = input_size_ * output_size_;
+ const int stride_recurrent_weights = output_size_ * output_size_;
+
+ rtc::ArrayView<float> state(state_.data(), output_size_);
+
+ // Update gate.
+ std::array<float, kGruLayerMaxUnits> update;
+ ComputeUpdateResetGate(
+ input_size_, output_size_, vector_math_, input, state,
+ bias.subview(0, output_size_), weights.subview(0, stride_weights),
+ recurrent_weights.subview(0, stride_recurrent_weights), update);
+ // Reset gate.
+ std::array<float, kGruLayerMaxUnits> reset;
+ ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
+ bias.subview(output_size_, output_size_),
+ weights.subview(stride_weights, stride_weights),
+ recurrent_weights.subview(stride_recurrent_weights,
+ stride_recurrent_weights),
+ reset);
+ // State gate.
+ ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
+ reset, bias.subview(2 * output_size_, output_size_),
+ weights.subview(2 * stride_weights, stride_weights),
+ recurrent_weights.subview(2 * stride_recurrent_weights,
+ stride_recurrent_weights),
+ state);
+}
+
+} // namespace rnn_vad
+} // namespace webrtc