/*
 *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"

#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"

namespace webrtc {
namespace rnn_vad {
namespace {

constexpr int kNumGruGates = 3;  // Update, reset, output.

// Returns a transposed, cast and scaled copy of `tensor_src`.
//
// `tensor_src` is read as a flattened row-major tensor indexed as
// [i][gate][output], with `i` in [0, n), where
// `n = tensor_src.size() / (kNumGruGates * output_size)` (the division must be
// exact). The returned vector has the same number of elements, is indexed as
// [gate][output][i] and holds each source coefficient converted to float and
// multiplied by `::rnnoise::kWeightsScale`. With this layout, the coefficients
// used for one output unit of one gate are contiguous.
//
// NOTE(review): the source element type is assumed to be `int8_t` (quantized
// coefficients); confirm against the declarations in rnn_gru.h.
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
                                       int output_size) {
  // Transpose, cast and scale.
  // `n` is the size of the first dimension of the 3-dim tensor `weights`.
  const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
                                     output_size * kNumGruGates);
  const int stride_src = kNumGruGates * output_size;
  const int stride_dst = n * output_size;
  std::vector<float> tensor_dst(tensor_src.size());
  for (int g = 0; g < kNumGruGates; ++g) {
    for (int o = 0; o < output_size; ++o) {
      for (int i = 0; i < n; ++i) {
        tensor_dst[g * stride_dst + o * n + i] =
            ::rnnoise::kWeightsScale *
            static_cast<float>(
                tensor_src[i * stride_src + g * output_size + o]);
      }
    }
  }
  return tensor_dst;
}

// Computes the output for the update or the reset gate.
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
// - `g`: output gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `s`: state gate vector
// - `b`: bias vector
//
// `weights` and `recurrent_weights` are flattened so that the coefficients for
// output unit `o` start at `o * input_size` and `o * output_size` respectively.
// Only the first `output_size` items of `gate` are written.
void ComputeUpdateResetGate(int input_size,
                            int output_size,
                            const VectorMath& vector_math,
                            rtc::ArrayView<const float> input,
                            rtc::ArrayView<const float> state,
                            rtc::ArrayView<const float> bias,
                            rtc::ArrayView<const float> weights,
                            rtc::ArrayView<const float> recurrent_weights,
                            rtc::ArrayView<float> gate) {
  RTC_DCHECK_EQ(input.size(), input_size);
  RTC_DCHECK_EQ(state.size(), output_size);
  RTC_DCHECK_EQ(bias.size(), output_size);
  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
  RTC_DCHECK_GE(gate.size(), output_size);  // `gate` is over-allocated.
  for (int o = 0; o < output_size; ++o) {
    float x = bias[o];
    x += vector_math.DotProduct(input,
                                weights.subview(o * input_size, input_size));
    x += vector_math.DotProduct(
        state, recurrent_weights.subview(o * output_size, output_size));
    gate[o] = ::rnnoise::SigmoidApproximated(x);
  }
}

// Computes the output for the state gate.
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
// - `s'`: output state gate vector
// - `s`: previous state gate vector
// - `u`: update gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `r`: reset gate vector
// - `b`: bias vector
// - `.*` element-wise product
//
// `state` holds `s` on entry and is overwritten in place with `s'`.
void ComputeStateGate(int input_size,
                      int output_size,
                      const VectorMath& vector_math,
                      rtc::ArrayView<const float> input,
                      rtc::ArrayView<const float> update,
                      rtc::ArrayView<const float> reset,
                      rtc::ArrayView<const float> bias,
                      rtc::ArrayView<const float> weights,
                      rtc::ArrayView<const float> recurrent_weights,
                      rtc::ArrayView<float> state) {
  RTC_DCHECK_EQ(input.size(), input_size);
  RTC_DCHECK_GE(update.size(), output_size);  // `update` is over-allocated.
  RTC_DCHECK_GE(reset.size(), output_size);   // `reset` is over-allocated.
  RTC_DCHECK_EQ(bias.size(), output_size);
  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
  RTC_DCHECK_EQ(state.size(), output_size);
  // `s .* r` is computed up front because the loop below overwrites `state`
  // item by item while every dot product needs all the previous state values.
  std::array<float, kGruLayerMaxUnits> reset_x_state;
  for (int o = 0; o < output_size; ++o) {
    reset_x_state[o] = state[o] * reset[o];
  }
  for (int o = 0; o < output_size; ++o) {
    float x = bias[o];
    x += vector_math.DotProduct(input,
                                weights.subview(o * input_size, input_size));
    x += vector_math.DotProduct(
        {reset_x_state.data(), static_cast<size_t>(output_size)},
        recurrent_weights.subview(o * output_size, output_size));
    state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
  }
}

}  // namespace

// Stores pre-processed (transposed, cast and scaled) copies of `bias`,
// `weights` and `recurrent_weights`, checks (debug builds) that their sizes are
// consistent with `input_size` and `output_size`, and zeroes the state.
// `layer_name` is only used in the check failure messages.
GatedRecurrentLayer::GatedRecurrentLayer(
    const int input_size,
    const int output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    const rtc::ArrayView<const int8_t> recurrent_weights,
    const AvailableCpuFeatures& cpu_features,
    absl::string_view layer_name)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(PreprocessGruTensor(bias, output_size)),
      weights_(PreprocessGruTensor(weights, output_size)),
      recurrent_weights_(
          PreprocessGruTensor(recurrent_weights, output_size)),
      vector_math_(cpu_features) {
  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
      << "Insufficient GRU layer over-allocation (" << layer_name << ").";
  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
      << "Mismatching output size and bias terms array size (" << layer_name
      << ").";
  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size ("
      << layer_name << ").";
  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
                recurrent_weights_.size())
      << "Mismatching input-output size and recurrent weight coefficients array"
         " size ("
      << layer_name << ").";
  Reset();
}

GatedRecurrentLayer::~GatedRecurrentLayer() = default;

// Sets every item of the layer state to zero.
void GatedRecurrentLayer::Reset() {
  state_.fill(0.f);
}

// Updates the first `output_size_` items of `state_` in place given `input`,
// which must hold `input_size_` items. The update and reset gates are computed
// from the state as it is on entry; the state gate then overwrites the state.
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  RTC_DCHECK_EQ(input.size(), input_size_);

  // The tensors below are organized as a sequence of flattened tensors for the
  // `update`, `reset` and `state` gates.
  rtc::ArrayView<const float> bias(bias_);
  rtc::ArrayView<const float> weights(weights_);
  rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
  // Strides to access to the flattened tensors for a specific gate.
  const int stride_weights = input_size_ * output_size_;
  const int stride_recurrent_weights = output_size_ * output_size_;

  rtc::ArrayView<float> state(state_.data(), output_size_);

  // Update gate.
  std::array<float, kGruLayerMaxUnits> update;
  ComputeUpdateResetGate(
      input_size_, output_size_, vector_math_, input, state,
      bias.subview(0, output_size_), weights.subview(0, stride_weights),
      recurrent_weights.subview(0, stride_recurrent_weights), update);
  // Reset gate.
  std::array<float, kGruLayerMaxUnits> reset;
  ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
                         bias.subview(output_size_, output_size_),
                         weights.subview(stride_weights, stride_weights),
                         recurrent_weights.subview(stride_recurrent_weights,
                                                   stride_recurrent_weights),
                         reset);
  // State gate.
  ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
                   reset, bias.subview(2 * output_size_, output_size_),
                   weights.subview(2 * stride_weights, stride_weights),
                   recurrent_weights.subview(2 * stride_recurrent_weights,
                                             stride_recurrent_weights),
                   state);
}

}  // namespace rnn_vad
}  // namespace webrtc