diff options
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc')
-rw-r--r-- | third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc new file mode 100644 index 0000000000..ef37410caa --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/agc2/rnn_vad/rnn_gru.cc @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h" + +#include "rtc_base/checks.h" +#include "rtc_base/numerics/safe_conversions.h" +#include "third_party/rnnoise/src/rnn_activations.h" +#include "third_party/rnnoise/src/rnn_vad_weights.h" + +namespace webrtc { +namespace rnn_vad { +namespace { + +constexpr int kNumGruGates = 3; // Update, reset, output. + +std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src, + int output_size) { + // Transpose, cast and scale. + // `n` is the size of the first dimension of the 3-dim tensor `weights`. + const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()), + output_size * kNumGruGates); + const int stride_src = kNumGruGates * output_size; + const int stride_dst = n * output_size; + std::vector<float> tensor_dst(tensor_src.size()); + for (int g = 0; g < kNumGruGates; ++g) { + for (int o = 0; o < output_size; ++o) { + for (int i = 0; i < n; ++i) { + tensor_dst[g * stride_dst + o * n + i] = + ::rnnoise::kWeightsScale * + static_cast<float>( + tensor_src[i * stride_src + g * output_size + o]); + } + } + } + return tensor_dst; +} + +// Computes the output for the update or the reset gate. +// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where +// - `g`: output gate vector +// - `W`: weights matrix +// - `i`: input vector +// - `R`: recurrent weights matrix +// - `s`: state gate vector +// - `b`: bias vector +void ComputeUpdateResetGate(int input_size, + int output_size, + const VectorMath& vector_math, + rtc::ArrayView<const float> input, + rtc::ArrayView<const float> state, + rtc::ArrayView<const float> bias, + rtc::ArrayView<const float> weights, + rtc::ArrayView<const float> recurrent_weights, + rtc::ArrayView<float> gate) { + RTC_DCHECK_EQ(input.size(), input_size); + RTC_DCHECK_EQ(state.size(), output_size); + RTC_DCHECK_EQ(bias.size(), output_size); + RTC_DCHECK_EQ(weights.size(), input_size * output_size); + RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size); + RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated. + for (int o = 0; o < output_size; ++o) { + float x = bias[o]; + x += vector_math.DotProduct(input, + weights.subview(o * input_size, input_size)); + x += vector_math.DotProduct( + state, recurrent_weights.subview(o * output_size, output_size)); + gate[o] = ::rnnoise::SigmoidApproximated(x); + } +} + +// Computes the output for the state gate. +// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where +// - `s'`: output state gate vector +// - `s`: previous state gate vector +// - `u`: update gate vector +// - `W`: weights matrix +// - `i`: input vector +// - `R`: recurrent weights matrix +// - `r`: reset gate vector +// - `b`: bias vector +// - `.*` element-wise product +void ComputeStateGate(int input_size, + int output_size, + const VectorMath& vector_math, + rtc::ArrayView<const float> input, + rtc::ArrayView<const float> update, + rtc::ArrayView<const float> reset, + rtc::ArrayView<const float> bias, + rtc::ArrayView<const float> weights, + rtc::ArrayView<const float> recurrent_weights, + rtc::ArrayView<float> state) { + RTC_DCHECK_EQ(input.size(), input_size); + RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated. + RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated. + RTC_DCHECK_EQ(bias.size(), output_size); + RTC_DCHECK_EQ(weights.size(), input_size * output_size); + RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size); + RTC_DCHECK_EQ(state.size(), output_size); + std::array<float, kGruLayerMaxUnits> reset_x_state; + for (int o = 0; o < output_size; ++o) { + reset_x_state[o] = state[o] * reset[o]; + } + for (int o = 0; o < output_size; ++o) { + float x = bias[o]; + x += vector_math.DotProduct(input, + weights.subview(o * input_size, input_size)); + x += vector_math.DotProduct( + {reset_x_state.data(), static_cast<size_t>(output_size)}, + recurrent_weights.subview(o * output_size, output_size)); + state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x); + } +} + +} // namespace + +GatedRecurrentLayer::GatedRecurrentLayer( + const int input_size, + const int output_size, + const rtc::ArrayView<const int8_t> bias, + const rtc::ArrayView<const int8_t> weights, + const rtc::ArrayView<const int8_t> recurrent_weights, + const AvailableCpuFeatures& cpu_features, + absl::string_view layer_name) + : input_size_(input_size), + output_size_(output_size), + bias_(PreprocessGruTensor(bias, output_size)), + weights_(PreprocessGruTensor(weights, output_size)), + recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)), + vector_math_(cpu_features) { + RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits) + << "Insufficient GRU layer over-allocation (" << layer_name << ")."; + RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size()) + << "Mismatching output size and bias terms array size (" << layer_name + << ")."; + RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size()) + << "Mismatching input-output size and weight coefficients array size (" + << layer_name << ")."; + RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_, + recurrent_weights_.size()) + << "Mismatching input-output size and recurrent weight coefficients array" + " size (" + << layer_name << ")."; + Reset(); +} + +GatedRecurrentLayer::~GatedRecurrentLayer() = default; + +void GatedRecurrentLayer::Reset() { + state_.fill(0.f); +} + +void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) { + RTC_DCHECK_EQ(input.size(), input_size_); + + // The tensors below are organized as a sequence of flattened tensors for the + // `update`, `reset` and `state` gates. + rtc::ArrayView<const float> bias(bias_); + rtc::ArrayView<const float> weights(weights_); + rtc::ArrayView<const float> recurrent_weights(recurrent_weights_); + // Strides to access to the flattened tensors for a specific gate. + const int stride_weights = input_size_ * output_size_; + const int stride_recurrent_weights = output_size_ * output_size_; + + rtc::ArrayView<float> state(state_.data(), output_size_); + + // Update gate. + std::array<float, kGruLayerMaxUnits> update; + ComputeUpdateResetGate( + input_size_, output_size_, vector_math_, input, state, + bias.subview(0, output_size_), weights.subview(0, stride_weights), + recurrent_weights.subview(0, stride_recurrent_weights), update); + // Reset gate. + std::array<float, kGruLayerMaxUnits> reset; + ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state, + bias.subview(output_size_, output_size_), + weights.subview(stride_weights, stride_weights), + recurrent_weights.subview(stride_recurrent_weights, + stride_recurrent_weights), + reset); + // State gate. + ComputeStateGate(input_size_, output_size_, vector_math_, input, update, + reset, bias.subview(2 * output_size_, output_size_), + weights.subview(2 * stride_weights, stride_weights), + recurrent_weights.subview(2 * stride_recurrent_weights, + stride_recurrent_weights), + state); +} + +} // namespace rnn_vad +} // namespace webrtc |