/* * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source * tree. An additional intellectual property rights grant can be found * in the file PATENTS. All contributing project authors may * be found in the AUTHORS file in the root of the source tree. */ #include "modules/audio_processing/transient/transient_suppressor_impl.h" #include #include #include #include #include #include #include #include #include "common_audio/include/audio_util.h" #include "common_audio/signal_processing/include/signal_processing_library.h" #include "common_audio/third_party/ooura/fft_size_256/fft4g.h" #include "modules/audio_processing/transient/common.h" #include "modules/audio_processing/transient/transient_detector.h" #include "modules/audio_processing/transient/transient_suppressor.h" #include "modules/audio_processing/transient/windows_private.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" namespace webrtc { static const float kMeanIIRCoefficient = 0.5f; // TODO(aluebs): Check if these values work also for 48kHz. static const size_t kMinVoiceBin = 3; static const size_t kMaxVoiceBin = 60; namespace { float ComplexMagnitude(float a, float b) { return std::abs(a) + std::abs(b); } std::string GetVadModeLabel(TransientSuppressor::VadMode vad_mode) { switch (vad_mode) { case TransientSuppressor::VadMode::kDefault: return "default"; case TransientSuppressor::VadMode::kRnnVad: return "RNN VAD"; case TransientSuppressor::VadMode::kNoVad: return "no VAD"; } } } // namespace TransientSuppressorImpl::TransientSuppressorImpl(VadMode vad_mode, int sample_rate_hz, int detector_rate_hz, int num_channels) : vad_mode_(vad_mode), voice_probability_delay_unit_(/*delay_num_samples=*/0, sample_rate_hz), analyzed_audio_is_silent_(false), data_length_(0), detection_length_(0), analysis_length_(0), buffer_delay_(0), complex_analysis_length_(0), num_channels_(0), window_(NULL), detector_smoothed_(0.f), keypress_counter_(0), chunks_since_keypress_(0), detection_enabled_(false), suppression_enabled_(false), use_hard_restoration_(false), chunks_since_voice_change_(0), seed_(182), using_reference_(false) { RTC_LOG(LS_INFO) << "VAD mode: " << GetVadModeLabel(vad_mode_); Initialize(sample_rate_hz, detector_rate_hz, num_channels); } TransientSuppressorImpl::~TransientSuppressorImpl() {} void TransientSuppressorImpl::Initialize(int sample_rate_hz, int detection_rate_hz, int num_channels) { RTC_DCHECK(sample_rate_hz == ts::kSampleRate8kHz || sample_rate_hz == ts::kSampleRate16kHz || sample_rate_hz == ts::kSampleRate32kHz || sample_rate_hz == ts::kSampleRate48kHz); RTC_DCHECK(detection_rate_hz == ts::kSampleRate8kHz || detection_rate_hz == ts::kSampleRate16kHz || detection_rate_hz == ts::kSampleRate32kHz || detection_rate_hz == ts::kSampleRate48kHz); RTC_DCHECK_GT(num_channels, 0); switch (sample_rate_hz) { case ts::kSampleRate8kHz: analysis_length_ = 128u; window_ = kBlocks80w128; break; case ts::kSampleRate16kHz: analysis_length_ = 256u; window_ = kBlocks160w256; break; case ts::kSampleRate32kHz: analysis_length_ = 512u; window_ = kBlocks320w512; break; case ts::kSampleRate48kHz: analysis_length_ = 1024u; window_ = kBlocks480w1024; break; default: RTC_DCHECK_NOTREACHED(); return; } detector_.reset(new TransientDetector(detection_rate_hz)); data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000; RTC_DCHECK_LE(data_length_, analysis_length_); buffer_delay_ = analysis_length_ - data_length_; voice_probability_delay_unit_.Initialize(/*delay_num_samples=*/buffer_delay_, sample_rate_hz); complex_analysis_length_ = analysis_length_ / 2 + 1; RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin); num_channels_ = num_channels; in_buffer_.reset(new float[analysis_length_ * num_channels_]); memset(in_buffer_.get(), 0, analysis_length_ * num_channels_ * sizeof(in_buffer_[0])); detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000; detection_buffer_.reset(new float[detection_length_]); memset(detection_buffer_.get(), 0, detection_length_ * sizeof(detection_buffer_[0])); out_buffer_.reset(new float[analysis_length_ * num_channels_]); memset(out_buffer_.get(), 0, analysis_length_ * num_channels_ * sizeof(out_buffer_[0])); // ip[0] must be zero to trigger initialization using rdft(). size_t ip_length = 2 + sqrtf(analysis_length_); ip_.reset(new size_t[ip_length]()); memset(ip_.get(), 0, ip_length * sizeof(ip_[0])); wfft_.reset(new float[complex_analysis_length_ - 1]); memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0])); spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]); memset(spectral_mean_.get(), 0, complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0])); fft_buffer_.reset(new float[analysis_length_ + 2]); memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0])); magnitudes_.reset(new float[complex_analysis_length_]); memset(magnitudes_.get(), 0, complex_analysis_length_ * sizeof(magnitudes_[0])); mean_factor_.reset(new float[complex_analysis_length_]); static const float kFactorHeight = 10.f; static const float kLowSlope = 1.f; static const float kHighSlope = 0.3f; for (size_t i = 0; i < complex_analysis_length_; ++i) { mean_factor_[i] = kFactorHeight / (1.f + std::exp(kLowSlope * static_cast(i - kMinVoiceBin))) + kFactorHeight / (1.f + std::exp(kHighSlope * static_cast(kMaxVoiceBin - i))); } detector_smoothed_ = 0.f; keypress_counter_ = 0; chunks_since_keypress_ = 0; detection_enabled_ = false; suppression_enabled_ = false; use_hard_restoration_ = false; chunks_since_voice_change_ = 0; seed_ = 182; using_reference_ = false; } float TransientSuppressorImpl::Suppress(float* data, size_t data_length, int num_channels, const float* detection_data, size_t detection_length, const float* reference_data, size_t reference_length, float voice_probability, bool key_pressed) { if (!data || data_length != data_length_ || num_channels != num_channels_ || detection_length != detection_length_ || voice_probability < 0 || voice_probability > 1) { // The audio is not modified, so the voice probability is returned as is // (delay not applied). return voice_probability; } UpdateKeypress(key_pressed); UpdateBuffers(data); if (detection_enabled_) { UpdateRestoration(voice_probability); if (!detection_data) { // Use the input data of the first channel if special detection data is // not supplied. detection_data = &in_buffer_[buffer_delay_]; } float detector_result = detector_->Detect(detection_data, detection_length, reference_data, reference_length); if (detector_result < 0) { // The audio is not modified, so the voice probability is returned as is // (delay not applied). return voice_probability; } using_reference_ = detector_->using_reference(); // `detector_smoothed_` follows the `detector_result` when this last one is // increasing, but has an exponential decaying tail to be able to suppress // the ringing of keyclicks. float smooth_factor = using_reference_ ? 0.6 : 0.1; detector_smoothed_ = detector_result >= detector_smoothed_ ? detector_result : smooth_factor * detector_smoothed_ + (1 - smooth_factor) * detector_result; for (int i = 0; i < num_channels_; ++i) { Suppress(&in_buffer_[i * analysis_length_], &spectral_mean_[i * complex_analysis_length_], &out_buffer_[i * analysis_length_]); } } // If the suppression isn't enabled, we use the in buffer to delay the signal // appropriately. This also gives time for the out buffer to be refreshed with // new data between detection and suppression getting enabled. for (int i = 0; i < num_channels_; ++i) { memcpy(&data[i * data_length_], suppression_enabled_ ? &out_buffer_[i * analysis_length_] : &in_buffer_[i * analysis_length_], data_length_ * sizeof(*data)); } // The audio has been modified, return the delayed voice probability. return voice_probability_delay_unit_.Delay(voice_probability); } // This should only be called when detection is enabled. UpdateBuffers() must // have been called. At return, `out_buffer_` will be filled with the // processed output. void TransientSuppressorImpl::Suppress(float* in_ptr, float* spectral_mean, float* out_ptr) { // Go to frequency domain. for (size_t i = 0; i < analysis_length_; ++i) { // TODO(aluebs): Rename windows fft_buffer_[i] = in_ptr[i] * window_[i]; } WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get()); // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end // for convenience. fft_buffer_[analysis_length_] = fft_buffer_[1]; fft_buffer_[analysis_length_ + 1] = 0.f; fft_buffer_[1] = 0.f; for (size_t i = 0; i < complex_analysis_length_; ++i) { magnitudes_[i] = ComplexMagnitude(fft_buffer_[i * 2], fft_buffer_[i * 2 + 1]); } // Restore audio if necessary. if (suppression_enabled_) { if (use_hard_restoration_) { HardRestoration(spectral_mean); } else { SoftRestoration(spectral_mean); } } // Update the spectral mean. for (size_t i = 0; i < complex_analysis_length_; ++i) { spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] + kMeanIIRCoefficient * magnitudes_[i]; } // Back to time domain. // Put R[n/2] back in fft_buffer_[1]. fft_buffer_[1] = fft_buffer_[analysis_length_]; WebRtc_rdft(analysis_length_, -1, fft_buffer_.get(), ip_.get(), wfft_.get()); const float fft_scaling = 2.f / analysis_length_; for (size_t i = 0; i < analysis_length_; ++i) { out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling; } } void TransientSuppressorImpl::UpdateKeypress(bool key_pressed) { const int kKeypressPenalty = 1000 / ts::kChunkSizeMs; const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs; const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs; // 4 seconds. if (key_pressed) { keypress_counter_ += kKeypressPenalty; chunks_since_keypress_ = 0; detection_enabled_ = true; } keypress_counter_ = std::max(0, keypress_counter_ - 1); if (keypress_counter_ > kIsTypingThreshold) { if (!suppression_enabled_) { RTC_LOG(LS_INFO) << "[ts] Transient suppression is now enabled."; } suppression_enabled_ = true; keypress_counter_ = 0; } if (detection_enabled_ && ++chunks_since_keypress_ > kChunksUntilNotTyping) { if (suppression_enabled_) { RTC_LOG(LS_INFO) << "[ts] Transient suppression is now disabled."; } detection_enabled_ = false; suppression_enabled_ = false; keypress_counter_ = 0; } } void TransientSuppressorImpl::UpdateRestoration(float voice_probability) { bool not_voiced; switch (vad_mode_) { case TransientSuppressor::VadMode::kDefault: { constexpr float kVoiceThreshold = 0.02f; not_voiced = voice_probability < kVoiceThreshold; break; } case TransientSuppressor::VadMode::kRnnVad: { constexpr float kVoiceThreshold = 0.7f; not_voiced = voice_probability < kVoiceThreshold; break; } case TransientSuppressor::VadMode::kNoVad: // Always assume that voice is detected. not_voiced = false; break; } if (not_voiced == use_hard_restoration_) { chunks_since_voice_change_ = 0; } else { ++chunks_since_voice_change_; // Number of 10 ms frames to wait to transition to and from hard // restoration. constexpr int kHardRestorationOffsetDelay = 3; constexpr int kHardRestorationOnsetDelay = 80; if ((use_hard_restoration_ && chunks_since_voice_change_ > kHardRestorationOffsetDelay) || (!use_hard_restoration_ && chunks_since_voice_change_ > kHardRestorationOnsetDelay)) { use_hard_restoration_ = not_voiced; chunks_since_voice_change_ = 0; } } } // Shift buffers to make way for new data. Must be called after // `detection_enabled_` is updated by UpdateKeypress(). void TransientSuppressorImpl::UpdateBuffers(float* data) { // TODO(aluebs): Change to ring buffer. memmove(in_buffer_.get(), &in_buffer_[data_length_], (buffer_delay_ + (num_channels_ - 1) * analysis_length_) * sizeof(in_buffer_[0])); // Copy new chunk to buffer. for (int i = 0; i < num_channels_; ++i) { memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_], &data[i * data_length_], data_length_ * sizeof(*data)); } if (detection_enabled_) { // Shift previous chunk in out buffer. memmove(out_buffer_.get(), &out_buffer_[data_length_], (buffer_delay_ + (num_channels_ - 1) * analysis_length_) * sizeof(out_buffer_[0])); // Initialize new chunk in out buffer. for (int i = 0; i < num_channels_; ++i) { memset(&out_buffer_[buffer_delay_ + i * analysis_length_], 0, data_length_ * sizeof(out_buffer_[0])); } } } // Restores the unvoiced signal if a click is present. // Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds // the spectral mean. The attenuation depends on `detector_smoothed_`. // If a restoration takes place, the `magnitudes_` are updated to the new value. void TransientSuppressorImpl::HardRestoration(float* spectral_mean) { const float detector_result = 1.f - std::pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f); // To restore, we get the peaks in the spectrum. If higher than the previous // spectral mean we adjust them. for (size_t i = 0; i < complex_analysis_length_; ++i) { if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) { // RandU() generates values on [0, int16::max()] const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) / std::numeric_limits::max(); const float scaled_mean = detector_result * spectral_mean[i]; fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] + scaled_mean * cosf(phase); fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] + scaled_mean * sinf(phase); magnitudes_[i] = magnitudes_[i] - detector_result * (magnitudes_[i] - spectral_mean[i]); } } } // Restores the voiced signal if a click is present. // Attenuates by a certain factor every peak in the `fft_buffer_` that exceeds // the spectral mean and that is lower than some function of the current block // frequency mean. The attenuation depends on `detector_smoothed_`. // If a restoration takes place, the `magnitudes_` are updated to the new value. void TransientSuppressorImpl::SoftRestoration(float* spectral_mean) { // Get the spectral magnitude mean of the current block. float block_frequency_mean = 0; for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) { block_frequency_mean += magnitudes_[i]; } block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin); // To restore, we get the peaks in the spectrum. If higher than the // previous spectral mean and lower than a factor of the block mean // we adjust them. The factor is a double sigmoid that has a minimum in the // voice frequency range (300Hz - 3kHz). for (size_t i = 0; i < complex_analysis_length_; ++i) { if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 && (using_reference_ || magnitudes_[i] < block_frequency_mean * mean_factor_[i])) { const float new_magnitude = magnitudes_[i] - detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]); const float magnitude_ratio = new_magnitude / magnitudes_[i]; fft_buffer_[i * 2] *= magnitude_ratio; fft_buffer_[i * 2 + 1] *= magnitude_ratio; magnitudes_[i] = new_magnitude; } } } } // namespace webrtc