author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000
---|---|---
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000
commit | 43a97878ce14b72f0981164f87f2e35e14151312 (patch) |
tree | 620249daf56c0258faa40cbdcf9cfba06de2a846 | /third_party/libwebrtc/modules/audio_processing/test
parent | Initial commit. (diff) |
download | firefox-43a97878ce14b72f0981164f87f2e35e14151312.tar.xz, firefox-43a97878ce14b72f0981164f87f2e35e14151312.zip |
Adding upstream version 110.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/test')
110 files changed, 15293 insertions, 0 deletions
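The bulk of this patch is an AEC-dump-driven simulator that replays recorded `webrtc::audioproc::Event` protobuf messages (INIT, STREAM, REVERSE_STREAM, CONFIG, RUNTIME_SETTING) through a live `AudioProcessing` instance. For orientation, here is a minimal sketch (not part of the patch) of the same event-reading loop that `AecDumpBasedSimulator::Analyze()` below is built on; it assumes the `ReadMessageFromFile()` helper from `modules/audio_processing/test/protobuf_utils.h` (enclosing namespace assumed) and a hypothetical dump path.

```cpp
// Minimal sketch: count capture/render frames in an aecdump file, mirroring
// the Analyze() loop in the simulator below. Not part of this commit.
#include <cstdio>
#include <iostream>

#include "modules/audio_processing/debug.pb.h"
#include "modules/audio_processing/test/protobuf_utils.h"

int main() {
  FILE* file = fopen("recording.aecdump", "rb");  // Hypothetical input path.
  if (!file) {
    return 1;
  }

  int num_capture_frames = 0;
  int num_render_frames = 0;
  webrtc::audioproc::Event event;
  // ReadMessageFromFile() returns false at end-of-file or on a parse error.
  while (webrtc::ReadMessageFromFile(file, &event)) {
    if (event.type() == webrtc::audioproc::Event::STREAM) {
      ++num_capture_frames;  // One 10 ms capture-side frame.
    } else if (event.type() == webrtc::audioproc::Event::REVERSE_STREAM) {
      ++num_render_frames;  // One 10 ms render-side frame.
    }
  }
  fclose(file);

  // The simulator assumes 100 frames per second, so frames / 100 is seconds.
  std::cout << "Capture: " << num_capture_frames
            << " frames, Render: " << num_render_frames << " frames"
            << std::endl;
  return 0;
}
```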
diff --git a/third_party/libwebrtc/modules/audio_processing/test/aec_dump_based_simulator.cc b/third_party/libwebrtc/modules/audio_processing/test/aec_dump_based_simulator.cc new file mode 100644 index 0000000000..ec35dd345c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/aec_dump_based_simulator.cc @@ -0,0 +1,654 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/aec_dump_based_simulator.h" + +#include <iostream> +#include <memory> + +#include "modules/audio_processing/echo_control_mobile_impl.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/aec_dump_based_simulator.h" +#include "modules/audio_processing/test/protobuf_utils.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_conversions.h" + +namespace webrtc { +namespace test { +namespace { + +// Verify output bitexactness for the fixed interface. +// TODO(peah): Check whether it would make sense to add a threshold +// to use for checking the bitexactness in a soft manner. +bool VerifyFixedBitExactness(const webrtc::audioproc::Stream& msg, + const Int16Frame& frame) { + if (sizeof(frame.data[0]) * frame.data.size() != msg.output_data().size()) { + return false; + } else { + const int16_t* frame_data = frame.data.data(); + for (int k = 0; k < frame.num_channels * frame.samples_per_channel; ++k) { + if (msg.output_data().data()[k] != frame_data[k]) { + return false; + } + } + } + return true; +} + +// Verify output bitexactness for the float interface. +bool VerifyFloatBitExactness(const webrtc::audioproc::Stream& msg, + const StreamConfig& out_config, + const ChannelBuffer<float>& out_buf) { + if (static_cast<size_t>(msg.output_channel_size()) != + out_config.num_channels() || + msg.output_channel(0).size() != out_config.num_frames()) { + return false; + } else { + for (int ch = 0; ch < msg.output_channel_size(); ++ch) { + for (size_t sample = 0; sample < out_config.num_frames(); ++sample) { + if (msg.output_channel(ch).data()[sample] != + out_buf.channels()[ch][sample]) { + return false; + } + } + } + } + return true; +} + +// Selectively reads the next proto-buf message from dump-file or string input. +// Returns a bool indicating whether a new message was available. +bool ReadNextMessage(bool use_dump_file, + FILE* dump_input_file, + std::stringstream& input, + webrtc::audioproc::Event& event_msg) { + if (use_dump_file) { + return ReadMessageFromFile(dump_input_file, &event_msg); + } + return ReadMessageFromString(&input, &event_msg); +} + +} // namespace + +AecDumpBasedSimulator::AecDumpBasedSimulator( + const SimulationSettings& settings, + rtc::scoped_refptr<AudioProcessing> audio_processing, + std::unique_ptr<AudioProcessingBuilder> ap_builder) + : AudioProcessingSimulator(settings, + std::move(audio_processing), + std::move(ap_builder)) { + MaybeOpenCallOrderFile(); +} + +AecDumpBasedSimulator::~AecDumpBasedSimulator() = default; + +void AecDumpBasedSimulator::PrepareProcessStreamCall( + const webrtc::audioproc::Stream& msg) { + if (msg.has_input_data()) { + // Fixed interface processing. 
+ // Verify interface invariance. + RTC_CHECK(interface_used_ == InterfaceType::kFixedInterface || + interface_used_ == InterfaceType::kNotSpecified); + interface_used_ = InterfaceType::kFixedInterface; + + // Populate input buffer. + RTC_CHECK_EQ(sizeof(fwd_frame_.data[0]) * fwd_frame_.data.size(), + msg.input_data().size()); + memcpy(fwd_frame_.data.data(), msg.input_data().data(), + msg.input_data().size()); + } else { + // Float interface processing. + // Verify interface invariance. + RTC_CHECK(interface_used_ == InterfaceType::kFloatInterface || + interface_used_ == InterfaceType::kNotSpecified); + interface_used_ = InterfaceType::kFloatInterface; + + RTC_CHECK_EQ(in_buf_->num_channels(), + static_cast<size_t>(msg.input_channel_size())); + + // Populate input buffer. + for (size_t i = 0; i < in_buf_->num_channels(); ++i) { + RTC_CHECK_EQ(in_buf_->num_frames() * sizeof(*in_buf_->channels()[i]), + msg.input_channel(i).size()); + std::memcpy(in_buf_->channels()[i], msg.input_channel(i).data(), + msg.input_channel(i).size()); + } + } + + if (artificial_nearend_buffer_reader_) { + if (artificial_nearend_buffer_reader_->Read( + artificial_nearend_buf_.get())) { + if (msg.has_input_data()) { + int16_t* fwd_frame_data = fwd_frame_.data.data(); + for (size_t k = 0; k < in_buf_->num_frames(); ++k) { + fwd_frame_data[k] = rtc::saturated_cast<int16_t>( + fwd_frame_data[k] + + static_cast<int16_t>(32767 * + artificial_nearend_buf_->channels()[0][k])); + } + } else { + for (int i = 0; i < msg.input_channel_size(); ++i) { + for (size_t k = 0; k < in_buf_->num_frames(); ++k) { + in_buf_->channels()[i][k] += + artificial_nearend_buf_->channels()[0][k]; + in_buf_->channels()[i][k] = std::min( + 32767.f, std::max(-32768.f, in_buf_->channels()[i][k])); + } + } + } + } else { + if (!artificial_nearend_eof_reported_) { + std::cout << "The artificial nearend file ended before the recording."; + artificial_nearend_eof_reported_ = true; + } + } + } + + if (!settings_.use_stream_delay || *settings_.use_stream_delay) { + if (!settings_.stream_delay) { + if (msg.has_delay()) { + RTC_CHECK_EQ(AudioProcessing::kNoError, + ap_->set_stream_delay_ms(msg.delay())); + } + } else { + RTC_CHECK_EQ(AudioProcessing::kNoError, + ap_->set_stream_delay_ms(*settings_.stream_delay)); + } + } + + if (settings_.override_key_pressed.has_value()) { + // Key pressed state overridden. + ap_->set_stream_key_pressed(*settings_.override_key_pressed); + } else { + // Set the recorded key pressed state. + if (msg.has_keypress()) { + ap_->set_stream_key_pressed(msg.keypress()); + } + } + + // Level is always logged in AEC dumps. + RTC_CHECK(msg.has_level()); + aec_dump_mic_level_ = msg.level(); +} + +void AecDumpBasedSimulator::VerifyProcessStreamBitExactness( + const webrtc::audioproc::Stream& msg) { + if (bitexact_output_) { + if (interface_used_ == InterfaceType::kFixedInterface) { + bitexact_output_ = VerifyFixedBitExactness(msg, fwd_frame_); + } else { + bitexact_output_ = VerifyFloatBitExactness(msg, out_config_, *out_buf_); + } + } +} + +void AecDumpBasedSimulator::PrepareReverseProcessStreamCall( + const webrtc::audioproc::ReverseStream& msg) { + if (msg.has_data()) { + // Fixed interface processing. + // Verify interface invariance. + RTC_CHECK(interface_used_ == InterfaceType::kFixedInterface || + interface_used_ == InterfaceType::kNotSpecified); + interface_used_ = InterfaceType::kFixedInterface; + + // Populate input buffer. 
+ RTC_CHECK_EQ(sizeof(rev_frame_.data[0]) * rev_frame_.data.size(), + msg.data().size()); + memcpy(rev_frame_.data.data(), msg.data().data(), msg.data().size()); + } else { + // Float interface processing. + // Verify interface invariance. + RTC_CHECK(interface_used_ == InterfaceType::kFloatInterface || + interface_used_ == InterfaceType::kNotSpecified); + interface_used_ = InterfaceType::kFloatInterface; + + RTC_CHECK_EQ(reverse_in_buf_->num_channels(), + static_cast<size_t>(msg.channel_size())); + + // Populate input buffer. + for (int i = 0; i < msg.channel_size(); ++i) { + RTC_CHECK_EQ(reverse_in_buf_->num_frames() * + sizeof(*reverse_in_buf_->channels()[i]), + msg.channel(i).size()); + std::memcpy(reverse_in_buf_->channels()[i], msg.channel(i).data(), + msg.channel(i).size()); + } + } +} + +void AecDumpBasedSimulator::Process() { + ConfigureAudioProcessor(); + + if (settings_.artificial_nearend_filename) { + std::unique_ptr<WavReader> artificial_nearend_file( + new WavReader(settings_.artificial_nearend_filename->c_str())); + + RTC_CHECK_EQ(1, artificial_nearend_file->num_channels()) + << "Only mono files for the artificial nearend are supported, " + "reverted to not using the artificial nearend file"; + + const int sample_rate_hz = artificial_nearend_file->sample_rate(); + artificial_nearend_buffer_reader_.reset( + new ChannelBufferWavReader(std::move(artificial_nearend_file))); + artificial_nearend_buf_.reset(new ChannelBuffer<float>( + rtc::CheckedDivExact(sample_rate_hz, kChunksPerSecond), 1)); + } + + const bool use_dump_file = !settings_.aec_dump_input_string.has_value(); + std::stringstream input; + if (use_dump_file) { + dump_input_file_ = + OpenFile(settings_.aec_dump_input_filename->c_str(), "rb"); + } else { + input << settings_.aec_dump_input_string.value(); + } + + webrtc::audioproc::Event event_msg; + int capture_frames_since_init = 0; + int init_index = 0; + while (ReadNextMessage(use_dump_file, dump_input_file_, input, event_msg)) { + SelectivelyToggleDataDumping(init_index, capture_frames_since_init); + HandleEvent(event_msg, capture_frames_since_init, init_index); + + // Perform an early exit if the init block to process has been fully + // processed + if (finished_processing_specified_init_block_) { + break; + } + RTC_CHECK(!settings_.init_to_process || + *settings_.init_to_process >= init_index); + } + + if (use_dump_file) { + fclose(dump_input_file_); + } + + DetachAecDump(); +} + +void AecDumpBasedSimulator::Analyze() { + const bool use_dump_file = !settings_.aec_dump_input_string.has_value(); + std::stringstream input; + if (use_dump_file) { + dump_input_file_ = + OpenFile(settings_.aec_dump_input_filename->c_str(), "rb"); + } else { + input << settings_.aec_dump_input_string.value(); + } + + webrtc::audioproc::Event event_msg; + int num_capture_frames = 0; + int num_render_frames = 0; + int init_index = 0; + while (ReadNextMessage(use_dump_file, dump_input_file_, input, event_msg)) { + if (event_msg.type() == webrtc::audioproc::Event::INIT) { + ++init_index; + constexpr float kNumFramesPerSecond = 100.f; + float capture_time_seconds = num_capture_frames / kNumFramesPerSecond; + float render_time_seconds = num_render_frames / kNumFramesPerSecond; + + std::cout << "Inits:" << std::endl; + std::cout << init_index << ": -->" << std::endl; + std::cout << " Time:" << std::endl; + std::cout << " Capture: " << capture_time_seconds << " s (" + << num_capture_frames << " frames) " << std::endl; + std::cout << " Render: " << render_time_seconds << " s (" + <<
num_render_frames << " frames) " << std::endl; + } else if (event_msg.type() == webrtc::audioproc::Event::STREAM) { + ++num_capture_frames; + } else if (event_msg.type() == webrtc::audioproc::Event::REVERSE_STREAM) { + ++num_render_frames; + } + } + + if (use_dump_file) { + fclose(dump_input_file_); + } +} + +void AecDumpBasedSimulator::HandleEvent( + const webrtc::audioproc::Event& event_msg, + int& capture_frames_since_init, + int& init_index) { + switch (event_msg.type()) { + case webrtc::audioproc::Event::INIT: + RTC_CHECK(event_msg.has_init()); + ++init_index; + capture_frames_since_init = 0; + HandleMessage(event_msg.init(), init_index); + break; + case webrtc::audioproc::Event::STREAM: + RTC_CHECK(event_msg.has_stream()); + ++capture_frames_since_init; + HandleMessage(event_msg.stream()); + break; + case webrtc::audioproc::Event::REVERSE_STREAM: + RTC_CHECK(event_msg.has_reverse_stream()); + HandleMessage(event_msg.reverse_stream()); + break; + case webrtc::audioproc::Event::CONFIG: + RTC_CHECK(event_msg.has_config()); + HandleMessage(event_msg.config()); + break; + case webrtc::audioproc::Event::RUNTIME_SETTING: + HandleMessage(event_msg.runtime_setting()); + break; + case webrtc::audioproc::Event::UNKNOWN_EVENT: + RTC_CHECK_NOTREACHED(); + } +} + +void AecDumpBasedSimulator::HandleMessage( + const webrtc::audioproc::Config& msg) { + if (settings_.use_verbose_logging) { + std::cout << "Config at frame:" << std::endl; + std::cout << " Forward: " << get_num_process_stream_calls() << std::endl; + std::cout << " Reverse: " << get_num_reverse_process_stream_calls() + << std::endl; + } + + if (!settings_.discard_all_settings_in_aecdump) { + if (settings_.use_verbose_logging) { + std::cout << "Setting used in config:" << std::endl; + } + AudioProcessing::Config apm_config = ap_->GetConfig(); + + if (msg.has_aec_enabled() || settings_.use_aec) { + bool enable = settings_.use_aec ? *settings_.use_aec : msg.aec_enabled(); + apm_config.echo_canceller.enabled = enable; + if (settings_.use_verbose_logging) { + std::cout << " aec_enabled: " << (enable ? "true" : "false") + << std::endl; + } + } + + if (msg.has_aecm_enabled() || settings_.use_aecm) { + bool enable = + settings_.use_aecm ? *settings_.use_aecm : msg.aecm_enabled(); + apm_config.echo_canceller.enabled |= enable; + apm_config.echo_canceller.mobile_mode = enable; + if (settings_.use_verbose_logging) { + std::cout << " aecm_enabled: " << (enable ? "true" : "false") + << std::endl; + } + } + + if (msg.has_aecm_comfort_noise_enabled() && + msg.aecm_comfort_noise_enabled()) { + RTC_LOG(LS_ERROR) << "Ignoring deprecated setting: AECM comfort noise"; + } + + if (msg.has_aecm_routing_mode() && + static_cast<webrtc::EchoControlMobileImpl::RoutingMode>( + msg.aecm_routing_mode()) != EchoControlMobileImpl::kSpeakerphone) { + RTC_LOG(LS_ERROR) << "Ignoring deprecated setting: AECM routing mode: " + << msg.aecm_routing_mode(); + } + + if (msg.has_agc_enabled() || settings_.use_agc) { + bool enable = settings_.use_agc ? *settings_.use_agc : msg.agc_enabled(); + apm_config.gain_controller1.enabled = enable; + if (settings_.use_verbose_logging) { + std::cout << " agc_enabled: " << (enable ? "true" : "false") + << std::endl; + } + } + + if (msg.has_agc_mode() || settings_.agc_mode) { + int mode = settings_.agc_mode ? 
*settings_.agc_mode : msg.agc_mode(); + apm_config.gain_controller1.mode = + static_cast<webrtc::AudioProcessing::Config::GainController1::Mode>( + mode); + if (settings_.use_verbose_logging) { + std::cout << " agc_mode: " << mode << std::endl; + } + } + + if (msg.has_agc_limiter_enabled() || settings_.use_agc_limiter) { + bool enable = settings_.use_agc_limiter ? *settings_.use_agc_limiter + : msg.agc_limiter_enabled(); + apm_config.gain_controller1.enable_limiter = enable; + if (settings_.use_verbose_logging) { + std::cout << " agc_limiter_enabled: " << (enable ? "true" : "false") + << std::endl; + } + } + + if (settings_.use_agc2) { + bool enable = *settings_.use_agc2; + apm_config.gain_controller2.enabled = enable; + if (settings_.agc2_fixed_gain_db) { + apm_config.gain_controller2.fixed_digital.gain_db = + *settings_.agc2_fixed_gain_db; + } + if (settings_.use_verbose_logging) { + std::cout << " agc2_enabled: " << (enable ? "true" : "false") + << std::endl; + } + } + + if (msg.has_noise_robust_agc_enabled()) { + apm_config.gain_controller1.analog_gain_controller.enabled = + settings_.use_analog_agc ? *settings_.use_analog_agc + : msg.noise_robust_agc_enabled(); + if (settings_.use_verbose_logging) { + std::cout << " noise_robust_agc_enabled: " + << (msg.noise_robust_agc_enabled() ? "true" : "false") + << std::endl; + } + } + + if (msg.has_transient_suppression_enabled() || settings_.use_ts) { + bool enable = settings_.use_ts ? *settings_.use_ts + : msg.transient_suppression_enabled(); + apm_config.transient_suppression.enabled = enable; + if (settings_.use_verbose_logging) { + std::cout << " transient_suppression_enabled: " + << (enable ? "true" : "false") << std::endl; + } + } + + if (msg.has_hpf_enabled() || settings_.use_hpf) { + bool enable = settings_.use_hpf ? *settings_.use_hpf : msg.hpf_enabled(); + apm_config.high_pass_filter.enabled = enable; + if (settings_.use_verbose_logging) { + std::cout << " hpf_enabled: " << (enable ? "true" : "false") + << std::endl; + } + } + + if (msg.has_ns_enabled() || settings_.use_ns) { + bool enable = settings_.use_ns ? *settings_.use_ns : msg.ns_enabled(); + apm_config.noise_suppression.enabled = enable; + if (settings_.use_verbose_logging) { + std::cout << " ns_enabled: " << (enable ? "true" : "false") + << std::endl; + } + } + + if (msg.has_ns_level() || settings_.ns_level) { + int level = settings_.ns_level ? *settings_.ns_level : msg.ns_level(); + apm_config.noise_suppression.level = + static_cast<AudioProcessing::Config::NoiseSuppression::Level>(level); + if (settings_.use_verbose_logging) { + std::cout << " ns_level: " << level << std::endl; + } + } + + if (msg.has_pre_amplifier_enabled() || settings_.use_pre_amplifier) { + const bool enable = settings_.use_pre_amplifier + ? *settings_.use_pre_amplifier + : msg.pre_amplifier_enabled(); + apm_config.pre_amplifier.enabled = enable; + } + + if (msg.has_pre_amplifier_fixed_gain_factor() || + settings_.pre_amplifier_gain_factor) { + const float gain = settings_.pre_amplifier_gain_factor + ? 
*settings_.pre_amplifier_gain_factor + : msg.pre_amplifier_fixed_gain_factor(); + apm_config.pre_amplifier.fixed_gain_factor = gain; + } + + if (settings_.use_verbose_logging && msg.has_experiments_description() && + !msg.experiments_description().empty()) { + std::cout << " experiments not included by default in the simulation: " + << msg.experiments_description() << std::endl; + } + + ap_->ApplyConfig(apm_config); + } +} + +void AecDumpBasedSimulator::HandleMessage(const webrtc::audioproc::Init& msg, + int init_index) { + RTC_CHECK(msg.has_sample_rate()); + RTC_CHECK(msg.has_num_input_channels()); + RTC_CHECK(msg.has_num_reverse_channels()); + RTC_CHECK(msg.has_reverse_sample_rate()); + + // Do not perform the init if the init block to process is fully processed + if (settings_.init_to_process && *settings_.init_to_process < init_index) { + finished_processing_specified_init_block_ = true; + } + + MaybeOpenCallOrderFile(); + + if (settings_.use_verbose_logging) { + std::cout << "Init at frame:" << std::endl; + std::cout << " Forward: " << get_num_process_stream_calls() << std::endl; + std::cout << " Reverse: " << get_num_reverse_process_stream_calls() + << std::endl; + } + + int num_output_channels; + if (settings_.output_num_channels) { + num_output_channels = *settings_.output_num_channels; + } else { + num_output_channels = msg.has_num_output_channels() + ? msg.num_output_channels() + : msg.num_input_channels(); + } + + int output_sample_rate; + if (settings_.output_sample_rate_hz) { + output_sample_rate = *settings_.output_sample_rate_hz; + } else { + output_sample_rate = msg.has_output_sample_rate() ? msg.output_sample_rate() + : msg.sample_rate(); + } + + int num_reverse_output_channels; + if (settings_.reverse_output_num_channels) { + num_reverse_output_channels = *settings_.reverse_output_num_channels; + } else { + num_reverse_output_channels = msg.has_num_reverse_output_channels() + ? msg.num_reverse_output_channels() + : msg.num_reverse_channels(); + } + + int reverse_output_sample_rate; + if (settings_.reverse_output_sample_rate_hz) { + reverse_output_sample_rate = *settings_.reverse_output_sample_rate_hz; + } else { + reverse_output_sample_rate = msg.has_reverse_output_sample_rate() + ? msg.reverse_output_sample_rate() + : msg.reverse_sample_rate(); + } + + SetupBuffersConfigsOutputs( + msg.sample_rate(), output_sample_rate, msg.reverse_sample_rate(), + reverse_output_sample_rate, msg.num_input_channels(), num_output_channels, + msg.num_reverse_channels(), num_reverse_output_channels); +} + +void AecDumpBasedSimulator::HandleMessage( + const webrtc::audioproc::Stream& msg) { + if (call_order_output_file_) { + *call_order_output_file_ << "c"; + } + PrepareProcessStreamCall(msg); + ProcessStream(interface_used_ == InterfaceType::kFixedInterface); + VerifyProcessStreamBitExactness(msg); +} + +void AecDumpBasedSimulator::HandleMessage( + const webrtc::audioproc::ReverseStream& msg) { + if (call_order_output_file_) { + *call_order_output_file_ << "r"; + } + PrepareReverseProcessStreamCall(msg); + ProcessReverseStream(interface_used_ == InterfaceType::kFixedInterface); +} + +void AecDumpBasedSimulator::HandleMessage( + const webrtc::audioproc::RuntimeSetting& msg) { + RTC_CHECK(ap_.get()); + if (msg.has_capture_pre_gain()) { + // Handle capture pre-gain runtime setting only if not overridden. 
+ const bool pre_amplifier_overridden = + (!settings_.use_pre_amplifier || *settings_.use_pre_amplifier) && + !settings_.pre_amplifier_gain_factor; + const bool capture_level_adjustment_overridden = + (!settings_.use_capture_level_adjustment || + *settings_.use_capture_level_adjustment) && + !settings_.pre_gain_factor; + if (pre_amplifier_overridden || capture_level_adjustment_overridden) { + ap_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePreGain( + msg.capture_pre_gain())); + } + } else if (msg.has_capture_post_gain()) { + // Handle capture post-gain runtime setting only if not overridden. + if ((!settings_.use_capture_level_adjustment || + *settings_.use_capture_level_adjustment) && + !settings_.post_gain_factor) { + ap_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePostGain( + msg.capture_post_gain())); + } + } else if (msg.has_capture_fixed_post_gain()) { + // Handle capture fixed-post-gain runtime setting only if not overridden. + if ((!settings_.use_agc2 || *settings_.use_agc2) && + !settings_.agc2_fixed_gain_db) { + ap_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureFixedPostGain( + msg.capture_fixed_post_gain())); + } + } else if (msg.has_playout_volume_change()) { + ap_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreatePlayoutVolumeChange( + msg.playout_volume_change())); + } else if (msg.has_playout_audio_device_change()) { + ap_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreatePlayoutAudioDeviceChange( + {msg.playout_audio_device_change().id(), + msg.playout_audio_device_change().max_volume()})); + } else if (msg.has_capture_output_used()) { + ap_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting( + msg.capture_output_used())); + } +} + +void AecDumpBasedSimulator::MaybeOpenCallOrderFile() { + if (settings_.call_order_output_filename.has_value()) { + const std::string filename = settings_.store_intermediate_output + ? *settings_.call_order_output_filename + + "_" + + std::to_string(output_reset_counter_) + : *settings_.call_order_output_filename; + call_order_output_file_ = std::make_unique<std::ofstream>(filename); + } +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/aec_dump_based_simulator.h b/third_party/libwebrtc/modules/audio_processing/test/aec_dump_based_simulator.h new file mode 100644 index 0000000000..e2c1f3e4ba --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/aec_dump_based_simulator.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_AEC_DUMP_BASED_SIMULATOR_H_ +#define MODULES_AUDIO_PROCESSING_TEST_AEC_DUMP_BASED_SIMULATOR_H_ + +#include <fstream> +#include <string> + +#include "modules/audio_processing/test/audio_processing_simulator.h" +#include "rtc_base/ignore_wundef.h" + +RTC_PUSH_IGNORING_WUNDEF() +#ifdef WEBRTC_ANDROID_PLATFORM_BUILD +#include "external/webrtc/webrtc/modules/audio_processing/debug.pb.h" +#else +#include "modules/audio_processing/debug.pb.h" +#endif +RTC_POP_IGNORING_WUNDEF() + +namespace webrtc { +namespace test { + +// Used to perform an audio processing simulation from an aec dump. +class AecDumpBasedSimulator final : public AudioProcessingSimulator { + public: + AecDumpBasedSimulator(const SimulationSettings& settings, + rtc::scoped_refptr<AudioProcessing> audio_processing, + std::unique_ptr<AudioProcessingBuilder> ap_builder); + + AecDumpBasedSimulator() = delete; + AecDumpBasedSimulator(const AecDumpBasedSimulator&) = delete; + AecDumpBasedSimulator& operator=(const AecDumpBasedSimulator&) = delete; + + ~AecDumpBasedSimulator() override; + + // Processes the messages in the aecdump file. + void Process() override; + + // Analyzes the data in the aecdump file and reports the resulting statistics. + void Analyze() override; + + private: + void HandleEvent(const webrtc::audioproc::Event& event_msg, + int& num_forward_chunks_processed, + int& init_index); + void HandleMessage(const webrtc::audioproc::Init& msg, int init_index); + void HandleMessage(const webrtc::audioproc::Stream& msg); + void HandleMessage(const webrtc::audioproc::ReverseStream& msg); + void HandleMessage(const webrtc::audioproc::Config& msg); + void HandleMessage(const webrtc::audioproc::RuntimeSetting& msg); + void PrepareProcessStreamCall(const webrtc::audioproc::Stream& msg); + void PrepareReverseProcessStreamCall( + const webrtc::audioproc::ReverseStream& msg); + void VerifyProcessStreamBitExactness(const webrtc::audioproc::Stream& msg); + void MaybeOpenCallOrderFile(); + enum InterfaceType { + kFixedInterface, + kFloatInterface, + kNotSpecified, + }; + + FILE* dump_input_file_; + std::unique_ptr<ChannelBuffer<float>> artificial_nearend_buf_; + std::unique_ptr<ChannelBufferWavReader> artificial_nearend_buffer_reader_; + bool artificial_nearend_eof_reported_ = false; + InterfaceType interface_used_ = InterfaceType::kNotSpecified; + std::unique_ptr<std::ofstream> call_order_output_file_; + bool finished_processing_specified_init_block_ = false; +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_AEC_DUMP_BASED_SIMULATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/AndroidManifest.xml b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/AndroidManifest.xml new file mode 100644 index 0000000000..c6063b3d76 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/AndroidManifest.xml @@ -0,0 +1,30 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- BEGIN_INCLUDE(manifest) --> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="com.example.native_activity" + android:versionCode="1" + android:versionName="1.0"> + + <!-- This is the platform API where NativeActivity was introduced. --> + <uses-sdk android:minSdkVersion="8" /> + + <!-- This .apk has no Java code itself, so set hasCode to false. 
--> + <application android:label="@string/app_name" android:hasCode="false" android:debuggable="true"> + + <!-- Our activity is the built-in NativeActivity framework class. + This will take care of integrating with our NDK code. --> + <activity android:name="android.app.NativeActivity" + android:label="@string/app_name" + android:configChanges="orientation|keyboardHidden"> + <!-- Tell NativeActivity the name of or .so --> + <meta-data android:name="android.app.lib_name" + android:value="apmtest-activity" /> + <intent-filter> + <action android:name="android.intent.action.MAIN" /> + <category android:name="android.intent.category.LAUNCHER" /> + </intent-filter> + </activity> + </application> + +</manifest> +<!-- END_INCLUDE(manifest) --> diff --git a/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/default.properties b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/default.properties new file mode 100644 index 0000000000..9a2c9f6c88 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/default.properties @@ -0,0 +1,11 @@ +# This file is automatically generated by Android Tools. +# Do not modify this file -- YOUR CHANGES WILL BE ERASED! +# +# This file must be checked in Version Control Systems. +# +# To customize properties used by the Ant build system use, +# "build.properties", and override values to adapt the script to your +# project structure. + +# Project target. +target=android-9 diff --git a/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/jni/main.c b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/jni/main.c new file mode 100644 index 0000000000..2e19635683 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/jni/main.c @@ -0,0 +1,307 @@ +/* + * Copyright (C) 2010 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +//BEGIN_INCLUDE(all) +#include <jni.h> +#include <errno.h> + +#include <EGL/egl.h> +#include <GLES/gl.h> + +#include <android/sensor.h> +#include <android/log.h> +#include <android_native_app_glue.h> + +#define LOGI(...) ((void)__android_log_print(ANDROID_LOG_INFO, "native-activity", __VA_ARGS__)) +#define LOGW(...) ((void)__android_log_print(ANDROID_LOG_WARN, "native-activity", __VA_ARGS__)) + +/** + * Our saved state data. + */ +struct saved_state { + float angle; + int32_t x; + int32_t y; +}; + +/** + * Shared state for our app. + */ +struct engine { + struct android_app* app; + + ASensorManager* sensorManager; + const ASensor* accelerometerSensor; + ASensorEventQueue* sensorEventQueue; + + int animating; + EGLDisplay display; + EGLSurface surface; + EGLContext context; + int32_t width; + int32_t height; + struct saved_state state; +}; + +/** + * Initialize an EGL context for the current display. + */ +static int engine_init_display(struct engine* engine) { + // initialize OpenGL ES and EGL + + /* + * Here specify the attributes of the desired configuration. 
+ * Below, we select an EGLConfig with at least 8 bits per color + * component compatible with on-screen windows + */ + const EGLint attribs[] = { + EGL_SURFACE_TYPE, EGL_WINDOW_BIT, + EGL_BLUE_SIZE, 8, + EGL_GREEN_SIZE, 8, + EGL_RED_SIZE, 8, + EGL_NONE + }; + EGLint w, h, dummy, format; + EGLint numConfigs; + EGLConfig config; + EGLSurface surface; + EGLContext context; + + EGLDisplay display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + + eglInitialize(display, 0, 0); + + /* Here, the application chooses the configuration it desires. In this + * sample, we have a very simplified selection process, where we pick + * the first EGLConfig that matches our criteria */ + eglChooseConfig(display, attribs, &config, 1, &numConfigs); + + /* EGL_NATIVE_VISUAL_ID is an attribute of the EGLConfig that is + * guaranteed to be accepted by ANativeWindow_setBuffersGeometry(). + * As soon as we picked a EGLConfig, we can safely reconfigure the + * ANativeWindow buffers to match, using EGL_NATIVE_VISUAL_ID. */ + eglGetConfigAttrib(display, config, EGL_NATIVE_VISUAL_ID, &format); + + ANativeWindow_setBuffersGeometry(engine->app->window, 0, 0, format); + + surface = eglCreateWindowSurface(display, config, engine->app->window, NULL); + context = eglCreateContext(display, config, NULL, NULL); + + if (eglMakeCurrent(display, surface, surface, context) == EGL_FALSE) { + LOGW("Unable to eglMakeCurrent"); + return -1; + } + + eglQuerySurface(display, surface, EGL_WIDTH, &w); + eglQuerySurface(display, surface, EGL_HEIGHT, &h); + + engine->display = display; + engine->context = context; + engine->surface = surface; + engine->width = w; + engine->height = h; + engine->state.angle = 0; + + // Initialize GL state. + glHint(GL_PERSPECTIVE_CORRECTION_HINT, GL_FASTEST); + glEnable(GL_CULL_FACE); + glShadeModel(GL_SMOOTH); + glDisable(GL_DEPTH_TEST); + + return 0; +} + +/** + * Just the current frame in the display. + */ +static void engine_draw_frame(struct engine* engine) { + if (engine->display == NULL) { + // No display. + return; + } + + // Just fill the screen with a color. + glClearColor(((float)engine->state.x)/engine->width, engine->state.angle, + ((float)engine->state.y)/engine->height, 1); + glClear(GL_COLOR_BUFFER_BIT); + + eglSwapBuffers(engine->display, engine->surface); +} + +/** + * Tear down the EGL context currently associated with the display. + */ +static void engine_term_display(struct engine* engine) { + if (engine->display != EGL_NO_DISPLAY) { + eglMakeCurrent(engine->display, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); + if (engine->context != EGL_NO_CONTEXT) { + eglDestroyContext(engine->display, engine->context); + } + if (engine->surface != EGL_NO_SURFACE) { + eglDestroySurface(engine->display, engine->surface); + } + eglTerminate(engine->display); + } + engine->animating = 0; + engine->display = EGL_NO_DISPLAY; + engine->context = EGL_NO_CONTEXT; + engine->surface = EGL_NO_SURFACE; +} + +/** + * Process the next input event. + */ +static int32_t engine_handle_input(struct android_app* app, AInputEvent* event) { + struct engine* engine = (struct engine*)app->userData; + if (AInputEvent_getType(event) == AINPUT_EVENT_TYPE_MOTION) { + engine->animating = 1; + engine->state.x = AMotionEvent_getX(event, 0); + engine->state.y = AMotionEvent_getY(event, 0); + return 1; + } + return 0; +} + +/** + * Process the next main command. 
+ */ +static void engine_handle_cmd(struct android_app* app, int32_t cmd) { + struct engine* engine = (struct engine*)app->userData; + switch (cmd) { + case APP_CMD_SAVE_STATE: + // The system has asked us to save our current state. Do so. + engine->app->savedState = malloc(sizeof(struct saved_state)); + *((struct saved_state*)engine->app->savedState) = engine->state; + engine->app->savedStateSize = sizeof(struct saved_state); + break; + case APP_CMD_INIT_WINDOW: + // The window is being shown, get it ready. + if (engine->app->window != NULL) { + engine_init_display(engine); + engine_draw_frame(engine); + } + break; + case APP_CMD_TERM_WINDOW: + // The window is being hidden or closed, clean it up. + engine_term_display(engine); + break; + case APP_CMD_GAINED_FOCUS: + // When our app gains focus, we start monitoring the accelerometer. + if (engine->accelerometerSensor != NULL) { + ASensorEventQueue_enableSensor(engine->sensorEventQueue, + engine->accelerometerSensor); + // We'd like to get 60 events per second (in us). + ASensorEventQueue_setEventRate(engine->sensorEventQueue, + engine->accelerometerSensor, (1000L/60)*1000); + } + break; + case APP_CMD_LOST_FOCUS: + // When our app loses focus, we stop monitoring the accelerometer. + // This is to avoid consuming battery while not being used. + if (engine->accelerometerSensor != NULL) { + ASensorEventQueue_disableSensor(engine->sensorEventQueue, + engine->accelerometerSensor); + } + // Also stop animating. + engine->animating = 0; + engine_draw_frame(engine); + break; + } +} + +/** + * This is the main entry point of a native application that is using + * android_native_app_glue. It runs in its own thread, with its own + * event loop for receiving input events and doing other things. + */ +void android_main(struct android_app* state) { + struct engine engine; + + // Make sure glue isn't stripped. + app_dummy(); + + memset(&engine, 0, sizeof(engine)); + state->userData = &engine; + state->onAppCmd = engine_handle_cmd; + state->onInputEvent = engine_handle_input; + engine.app = state; + + // Prepare to monitor accelerometer + engine.sensorManager = ASensorManager_getInstance(); + engine.accelerometerSensor = ASensorManager_getDefaultSensor(engine.sensorManager, + ASENSOR_TYPE_ACCELEROMETER); + engine.sensorEventQueue = ASensorManager_createEventQueue(engine.sensorManager, + state->looper, LOOPER_ID_USER, NULL, NULL); + + if (state->savedState != NULL) { + // We are starting with a previous saved state; restore from it. + engine.state = *(struct saved_state*)state->savedState; + } + + // loop waiting for stuff to do. + + while (1) { + // Read all pending events. + int ident; + int events; + struct android_poll_source* source; + + // If not animating, we will block forever waiting for events. + // If animating, we loop until all events are read, then continue + // to draw the next frame of animation. + while ((ident=ALooper_pollAll(engine.animating ? 0 : -1, NULL, &events, + (void**)&source)) >= 0) { + + // Process this event. + if (source != NULL) { + source->process(state, source); + } + + // If a sensor has data, process it now. + if (ident == LOOPER_ID_USER) { + if (engine.accelerometerSensor != NULL) { + ASensorEvent event; + while (ASensorEventQueue_getEvents(engine.sensorEventQueue, + &event, 1) > 0) { + LOGI("accelerometer: x=%f y=%f z=%f", + event.acceleration.x, event.acceleration.y, + event.acceleration.z); + } + } + } + + // Check if we are exiting. 
+ if (state->destroyRequested != 0) { + engine_term_display(&engine); + return; + } + } + + if (engine.animating) { + // Done with events; draw next animation frame. + engine.state.angle += .01f; + if (engine.state.angle > 1) { + engine.state.angle = 0; + } + + // Drawing is throttled to the screen update rate, so there + // is no need to do timing here. + engine_draw_frame(&engine); + } + } +} +//END_INCLUDE(all) diff --git a/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/res/values/strings.xml b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/res/values/strings.xml new file mode 100644 index 0000000000..d0bd0f3051 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/android/apmtest/res/values/strings.xml @@ -0,0 +1,4 @@ +<?xml version="1.0" encoding="utf-8"?> +<resources> + <string name="app_name">apmtest</string> +</resources> diff --git a/third_party/libwebrtc/modules/audio_processing/test/api_call_statistics.cc b/third_party/libwebrtc/modules/audio_processing/test/api_call_statistics.cc new file mode 100644 index 0000000000..ee8a308596 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/api_call_statistics.cc @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/api_call_statistics.h" + +#include <algorithm> +#include <fstream> +#include <iostream> +#include <limits> +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "rtc_base/time_utils.h" + +namespace webrtc { +namespace test { + +void ApiCallStatistics::Add(int64_t duration_nanos, CallType call_type) { + calls_.push_back(CallData(duration_nanos, call_type)); +} + +void ApiCallStatistics::PrintReport() const { + int64_t min_render = std::numeric_limits<int64_t>::max(); + int64_t min_capture = std::numeric_limits<int64_t>::max(); + int64_t max_render = 0; + int64_t max_capture = 0; + int64_t sum_render = 0; + int64_t sum_capture = 0; + int64_t num_render = 0; + int64_t num_capture = 0; + int64_t avg_render = 0; + int64_t avg_capture = 0; + + for (auto v : calls_) { + if (v.call_type == CallType::kRender) { + ++num_render; + min_render = std::min(min_render, v.duration_nanos); + max_render = std::max(max_render, v.duration_nanos); + sum_render += v.duration_nanos; + } else { + ++num_capture; + min_capture = std::min(min_capture, v.duration_nanos); + max_capture = std::max(max_capture, v.duration_nanos); + sum_capture += v.duration_nanos; + } + } + min_render /= rtc::kNumNanosecsPerMicrosec; + max_render /= rtc::kNumNanosecsPerMicrosec; + sum_render /= rtc::kNumNanosecsPerMicrosec; + min_capture /= rtc::kNumNanosecsPerMicrosec; + max_capture /= rtc::kNumNanosecsPerMicrosec; + sum_capture /= rtc::kNumNanosecsPerMicrosec; + avg_render = num_render > 0 ? sum_render / num_render : 0; + avg_capture = num_capture > 0 ? 
sum_capture / num_capture : 0; + + std::cout << std::endl + << "Total time: " << (sum_capture + sum_render) * 1e-6 << " s" + << std::endl + << " Render API calls:" << std::endl + << " min: " << min_render << " us" << std::endl + << " max: " << max_render << " us" << std::endl + << " avg: " << avg_render << " us" << std::endl + << " Capture API calls:" << std::endl + << " min: " << min_capture << " us" << std::endl + << " max: " << max_capture << " us" << std::endl + << " avg: " << avg_capture << " us" << std::endl; +} + +void ApiCallStatistics::WriteReportToFile(absl::string_view filename) const { + std::unique_ptr<std::ofstream> out = + std::make_unique<std::ofstream>(std::string(filename)); + for (auto v : calls_) { + if (v.call_type == CallType::kRender) { + *out << "render, "; + } else { + *out << "capture, "; + } + *out << (v.duration_nanos / rtc::kNumNanosecsPerMicrosec) << std::endl; + } +} + +ApiCallStatistics::CallData::CallData(int64_t duration_nanos, + CallType call_type) + : duration_nanos(duration_nanos), call_type(call_type) {} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/api_call_statistics.h b/third_party/libwebrtc/modules/audio_processing/test/api_call_statistics.h new file mode 100644 index 0000000000..8fced104f9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/api_call_statistics.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_API_CALL_STATISTICS_H_ +#define MODULES_AUDIO_PROCESSING_TEST_API_CALL_STATISTICS_H_ + +#include <vector> + +#include "absl/strings/string_view.h" + +namespace webrtc { +namespace test { + +// Collects statistics about the API call durations. +class ApiCallStatistics { + public: + enum class CallType { kRender, kCapture }; + + // Adds a new datapoint. + void Add(int64_t duration_nanos, CallType call_type); + + // Prints out a report of the statistics. + void PrintReport() const; + + // Writes the call information to a file. + void WriteReportToFile(absl::string_view filename) const; + + private: + struct CallData { + CallData(int64_t duration_nanos, CallType call_type); + int64_t duration_nanos; + CallType call_type; + }; + std::vector<CallData> calls_; +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_API_CALL_STATISTICS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/apmtest.m b/third_party/libwebrtc/modules/audio_processing/test/apmtest.m new file mode 100644 index 0000000000..1c8183c3ec --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/apmtest.m @@ -0,0 +1,365 @@ +% +% Copyright (c) 2011 The WebRTC project authors. All Rights Reserved. +% +% Use of this source code is governed by a BSD-style license +% that can be found in the LICENSE file in the root of the source +% tree. An additional intellectual property rights grant can be found +% in the file PATENTS. All contributing project authors may +% be found in the AUTHORS file in the root of the source tree. 
+% + +function apmtest(task, testname, filepath, casenumber, legacy) +%APMTEST is a tool to process APM file sets and easily display the output. +% APMTEST(TASK, TESTNAME, CASENUMBER) performs one of several TASKs: +% 'test' Processes the files to produce test output. +% 'list' Prints a list of cases in the test set, preceded by their +% CASENUMBERs. +% 'show' Uses spclab to show the test case specified by the +% CASENUMBER parameter. +% +% using a set of test files determined by TESTNAME: +% 'all' All tests. +% 'apm' The standard APM test set (default). +% 'apmm' The mobile APM test set. +% 'aec' The AEC test set. +% 'aecm' The AECM test set. +% 'agc' The AGC test set. +% 'ns' The NS test set. +% 'vad' The VAD test set. +% +% FILEPATH specifies the path to the test data files. +% +% CASENUMBER can be used to select a single test case. Omit CASENUMBER, +% or set to zero, to use all test cases. +% + +if nargin < 5 || isempty(legacy) + % Set to true to run old VQE recordings. + legacy = false; +end + +if nargin < 4 || isempty(casenumber) + casenumber = 0; +end + +if nargin < 3 || isempty(filepath) + filepath = 'data/'; +end + +if nargin < 2 || isempty(testname) + testname = 'all'; +end + +if nargin < 1 || isempty(task) + task = 'test'; +end + +if ~strcmp(task, 'test') && ~strcmp(task, 'list') && ~strcmp(task, 'show') + error(['TASK ' task ' is not recognized']); +end + +if casenumber == 0 && strcmp(task, 'show') + error(['CASENUMBER must be specified for TASK ' task]); +end + +inpath = [filepath 'input/']; +outpath = [filepath 'output/']; +refpath = [filepath 'reference/']; + +if strcmp(testname, 'all') + tests = {'apm','apmm','aec','aecm','agc','ns','vad'}; +else + tests = {testname}; +end + +if legacy + progname = './test'; +else + progname = './process_test'; +end + +global farFile; +global nearFile; +global eventFile; +global delayFile; +global driftFile; + +if legacy + farFile = 'vqeFar.pcm'; + nearFile = 'vqeNear.pcm'; + eventFile = 'vqeEvent.dat'; + delayFile = 'vqeBuf.dat'; + driftFile = 'vqeDrift.dat'; +else + farFile = 'apm_far.pcm'; + nearFile = 'apm_near.pcm'; + eventFile = 'apm_event.dat'; + delayFile = 'apm_delay.dat'; + driftFile = 'apm_drift.dat'; +end + +simulateMode = false; +nErr = 0; +nCases = 0; +for i=1:length(tests) + simulateMode = false; + + if strcmp(tests{i}, 'apm') + testdir = ['apm/']; + outfile = ['out']; + if legacy + opt = ['-ec 1 -agc 2 -nc 2 -vad 3']; + else + opt = ['--no_progress -hpf' ... + ' -aec --drift_compensation -agc --fixed_digital' ... + ' -ns --ns_moderate -vad']; + end + + elseif strcmp(tests{i}, 'apm-swb') + simulateMode = true; + testdir = ['apm-swb/']; + outfile = ['out']; + if legacy + opt = ['-fs 32000 -ec 1 -agc 2 -nc 2']; + else + opt = ['--no_progress -fs 32000 -hpf' ... + ' -aec --drift_compensation -agc --adaptive_digital' ... + ' -ns --ns_moderate -vad']; + end + elseif strcmp(tests{i}, 'apmm') + testdir = ['apmm/']; + outfile = ['out']; + opt = ['-aec --drift_compensation -agc --fixed_digital -hpf -ns ' ... 
+ '--ns_moderate']; + + else + error(['TESTNAME ' tests{i} ' is not recognized']); + end + + inpathtest = [inpath testdir]; + outpathtest = [outpath testdir]; + refpathtest = [refpath testdir]; + + if ~exist(inpathtest,'dir') + error(['Input directory ' inpathtest ' does not exist']); + end + + if ~exist(refpathtest,'dir') + warning(['Reference directory ' refpathtest ' does not exist']); + end + + [status, errMsg] = mkdir(outpathtest); + if (status == 0) + error(errMsg); + end + + [nErr, nCases] = recurseDir(inpathtest, outpathtest, refpathtest, outfile, ... + progname, opt, simulateMode, nErr, nCases, task, casenumber, legacy); + + if strcmp(task, 'test') || strcmp(task, 'show') + system(['rm ' farFile]); + system(['rm ' nearFile]); + if simulateMode == false + system(['rm ' eventFile]); + system(['rm ' delayFile]); + system(['rm ' driftFile]); + end + end +end + +if ~strcmp(task, 'list') + if nErr == 0 + fprintf(1, '\nAll files are bit-exact to reference\n', nErr); + else + fprintf(1, '\n%d files are NOT bit-exact to reference\n', nErr); + end +end + + +function [nErrOut, nCases] = recurseDir(inpath, outpath, refpath, ... + outfile, progname, opt, simulateMode, nErr, nCases, task, casenumber, ... + legacy) + +global farFile; +global nearFile; +global eventFile; +global delayFile; +global driftFile; + +dirs = dir(inpath); +nDirs = 0; +nErrOut = nErr; +for i=3:length(dirs) % skip . and .. + nDirs = nDirs + dirs(i).isdir; +end + + +if nDirs == 0 + nCases = nCases + 1; + + if casenumber == nCases || casenumber == 0 + + if strcmp(task, 'list') + fprintf([num2str(nCases) '. ' outfile '\n']) + else + vadoutfile = ['vad_' outfile '.dat']; + outfile = [outfile '.pcm']; + + % Check for VAD test + vadTest = 0; + if ~isempty(findstr(opt, '-vad')) + vadTest = 1; + if legacy + opt = [opt ' ' outpath vadoutfile]; + else + opt = [opt ' --vad_out_file ' outpath vadoutfile]; + end + end + + if exist([inpath 'vqeFar.pcm']) + system(['ln -s -f ' inpath 'vqeFar.pcm ' farFile]); + elseif exist([inpath 'apm_far.pcm']) + system(['ln -s -f ' inpath 'apm_far.pcm ' farFile]); + end + + if exist([inpath 'vqeNear.pcm']) + system(['ln -s -f ' inpath 'vqeNear.pcm ' nearFile]); + elseif exist([inpath 'apm_near.pcm']) + system(['ln -s -f ' inpath 'apm_near.pcm ' nearFile]); + end + + if exist([inpath 'vqeEvent.dat']) + system(['ln -s -f ' inpath 'vqeEvent.dat ' eventFile]); + elseif exist([inpath 'apm_event.dat']) + system(['ln -s -f ' inpath 'apm_event.dat ' eventFile]); + end + + if exist([inpath 'vqeBuf.dat']) + system(['ln -s -f ' inpath 'vqeBuf.dat ' delayFile]); + elseif exist([inpath 'apm_delay.dat']) + system(['ln -s -f ' inpath 'apm_delay.dat ' delayFile]); + end + + if exist([inpath 'vqeSkew.dat']) + system(['ln -s -f ' inpath 'vqeSkew.dat ' driftFile]); + elseif exist([inpath 'vqeDrift.dat']) + system(['ln -s -f ' inpath 'vqeDrift.dat ' driftFile]); + elseif exist([inpath 'apm_drift.dat']) + system(['ln -s -f ' inpath 'apm_drift.dat ' driftFile]); + end + + if simulateMode == false + command = [progname ' -o ' outpath outfile ' ' opt]; + else + if legacy + inputCmd = [' -in ' nearFile]; + else + inputCmd = [' -i ' nearFile]; + end + + if exist([farFile]) + if legacy + inputCmd = [' -if ' farFile inputCmd]; + else + inputCmd = [' -ir ' farFile inputCmd]; + end + end + command = [progname inputCmd ' -o ' outpath outfile ' ' opt]; + end + % This prevents MATLAB from using its own C libraries. 
+ shellcmd = ['bash -c "unset LD_LIBRARY_PATH;']; + fprintf([command '\n']); + [status, result] = system([shellcmd command '"']); + fprintf(result); + + fprintf(['Reference file: ' refpath outfile '\n']); + + if vadTest == 1 + equal_to_ref = are_files_equal([outpath vadoutfile], ... + [refpath vadoutfile], ... + 'int8'); + if ~equal_to_ref + nErr = nErr + 1; + end + end + + [equal_to_ref, diffvector] = are_files_equal([outpath outfile], ... + [refpath outfile], ... + 'int16'); + if ~equal_to_ref + nErr = nErr + 1; + end + + if strcmp(task, 'show') + % Assume the last init gives the sample rate of interest. + str_idx = strfind(result, 'Sample rate:'); + fs = str2num(result(str_idx(end) + 13:str_idx(end) + 17)); + fprintf('Using %d Hz\n', fs); + + if exist([farFile]) + spclab(fs, farFile, nearFile, [refpath outfile], ... + [outpath outfile], diffvector); + %spclab(fs, diffvector); + else + spclab(fs, nearFile, [refpath outfile], [outpath outfile], ... + diffvector); + %spclab(fs, diffvector); + end + end + end + end +else + + for i=3:length(dirs) + if dirs(i).isdir + [nErr, nCases] = recurseDir([inpath dirs(i).name '/'], outpath, ... + refpath,[outfile '_' dirs(i).name], progname, opt, ... + simulateMode, nErr, nCases, task, casenumber, legacy); + end + end +end +nErrOut = nErr; + +function [are_equal, diffvector] = ... + are_files_equal(newfile, reffile, precision, diffvector) + +are_equal = false; +diffvector = 0; +if ~exist(newfile,'file') + warning(['Output file ' newfile ' does not exist']); + return +end + +if ~exist(reffile,'file') + warning(['Reference file ' reffile ' does not exist']); + return +end + +fid = fopen(newfile,'rb'); +new = fread(fid,inf,precision); +fclose(fid); + +fid = fopen(reffile,'rb'); +ref = fread(fid,inf,precision); +fclose(fid); + +if length(new) ~= length(ref) + warning('Reference is not the same length as output'); + minlength = min(length(new), length(ref)); + new = new(1:minlength); + ref = ref(1:minlength); +end +diffvector = new - ref; + +if isequal(new, ref) + fprintf([newfile ' is bit-exact to reference\n']); + are_equal = true; +else + if isempty(new) + warning([newfile ' is empty']); + return + end + snr = snrseg(new,ref,80); + fprintf('\n'); + are_equal = false; +end diff --git a/third_party/libwebrtc/modules/audio_processing/test/audio_buffer_tools.cc b/third_party/libwebrtc/modules/audio_processing/test/audio_buffer_tools.cc new file mode 100644 index 0000000000..64fb9c7ab1 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/audio_buffer_tools.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/test/audio_buffer_tools.h" + +#include <string.h> + +namespace webrtc { +namespace test { + +void SetupFrame(const StreamConfig& stream_config, + std::vector<float*>* frame, + std::vector<float>* frame_samples) { + frame_samples->resize(stream_config.num_channels() * + stream_config.num_frames()); + frame->resize(stream_config.num_channels()); + for (size_t ch = 0; ch < stream_config.num_channels(); ++ch) { + (*frame)[ch] = &(*frame_samples)[ch * stream_config.num_frames()]; + } +} + +void CopyVectorToAudioBuffer(const StreamConfig& stream_config, + rtc::ArrayView<const float> source, + AudioBuffer* destination) { + std::vector<float*> input; + std::vector<float> input_samples; + + SetupFrame(stream_config, &input, &input_samples); + + RTC_CHECK_EQ(input_samples.size(), source.size()); + memcpy(input_samples.data(), source.data(), + source.size() * sizeof(source[0])); + + destination->CopyFrom(&input[0], stream_config); +} + +void ExtractVectorFromAudioBuffer(const StreamConfig& stream_config, + AudioBuffer* source, + std::vector<float>* destination) { + std::vector<float*> output; + + SetupFrame(stream_config, &output, destination); + + source->CopyTo(stream_config, &output[0]); +} + +void FillBuffer(float value, AudioBuffer& audio_buffer) { + for (size_t ch = 0; ch < audio_buffer.num_channels(); ++ch) { + FillBufferChannel(value, ch, audio_buffer); + } +} + +void FillBufferChannel(float value, int channel, AudioBuffer& audio_buffer) { + RTC_CHECK_LT(channel, audio_buffer.num_channels()); + for (size_t i = 0; i < audio_buffer.num_frames(); ++i) { + audio_buffer.channels()[channel][i] = value; + } +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/audio_buffer_tools.h b/third_party/libwebrtc/modules/audio_processing/test/audio_buffer_tools.h new file mode 100644 index 0000000000..faac4bf9ff --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/audio_buffer_tools.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_AUDIO_BUFFER_TOOLS_H_ +#define MODULES_AUDIO_PROCESSING_TEST_AUDIO_BUFFER_TOOLS_H_ + +#include <vector> + +#include "api/array_view.h" +#include "modules/audio_processing/audio_buffer.h" +#include "modules/audio_processing/include/audio_processing.h" + +namespace webrtc { +namespace test { + +// Copies a vector into an audiobuffer. +void CopyVectorToAudioBuffer(const StreamConfig& stream_config, + rtc::ArrayView<const float> source, + AudioBuffer* destination); + +// Extracts a vector from an audiobuffer. +void ExtractVectorFromAudioBuffer(const StreamConfig& stream_config, + AudioBuffer* source, + std::vector<float>* destination); + +// Sets all values in `audio_buffer` to `value`. +void FillBuffer(float value, AudioBuffer& audio_buffer); + +// Sets all values channel `channel` for `audio_buffer` to `value`. 
+void FillBufferChannel(float value, int channel, AudioBuffer& audio_buffer); + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_AUDIO_BUFFER_TOOLS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/audio_processing_builder_for_testing.cc b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_builder_for_testing.cc new file mode 100644 index 0000000000..6bd266dc58 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_builder_for_testing.cc @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/audio_processing_builder_for_testing.h" + +#include <memory> +#include <utility> + +#include "modules/audio_processing/audio_processing_impl.h" + +namespace webrtc { + +AudioProcessingBuilderForTesting::AudioProcessingBuilderForTesting() = default; +AudioProcessingBuilderForTesting::~AudioProcessingBuilderForTesting() = default; + +#ifdef WEBRTC_EXCLUDE_AUDIO_PROCESSING_MODULE + +rtc::scoped_refptr<AudioProcessing> AudioProcessingBuilderForTesting::Create() { + return rtc::make_ref_counted<AudioProcessingImpl>( + config_, std::move(capture_post_processing_), + std::move(render_pre_processing_), std::move(echo_control_factory_), + std::move(echo_detector_), std::move(capture_analyzer_)); +} + +#else + +rtc::scoped_refptr<AudioProcessing> AudioProcessingBuilderForTesting::Create() { + AudioProcessingBuilder builder; + TransferOwnershipsToBuilder(&builder); + return builder.SetConfig(config_).Create(); +} + +#endif + +void AudioProcessingBuilderForTesting::TransferOwnershipsToBuilder( + AudioProcessingBuilder* builder) { + builder->SetCapturePostProcessing(std::move(capture_post_processing_)); + builder->SetRenderPreProcessing(std::move(render_pre_processing_)); + builder->SetEchoControlFactory(std::move(echo_control_factory_)); + builder->SetEchoDetector(std::move(echo_detector_)); + builder->SetCaptureAnalyzer(std::move(capture_analyzer_)); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/audio_processing_builder_for_testing.h b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_builder_for_testing.h new file mode 100644 index 0000000000..e73706c1b6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_builder_for_testing.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_TEST_AUDIO_PROCESSING_BUILDER_FOR_TESTING_H_
+#define MODULES_AUDIO_PROCESSING_TEST_AUDIO_PROCESSING_BUILDER_FOR_TESTING_H_
+
+#include <list>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "modules/audio_processing/include/audio_processing.h"
+
+namespace webrtc {
+
+// Facilitates building of AudioProcessingImpl for the tests.
+class AudioProcessingBuilderForTesting {
+ public:
+  AudioProcessingBuilderForTesting();
+  AudioProcessingBuilderForTesting(const AudioProcessingBuilderForTesting&) =
+      delete;
+  AudioProcessingBuilderForTesting& operator=(
+      const AudioProcessingBuilderForTesting&) = delete;
+  ~AudioProcessingBuilderForTesting();
+
+  // Sets the APM configuration.
+  AudioProcessingBuilderForTesting& SetConfig(
+      const AudioProcessing::Config& config) {
+    config_ = config;
+    return *this;
+  }
+
+  // Sets the echo controller factory to inject when APM is created.
+  AudioProcessingBuilderForTesting& SetEchoControlFactory(
+      std::unique_ptr<EchoControlFactory> echo_control_factory) {
+    echo_control_factory_ = std::move(echo_control_factory);
+    return *this;
+  }
+
+  // Sets the capture post-processing sub-module to inject when APM is created.
+  AudioProcessingBuilderForTesting& SetCapturePostProcessing(
+      std::unique_ptr<CustomProcessing> capture_post_processing) {
+    capture_post_processing_ = std::move(capture_post_processing);
+    return *this;
+  }
+
+  // Sets the render pre-processing sub-module to inject when APM is created.
+  AudioProcessingBuilderForTesting& SetRenderPreProcessing(
+      std::unique_ptr<CustomProcessing> render_pre_processing) {
+    render_pre_processing_ = std::move(render_pre_processing);
+    return *this;
+  }
+
+  // Sets the echo detector to inject when APM is created.
+  AudioProcessingBuilderForTesting& SetEchoDetector(
+      rtc::scoped_refptr<EchoDetector> echo_detector) {
+    echo_detector_ = std::move(echo_detector);
+    return *this;
+  }
+
+  // Sets the capture analyzer sub-module to inject when APM is created.
+  AudioProcessingBuilderForTesting& SetCaptureAnalyzer(
+      std::unique_ptr<CustomAudioAnalyzer> capture_analyzer) {
+    capture_analyzer_ = std::move(capture_analyzer);
+    return *this;
+  }
+
+  // Creates an APM instance with the specified config or the default one if
+  // unspecified. Injects the specified components transferring the ownership
+  // to the newly created APM instance - i.e., except for the config, the
+  // builder is reset to its initial state.
+  rtc::scoped_refptr<AudioProcessing> Create();
+
+ private:
+  // Transfers the ownership to a non-testing builder.
+  void TransferOwnershipsToBuilder(AudioProcessingBuilder* builder);
+
+  AudioProcessing::Config config_;
+  std::unique_ptr<EchoControlFactory> echo_control_factory_;
+  std::unique_ptr<CustomProcessing> capture_post_processing_;
+  std::unique_ptr<CustomProcessing> render_pre_processing_;
+  rtc::scoped_refptr<EchoDetector> echo_detector_;
+  std::unique_ptr<CustomAudioAnalyzer> capture_analyzer_;
+};
+
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_TEST_AUDIO_PROCESSING_BUILDER_FOR_TESTING_H_
diff --git a/third_party/libwebrtc/modules/audio_processing/test/audio_processing_simulator.cc b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_simulator.cc
new file mode 100644
index 0000000000..b29027c35e
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_simulator.cc
@@ -0,0 +1,609 @@
+/*
+ *  Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
+ * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/audio_processing_simulator.h" + +#include <algorithm> +#include <fstream> +#include <iostream> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/audio/echo_canceller3_config_json.h" +#include "api/audio/echo_canceller3_factory.h" +#include "api/audio/echo_detector_creator.h" +#include "modules/audio_processing/aec_dump/aec_dump_factory.h" +#include "modules/audio_processing/echo_control_mobile_impl.h" +#include "modules/audio_processing/include/audio_processing.h" +#include "modules/audio_processing/logging/apm_data_dumper.h" +#include "modules/audio_processing/test/fake_recording_device.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/json.h" +#include "rtc_base/strings/string_builder.h" + +namespace webrtc { +namespace test { +namespace { +// Helper for reading JSON from a file and parsing it to an AEC3 configuration. +EchoCanceller3Config ReadAec3ConfigFromJsonFile(absl::string_view filename) { + std::string json_string; + std::string s; + std::ifstream f(std::string(filename).c_str()); + if (f.fail()) { + std::cout << "Failed to open the file " << filename << std::endl; + RTC_CHECK_NOTREACHED(); + } + while (std::getline(f, s)) { + json_string += s; + } + + bool parsing_successful; + EchoCanceller3Config cfg; + Aec3ConfigFromJsonString(json_string, &cfg, &parsing_successful); + if (!parsing_successful) { + std::cout << "Parsing of json string failed: " << std::endl + << json_string << std::endl; + RTC_CHECK_NOTREACHED(); + } + RTC_CHECK(EchoCanceller3Config::Validate(&cfg)); + + return cfg; +} + +std::string GetIndexedOutputWavFilename(absl::string_view wav_name, + int counter) { + rtc::StringBuilder ss; + ss << wav_name.substr(0, wav_name.size() - 4) << "_" << counter + << wav_name.substr(wav_name.size() - 4); + return ss.Release(); +} + +void WriteEchoLikelihoodGraphFileHeader(std::ofstream* output_file) { + (*output_file) << "import numpy as np" << std::endl + << "import matplotlib.pyplot as plt" << std::endl + << "y = np.array(["; +} + +void WriteEchoLikelihoodGraphFileFooter(std::ofstream* output_file) { + (*output_file) << "])" << std::endl + << "if __name__ == '__main__':" << std::endl + << " x = np.arange(len(y))*.01" << std::endl + << " plt.plot(x, y)" << std::endl + << " plt.ylabel('Echo likelihood')" << std::endl + << " plt.xlabel('Time (s)')" << std::endl + << " plt.show()" << std::endl; +} + +// RAII class for execution time measurement. Updates the provided +// ApiCallStatistics based on the time between ScopedTimer creation and +// leaving the enclosing scope. 
+class ScopedTimer {
+ public:
+  ScopedTimer(ApiCallStatistics* api_call_statistics,
+              ApiCallStatistics::CallType call_type)
+      : start_time_(rtc::TimeNanos()),
+        call_type_(call_type),
+        api_call_statistics_(api_call_statistics) {}
+
+  ~ScopedTimer() {
+    api_call_statistics_->Add(rtc::TimeNanos() - start_time_, call_type_);
+  }
+
+ private:
+  const int64_t start_time_;
+  const ApiCallStatistics::CallType call_type_;
+  ApiCallStatistics* const api_call_statistics_;
+};
+
+}  // namespace
+
+SimulationSettings::SimulationSettings() = default;
+SimulationSettings::SimulationSettings(const SimulationSettings&) = default;
+SimulationSettings::~SimulationSettings() = default;
+
+AudioProcessingSimulator::AudioProcessingSimulator(
+    const SimulationSettings& settings,
+    rtc::scoped_refptr<AudioProcessing> audio_processing,
+    std::unique_ptr<AudioProcessingBuilder> ap_builder)
+    : settings_(settings),
+      ap_(std::move(audio_processing)),
+      analog_mic_level_(settings.initial_mic_level),
+      fake_recording_device_(
+          settings.initial_mic_level,
+          settings_.simulate_mic_gain ? *settings.simulated_mic_kind : 0),
+      worker_queue_("file_writer_task_queue") {
+  RTC_CHECK(!settings_.dump_internal_data || WEBRTC_APM_DEBUG_DUMP == 1);
+  if (settings_.dump_start_frame || settings_.dump_end_frame) {
+    ApmDataDumper::SetActivated(!settings_.dump_start_frame);
+  } else {
+    ApmDataDumper::SetActivated(settings_.dump_internal_data);
+  }
+
+  if (settings_.dump_set_to_use) {
+    ApmDataDumper::SetDumpSetToUse(*settings_.dump_set_to_use);
+  }
+
+  if (settings_.dump_internal_data_output_dir.has_value()) {
+    ApmDataDumper::SetOutputDirectory(
+        settings_.dump_internal_data_output_dir.value());
+  }
+
+  if (settings_.ed_graph_output_filename &&
+      !settings_.ed_graph_output_filename->empty()) {
+    residual_echo_likelihood_graph_writer_.open(
+        *settings_.ed_graph_output_filename);
+    RTC_CHECK(residual_echo_likelihood_graph_writer_.is_open());
+    WriteEchoLikelihoodGraphFileHeader(&residual_echo_likelihood_graph_writer_);
+  }
+
+  if (settings_.simulate_mic_gain)
+    RTC_LOG(LS_VERBOSE) << "Simulating analog mic gain";
+
+  // Create the audio processing object.
+  RTC_CHECK(!(ap_ && ap_builder))
+      << "The AudioProcessing and the AudioProcessingBuilder cannot both be "
+         "specified at the same time.";
+
+  if (ap_) {
+    RTC_CHECK(!settings_.aec_settings_filename);
+    RTC_CHECK(!settings_.print_aec_parameter_values);
+  } else {
+    // Use the specified builder if one is provided, otherwise create a new
+    // builder.
+    std::unique_ptr<AudioProcessingBuilder> builder =
+        !!ap_builder ? std::move(ap_builder)
+                     : std::make_unique<AudioProcessingBuilder>();
+
+    // Create and set an EchoCanceller3Factory if needed.
+    const bool use_aec = settings_.use_aec && *settings_.use_aec;
+    if (use_aec) {
+      EchoCanceller3Config cfg;
+      if (settings_.aec_settings_filename) {
+        if (settings_.use_verbose_logging) {
+          std::cout << "Reading AEC Parameters from JSON input."
<< std::endl; + } + cfg = ReadAec3ConfigFromJsonFile(*settings_.aec_settings_filename); + } + + if (settings_.linear_aec_output_filename) { + cfg.filter.export_linear_aec_output = true; + } + + if (settings_.print_aec_parameter_values) { + if (!settings_.use_quiet_output) { + std::cout << "AEC settings:" << std::endl; + } + std::cout << Aec3ConfigToJsonString(cfg) << std::endl; + } + + auto echo_control_factory = std::make_unique<EchoCanceller3Factory>(cfg); + builder->SetEchoControlFactory(std::move(echo_control_factory)); + } + + if (settings_.use_ed && *settings.use_ed) { + builder->SetEchoDetector(CreateEchoDetector()); + } + + // Create an audio processing object. + ap_ = builder->Create(); + RTC_CHECK(ap_); + } +} + +AudioProcessingSimulator::~AudioProcessingSimulator() { + if (residual_echo_likelihood_graph_writer_.is_open()) { + WriteEchoLikelihoodGraphFileFooter(&residual_echo_likelihood_graph_writer_); + residual_echo_likelihood_graph_writer_.close(); + } +} + +void AudioProcessingSimulator::ProcessStream(bool fixed_interface) { + // Optionally use the fake recording device to simulate analog gain. + if (settings_.simulate_mic_gain) { + if (settings_.aec_dump_input_filename) { + // When the analog gain is simulated and an AEC dump is used as input, set + // the undo level to `aec_dump_mic_level_` to virtually restore the + // unmodified microphone signal level. + fake_recording_device_.SetUndoMicLevel(aec_dump_mic_level_); + } + + if (fixed_interface) { + fake_recording_device_.SimulateAnalogGain(fwd_frame_.data); + } else { + fake_recording_device_.SimulateAnalogGain(in_buf_.get()); + } + + // Notify the current mic level to AGC. + ap_->set_stream_analog_level(fake_recording_device_.MicLevel()); + } else { + // Notify the current mic level to AGC. + ap_->set_stream_analog_level(settings_.aec_dump_input_filename + ? aec_dump_mic_level_ + : analog_mic_level_); + } + + // Post any scheduled runtime settings. + if (settings_.frame_for_sending_capture_output_used_false && + *settings_.frame_for_sending_capture_output_used_false == + static_cast<int>(num_process_stream_calls_)) { + ap_->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting(false)); + } + if (settings_.frame_for_sending_capture_output_used_true && + *settings_.frame_for_sending_capture_output_used_true == + static_cast<int>(num_process_stream_calls_)) { + ap_->PostRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting(true)); + } + + // Process the current audio frame. + if (fixed_interface) { + { + const auto st = ScopedTimer(&api_call_statistics_, + ApiCallStatistics::CallType::kCapture); + RTC_CHECK_EQ( + AudioProcessing::kNoError, + ap_->ProcessStream(fwd_frame_.data.data(), fwd_frame_.config, + fwd_frame_.config, fwd_frame_.data.data())); + } + fwd_frame_.CopyTo(out_buf_.get()); + } else { + const auto st = ScopedTimer(&api_call_statistics_, + ApiCallStatistics::CallType::kCapture); + RTC_CHECK_EQ(AudioProcessing::kNoError, + ap_->ProcessStream(in_buf_->channels(), in_config_, + out_config_, out_buf_->channels())); + } + + // Store the mic level suggested by AGC. + // Note that when the analog gain is simulated and an AEC dump is used as + // input, `analog_mic_level_` will not be used with set_stream_analog_level(). 
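+  // The recommended level is also fed back to the fake recording device
+  // below when mic gain simulation is enabled, which closes the simulated
+  // analog gain loop.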
+ analog_mic_level_ = ap_->recommended_stream_analog_level(); + if (settings_.simulate_mic_gain) { + fake_recording_device_.SetMicLevel(analog_mic_level_); + } + if (buffer_memory_writer_) { + RTC_CHECK(!buffer_file_writer_); + buffer_memory_writer_->Write(*out_buf_); + } else if (buffer_file_writer_) { + RTC_CHECK(!buffer_memory_writer_); + buffer_file_writer_->Write(*out_buf_); + } + + if (linear_aec_output_file_writer_) { + bool output_available = ap_->GetLinearAecOutput(linear_aec_output_buf_); + RTC_CHECK(output_available); + RTC_CHECK_GT(linear_aec_output_buf_.size(), 0); + RTC_CHECK_EQ(linear_aec_output_buf_[0].size(), 160); + for (size_t k = 0; k < linear_aec_output_buf_[0].size(); ++k) { + for (size_t ch = 0; ch < linear_aec_output_buf_.size(); ++ch) { + RTC_CHECK_EQ(linear_aec_output_buf_[ch].size(), 160); + float sample = FloatToFloatS16(linear_aec_output_buf_[ch][k]); + linear_aec_output_file_writer_->WriteSamples(&sample, 1); + } + } + } + + if (residual_echo_likelihood_graph_writer_.is_open()) { + auto stats = ap_->GetStatistics(); + residual_echo_likelihood_graph_writer_ + << stats.residual_echo_likelihood.value_or(-1.f) << ", "; + } + + ++num_process_stream_calls_; +} + +void AudioProcessingSimulator::ProcessReverseStream(bool fixed_interface) { + if (fixed_interface) { + { + const auto st = ScopedTimer(&api_call_statistics_, + ApiCallStatistics::CallType::kRender); + RTC_CHECK_EQ( + AudioProcessing::kNoError, + ap_->ProcessReverseStream(rev_frame_.data.data(), rev_frame_.config, + rev_frame_.config, rev_frame_.data.data())); + } + rev_frame_.CopyTo(reverse_out_buf_.get()); + } else { + const auto st = ScopedTimer(&api_call_statistics_, + ApiCallStatistics::CallType::kRender); + RTC_CHECK_EQ(AudioProcessing::kNoError, + ap_->ProcessReverseStream( + reverse_in_buf_->channels(), reverse_in_config_, + reverse_out_config_, reverse_out_buf_->channels())); + } + + if (reverse_buffer_file_writer_) { + reverse_buffer_file_writer_->Write(*reverse_out_buf_); + } + + ++num_reverse_process_stream_calls_; +} + +void AudioProcessingSimulator::SetupBuffersConfigsOutputs( + int input_sample_rate_hz, + int output_sample_rate_hz, + int reverse_input_sample_rate_hz, + int reverse_output_sample_rate_hz, + int input_num_channels, + int output_num_channels, + int reverse_input_num_channels, + int reverse_output_num_channels) { + in_config_ = StreamConfig(input_sample_rate_hz, input_num_channels); + in_buf_.reset(new ChannelBuffer<float>( + rtc::CheckedDivExact(input_sample_rate_hz, kChunksPerSecond), + input_num_channels)); + + reverse_in_config_ = + StreamConfig(reverse_input_sample_rate_hz, reverse_input_num_channels); + reverse_in_buf_.reset(new ChannelBuffer<float>( + rtc::CheckedDivExact(reverse_input_sample_rate_hz, kChunksPerSecond), + reverse_input_num_channels)); + + out_config_ = StreamConfig(output_sample_rate_hz, output_num_channels); + out_buf_.reset(new ChannelBuffer<float>( + rtc::CheckedDivExact(output_sample_rate_hz, kChunksPerSecond), + output_num_channels)); + + reverse_out_config_ = + StreamConfig(reverse_output_sample_rate_hz, reverse_output_num_channels); + reverse_out_buf_.reset(new ChannelBuffer<float>( + rtc::CheckedDivExact(reverse_output_sample_rate_hz, kChunksPerSecond), + reverse_output_num_channels)); + + fwd_frame_.SetFormat(input_sample_rate_hz, input_num_channels); + rev_frame_.SetFormat(reverse_input_sample_rate_hz, + reverse_input_num_channels); + + if (settings_.use_verbose_logging) { + rtc::LogMessage::LogToDebug(rtc::LS_VERBOSE); + + std::cout << "Sample 
rates:" << std::endl; + std::cout << " Forward input: " << input_sample_rate_hz << std::endl; + std::cout << " Forward output: " << output_sample_rate_hz << std::endl; + std::cout << " Reverse input: " << reverse_input_sample_rate_hz + << std::endl; + std::cout << " Reverse output: " << reverse_output_sample_rate_hz + << std::endl; + std::cout << "Number of channels: " << std::endl; + std::cout << " Forward input: " << input_num_channels << std::endl; + std::cout << " Forward output: " << output_num_channels << std::endl; + std::cout << " Reverse input: " << reverse_input_num_channels << std::endl; + std::cout << " Reverse output: " << reverse_output_num_channels + << std::endl; + } + + SetupOutput(); +} + +void AudioProcessingSimulator::SelectivelyToggleDataDumping( + int init_index, + int capture_frames_since_init) const { + if (!(settings_.dump_start_frame || settings_.dump_end_frame)) { + return; + } + + if (settings_.init_to_process && *settings_.init_to_process != init_index) { + return; + } + + if (settings_.dump_start_frame && + *settings_.dump_start_frame == capture_frames_since_init) { + ApmDataDumper::SetActivated(true); + } + + if (settings_.dump_end_frame && + *settings_.dump_end_frame == capture_frames_since_init) { + ApmDataDumper::SetActivated(false); + } +} + +void AudioProcessingSimulator::SetupOutput() { + if (settings_.output_filename) { + std::string filename; + if (settings_.store_intermediate_output) { + filename = GetIndexedOutputWavFilename(*settings_.output_filename, + output_reset_counter_); + } else { + filename = *settings_.output_filename; + } + + std::unique_ptr<WavWriter> out_file( + new WavWriter(filename, out_config_.sample_rate_hz(), + static_cast<size_t>(out_config_.num_channels()), + settings_.wav_output_format)); + buffer_file_writer_.reset(new ChannelBufferWavWriter(std::move(out_file))); + } else if (settings_.aec_dump_input_string.has_value()) { + buffer_memory_writer_ = std::make_unique<ChannelBufferVectorWriter>( + settings_.processed_capture_samples); + } + + if (settings_.linear_aec_output_filename) { + std::string filename; + if (settings_.store_intermediate_output) { + filename = GetIndexedOutputWavFilename( + *settings_.linear_aec_output_filename, output_reset_counter_); + } else { + filename = *settings_.linear_aec_output_filename; + } + + linear_aec_output_file_writer_.reset( + new WavWriter(filename, 16000, out_config_.num_channels(), + settings_.wav_output_format)); + + linear_aec_output_buf_.resize(out_config_.num_channels()); + } + + if (settings_.reverse_output_filename) { + std::string filename; + if (settings_.store_intermediate_output) { + filename = GetIndexedOutputWavFilename(*settings_.reverse_output_filename, + output_reset_counter_); + } else { + filename = *settings_.reverse_output_filename; + } + + std::unique_ptr<WavWriter> reverse_out_file( + new WavWriter(filename, reverse_out_config_.sample_rate_hz(), + static_cast<size_t>(reverse_out_config_.num_channels()), + settings_.wav_output_format)); + reverse_buffer_file_writer_.reset( + new ChannelBufferWavWriter(std::move(reverse_out_file))); + } + + ++output_reset_counter_; +} + +void AudioProcessingSimulator::DetachAecDump() { + if (settings_.aec_dump_output_filename) { + ap_->DetachAecDump(); + } +} + +void AudioProcessingSimulator::ConfigureAudioProcessor() { + AudioProcessing::Config apm_config; + if (settings_.use_ts) { + apm_config.transient_suppression.enabled = *settings_.use_ts != 0; + } + if (settings_.multi_channel_render) { + 
apm_config.pipeline.multi_channel_render = *settings_.multi_channel_render; + } + + if (settings_.multi_channel_capture) { + apm_config.pipeline.multi_channel_capture = + *settings_.multi_channel_capture; + } + + if (settings_.use_agc2) { + apm_config.gain_controller2.enabled = *settings_.use_agc2; + if (settings_.agc2_fixed_gain_db) { + apm_config.gain_controller2.fixed_digital.gain_db = + *settings_.agc2_fixed_gain_db; + } + if (settings_.agc2_use_adaptive_gain) { + apm_config.gain_controller2.adaptive_digital.enabled = + *settings_.agc2_use_adaptive_gain; + } + } + if (settings_.use_pre_amplifier) { + apm_config.pre_amplifier.enabled = *settings_.use_pre_amplifier; + if (settings_.pre_amplifier_gain_factor) { + apm_config.pre_amplifier.fixed_gain_factor = + *settings_.pre_amplifier_gain_factor; + } + } + + if (settings_.use_analog_mic_gain_emulation) { + if (*settings_.use_analog_mic_gain_emulation) { + apm_config.capture_level_adjustment.enabled = true; + apm_config.capture_level_adjustment.analog_mic_gain_emulation.enabled = + true; + } else { + apm_config.capture_level_adjustment.analog_mic_gain_emulation.enabled = + false; + } + } + if (settings_.analog_mic_gain_emulation_initial_level) { + apm_config.capture_level_adjustment.analog_mic_gain_emulation + .initial_level = *settings_.analog_mic_gain_emulation_initial_level; + } + + if (settings_.use_capture_level_adjustment) { + apm_config.capture_level_adjustment.enabled = + *settings_.use_capture_level_adjustment; + } + if (settings_.pre_gain_factor) { + apm_config.capture_level_adjustment.pre_gain_factor = + *settings_.pre_gain_factor; + } + if (settings_.post_gain_factor) { + apm_config.capture_level_adjustment.post_gain_factor = + *settings_.post_gain_factor; + } + + const bool use_aec = settings_.use_aec && *settings_.use_aec; + const bool use_aecm = settings_.use_aecm && *settings_.use_aecm; + if (use_aec || use_aecm) { + apm_config.echo_canceller.enabled = true; + apm_config.echo_canceller.mobile_mode = use_aecm; + } + apm_config.echo_canceller.export_linear_aec_output = + !!settings_.linear_aec_output_filename; + + if (settings_.use_hpf) { + apm_config.high_pass_filter.enabled = *settings_.use_hpf; + } + + if (settings_.use_agc) { + apm_config.gain_controller1.enabled = *settings_.use_agc; + } + if (settings_.agc_mode) { + apm_config.gain_controller1.mode = + static_cast<webrtc::AudioProcessing::Config::GainController1::Mode>( + *settings_.agc_mode); + } + if (settings_.use_agc_limiter) { + apm_config.gain_controller1.enable_limiter = *settings_.use_agc_limiter; + } + if (settings_.agc_target_level) { + apm_config.gain_controller1.target_level_dbfs = *settings_.agc_target_level; + } + if (settings_.agc_compression_gain) { + apm_config.gain_controller1.compression_gain_db = + *settings_.agc_compression_gain; + } + if (settings_.use_analog_agc) { + apm_config.gain_controller1.analog_gain_controller.enabled = + *settings_.use_analog_agc; + } + if (settings_.analog_agc_use_digital_adaptive_controller) { + apm_config.gain_controller1.analog_gain_controller.enable_digital_adaptive = + *settings_.analog_agc_use_digital_adaptive_controller; + } + + if (settings_.maximum_internal_processing_rate) { + apm_config.pipeline.maximum_internal_processing_rate = + *settings_.maximum_internal_processing_rate; + } + + if (settings_.use_ns) { + apm_config.noise_suppression.enabled = *settings_.use_ns; + } + if (settings_.ns_level) { + const int level = *settings_.ns_level; + RTC_CHECK_GE(level, 0); + RTC_CHECK_LE(level, 3); + 
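+    // The validated value maps directly onto the NoiseSuppression::Level
+    // enum (0 = low, 1 = moderate, 2 = high, 3 = very high).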
apm_config.noise_suppression.level = + static_cast<AudioProcessing::Config::NoiseSuppression::Level>(level); + } + if (settings_.ns_analysis_on_linear_aec_output) { + apm_config.noise_suppression.analyze_linear_aec_output_when_available = + *settings_.ns_analysis_on_linear_aec_output; + } + + ap_->ApplyConfig(apm_config); + + if (settings_.use_ts) { + // Default to key pressed if activating the transient suppressor with + // continuous key events. + ap_->set_stream_key_pressed(*settings_.use_ts == 2); + } + + if (settings_.aec_dump_output_filename) { + ap_->AttachAecDump(AecDumpFactory::Create( + *settings_.aec_dump_output_filename, -1, &worker_queue_)); + } +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/audio_processing_simulator.h b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_simulator.h new file mode 100644 index 0000000000..b63bc12d6f --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/audio_processing_simulator.h @@ -0,0 +1,247 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_AUDIO_PROCESSING_SIMULATOR_H_ +#define MODULES_AUDIO_PROCESSING_TEST_AUDIO_PROCESSING_SIMULATOR_H_ + +#include <algorithm> +#include <fstream> +#include <limits> +#include <memory> +#include <string> + +#include "absl/types/optional.h" +#include "common_audio/channel_buffer.h" +#include "common_audio/include/audio_util.h" +#include "modules/audio_processing/include/audio_processing.h" +#include "modules/audio_processing/test/api_call_statistics.h" +#include "modules/audio_processing/test/fake_recording_device.h" +#include "modules/audio_processing/test/test_utils.h" +#include "rtc_base/task_queue_for_test.h" +#include "rtc_base/time_utils.h" + +namespace webrtc { +namespace test { + +static const int kChunksPerSecond = 1000 / AudioProcessing::kChunkSizeMs; + +struct Int16Frame { + void SetFormat(int sample_rate_hz, int num_channels) { + this->sample_rate_hz = sample_rate_hz; + samples_per_channel = + rtc::CheckedDivExact(sample_rate_hz, kChunksPerSecond); + this->num_channels = num_channels; + config = StreamConfig(sample_rate_hz, num_channels); + data.resize(num_channels * samples_per_channel); + } + + void CopyTo(ChannelBuffer<float>* dest) { + RTC_DCHECK(dest); + RTC_CHECK_EQ(num_channels, dest->num_channels()); + RTC_CHECK_EQ(samples_per_channel, dest->num_frames()); + // Copy the data from the input buffer. 
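+    // The interleaved int16 samples are converted to float and then
+    // deinterleaved into one plane per channel.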
+ std::vector<float> tmp(samples_per_channel * num_channels); + S16ToFloat(data.data(), tmp.size(), tmp.data()); + Deinterleave(tmp.data(), samples_per_channel, num_channels, + dest->channels()); + } + + void CopyFrom(const ChannelBuffer<float>& src) { + RTC_CHECK_EQ(src.num_channels(), num_channels); + RTC_CHECK_EQ(src.num_frames(), samples_per_channel); + data.resize(num_channels * samples_per_channel); + int16_t* dest_data = data.data(); + for (int ch = 0; ch < num_channels; ++ch) { + for (int sample = 0; sample < samples_per_channel; ++sample) { + dest_data[sample * num_channels + ch] = + src.channels()[ch][sample] * 32767; + } + } + } + + int sample_rate_hz; + int samples_per_channel; + int num_channels; + + StreamConfig config; + + std::vector<int16_t> data; +}; + +// Holds all the parameters available for controlling the simulation. +struct SimulationSettings { + SimulationSettings(); + SimulationSettings(const SimulationSettings&); + ~SimulationSettings(); + absl::optional<int> stream_delay; + absl::optional<bool> use_stream_delay; + absl::optional<int> output_sample_rate_hz; + absl::optional<int> output_num_channels; + absl::optional<int> reverse_output_sample_rate_hz; + absl::optional<int> reverse_output_num_channels; + absl::optional<std::string> output_filename; + absl::optional<std::string> reverse_output_filename; + absl::optional<std::string> input_filename; + absl::optional<std::string> reverse_input_filename; + absl::optional<std::string> artificial_nearend_filename; + absl::optional<std::string> linear_aec_output_filename; + absl::optional<bool> use_aec; + absl::optional<bool> use_aecm; + absl::optional<bool> use_ed; // Residual Echo Detector. + absl::optional<std::string> ed_graph_output_filename; + absl::optional<bool> use_agc; + absl::optional<bool> use_agc2; + absl::optional<bool> use_pre_amplifier; + absl::optional<bool> use_capture_level_adjustment; + absl::optional<bool> use_analog_mic_gain_emulation; + absl::optional<bool> use_hpf; + absl::optional<bool> use_ns; + absl::optional<int> use_ts; + absl::optional<bool> use_analog_agc; + absl::optional<bool> use_all; + absl::optional<bool> analog_agc_use_digital_adaptive_controller; + absl::optional<int> agc_mode; + absl::optional<int> agc_target_level; + absl::optional<bool> use_agc_limiter; + absl::optional<int> agc_compression_gain; + absl::optional<bool> agc2_use_adaptive_gain; + absl::optional<float> agc2_fixed_gain_db; + absl::optional<float> pre_amplifier_gain_factor; + absl::optional<float> pre_gain_factor; + absl::optional<float> post_gain_factor; + absl::optional<float> analog_mic_gain_emulation_initial_level; + absl::optional<int> ns_level; + absl::optional<bool> ns_analysis_on_linear_aec_output; + absl::optional<bool> override_key_pressed; + absl::optional<int> maximum_internal_processing_rate; + int initial_mic_level; + bool simulate_mic_gain = false; + absl::optional<bool> multi_channel_render; + absl::optional<bool> multi_channel_capture; + absl::optional<int> simulated_mic_kind; + absl::optional<int> frame_for_sending_capture_output_used_false; + absl::optional<int> frame_for_sending_capture_output_used_true; + bool report_performance = false; + absl::optional<std::string> performance_report_output_filename; + bool report_bitexactness = false; + bool use_verbose_logging = false; + bool use_quiet_output = false; + bool discard_all_settings_in_aecdump = true; + absl::optional<std::string> aec_dump_input_filename; + absl::optional<std::string> aec_dump_output_filename; + bool fixed_interface = false; + bool 
store_intermediate_output = false; + bool print_aec_parameter_values = false; + bool dump_internal_data = false; + WavFile::SampleFormat wav_output_format = WavFile::SampleFormat::kInt16; + absl::optional<std::string> dump_internal_data_output_dir; + absl::optional<int> dump_set_to_use; + absl::optional<std::string> call_order_input_filename; + absl::optional<std::string> call_order_output_filename; + absl::optional<std::string> aec_settings_filename; + absl::optional<absl::string_view> aec_dump_input_string; + std::vector<float>* processed_capture_samples = nullptr; + bool analysis_only = false; + absl::optional<int> dump_start_frame; + absl::optional<int> dump_end_frame; + absl::optional<int> init_to_process; +}; + +// Provides common functionality for performing audioprocessing simulations. +class AudioProcessingSimulator { + public: + AudioProcessingSimulator(const SimulationSettings& settings, + rtc::scoped_refptr<AudioProcessing> audio_processing, + std::unique_ptr<AudioProcessingBuilder> ap_builder); + + AudioProcessingSimulator() = delete; + AudioProcessingSimulator(const AudioProcessingSimulator&) = delete; + AudioProcessingSimulator& operator=(const AudioProcessingSimulator&) = delete; + + virtual ~AudioProcessingSimulator(); + + // Processes the data in the input. + virtual void Process() = 0; + + // Returns the execution times of all AudioProcessing calls. + const ApiCallStatistics& GetApiCallStatistics() const { + return api_call_statistics_; + } + + // Analyzes the data in the input and reports the resulting statistics. + virtual void Analyze() = 0; + + // Reports whether the processed recording was bitexact. + bool OutputWasBitexact() { return bitexact_output_; } + + size_t get_num_process_stream_calls() { return num_process_stream_calls_; } + size_t get_num_reverse_process_stream_calls() { + return num_reverse_process_stream_calls_; + } + + protected: + void ProcessStream(bool fixed_interface); + void ProcessReverseStream(bool fixed_interface); + void ConfigureAudioProcessor(); + void DetachAecDump(); + void SetupBuffersConfigsOutputs(int input_sample_rate_hz, + int output_sample_rate_hz, + int reverse_input_sample_rate_hz, + int reverse_output_sample_rate_hz, + int input_num_channels, + int output_num_channels, + int reverse_input_num_channels, + int reverse_output_num_channels); + void SelectivelyToggleDataDumping(int init_index, + int capture_frames_since_init) const; + + const SimulationSettings settings_; + rtc::scoped_refptr<AudioProcessing> ap_; + + std::unique_ptr<ChannelBuffer<float>> in_buf_; + std::unique_ptr<ChannelBuffer<float>> out_buf_; + std::unique_ptr<ChannelBuffer<float>> reverse_in_buf_; + std::unique_ptr<ChannelBuffer<float>> reverse_out_buf_; + std::vector<std::array<float, 160>> linear_aec_output_buf_; + StreamConfig in_config_; + StreamConfig out_config_; + StreamConfig reverse_in_config_; + StreamConfig reverse_out_config_; + std::unique_ptr<ChannelBufferWavReader> buffer_reader_; + std::unique_ptr<ChannelBufferWavReader> reverse_buffer_reader_; + Int16Frame rev_frame_; + Int16Frame fwd_frame_; + bool bitexact_output_ = true; + int aec_dump_mic_level_ = 0; + + protected: + size_t output_reset_counter_ = 0; + + private: + void SetupOutput(); + + size_t num_process_stream_calls_ = 0; + size_t num_reverse_process_stream_calls_ = 0; + std::unique_ptr<ChannelBufferWavWriter> buffer_file_writer_; + std::unique_ptr<ChannelBufferWavWriter> reverse_buffer_file_writer_; + std::unique_ptr<ChannelBufferVectorWriter> buffer_memory_writer_; + 
std::unique_ptr<WavWriter> linear_aec_output_file_writer_; + ApiCallStatistics api_call_statistics_; + std::ofstream residual_echo_likelihood_graph_writer_; + int analog_mic_level_; + FakeRecordingDevice fake_recording_device_; + + TaskQueueForTest worker_queue_; +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_AUDIO_PROCESSING_SIMULATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/audioproc_float_impl.cc b/third_party/libwebrtc/modules/audio_processing/test/audioproc_float_impl.cc new file mode 100644 index 0000000000..dd9fc70734 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/audioproc_float_impl.cc @@ -0,0 +1,815 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/audioproc_float_impl.h" + +#include <string.h> + +#include <iostream> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/string_view.h" +#include "modules/audio_processing/include/audio_processing.h" +#include "modules/audio_processing/test/aec_dump_based_simulator.h" +#include "modules/audio_processing/test/audio_processing_simulator.h" +#include "modules/audio_processing/test/wav_based_simulator.h" +#include "rtc_base/checks.h" +#include "rtc_base/strings/string_builder.h" +#include "system_wrappers/include/field_trial.h" + +constexpr int kParameterNotSpecifiedValue = -10000; + +ABSL_FLAG(std::string, dump_input, "", "Aec dump input filename"); +ABSL_FLAG(std::string, dump_output, "", "Aec dump output filename"); +ABSL_FLAG(std::string, i, "", "Forward stream input wav filename"); +ABSL_FLAG(std::string, o, "", "Forward stream output wav filename"); +ABSL_FLAG(std::string, ri, "", "Reverse stream input wav filename"); +ABSL_FLAG(std::string, ro, "", "Reverse stream output wav filename"); +ABSL_FLAG(std::string, + artificial_nearend, + "", + "Artificial nearend wav filename"); +ABSL_FLAG(std::string, linear_aec_output, "", "Linear AEC output wav filename"); +ABSL_FLAG(int, + output_num_channels, + kParameterNotSpecifiedValue, + "Number of forward stream output channels"); +ABSL_FLAG(int, + reverse_output_num_channels, + kParameterNotSpecifiedValue, + "Number of Reverse stream output channels"); +ABSL_FLAG(int, + output_sample_rate_hz, + kParameterNotSpecifiedValue, + "Forward stream output sample rate in Hz"); +ABSL_FLAG(int, + reverse_output_sample_rate_hz, + kParameterNotSpecifiedValue, + "Reverse stream output sample rate in Hz"); +ABSL_FLAG(bool, + fixed_interface, + false, + "Use the fixed interface when operating on wav files"); +ABSL_FLAG(int, + aec, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the echo canceller"); +ABSL_FLAG(int, + aecm, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the mobile echo controller"); +ABSL_FLAG(int, + ed, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the residual echo detector"); +ABSL_FLAG(std::string, + ed_graph, + "", + "Output filename for graph of echo likelihood"); +ABSL_FLAG(int, + agc, + kParameterNotSpecifiedValue, + 
"Activate (1) or deactivate (0) the AGC"); +ABSL_FLAG(int, + agc2, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the AGC2"); +ABSL_FLAG(int, + pre_amplifier, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate(0) the pre amplifier"); +ABSL_FLAG( + int, + capture_level_adjustment, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate(0) the capture level adjustment functionality"); +ABSL_FLAG(int, + analog_mic_gain_emulation, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate(0) the analog mic gain emulation in the " + "production (non-test) code."); +ABSL_FLAG(int, + hpf, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the high-pass filter"); +ABSL_FLAG(int, + ns, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the noise suppressor"); +ABSL_FLAG(int, + ts, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the transient suppressor"); +ABSL_FLAG(int, + analog_agc, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the analog AGC"); +ABSL_FLAG(bool, + all_default, + false, + "Activate all of the default components (will be overridden by any " + "other settings)"); +ABSL_FLAG(int, + analog_agc_use_digital_adaptive_controller, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) digital adaptation in AGC1. " + "Digital adaptation is active by default."); +ABSL_FLAG(int, + agc_mode, + kParameterNotSpecifiedValue, + "Specify the AGC mode (0-2)"); +ABSL_FLAG(int, + agc_target_level, + kParameterNotSpecifiedValue, + "Specify the AGC target level (0-31)"); +ABSL_FLAG(int, + agc_limiter, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the level estimator"); +ABSL_FLAG(int, + agc_compression_gain, + kParameterNotSpecifiedValue, + "Specify the AGC compression gain (0-90)"); +ABSL_FLAG(int, + agc2_enable_adaptive_gain, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) the AGC2 adaptive gain"); +ABSL_FLAG(float, + agc2_fixed_gain_db, + kParameterNotSpecifiedValue, + "AGC2 fixed gain (dB) to apply"); +ABSL_FLAG(float, + pre_amplifier_gain_factor, + kParameterNotSpecifiedValue, + "Pre-amplifier gain factor (linear) to apply"); +ABSL_FLAG(float, + pre_gain_factor, + kParameterNotSpecifiedValue, + "Pre-gain factor (linear) to apply in the capture level adjustment"); +ABSL_FLAG(float, + post_gain_factor, + kParameterNotSpecifiedValue, + "Post-gain factor (linear) to apply in the capture level adjustment"); +ABSL_FLAG(float, + analog_mic_gain_emulation_initial_level, + kParameterNotSpecifiedValue, + "Emulated analog mic level to apply initially in the production " + "(non-test) code."); +ABSL_FLAG(int, + ns_level, + kParameterNotSpecifiedValue, + "Specify the NS level (0-3)"); +ABSL_FLAG(int, + ns_analysis_on_linear_aec_output, + kParameterNotSpecifiedValue, + "Specifies whether the noise suppression analysis is done on the " + "linear AEC output"); +ABSL_FLAG(int, + maximum_internal_processing_rate, + kParameterNotSpecifiedValue, + "Set a maximum internal processing rate (32000 or 48000) to override " + "the default rate"); +ABSL_FLAG(int, + stream_delay, + kParameterNotSpecifiedValue, + "Specify the stream delay in ms to use"); +ABSL_FLAG(int, + use_stream_delay, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) reporting the stream delay"); +ABSL_FLAG(int, + stream_drift_samples, + kParameterNotSpecifiedValue, + "Specify the number of stream drift samples to use"); +ABSL_FLAG(int, + initial_mic_level, + 100, + "Initial mic level (0-255) for 
the analog mic gain simulation in the " + "test code"); +ABSL_FLAG(int, + simulate_mic_gain, + 0, + "Activate (1) or deactivate(0) the analog mic gain simulation in the " + "test code"); +ABSL_FLAG(int, + multi_channel_render, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) multi-channel render processing in " + "APM pipeline"); +ABSL_FLAG(int, + multi_channel_capture, + kParameterNotSpecifiedValue, + "Activate (1) or deactivate (0) multi-channel capture processing in " + "APM pipeline"); +ABSL_FLAG(int, + simulated_mic_kind, + kParameterNotSpecifiedValue, + "Specify which microphone kind to use for microphone simulation"); +ABSL_FLAG(int, + override_key_pressed, + kParameterNotSpecifiedValue, + "Always set to true (1) or to false (0) the key press state. If " + "unspecified, false is set with Wav files or, with AEC dumps, the " + "recorded event is used."); +ABSL_FLAG(int, + frame_for_sending_capture_output_used_false, + kParameterNotSpecifiedValue, + "Capture frame index for sending a runtime setting for that the " + "capture output is not used."); +ABSL_FLAG(int, + frame_for_sending_capture_output_used_true, + kParameterNotSpecifiedValue, + "Capture frame index for sending a runtime setting for that the " + "capture output is used."); +ABSL_FLAG(bool, performance_report, false, "Report the APM performance "); +ABSL_FLAG(std::string, + performance_report_output_file, + "", + "Generate a CSV file with the API call durations"); +ABSL_FLAG(bool, verbose, false, "Produce verbose output"); +ABSL_FLAG(bool, + quiet, + false, + "Avoid producing information about the progress."); +ABSL_FLAG(bool, + bitexactness_report, + false, + "Report bitexactness for aec dump result reproduction"); +ABSL_FLAG(bool, + discard_settings_in_aecdump, + false, + "Discard any config settings specified in the aec dump"); +ABSL_FLAG(bool, + store_intermediate_output, + false, + "Creates new output files after each init"); +ABSL_FLAG(std::string, + custom_call_order_file, + "", + "Custom process API call order file"); +ABSL_FLAG(std::string, + output_custom_call_order_file, + "", + "Generate custom process API call order file from AEC dump"); +ABSL_FLAG(bool, + print_aec_parameter_values, + false, + "Print parameter values used in AEC in JSON-format"); +ABSL_FLAG(std::string, + aec_settings, + "", + "File in JSON-format with custom AEC settings"); +ABSL_FLAG(bool, + dump_data, + false, + "Dump internal data during the call (requires build flag)"); +ABSL_FLAG(std::string, + dump_data_output_dir, + "", + "Internal data dump output directory"); +ABSL_FLAG(int, + dump_set_to_use, + kParameterNotSpecifiedValue, + "Specifies the dump set to use (if not all the dump sets will " + "be used"); +ABSL_FLAG(bool, + analyze, + false, + "Only analyze the call setup behavior (no processing)"); +ABSL_FLAG(float, + dump_start_seconds, + kParameterNotSpecifiedValue, + "Start of when to dump data (seconds)."); +ABSL_FLAG(float, + dump_end_seconds, + kParameterNotSpecifiedValue, + "End of when to dump data (seconds)."); +ABSL_FLAG(int, + dump_start_frame, + kParameterNotSpecifiedValue, + "Start of when to dump data (frames)."); +ABSL_FLAG(int, + dump_end_frame, + kParameterNotSpecifiedValue, + "End of when to dump data (frames)."); +ABSL_FLAG(int, + init_to_process, + kParameterNotSpecifiedValue, + "Init index to process."); + +ABSL_FLAG(bool, + float_wav_output, + false, + "Produce floating point wav output files."); + +ABSL_FLAG(std::string, + force_fieldtrials, + "", + "Field trials control experimental feature code 
which can be forced. " + "E.g. running with --force_fieldtrials=WebRTC-FooFeature/Enable/" + " will assign the group Enable to field trial WebRTC-FooFeature."); + +namespace webrtc { +namespace test { +namespace { + +const char kUsageDescription[] = + "Usage: audioproc_f [options] -i <input.wav>\n" + " or\n" + " audioproc_f [options] -dump_input <aec_dump>\n" + "\n\n" + "Command-line tool to simulate a call using the audio " + "processing module, either based on wav files or " + "protobuf debug dump recordings.\n"; + +void SetSettingIfSpecified(absl::string_view value, + absl::optional<std::string>* parameter) { + if (value.compare("") != 0) { + *parameter = std::string(value); + } +} + +void SetSettingIfSpecified(int value, absl::optional<int>* parameter) { + if (value != kParameterNotSpecifiedValue) { + *parameter = value; + } +} + +void SetSettingIfSpecified(float value, absl::optional<float>* parameter) { + constexpr float kFloatParameterNotSpecifiedValue = + kParameterNotSpecifiedValue; + if (value != kFloatParameterNotSpecifiedValue) { + *parameter = value; + } +} + +void SetSettingIfFlagSet(int32_t flag, absl::optional<bool>* parameter) { + if (flag == 0) { + *parameter = false; + } else if (flag == 1) { + *parameter = true; + } +} + +SimulationSettings CreateSettings() { + SimulationSettings settings; + if (absl::GetFlag(FLAGS_all_default)) { + settings.use_ts = true; + settings.use_analog_agc = true; + settings.use_ns = true; + settings.use_hpf = true; + settings.use_agc = true; + settings.use_agc2 = false; + settings.use_pre_amplifier = false; + settings.use_aec = true; + settings.use_aecm = false; + settings.use_ed = false; + } + SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_input), + &settings.aec_dump_input_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_output), + &settings.aec_dump_output_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_i), &settings.input_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_o), &settings.output_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_ri), + &settings.reverse_input_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_ro), + &settings.reverse_output_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_artificial_nearend), + &settings.artificial_nearend_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_linear_aec_output), + &settings.linear_aec_output_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_output_num_channels), + &settings.output_num_channels); + SetSettingIfSpecified(absl::GetFlag(FLAGS_reverse_output_num_channels), + &settings.reverse_output_num_channels); + SetSettingIfSpecified(absl::GetFlag(FLAGS_output_sample_rate_hz), + &settings.output_sample_rate_hz); + SetSettingIfSpecified(absl::GetFlag(FLAGS_reverse_output_sample_rate_hz), + &settings.reverse_output_sample_rate_hz); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_aec), &settings.use_aec); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_aecm), &settings.use_aecm); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_ed), &settings.use_ed); + SetSettingIfSpecified(absl::GetFlag(FLAGS_ed_graph), + &settings.ed_graph_output_filename); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_agc), &settings.use_agc); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_agc2), &settings.use_agc2); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_pre_amplifier), + &settings.use_pre_amplifier); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_capture_level_adjustment), + &settings.use_capture_level_adjustment); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_analog_mic_gain_emulation), + 
&settings.use_analog_mic_gain_emulation); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_hpf), &settings.use_hpf); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_ns), &settings.use_ns); + SetSettingIfSpecified(absl::GetFlag(FLAGS_ts), &settings.use_ts); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_analog_agc), + &settings.use_analog_agc); + SetSettingIfFlagSet( + absl::GetFlag(FLAGS_analog_agc_use_digital_adaptive_controller), + &settings.analog_agc_use_digital_adaptive_controller); + SetSettingIfSpecified(absl::GetFlag(FLAGS_agc_mode), &settings.agc_mode); + SetSettingIfSpecified(absl::GetFlag(FLAGS_agc_target_level), + &settings.agc_target_level); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_agc_limiter), + &settings.use_agc_limiter); + SetSettingIfSpecified(absl::GetFlag(FLAGS_agc_compression_gain), + &settings.agc_compression_gain); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_agc2_enable_adaptive_gain), + &settings.agc2_use_adaptive_gain); + + SetSettingIfSpecified(absl::GetFlag(FLAGS_agc2_fixed_gain_db), + &settings.agc2_fixed_gain_db); + SetSettingIfSpecified(absl::GetFlag(FLAGS_pre_amplifier_gain_factor), + &settings.pre_amplifier_gain_factor); + SetSettingIfSpecified(absl::GetFlag(FLAGS_pre_gain_factor), + &settings.pre_gain_factor); + SetSettingIfSpecified(absl::GetFlag(FLAGS_post_gain_factor), + &settings.post_gain_factor); + SetSettingIfSpecified( + absl::GetFlag(FLAGS_analog_mic_gain_emulation_initial_level), + &settings.analog_mic_gain_emulation_initial_level); + SetSettingIfSpecified(absl::GetFlag(FLAGS_ns_level), &settings.ns_level); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_ns_analysis_on_linear_aec_output), + &settings.ns_analysis_on_linear_aec_output); + SetSettingIfSpecified(absl::GetFlag(FLAGS_maximum_internal_processing_rate), + &settings.maximum_internal_processing_rate); + SetSettingIfSpecified(absl::GetFlag(FLAGS_stream_delay), + &settings.stream_delay); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_use_stream_delay), + &settings.use_stream_delay); + SetSettingIfSpecified(absl::GetFlag(FLAGS_custom_call_order_file), + &settings.call_order_input_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_output_custom_call_order_file), + &settings.call_order_output_filename); + SetSettingIfSpecified(absl::GetFlag(FLAGS_aec_settings), + &settings.aec_settings_filename); + settings.initial_mic_level = absl::GetFlag(FLAGS_initial_mic_level); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_multi_channel_render), + &settings.multi_channel_render); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_multi_channel_capture), + &settings.multi_channel_capture); + settings.simulate_mic_gain = absl::GetFlag(FLAGS_simulate_mic_gain); + SetSettingIfSpecified(absl::GetFlag(FLAGS_simulated_mic_kind), + &settings.simulated_mic_kind); + SetSettingIfFlagSet(absl::GetFlag(FLAGS_override_key_pressed), + &settings.override_key_pressed); + SetSettingIfSpecified( + absl::GetFlag(FLAGS_frame_for_sending_capture_output_used_false), + &settings.frame_for_sending_capture_output_used_false); + SetSettingIfSpecified( + absl::GetFlag(FLAGS_frame_for_sending_capture_output_used_true), + &settings.frame_for_sending_capture_output_used_true); + settings.report_performance = absl::GetFlag(FLAGS_performance_report); + SetSettingIfSpecified(absl::GetFlag(FLAGS_performance_report_output_file), + &settings.performance_report_output_filename); + settings.use_verbose_logging = absl::GetFlag(FLAGS_verbose); + settings.use_quiet_output = absl::GetFlag(FLAGS_quiet); + settings.report_bitexactness = absl::GetFlag(FLAGS_bitexactness_report); + 
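+  // As an illustration (file names are hypothetical), an invocation such as
+  //   audioproc_f --dump_input call.aecdump --aec 1 --ns 1 -o out.wav
+  // populates aec_dump_input_filename, use_aec, use_ns and output_filename
+  // through the mappings in this function.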
settings.discard_all_settings_in_aecdump =
+      absl::GetFlag(FLAGS_discard_settings_in_aecdump);
+  settings.fixed_interface = absl::GetFlag(FLAGS_fixed_interface);
+  settings.store_intermediate_output =
+      absl::GetFlag(FLAGS_store_intermediate_output);
+  settings.print_aec_parameter_values =
+      absl::GetFlag(FLAGS_print_aec_parameter_values);
+  settings.dump_internal_data = absl::GetFlag(FLAGS_dump_data);
+  SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_data_output_dir),
+                        &settings.dump_internal_data_output_dir);
+  SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_set_to_use),
+                        &settings.dump_set_to_use);
+  settings.wav_output_format = absl::GetFlag(FLAGS_float_wav_output)
+                                   ? WavFile::SampleFormat::kFloat
+                                   : WavFile::SampleFormat::kInt16;
+
+  settings.analysis_only = absl::GetFlag(FLAGS_analyze);
+
+  SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_start_frame),
+                        &settings.dump_start_frame);
+  SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_end_frame),
+                        &settings.dump_end_frame);
+
+  constexpr int kFramesPerSecond = 100;
+  absl::optional<float> start_seconds;
+  SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_start_seconds),
+                        &start_seconds);
+  if (start_seconds) {
+    settings.dump_start_frame = *start_seconds * kFramesPerSecond;
+  }
+
+  absl::optional<float> end_seconds;
+  SetSettingIfSpecified(absl::GetFlag(FLAGS_dump_end_seconds), &end_seconds);
+  if (end_seconds) {
+    settings.dump_end_frame = *end_seconds * kFramesPerSecond;
+  }
+
+  SetSettingIfSpecified(absl::GetFlag(FLAGS_init_to_process),
+                        &settings.init_to_process);
+
+  return settings;
+}
+
+void ReportConditionalErrorAndExit(bool condition, absl::string_view message) {
+  if (condition) {
+    std::cerr << message << std::endl;
+    exit(1);
+  }
+}
+
+void PerformBasicParameterSanityChecks(
+    const SimulationSettings& settings,
+    bool pre_constructed_ap_provided,
+    bool pre_constructed_ap_builder_provided) {
+  if (settings.input_filename || settings.reverse_input_filename) {
+    ReportConditionalErrorAndExit(
+        !!settings.aec_dump_input_filename,
+        "Error: The aec dump file cannot be specified "
+        "together with input wav files!\n");
+
+    ReportConditionalErrorAndExit(
+        !!settings.aec_dump_input_string,
+        "Error: The aec dump input string cannot be specified "
+        "together with input wav files!\n");
+
+    ReportConditionalErrorAndExit(!!settings.artificial_nearend_filename,
+                                  "Error: The artificial nearend cannot be "
+                                  "specified together with input wav files!\n");
+
+    ReportConditionalErrorAndExit(!settings.input_filename,
+                                  "Error: When operating on wav files, the "
+                                  "input wav filename must be "
+                                  "specified!\n");
+
+    ReportConditionalErrorAndExit(
+        settings.reverse_output_filename && !settings.reverse_input_filename,
+        "Error: When operating on wav files, the reverse input wav filename "
+        "must be specified if the reverse output wav filename is specified!\n");
+  } else {
+    ReportConditionalErrorAndExit(
+        !settings.aec_dump_input_filename && !settings.aec_dump_input_string,
+        "Error: Either the aec dump input file, the wav "
+        "input file or the aec dump input string must be specified!\n");
+    ReportConditionalErrorAndExit(
+        settings.aec_dump_input_filename && settings.aec_dump_input_string,
+        "Error: The aec dump input file cannot be specified together with the "
+        "aec dump input string!\n");
+  }
+
+  ReportConditionalErrorAndExit(settings.use_aec && !(*settings.use_aec) &&
+                                    settings.linear_aec_output_filename,
+                                "Error: The linear AEC output filename cannot "
+                                "be specified without the AEC being active");
+
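+  // AEC3 and AECM are alternative echo control implementations (AECM is
+  // selected via the mobile_mode config), so only one of them may be active.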
+  ReportConditionalErrorAndExit(
+      settings.use_aec && *settings.use_aec && settings.use_aecm &&
+          *settings.use_aecm,
+      "Error: The AEC and the AECM cannot be activated at the same time!\n");
+
+  ReportConditionalErrorAndExit(
+      settings.output_sample_rate_hz && *settings.output_sample_rate_hz <= 0,
+      "Error: --output_sample_rate_hz must be positive!\n");
+
+  ReportConditionalErrorAndExit(
+      settings.reverse_output_sample_rate_hz &&
+          *settings.reverse_output_sample_rate_hz <= 0,
+      "Error: --reverse_output_sample_rate_hz must be positive!\n");
+
+  ReportConditionalErrorAndExit(
+      settings.output_num_channels && *settings.output_num_channels <= 0,
+      "Error: --output_num_channels must be positive!\n");
+
+  ReportConditionalErrorAndExit(
+      settings.reverse_output_num_channels &&
+          *settings.reverse_output_num_channels <= 0,
+      "Error: --reverse_output_num_channels must be positive!\n");
+
+  ReportConditionalErrorAndExit(
+      settings.agc_target_level && ((*settings.agc_target_level) < 0 ||
+                                    (*settings.agc_target_level) > 31),
+      "Error: --agc_target_level must be between 0 and 31.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.agc_compression_gain && ((*settings.agc_compression_gain) < 0 ||
+                                        (*settings.agc_compression_gain) > 90),
+      "Error: --agc_compression_gain must be between 0 and 90.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.agc2_fixed_gain_db && ((*settings.agc2_fixed_gain_db) < 0 ||
+                                      (*settings.agc2_fixed_gain_db) > 90),
+      "Error: --agc2_fixed_gain_db must be between 0 and 90.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.ns_level &&
+          ((*settings.ns_level) < 0 || (*settings.ns_level) > 3),
+      "Error: --ns_level must be between 0 and 3.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.report_bitexactness && !settings.aec_dump_input_filename,
+      "Error: --bitexactness_report can only be used when operating on an "
+      "aecdump\n");
+
+  ReportConditionalErrorAndExit(
+      settings.call_order_input_filename && settings.aec_dump_input_filename,
+      "Error: --custom_call_order_file cannot be used when operating on an "
+      "aecdump\n");
+
+  ReportConditionalErrorAndExit(
+      (settings.initial_mic_level < 0 || settings.initial_mic_level > 255),
+      "Error: --initial_mic_level must be between 0 and 255.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.simulated_mic_kind && !settings.simulate_mic_gain,
+      "Error: --simulated_mic_kind cannot be specified when mic simulation is "
+      "disabled\n");
+
+  ReportConditionalErrorAndExit(
+      !settings.simulated_mic_kind && settings.simulate_mic_gain,
+      "Error: --simulated_mic_kind must be specified when mic simulation is "
+      "enabled\n");
+
+  auto valid_wav_name = [](absl::string_view wav_file_name) {
+    if (wav_file_name.size() < 5) {
+      return false;
+    }
+    if ((wav_file_name.compare(wav_file_name.size() - 4, 4, ".wav") == 0) ||
+        (wav_file_name.compare(wav_file_name.size() - 4, 4, ".WAV") == 0)) {
+      return true;
+    }
+    return false;
+  };
+
+  ReportConditionalErrorAndExit(
+      settings.input_filename && (!valid_wav_name(*settings.input_filename)),
+      "Error: --i must be a valid .wav file name.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.output_filename && (!valid_wav_name(*settings.output_filename)),
+      "Error: --o must be a valid .wav file name.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.reverse_input_filename &&
+          (!valid_wav_name(*settings.reverse_input_filename)),
+      "Error: --ri must be a valid .wav file name.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.reverse_output_filename &&
+          (!valid_wav_name(*settings.reverse_output_filename)),
+      "Error: --ro must be a valid .wav file name.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.artificial_nearend_filename &&
+          !valid_wav_name(*settings.artificial_nearend_filename),
+      "Error: --artificial_nearend must be a valid .wav file name.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.linear_aec_output_filename &&
+          (!valid_wav_name(*settings.linear_aec_output_filename)),
+      "Error: --linear_aec_output must be a valid .wav file name.\n");
+
+  ReportConditionalErrorAndExit(
+      WEBRTC_APM_DEBUG_DUMP == 0 && settings.dump_internal_data,
+      "Error: --dump_data cannot be set without proper build support.\n");
+
+  ReportConditionalErrorAndExit(settings.init_to_process &&
+                                    *settings.init_to_process != 1 &&
+                                    !settings.aec_dump_input_filename,
+                                "Error: --init_to_process must be set to 1 for "
+                                "wav-file based simulations.\n");
+
+  ReportConditionalErrorAndExit(
+      !settings.init_to_process &&
+          (settings.dump_start_frame || settings.dump_end_frame),
+      "Error: --init_to_process must be set when specifying a start and/or end "
+      "frame for when to dump internal data.\n");
+
+  ReportConditionalErrorAndExit(
+      !settings.dump_internal_data &&
+          settings.dump_internal_data_output_dir.has_value(),
+      "Error: --dump_data_output_dir cannot be set without --dump_data.\n");
+
+  ReportConditionalErrorAndExit(
+      !settings.aec_dump_input_filename &&
+          settings.call_order_output_filename.has_value(),
+      "Error: --output_custom_call_order_file needs an AEC dump input file.\n");
+
+  ReportConditionalErrorAndExit(
+      (!settings.use_pre_amplifier || !(*settings.use_pre_amplifier)) &&
+          settings.pre_amplifier_gain_factor.has_value(),
+      "Error: --pre_amplifier_gain_factor needs --pre_amplifier to be "
+      "specified and set.\n");
+
+  ReportConditionalErrorAndExit(
+      pre_constructed_ap_provided && pre_constructed_ap_builder_provided,
+      "Error: The AudioProcessing and the AudioProcessingBuilder cannot both "
+      "be specified at the same time.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.aec_settings_filename && pre_constructed_ap_provided,
+      "Error: The aec_settings_filename cannot be specified when a "
+      "pre-constructed audio processing object is provided.\n");
+
+  ReportConditionalErrorAndExit(
+      settings.print_aec_parameter_values && pre_constructed_ap_provided,
+      "Error: The print_aec_parameter_values cannot be set when a "
+      "pre-constructed audio processing object is provided.\n");
+
+  if (settings.linear_aec_output_filename && pre_constructed_ap_provided) {
+    std::cout << "Warning: For the linear AEC output to be stored, this must "
+                 "be configured in the AEC that is part of the provided "
+                 "AudioProcessing object."
+              << std::endl;
+  }
+}
+
+int RunSimulation(rtc::scoped_refptr<AudioProcessing> audio_processing,
+                  std::unique_ptr<AudioProcessingBuilder> ap_builder,
+                  int argc,
+                  char* argv[],
+                  absl::string_view input_aecdump,
+                  std::vector<float>* processed_capture_samples) {
+  std::vector<char*> args = absl::ParseCommandLine(argc, argv);
+  if (args.size() != 1) {
+    printf("%s", kUsageDescription);
+    return 1;
+  }
+  // InitFieldTrialsFromString stores the char*, so the char array must
+  // outlive the application.
+ const std::string field_trials = absl::GetFlag(FLAGS_force_fieldtrials); + webrtc::field_trial::InitFieldTrialsFromString(field_trials.c_str()); + + SimulationSettings settings = CreateSettings(); + if (!input_aecdump.empty()) { + settings.aec_dump_input_string = input_aecdump; + settings.processed_capture_samples = processed_capture_samples; + RTC_CHECK(settings.processed_capture_samples); + } + PerformBasicParameterSanityChecks(settings, !!audio_processing, !!ap_builder); + std::unique_ptr<AudioProcessingSimulator> processor; + + if (settings.aec_dump_input_filename || settings.aec_dump_input_string) { + processor.reset(new AecDumpBasedSimulator( + settings, std::move(audio_processing), std::move(ap_builder))); + } else { + processor.reset(new WavBasedSimulator(settings, std::move(audio_processing), + std::move(ap_builder))); + } + + if (settings.analysis_only) { + processor->Analyze(); + } else { + processor->Process(); + } + + if (settings.report_performance) { + processor->GetApiCallStatistics().PrintReport(); + } + if (settings.performance_report_output_filename) { + processor->GetApiCallStatistics().WriteReportToFile( + *settings.performance_report_output_filename); + } + + if (settings.report_bitexactness && settings.aec_dump_input_filename) { + if (processor->OutputWasBitexact()) { + std::cout << "The processing was bitexact."; + } else { + std::cout << "The processing was not bitexact."; + } + } + + return 0; +} + +} // namespace + +int AudioprocFloatImpl(rtc::scoped_refptr<AudioProcessing> audio_processing, + int argc, + char* argv[]) { + return RunSimulation( + std::move(audio_processing), /*ap_builder=*/nullptr, argc, argv, + /*input_aecdump=*/"", /*processed_capture_samples=*/nullptr); +} + +int AudioprocFloatImpl(std::unique_ptr<AudioProcessingBuilder> ap_builder, + int argc, + char* argv[], + absl::string_view input_aecdump, + std::vector<float>* processed_capture_samples) { + return RunSimulation(/*audio_processing=*/nullptr, std::move(ap_builder), + argc, argv, input_aecdump, processed_capture_samples); +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/audioproc_float_impl.h b/third_party/libwebrtc/modules/audio_processing/test/audioproc_float_impl.h new file mode 100644 index 0000000000..5ed3aefab7 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/audioproc_float_impl.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_AUDIOPROC_FLOAT_IMPL_H_ +#define MODULES_AUDIO_PROCESSING_TEST_AUDIOPROC_FLOAT_IMPL_H_ + +#include <memory> + +#include "modules/audio_processing/include/audio_processing.h" + +namespace webrtc { +namespace test { + +// This function implements the audio processing simulation utility. Pass +// `input_aecdump` to provide the content of an AEC dump file as a string; if +// `input_aecdump` is not passed, a WAV or AEC input dump file must be specified +// via the `argv` argument. 
 Pass `processed_capture_samples` to have the
+// samples processed on the capture side written to it; if
+// `processed_capture_samples` is not passed, the output file can optionally be
+// specified via the `argv` argument. Any audio_processing object specified in
+// the input is used for the simulation. Note that when the audio_processing
+// object is specified, all functionality that relies on using the internal
+// builder is deactivated, since the AudioProcessing object is already created
+// and the builder is not used in the simulation.
+int AudioprocFloatImpl(rtc::scoped_refptr<AudioProcessing> audio_processing,
+                       int argc,
+                       char* argv[]);
+
+// This function implements the audio processing simulation utility. Pass
+// `input_aecdump` to provide the content of an AEC dump file as a string; if
+// `input_aecdump` is not passed, a WAV or AEC input dump file must be specified
+// via the `argv` argument. Pass `processed_capture_samples` to have the
+// samples processed on the capture side written to it; if
+// `processed_capture_samples` is not passed, the output file can optionally be
+// specified via the `argv` argument.
+int AudioprocFloatImpl(std::unique_ptr<AudioProcessingBuilder> ap_builder,
+                       int argc,
+                       char* argv[],
+                       absl::string_view input_aecdump,
+                       std::vector<float>* processed_capture_samples);
+
+}  // namespace test
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_TEST_AUDIOPROC_FLOAT_IMPL_H_
diff --git a/third_party/libwebrtc/modules/audio_processing/test/bitexactness_tools.cc b/third_party/libwebrtc/modules/audio_processing/test/bitexactness_tools.cc
new file mode 100644
index 0000000000..0464345364
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/bitexactness_tools.cc
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/test/bitexactness_tools.h"
+
+#include <math.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "api/array_view.h"
+#include "test/testsupport/file_utils.h"
+
+namespace webrtc {
+namespace test {
+
+std::string GetApmRenderTestVectorFileName(int sample_rate_hz) {
+  switch (sample_rate_hz) {
+    case 8000:
+      return ResourcePath("far8_stereo", "pcm");
+    case 16000:
+      return ResourcePath("far16_stereo", "pcm");
+    case 32000:
+      return ResourcePath("far32_stereo", "pcm");
+    case 48000:
+      return ResourcePath("far48_stereo", "pcm");
+    default:
+      RTC_DCHECK_NOTREACHED();
+  }
+  return "";
+}
+
+std::string GetApmCaptureTestVectorFileName(int sample_rate_hz) {
+  switch (sample_rate_hz) {
+    case 8000:
+      return ResourcePath("near8_stereo", "pcm");
+    case 16000:
+      return ResourcePath("near16_stereo", "pcm");
+    case 32000:
+      return ResourcePath("near32_stereo", "pcm");
+    case 48000:
+      return ResourcePath("near48_stereo", "pcm");
+    default:
+      RTC_DCHECK_NOTREACHED();
+  }
+  return "";
+}
+
+void ReadFloatSamplesFromStereoFile(size_t samples_per_channel,
+                                    size_t num_channels,
+                                    InputAudioFile* stereo_pcm_file,
+                                    rtc::ArrayView<float> data) {
+  RTC_DCHECK_LE(num_channels, 2);
+  RTC_DCHECK_EQ(data.size(), samples_per_channel * num_channels);
+  std::vector<int16_t> read_samples(samples_per_channel * 2);
+  stereo_pcm_file->Read(samples_per_channel * 2, read_samples.data());
+
+  // Convert samples to float and discard any channels not needed.
+  for (size_t sample = 0; sample < samples_per_channel; ++sample) {
+    for (size_t channel = 0; channel < num_channels; ++channel) {
+      data[sample * num_channels + channel] =
+          read_samples[sample * 2 + channel] / 32768.0f;
+    }
+  }
+}
+
+::testing::AssertionResult VerifyDeinterleavedArray(
+    size_t samples_per_channel,
+    size_t num_channels,
+    rtc::ArrayView<const float> reference,
+    rtc::ArrayView<const float> output,
+    float element_error_bound) {
+  // Form vectors to compare the reference to. Only the first values of the
+  // outputs are compared, so that not all preceding frames have to be
+  // specified as test vectors.
+  const size_t reference_frame_length =
+      rtc::CheckedDivExact(reference.size(), num_channels);
+
+  std::vector<float> output_to_verify;
+  for (size_t channel_no = 0; channel_no < num_channels; ++channel_no) {
+    output_to_verify.insert(output_to_verify.end(),
+                            output.begin() + channel_no * samples_per_channel,
+                            output.begin() + channel_no * samples_per_channel +
+                                reference_frame_length);
+  }
+
+  return VerifyArray(reference, output_to_verify, element_error_bound);
+}
+
+::testing::AssertionResult VerifyArray(rtc::ArrayView<const float> reference,
+                                       rtc::ArrayView<const float> output,
+                                       float element_error_bound) {
+  // The vectors are deemed to be bitexact only if
+  // a) the output has a size at least as long as the reference, and
+  // b) the samples in the reference are bitexact with the corresponding
+  //    samples in the output.
+
+  bool equal = true;
+  if (output.size() < reference.size()) {
+    equal = false;
+  } else {
+    // Compare the first samples in the vectors.
+    for (size_t k = 0; k < reference.size(); ++k) {
+      if (fabs(output[k] - reference[k]) > element_error_bound) {
+        equal = false;
+        break;
+      }
+    }
+  }
+
+  if (equal) {
+    return ::testing::AssertionSuccess();
+  }
+
+  // Lambda function that produces a formatted string with the data in the
+  // vector.
+ auto print_vector_in_c_format = [](rtc::ArrayView<const float> v, + size_t num_values_to_print) { + std::string s = "{ "; + for (size_t k = 0; k < std::min(num_values_to_print, v.size()); ++k) { + s += std::to_string(v[k]) + "f"; + s += (k < (num_values_to_print - 1)) ? ", " : ""; + } + return s + " }"; + }; + + // If the vectors are deemed not to be similar, return a report of the + // difference. + return ::testing::AssertionFailure() + << std::endl + << " Actual values : " + << print_vector_in_c_format(output, + std::min(output.size(), reference.size())) + << std::endl + << " Expected values: " + << print_vector_in_c_format(reference, reference.size()) << std::endl; +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/bitexactness_tools.h b/third_party/libwebrtc/modules/audio_processing/test/bitexactness_tools.h new file mode 100644 index 0000000000..2d3113276d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/bitexactness_tools.h @@ -0,0 +1,56 @@ + +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_BITEXACTNESS_TOOLS_H_ +#define MODULES_AUDIO_PROCESSING_TEST_BITEXACTNESS_TOOLS_H_ + +#include <string> + +#include "api/array_view.h" +#include "modules/audio_coding/neteq/tools/input_audio_file.h" +#include "test/gtest.h" + +namespace webrtc { +namespace test { + +// Returns test vector to use for the render signal in an +// APM bitexactness test. +std::string GetApmRenderTestVectorFileName(int sample_rate_hz); + +// Returns test vector to use for the capture signal in an +// APM bitexactness test. +std::string GetApmCaptureTestVectorFileName(int sample_rate_hz); + +// Extract float samples of up to two channels from a pcm file. +void ReadFloatSamplesFromStereoFile(size_t samples_per_channel, + size_t num_channels, + InputAudioFile* stereo_pcm_file, + rtc::ArrayView<float> data); + +// Verifies a frame against a reference and returns the results as an +// AssertionResult. +::testing::AssertionResult VerifyDeinterleavedArray( + size_t samples_per_channel, + size_t num_channels, + rtc::ArrayView<const float> reference, + rtc::ArrayView<const float> output, + float element_error_bound); + +// Verifies a vector against a reference and returns the results as an +// AssertionResult. +::testing::AssertionResult VerifyArray(rtc::ArrayView<const float> reference, + rtc::ArrayView<const float> output, + float element_error_bound); + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_BITEXACTNESS_TOOLS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/BUILD.gn b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/BUILD.gn new file mode 100644 index 0000000000..2c3678092e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/BUILD.gn @@ -0,0 +1,81 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. 
An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../../webrtc.gni") + +if (!build_with_chromium) { + group("conversational_speech") { + testonly = true + deps = [ ":conversational_speech_generator" ] + } + + rtc_executable("conversational_speech_generator") { + testonly = true + sources = [ "generator.cc" ] + deps = [ + ":lib", + "../../../../test:fileutils", + "../../../../test:test_support", + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/flags:parse", + ] + } +} + +rtc_library("lib") { + testonly = true + sources = [ + "config.cc", + "config.h", + "multiend_call.cc", + "multiend_call.h", + "simulator.cc", + "simulator.h", + "timing.cc", + "timing.h", + "wavreader_abstract_factory.h", + "wavreader_factory.cc", + "wavreader_factory.h", + "wavreader_interface.h", + ] + deps = [ + "../../../../api:array_view", + "../../../../common_audio", + "../../../../rtc_base:checks", + "../../../../rtc_base:logging", + "../../../../rtc_base:safe_conversions", + "../../../../rtc_base:stringutils", + "../../../../test:fileutils", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. +} + +rtc_library("unittest") { + testonly = true + sources = [ + "generator_unittest.cc", + "mock_wavreader.cc", + "mock_wavreader.h", + "mock_wavreader_factory.cc", + "mock_wavreader_factory.h", + ] + deps = [ + ":lib", + "../../../../api:array_view", + "../../../../common_audio", + "../../../../rtc_base:logging", + "../../../../test:fileutils", + "../../../../test:test_support", + "//testing/gtest", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/OWNERS b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/OWNERS new file mode 100644 index 0000000000..07cff405e6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/OWNERS @@ -0,0 +1,3 @@ +alessiob@webrtc.org +henrik.lundin@webrtc.org +peah@webrtc.org diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/README.md b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/README.md new file mode 100644 index 0000000000..0fa66669e6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/README.md @@ -0,0 +1,74 @@ +# Conversational Speech generator tool + +Tool to generate multiple-end audio tracks to simulate conversational speech +with two or more participants. + +The input to the tool is a directory containing a number of audio tracks and +a text file indicating how to time the sequence of speech turns (see the Example +section). + +Since the timing of the speaking turns is specified by the user, the generated +tracks may not be suitable for testing scenarios in which there is unpredictable +network delay (e.g., end-to-end RTC assessment). + +Instead, the generated pairs can be used when the delay is constant (obviously +including the case in which there is no delay). +For instance, echo cancellation in the APM module can be evaluated using two-end +audio tracks as input and reverse input. 
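+
+As a quick sketch of how the tool might be invoked (the `-i`, `-t` and `-o`
+flags are the ones defined in `generator.cc` below; the paths are
+placeholders):
+
+```
+conversational_speech_generator \
+  -i /path/to/source/audiotracks \
+  -t /path/to/timing_file.txt \
+  -o /output/path
+```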
+
+By indicating negative and positive time offsets, one can reproduce cross-talk
+(aka double-talk) and silence in the conversation.
+
+### Example
+
+For each end, there is a set of audio tracks, e.g., a1, a2, a3 and a4
+(speaker A) and b1, b2 (speaker B).
+The text file with the timing information may look like this:
+
+```
+A a1 0
+B b1 0
+A a2 100
+B b2 -200
+A a3 0
+A a4 0
+```
+
+The first column indicates the speaker name, the second contains the audio track
+file names, and the third the offsets (in milliseconds) used to concatenate the
+chunks. An optional fourth column contains positive or negative integral gains
+in dB that will be applied to the tracks. It is possible to specify the gain for
+some turns but not for others. If the gain is left out, no gain is applied.
+
+Assume that all the audio tracks in the example above are 1000 ms long.
+The tool will then generate two tracks (A and B) that look like this:
+
+**Track A**
+```
+  a1 (1000 ms)
+  silence (1100 ms)
+  a2 (1000 ms)
+  silence (800 ms)
+  a3 (1000 ms)
+  a4 (1000 ms)
+```
+
+**Track B**
+```
+  silence (1000 ms)
+  b1 (1000 ms)
+  silence (900 ms)
+  b2 (1000 ms)
+  silence (2000 ms)
+```
+
+The two tracks can also be visualized as follows (one character represents
+100 ms; "." is silence and "*" is speech).
+
+```
+t: 0 1 2 3 4 5 6 (s)
+A: **********...........**********........********************
+B: ..........**********.........**********....................
+              ^ 200 ms cross-talk
+   100 ms silence ^
+```
diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/config.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/config.cc
new file mode 100644
index 0000000000..76d3de8108
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/config.cc
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/test/conversational_speech/config.h"
+
+namespace webrtc {
+namespace test {
+namespace conversational_speech {
+
+const std::string& Config::audiotracks_path() const {
+  return audiotracks_path_;
+}
+
+const std::string& Config::timing_filepath() const {
+  return timing_filepath_;
+}
+
+const std::string& Config::output_path() const {
+  return output_path_;
+}
+
+}  // namespace conversational_speech
+}  // namespace test
+}  // namespace webrtc
diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/config.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/config.h
new file mode 100644
index 0000000000..5a847e06a2
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/config.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_CONFIG_H_ +#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_CONFIG_H_ + +#include <string> + +#include "absl/strings/string_view.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +struct Config { + Config(absl::string_view audiotracks_path, + absl::string_view timing_filepath, + absl::string_view output_path) + : audiotracks_path_(audiotracks_path), + timing_filepath_(timing_filepath), + output_path_(output_path) {} + + const std::string& audiotracks_path() const; + const std::string& timing_filepath() const; + const std::string& output_path() const; + + const std::string audiotracks_path_; + const std::string timing_filepath_; + const std::string output_path_; +}; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_CONFIG_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/generator.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/generator.cc new file mode 100644 index 0000000000..d0bc2f2319 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/generator.cc @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <iostream> +#include <vector> + +#include <memory> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "modules/audio_processing/test/conversational_speech/config.h" +#include "modules/audio_processing/test/conversational_speech/multiend_call.h" +#include "modules/audio_processing/test/conversational_speech/simulator.h" +#include "modules/audio_processing/test/conversational_speech/timing.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_factory.h" +#include "test/testsupport/file_utils.h" + +ABSL_FLAG(std::string, i, "", "Directory containing the speech turn wav files"); +ABSL_FLAG(std::string, t, "", "Path to the timing text file"); +ABSL_FLAG(std::string, o, "", "Output wav files destination path"); + +namespace webrtc { +namespace test { +namespace { + +const char kUsageDescription[] = + "Usage: conversational_speech_generator\n" + " -i <path/to/source/audiotracks>\n" + " -t <path/to/timing_file.txt>\n" + " -o <output/path>\n" + "\n\n" + "Command-line tool to generate multiple-end audio tracks to simulate " + "conversational speech with two or more participants.\n"; + +} // namespace + +int main(int argc, char* argv[]) { + std::vector<char*> args = absl::ParseCommandLine(argc, argv); + if (args.size() != 1) { + printf("%s", kUsageDescription); + return 1; + } + RTC_CHECK(DirExists(absl::GetFlag(FLAGS_i))); + RTC_CHECK(FileExists(absl::GetFlag(FLAGS_t))); + RTC_CHECK(DirExists(absl::GetFlag(FLAGS_o))); + + conversational_speech::Config config( + absl::GetFlag(FLAGS_i), absl::GetFlag(FLAGS_t), absl::GetFlag(FLAGS_o)); + + // Load timing. + std::vector<conversational_speech::Turn> timing = + conversational_speech::LoadTiming(config.timing_filepath()); + + // Parse timing and audio tracks. 
+  auto wavreader_factory =
+      std::make_unique<conversational_speech::WavReaderFactory>();
+  conversational_speech::MultiEndCall multiend_call(
+      timing, config.audiotracks_path(), std::move(wavreader_factory));
+
+  // Generate output audio tracks.
+  auto generated_audiotrack_pairs =
+      conversational_speech::Simulate(multiend_call, config.output_path());
+
+  // Show paths to created audio tracks.
+  std::cout << "Output files:" << std::endl;
+  for (const auto& output_paths_entry : *generated_audiotrack_pairs) {
+    std::cout << "  speaker: " << output_paths_entry.first << std::endl;
+    std::cout << "    near end: " << output_paths_entry.second.near_end
+              << std::endl;
+    std::cout << "    far end: " << output_paths_entry.second.far_end
+              << std::endl;
+  }
+
+  return 0;
+}
+
+}  // namespace test
+}  // namespace webrtc
+
+int main(int argc, char* argv[]) {
+  return webrtc::test::main(argc, argv);
+}
diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/generator_unittest.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/generator_unittest.cc
new file mode 100644
index 0000000000..17714440d4
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/generator_unittest.cc
@@ -0,0 +1,675 @@
+/*
+ * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+// This file consists of unit tests for webrtc::test::conversational_speech
+// members. Some of them focus on accepting or rejecting different
+// conversational speech setups. A setup is defined by a set of audio tracks
+// and timing information.
+// The docstring at the beginning of each TEST(ConversationalSpeechTest,
+// MultiEndCallSetup*) function looks like the drawing below and indicates which
+// setup is tested.
+//
+// Accept:
+// A 0****.....
+// B .....1****
+//
+// The drawing indicates the following:
+// - the illustrated setup should be accepted,
+// - there are two speakers (namely, A and B),
+// - A speaks first and B speaks second,
+// - each character after the speaker's letter indicates a time unit (e.g., 100
+//   ms),
+// - "*" indicates speaking, "." listening,
+// - numbers indicate the turn index in std::vector<Turn>.
+//
+// Note that the same speaker can appear in multiple lines in order to depict
+// cases in which there are wrong offsets leading to self cross-talk (which is
+// rejected).
+
+// MSVC++ requires this to be set before any other includes to get M_PI.
+#define _USE_MATH_DEFINES + +#include <stdio.h> + +#include <cmath> +#include <map> +#include <memory> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common_audio/wav_file.h" +#include "modules/audio_processing/test/conversational_speech/config.h" +#include "modules/audio_processing/test/conversational_speech/mock_wavreader_factory.h" +#include "modules/audio_processing/test/conversational_speech/multiend_call.h" +#include "modules/audio_processing/test/conversational_speech/simulator.h" +#include "modules/audio_processing/test/conversational_speech/timing.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_factory.h" +#include "rtc_base/logging.h" +#include "test/gmock.h" +#include "test/gtest.h" +#include "test/testsupport/file_utils.h" + +namespace webrtc { +namespace test { +namespace { + +using conversational_speech::LoadTiming; +using conversational_speech::MockWavReaderFactory; +using conversational_speech::MultiEndCall; +using conversational_speech::SaveTiming; +using conversational_speech::Turn; +using conversational_speech::WavReaderFactory; + +const char* const audiotracks_path = "/path/to/audiotracks"; +const char* const timing_filepath = "/path/to/timing_file.txt"; +const char* const output_path = "/path/to/output_dir"; + +const std::vector<Turn> expected_timing = { + {"A", "a1", 0, 0}, {"B", "b1", 0, 0}, {"A", "a2", 100, 0}, + {"B", "b2", -200, 0}, {"A", "a3", 0, 0}, {"A", "a3", 0, 0}, +}; +const std::size_t kNumberOfTurns = expected_timing.size(); + +// Default arguments for MockWavReaderFactory ctor. +// Fake audio track parameters. +constexpr int kDefaultSampleRate = 48000; +const std::map<std::string, const MockWavReaderFactory::Params> + kDefaultMockWavReaderFactoryParamsMap = { + {"t300", {kDefaultSampleRate, 1u, 14400u}}, // Mono, 0.3 seconds. + {"t500", {kDefaultSampleRate, 1u, 24000u}}, // Mono, 0.5 seconds. + {"t1000", {kDefaultSampleRate, 1u, 48000u}}, // Mono, 1.0 seconds. + {"sr8000", {8000, 1u, 8000u}}, // 8kHz sample rate, mono, 1 second. + {"sr16000", {16000, 1u, 16000u}}, // 16kHz sample rate, mono, 1 second. + {"sr16000_stereo", {16000, 2u, 16000u}}, // Like sr16000, but stereo. +}; +const MockWavReaderFactory::Params& kDefaultMockWavReaderFactoryParams = + kDefaultMockWavReaderFactoryParamsMap.at("t500"); + +std::unique_ptr<MockWavReaderFactory> CreateMockWavReaderFactory() { + return std::unique_ptr<MockWavReaderFactory>( + new MockWavReaderFactory(kDefaultMockWavReaderFactoryParams, + kDefaultMockWavReaderFactoryParamsMap)); +} + +void CreateSineWavFile(absl::string_view filepath, + const MockWavReaderFactory::Params& params, + float frequency = 440.0f) { + // Create samples. + constexpr double two_pi = 2.0 * M_PI; + std::vector<int16_t> samples(params.num_samples); + for (std::size_t i = 0; i < params.num_samples; ++i) { + // TODO(alessiob): the produced tone is not pure, improve. + samples[i] = std::lround( + 32767.0f * std::sin(two_pi * i * frequency / params.sample_rate)); + } + + // Write samples. + WavWriter wav_writer(filepath, params.sample_rate, params.num_channels); + wav_writer.WriteSamples(samples.data(), params.num_samples); +} + +// Parameters to generate audio tracks with CreateSineWavFile. +struct SineAudioTrackParams { + MockWavReaderFactory::Params params; + float frequency; +}; + +// Creates a temporary directory in which sine audio tracks are written. 
+std::string CreateTemporarySineAudioTracks( + const std::map<std::string, SineAudioTrackParams>& sine_tracks_params) { + // Create temporary directory. + std::string temp_directory = + OutputPath() + "TempConversationalSpeechAudioTracks"; + CreateDir(temp_directory); + + // Create sine tracks. + for (const auto& it : sine_tracks_params) { + const std::string temp_filepath = JoinFilename(temp_directory, it.first); + CreateSineWavFile(temp_filepath, it.second.params, it.second.frequency); + } + + return temp_directory; +} + +void CheckAudioTrackParams(const WavReaderFactory& wav_reader_factory, + absl::string_view filepath, + const MockWavReaderFactory::Params& expeted_params) { + auto wav_reader = wav_reader_factory.Create(filepath); + EXPECT_EQ(expeted_params.sample_rate, wav_reader->SampleRate()); + EXPECT_EQ(expeted_params.num_channels, wav_reader->NumChannels()); + EXPECT_EQ(expeted_params.num_samples, wav_reader->NumSamples()); +} + +void DeleteFolderAndContents(absl::string_view dir) { + if (!DirExists(dir)) { + return; + } + absl::optional<std::vector<std::string>> dir_content = ReadDirectory(dir); + EXPECT_TRUE(dir_content); + for (const auto& path : *dir_content) { + if (DirExists(path)) { + DeleteFolderAndContents(path); + } else if (FileExists(path)) { + // TODO(alessiob): Wrap with EXPECT_TRUE() once webrtc:7769 bug fixed. + RemoveFile(path); + } else { + FAIL(); + } + } + // TODO(alessiob): Wrap with EXPECT_TRUE() once webrtc:7769 bug fixed. + RemoveDir(dir); +} + +} // namespace + +using ::testing::_; + +TEST(ConversationalSpeechTest, Settings) { + const conversational_speech::Config config(audiotracks_path, timing_filepath, + output_path); + + // Test getters. + EXPECT_EQ(audiotracks_path, config.audiotracks_path()); + EXPECT_EQ(timing_filepath, config.timing_filepath()); + EXPECT_EQ(output_path, config.output_path()); +} + +TEST(ConversationalSpeechTest, TimingSaveLoad) { + // Save test timing. + const std::string temporary_filepath = + TempFilename(OutputPath(), "TempTimingTestFile"); + SaveTiming(temporary_filepath, expected_timing); + + // Create a std::vector<Turn> instance by loading from file. + std::vector<Turn> actual_timing = LoadTiming(temporary_filepath); + RemoveFile(temporary_filepath); + + // Check size. + EXPECT_EQ(expected_timing.size(), actual_timing.size()); + + // Check Turn instances. + for (size_t index = 0; index < expected_timing.size(); ++index) { + EXPECT_EQ(expected_timing[index], actual_timing[index]) + << "turn #" << index << " not matching"; + } +} + +TEST(ConversationalSpeechTest, MultiEndCallCreate) { + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are 5 unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(5); + + // Inject the mock wav reader factory. + conversational_speech::MultiEndCall multiend_call( + expected_timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(2u, multiend_call.speaker_names().size()); + EXPECT_EQ(5u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(6u, multiend_call.speaking_turns().size()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupDifferentSampleRates) { + const std::vector<Turn> timing = { + {"A", "sr8000", 0, 0}, + {"B", "sr16000", 0, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are two unique audio tracks to read. 
+ EXPECT_CALL(*mock_wavreader_factory, Create(::testing::_)).Times(2); + + MultiEndCall multiend_call(timing, audiotracks_path, + std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupMultipleChannels) { + const std::vector<Turn> timing = { + {"A", "sr16000_stereo", 0, 0}, + {"B", "sr16000_stereo", 0, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(::testing::_)).Times(1); + + MultiEndCall multiend_call(timing, audiotracks_path, + std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, + MultiEndCallSetupDifferentSampleRatesAndMultipleNumChannels) { + const std::vector<Turn> timing = { + {"A", "sr8000", 0, 0}, + {"B", "sr16000_stereo", 0, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are two unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(::testing::_)).Times(2); + + MultiEndCall multiend_call(timing, audiotracks_path, + std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupFirstOffsetNegative) { + const std::vector<Turn> timing = { + {"A", "t500", -100, 0}, + {"B", "t500", 0, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(1); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupSimple) { + // Accept: + // A 0****..... + // B .....1**** + constexpr std::size_t expected_duration = kDefaultSampleRate; + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, + {"B", "t500", 0, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(1); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(2u, multiend_call.speaker_names().size()); + EXPECT_EQ(1u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(2u, multiend_call.speaking_turns().size()); + EXPECT_EQ(expected_duration, multiend_call.total_duration_samples()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupPause) { + // Accept: + // A 0****....... + // B .......1**** + constexpr std::size_t expected_duration = kDefaultSampleRate * 1.2; + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, + {"B", "t500", 200, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(1); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(2u, multiend_call.speaker_names().size()); + EXPECT_EQ(1u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(2u, multiend_call.speaking_turns().size()); + EXPECT_EQ(expected_duration, multiend_call.total_duration_samples()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupCrossTalk) { + // Accept: + // A 0****.... 
+ // B ....1**** + constexpr std::size_t expected_duration = kDefaultSampleRate * 0.9; + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, + {"B", "t500", -100, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(1); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(2u, multiend_call.speaker_names().size()); + EXPECT_EQ(1u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(2u, multiend_call.speaking_turns().size()); + EXPECT_EQ(expected_duration, multiend_call.total_duration_samples()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupInvalidOrder) { + // Reject: + // A ..0**** + // B .1****. The n-th turn cannot start before the (n-1)-th one. + const std::vector<Turn> timing = { + {"A", "t500", 200, 0}, + {"B", "t500", -600, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(1); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupCrossTalkThree) { + // Accept: + // A 0****2****... + // B ...1********* + constexpr std::size_t expected_duration = kDefaultSampleRate * 1.3; + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, + {"B", "t1000", -200, 0}, + {"A", "t500", -800, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are two unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(2); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(2u, multiend_call.speaker_names().size()); + EXPECT_EQ(2u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(3u, multiend_call.speaking_turns().size()); + EXPECT_EQ(expected_duration, multiend_call.total_duration_samples()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupSelfCrossTalkNearInvalid) { + // Reject: + // A 0****...... + // A ...1****... + // B ......2**** + // ^ Turn #1 overlaps with #0 which is from the same speaker. + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, + {"A", "t500", -200, 0}, + {"B", "t500", -200, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(1); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupSelfCrossTalkFarInvalid) { + // Reject: + // A 0********* + // B 1**....... + // C ...2**.... + // A ......3**. + // ^ Turn #3 overlaps with #0 which is from the same speaker. + const std::vector<Turn> timing = { + {"A", "t1000", 0, 0}, + {"B", "t300", -1000, 0}, + {"C", "t300", 0, 0}, + {"A", "t300", 0, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are two unique audio tracks to read. 
+ EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(2); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupCrossTalkMiddleValid) { + // Accept: + // A 0*********.. + // B ..1****..... + // C .......2**** + constexpr std::size_t expected_duration = kDefaultSampleRate * 1.2; + const std::vector<Turn> timing = { + {"A", "t1000", 0, 0}, + {"B", "t500", -800, 0}, + {"C", "t500", 0, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are two unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(2); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(3u, multiend_call.speaker_names().size()); + EXPECT_EQ(2u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(3u, multiend_call.speaking_turns().size()); + EXPECT_EQ(expected_duration, multiend_call.total_duration_samples()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupCrossTalkMiddleInvalid) { + // Reject: + // A 0********* + // B ..1****... + // C ....2****. + // ^ Turn #2 overlaps both with #0 and #1 (cross-talk with 3+ speakers + // not permitted). + const std::vector<Turn> timing = { + {"A", "t1000", 0, 0}, + {"B", "t500", -800, 0}, + {"C", "t500", -300, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are two unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(2); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupCrossTalkMiddleAndPause) { + // Accept: + // A 0*********.. + // B .2****...... + // C .......3**** + constexpr std::size_t expected_duration = kDefaultSampleRate * 1.2; + const std::vector<Turn> timing = { + {"A", "t1000", 0, 0}, + {"B", "t500", -900, 0}, + {"C", "t500", 100, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There are two unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(2); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(3u, multiend_call.speaker_names().size()); + EXPECT_EQ(2u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(3u, multiend_call.speaking_turns().size()); + EXPECT_EQ(expected_duration, multiend_call.total_duration_samples()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupCrossTalkFullOverlapValid) { + // Accept: + // A 0**** + // B 1**** + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, + {"B", "t500", -500, 0}, + }; + auto mock_wavreader_factory = CreateMockWavReaderFactory(); + + // There is one unique audio track to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(1); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. 
+ EXPECT_EQ(2u, multiend_call.speaker_names().size()); + EXPECT_EQ(1u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(2u, multiend_call.speaking_turns().size()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupLongSequence) { + // Accept: + // A 0****....3****.5**. + // B .....1****...4**... + // C ......2**.......6**.. + constexpr std::size_t expected_duration = kDefaultSampleRate * 1.9; + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, {"B", "t500", 0, 0}, {"C", "t300", -400, 0}, + {"A", "t500", 0, 0}, {"B", "t300", -100, 0}, {"A", "t300", -100, 0}, + {"C", "t300", -200, 0}, + }; + auto mock_wavreader_factory = std::unique_ptr<MockWavReaderFactory>( + new MockWavReaderFactory(kDefaultMockWavReaderFactoryParams, + kDefaultMockWavReaderFactoryParamsMap)); + + // There are two unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(2); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_TRUE(multiend_call.valid()); + + // Test. + EXPECT_EQ(3u, multiend_call.speaker_names().size()); + EXPECT_EQ(2u, multiend_call.audiotrack_readers().size()); + EXPECT_EQ(7u, multiend_call.speaking_turns().size()); + EXPECT_EQ(expected_duration, multiend_call.total_duration_samples()); +} + +TEST(ConversationalSpeechTest, MultiEndCallSetupLongSequenceInvalid) { + // Reject: + // A 0****....3****.6** + // B .....1****...4**.. + // C ......2**.....5**.. + // ^ Turns #4, #5 and #6 overlapping (cross-talk with 3+ + // speakers not permitted). + const std::vector<Turn> timing = { + {"A", "t500", 0, 0}, {"B", "t500", 0, 0}, {"C", "t300", -400, 0}, + {"A", "t500", 0, 0}, {"B", "t300", -100, 0}, {"A", "t300", -200, 0}, + {"C", "t300", -200, 0}, + }; + auto mock_wavreader_factory = std::unique_ptr<MockWavReaderFactory>( + new MockWavReaderFactory(kDefaultMockWavReaderFactoryParams, + kDefaultMockWavReaderFactoryParamsMap)); + + // There are two unique audio tracks to read. + EXPECT_CALL(*mock_wavreader_factory, Create(_)).Times(2); + + conversational_speech::MultiEndCall multiend_call( + timing, audiotracks_path, std::move(mock_wavreader_factory)); + EXPECT_FALSE(multiend_call.valid()); +} + +TEST(ConversationalSpeechTest, MultiEndCallWavReaderAdaptorSine) { + // Parameters with which wav files are created. + constexpr int duration_seconds = 5; + const int sample_rates[] = {8000, 11025, 16000, 22050, 32000, 44100, 48000}; + + for (int sample_rate : sample_rates) { + const std::string temp_filename = OutputPath() + "TempSineWavFile_" + + std::to_string(sample_rate) + ".wav"; + + // Write wav file. + const std::size_t num_samples = duration_seconds * sample_rate; + MockWavReaderFactory::Params params = {sample_rate, 1u, num_samples}; + CreateSineWavFile(temp_filename, params); + + // Load wav file and check if params match. + WavReaderFactory wav_reader_factory; + MockWavReaderFactory::Params expeted_params = {sample_rate, 1u, + num_samples}; + CheckAudioTrackParams(wav_reader_factory, temp_filename, expeted_params); + + // Clean up. + RemoveFile(temp_filename); + } +} + +TEST(ConversationalSpeechTest, DISABLED_MultiEndCallSimulator) { + // Simulated call (one character corresponding to 500 ms): + // A 0*********...........2*********..... 
+ // B ...........1*********.....3********* + const std::vector<Turn> expected_timing = { + {"A", "t5000_440.wav", 0, 0}, + {"B", "t5000_880.wav", 500, 0}, + {"A", "t5000_440.wav", 0, 0}, + {"B", "t5000_880.wav", -2500, 0}, + }; + const std::size_t expected_duration_seconds = 18; + + // Create temporary audio track files. + const int sample_rate = 16000; + const std::map<std::string, SineAudioTrackParams> sine_tracks_params = { + {"t5000_440.wav", {{sample_rate, 1u, sample_rate * 5}, 440.0}}, + {"t5000_880.wav", {{sample_rate, 1u, sample_rate * 5}, 880.0}}, + }; + const std::string audiotracks_path = + CreateTemporarySineAudioTracks(sine_tracks_params); + + // Set up the multi-end call. + auto wavreader_factory = + std::unique_ptr<WavReaderFactory>(new WavReaderFactory()); + MultiEndCall multiend_call(expected_timing, audiotracks_path, + std::move(wavreader_factory)); + + // Simulate the call. + std::string output_path = JoinFilename(audiotracks_path, "output"); + CreateDir(output_path); + RTC_LOG(LS_VERBOSE) << "simulator output path: " << output_path; + auto generated_audiotrak_pairs = + conversational_speech::Simulate(multiend_call, output_path); + EXPECT_EQ(2u, generated_audiotrak_pairs->size()); + + // Check the output. + WavReaderFactory wav_reader_factory; + const MockWavReaderFactory::Params expeted_params = { + sample_rate, 1u, sample_rate * expected_duration_seconds}; + for (const auto& it : *generated_audiotrak_pairs) { + RTC_LOG(LS_VERBOSE) << "checking far/near-end for <" << it.first << ">"; + CheckAudioTrackParams(wav_reader_factory, it.second.near_end, + expeted_params); + CheckAudioTrackParams(wav_reader_factory, it.second.far_end, + expeted_params); + } + + // Clean. + EXPECT_NO_FATAL_FAILURE(DeleteFolderAndContents(audiotracks_path)); +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader.cc new file mode 100644 index 0000000000..1263e938c4 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/test/conversational_speech/mock_wavreader.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +using ::testing::Return; + +MockWavReader::MockWavReader(int sample_rate, + size_t num_channels, + size_t num_samples) + : sample_rate_(sample_rate), + num_channels_(num_channels), + num_samples_(num_samples) { + ON_CALL(*this, SampleRate()).WillByDefault(Return(sample_rate_)); + ON_CALL(*this, NumChannels()).WillByDefault(Return(num_channels_)); + ON_CALL(*this, NumSamples()).WillByDefault(Return(num_samples_)); +} + +MockWavReader::~MockWavReader() = default; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader.h new file mode 100644 index 0000000000..94e20b9ec6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MOCK_WAVREADER_H_ +#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MOCK_WAVREADER_H_ + +#include <cstddef> +#include <string> + +#include "api/array_view.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_interface.h" +#include "test/gmock.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +class MockWavReader : public WavReaderInterface { + public: + MockWavReader(int sample_rate, size_t num_channels, size_t num_samples); + ~MockWavReader(); + + // TODO(alessiob): use ON_CALL to return random samples if needed. + MOCK_METHOD(size_t, ReadFloatSamples, (rtc::ArrayView<float>), (override)); + MOCK_METHOD(size_t, ReadInt16Samples, (rtc::ArrayView<int16_t>), (override)); + + MOCK_METHOD(int, SampleRate, (), (const, override)); + MOCK_METHOD(size_t, NumChannels, (), (const, override)); + MOCK_METHOD(size_t, NumSamples, (), (const, override)); + + private: + const int sample_rate_; + const size_t num_channels_; + const size_t num_samples_; +}; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MOCK_WAVREADER_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader_factory.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader_factory.cc new file mode 100644 index 0000000000..a377cce7e3 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader_factory.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. 
All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/conversational_speech/mock_wavreader_factory.h" + +#include "absl/strings/string_view.h" +#include "modules/audio_processing/test/conversational_speech/mock_wavreader.h" +#include "rtc_base/logging.h" +#include "test/gmock.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +using ::testing::_; +using ::testing::Invoke; + +MockWavReaderFactory::MockWavReaderFactory( + const Params& default_params, + const std::map<std::string, const Params>& params) + : default_params_(default_params), audiotrack_names_params_(params) { + ON_CALL(*this, Create(_)) + .WillByDefault(Invoke(this, &MockWavReaderFactory::CreateMock)); +} + +MockWavReaderFactory::MockWavReaderFactory(const Params& default_params) + : MockWavReaderFactory(default_params, + std::map<std::string, const Params>{}) {} + +MockWavReaderFactory::~MockWavReaderFactory() = default; + +std::unique_ptr<WavReaderInterface> MockWavReaderFactory::CreateMock( + absl::string_view filepath) { + // Search the parameters corresponding to filepath. + size_t delimiter = filepath.find_last_of("/\\"); // Either windows or posix + std::string filename(filepath.substr( + delimiter == absl::string_view::npos ? 0 : delimiter + 1)); + const auto it = audiotrack_names_params_.find(filename); + + // If not found, use default parameters. + if (it == audiotrack_names_params_.end()) { + RTC_LOG(LS_VERBOSE) << "using default parameters for " << filepath; + return std::unique_ptr<WavReaderInterface>(new MockWavReader( + default_params_.sample_rate, default_params_.num_channels, + default_params_.num_samples)); + } + + // Found, use the audiotrack-specific parameters. + RTC_LOG(LS_VERBOSE) << "using ad-hoc parameters for " << filepath; + RTC_LOG(LS_VERBOSE) << "sample_rate " << it->second.sample_rate; + RTC_LOG(LS_VERBOSE) << "num_channels " << it->second.num_channels; + RTC_LOG(LS_VERBOSE) << "num_samples " << it->second.num_samples; + return std::unique_ptr<WavReaderInterface>(new MockWavReader( + it->second.sample_rate, it->second.num_channels, it->second.num_samples)); +} + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader_factory.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader_factory.h new file mode 100644 index 0000000000..bcc7f3069b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/mock_wavreader_factory.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MOCK_WAVREADER_FACTORY_H_ +#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MOCK_WAVREADER_FACTORY_H_ + +#include <map> +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_abstract_factory.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_interface.h" +#include "test/gmock.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +class MockWavReaderFactory : public WavReaderAbstractFactory { + public: + struct Params { + int sample_rate; + size_t num_channels; + size_t num_samples; + }; + + MockWavReaderFactory(const Params& default_params, + const std::map<std::string, const Params>& params); + explicit MockWavReaderFactory(const Params& default_params); + ~MockWavReaderFactory(); + + MOCK_METHOD(std::unique_ptr<WavReaderInterface>, + Create, + (absl::string_view), + (const, override)); + + private: + // Creates a MockWavReader instance using the parameters in + // audiotrack_names_params_ if the entry corresponding to filepath exists, + // otherwise creates a MockWavReader instance using the default parameters. + std::unique_ptr<WavReaderInterface> CreateMock(absl::string_view filepath); + + const Params& default_params_; + std::map<std::string, const Params> audiotrack_names_params_; +}; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MOCK_WAVREADER_FACTORY_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/multiend_call.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/multiend_call.cc new file mode 100644 index 0000000000..952114a78b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/multiend_call.cc @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/conversational_speech/multiend_call.h" + +#include <algorithm> +#include <iterator> + +#include "absl/strings/string_view.h" +#include "rtc_base/logging.h" +#include "test/testsupport/file_utils.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +MultiEndCall::MultiEndCall( + rtc::ArrayView<const Turn> timing, + absl::string_view audiotracks_path, + std::unique_ptr<WavReaderAbstractFactory> wavreader_abstract_factory) + : timing_(timing), + audiotracks_path_(audiotracks_path), + wavreader_abstract_factory_(std::move(wavreader_abstract_factory)), + valid_(false) { + FindSpeakerNames(); + if (CreateAudioTrackReaders()) + valid_ = CheckTiming(); +} + +MultiEndCall::~MultiEndCall() = default; + +void MultiEndCall::FindSpeakerNames() { + RTC_DCHECK(speaker_names_.empty()); + for (const Turn& turn : timing_) { + speaker_names_.emplace(turn.speaker_name); + } +} + +bool MultiEndCall::CreateAudioTrackReaders() { + RTC_DCHECK(audiotrack_readers_.empty()); + sample_rate_hz_ = 0; // Sample rate will be set when reading the first track. 
+  for (const Turn& turn : timing_) {
+    auto it = audiotrack_readers_.find(turn.audiotrack_file_name);
+    if (it != audiotrack_readers_.end())
+      continue;
+
+    const std::string audiotrack_file_path =
+        test::JoinFilename(audiotracks_path_, turn.audiotrack_file_name);
+
+    // Map the audiotrack file name to a new instance of WavReaderInterface.
+    std::unique_ptr<WavReaderInterface> wavreader =
+        wavreader_abstract_factory_->Create(audiotrack_file_path);
+
+    if (sample_rate_hz_ == 0) {
+      sample_rate_hz_ = wavreader->SampleRate();
+    } else if (sample_rate_hz_ != wavreader->SampleRate()) {
+      RTC_LOG(LS_ERROR)
+          << "All the audio tracks should have the same sample rate.";
+      return false;
+    }
+
+    if (wavreader->NumChannels() != 1) {
+      RTC_LOG(LS_ERROR) << "Only mono audio tracks supported.";
+      return false;
+    }
+
+    audiotrack_readers_.emplace(turn.audiotrack_file_name,
+                                std::move(wavreader));
+  }
+
+  return true;
+}
+
+bool MultiEndCall::CheckTiming() {
+  struct Interval {
+    size_t begin;
+    size_t end;
+  };
+  size_t number_of_turns = timing_.size();
+  auto millisecond_to_samples = [](int ms, int sr) -> int {
+    // Truncation may happen if the sampling rate is not an integer multiple
+    // of 1000 (e.g., 44100).
+    return ms * sr / 1000;
+  };
+  auto in_interval = [](size_t value, const Interval& interval) {
+    return interval.begin <= value && value < interval.end;
+  };
+  total_duration_samples_ = 0;
+  speaking_turns_.clear();
+
+  // Begin and end timestamps for the last two turns (unit: number of samples).
+  Interval second_last_turn = {0, 0};
+  Interval last_turn = {0, 0};
+
+  // Initialize map to store speaking turn indices of each speaker (used to
+  // detect self cross-talk).
+  std::map<std::string, std::vector<size_t>> speaking_turn_indices;
+  for (const std::string& speaker_name : speaker_names_) {
+    speaking_turn_indices.emplace(std::piecewise_construct,
+                                  std::forward_as_tuple(speaker_name),
+                                  std::forward_as_tuple());
+  }
+
+  // Parse turns.
+  for (size_t turn_index = 0; turn_index < number_of_turns; ++turn_index) {
+    const Turn& turn = timing_[turn_index];
+    auto it = audiotrack_readers_.find(turn.audiotrack_file_name);
+    RTC_CHECK(it != audiotrack_readers_.end())
+        << "Audio track reader not created";
+
+    // Begin and end timestamps for the current turn.
+    int offset_samples =
+        millisecond_to_samples(turn.offset, it->second->SampleRate());
+    std::size_t begin_timestamp = last_turn.end + offset_samples;
+    std::size_t end_timestamp = begin_timestamp + it->second->NumSamples();
+    RTC_LOG(LS_INFO) << "turn #" << turn_index << " " << begin_timestamp
+                     << "-" << end_timestamp << " samples";
+
+    // The order is invalid if the offset is negative and its absolute value is
+    // larger than the duration of the previous turn.
+    if (offset_samples < 0 &&
+        -offset_samples > static_cast<int>(last_turn.end - last_turn.begin)) {
+      RTC_LOG(LS_ERROR) << "invalid order";
+      return false;
+    }
+
+    // Cross-talk with 3 or more speakers occurs when the beginning of the
+    // current interval falls in the last two turns.
+    if (turn_index > 1 && in_interval(begin_timestamp, last_turn) &&
+        in_interval(begin_timestamp, second_last_turn)) {
+      RTC_LOG(LS_ERROR) << "cross-talk with 3+ speakers";
+      return false;
+    }
+
+    // Append turn.
+    speaking_turns_.emplace_back(turn.speaker_name, turn.audiotrack_file_name,
+                                 begin_timestamp, end_timestamp, turn.gain);
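+    // A worked example of the arithmetic above (values borrowed from the
+    // unit test, for illustration only): with 5 s mono tracks at 16 kHz
+    // (80000 samples) and offsets {0, 500, 0, -2500} ms, the turns map to
+    //   turn 0: begin 0,      end 80000
+    //   turn 1: begin 88000,  end 168000  (+500 ms -> +8000 samples)
+    //   turn 2: begin 168000, end 248000
+    //   turn 3: begin 208000, end 288000  (-2500 ms -> -40000 samples)
+    // for a total of 288000 samples (18 s). Turn 3 overlaps turn 2 only, so
+    // both the ordering check and the 3+ speaker check above pass.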
+    // Save speaking turn index for self cross-talk detection.
+    RTC_DCHECK_EQ(speaking_turns_.size(), turn_index + 1);
+    speaking_turn_indices[turn.speaker_name].push_back(turn_index);
+
+    // Update total duration of the conversational speech.
+    if (total_duration_samples_ < end_timestamp)
+      total_duration_samples_ = end_timestamp;
+
+    // Update and continue with the next turn.
+    second_last_turn = last_turn;
+    last_turn.begin = begin_timestamp;
+    last_turn.end = end_timestamp;
+  }
+
+  // Detect self cross-talk.
+  for (const std::string& speaker_name : speaker_names_) {
+    RTC_LOG(LS_INFO) << "checking self cross-talk for <" << speaker_name
+                     << ">";
+
+    // Copy all turns for this speaker to a new vector.
+    std::vector<SpeakingTurn> speaking_turns_for_name;
+    std::copy_if(speaking_turns_.begin(), speaking_turns_.end(),
+                 std::back_inserter(speaking_turns_for_name),
+                 [&speaker_name](const SpeakingTurn& st) {
+                   return st.speaker_name == speaker_name;
+                 });
+
+    // Check for overlap between adjacent elements.
+    // This is a sufficient condition for self cross-talk since the intervals
+    // are sorted by begin timestamp.
+    auto overlap = std::adjacent_find(
+        speaking_turns_for_name.begin(), speaking_turns_for_name.end(),
+        [](const SpeakingTurn& a, const SpeakingTurn& b) {
+          return a.end > b.begin;
+        });
+
+    if (overlap != speaking_turns_for_name.end()) {
+      RTC_LOG(LS_ERROR) << "Self cross-talk detected";
+      return false;
+    }
+  }
+
+  return true;
+}
+
+}  // namespace conversational_speech
+}  // namespace test
+}  // namespace webrtc
diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/multiend_call.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/multiend_call.h
new file mode 100644
index 0000000000..63283465fa
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/multiend_call.h
@@ -0,0 +1,104 @@
+/*
+ *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MULTIEND_CALL_H_
+#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MULTIEND_CALL_H_
+
+#include <stddef.h>
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "modules/audio_processing/test/conversational_speech/timing.h"
+#include "modules/audio_processing/test/conversational_speech/wavreader_abstract_factory.h"
+#include "modules/audio_processing/test/conversational_speech/wavreader_interface.h"
+
+namespace webrtc {
+namespace test {
+namespace conversational_speech {
+
+class MultiEndCall {
+ public:
+  struct SpeakingTurn {
+    // Constructor required in order to use std::vector::emplace_back().
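+    // For instance (illustrative values), CheckTiming() in multiend_call.cc
+    // appends turns via
+    //   speaking_turns_.emplace_back("A", "a.wav", /*begin=*/0,
+    //                                /*end=*/80000, /*gain=*/0);
+    // which forwards these arguments to the constructor below.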
+ SpeakingTurn(absl::string_view new_speaker_name, + absl::string_view new_audiotrack_file_name, + size_t new_begin, + size_t new_end, + int gain) + : speaker_name(new_speaker_name), + audiotrack_file_name(new_audiotrack_file_name), + begin(new_begin), + end(new_end), + gain(gain) {} + std::string speaker_name; + std::string audiotrack_file_name; + size_t begin; + size_t end; + int gain; + }; + + MultiEndCall( + rtc::ArrayView<const Turn> timing, + absl::string_view audiotracks_path, + std::unique_ptr<WavReaderAbstractFactory> wavreader_abstract_factory); + ~MultiEndCall(); + + MultiEndCall(const MultiEndCall&) = delete; + MultiEndCall& operator=(const MultiEndCall&) = delete; + + const std::set<std::string>& speaker_names() const { return speaker_names_; } + const std::map<std::string, std::unique_ptr<WavReaderInterface>>& + audiotrack_readers() const { + return audiotrack_readers_; + } + bool valid() const { return valid_; } + int sample_rate() const { return sample_rate_hz_; } + size_t total_duration_samples() const { return total_duration_samples_; } + const std::vector<SpeakingTurn>& speaking_turns() const { + return speaking_turns_; + } + + private: + // Finds unique speaker names. + void FindSpeakerNames(); + + // Creates one WavReader instance for each unique audiotrack. It returns false + // if the audio tracks do not have the same sample rate or if they are not + // mono. + bool CreateAudioTrackReaders(); + + // Validates the speaking turns timing information. Accepts cross-talk, but + // only up to 2 speakers. Rejects unordered turns and self cross-talk. + bool CheckTiming(); + + rtc::ArrayView<const Turn> timing_; + std::string audiotracks_path_; + std::unique_ptr<WavReaderAbstractFactory> wavreader_abstract_factory_; + std::set<std::string> speaker_names_; + std::map<std::string, std::unique_ptr<WavReaderInterface>> + audiotrack_readers_; + bool valid_; + int sample_rate_hz_; + size_t total_duration_samples_; + std::vector<SpeakingTurn> speaking_turns_; +}; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_MULTIEND_CALL_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/simulator.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/simulator.cc new file mode 100644 index 0000000000..89bcd48d84 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/simulator.cc @@ -0,0 +1,235 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+
+#include "modules/audio_processing/test/conversational_speech/simulator.h"
+
+#include <math.h>
+
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "common_audio/include/audio_util.h"
+#include "common_audio/wav_file.h"
+#include "modules/audio_processing/test/conversational_speech/wavreader_interface.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/numerics/safe_conversions.h"
+#include "test/testsupport/file_utils.h"
+
+namespace webrtc {
+namespace test {
+namespace {
+
+using conversational_speech::MultiEndCall;
+using conversational_speech::SpeakerOutputFilePaths;
+using conversational_speech::WavReaderInterface;
+
+// Combines output path and speaker names to define the output file paths for
+// the near-end and far-end audio tracks.
+std::unique_ptr<std::map<std::string, SpeakerOutputFilePaths>>
+InitSpeakerOutputFilePaths(const std::set<std::string>& speaker_names,
+                           absl::string_view output_path) {
+  // Create map.
+  auto speaker_output_file_paths_map =
+      std::make_unique<std::map<std::string, SpeakerOutputFilePaths>>();
+
+  // Add near-end and far-end output paths into the map.
+  for (const auto& speaker_name : speaker_names) {
+    const std::string near_end_path =
+        test::JoinFilename(output_path, "s_" + speaker_name + "-near_end.wav");
+    RTC_LOG(LS_VERBOSE) << "The near-end audio track will be created in "
+                        << near_end_path << ".";
+
+    const std::string far_end_path =
+        test::JoinFilename(output_path, "s_" + speaker_name + "-far_end.wav");
+    RTC_LOG(LS_VERBOSE) << "The far-end audio track will be created in "
+                        << far_end_path << ".";
+
+    // Add to map.
+    speaker_output_file_paths_map->emplace(
+        std::piecewise_construct, std::forward_as_tuple(speaker_name),
+        std::forward_as_tuple(near_end_path, far_end_path));
+  }
+
+  return speaker_output_file_paths_map;
+}
+
+// Class that provides one WavWriter for the near-end and one for the far-end
+// output track of a speaker.
+class SpeakerWavWriters {
+ public:
+  SpeakerWavWriters(const SpeakerOutputFilePaths& output_file_paths,
+                    int sample_rate)
+      : near_end_wav_writer_(output_file_paths.near_end, sample_rate, 1u),
+        far_end_wav_writer_(output_file_paths.far_end, sample_rate, 1u) {}
+  WavWriter* near_end_wav_writer() { return &near_end_wav_writer_; }
+  WavWriter* far_end_wav_writer() { return &far_end_wav_writer_; }
+
+ private:
+  WavWriter near_end_wav_writer_;
+  WavWriter far_end_wav_writer_;
+};
+
+// Initializes one WavWriter instance for each speaker, for both the near-end
+// and the far-end output tracks.
+std::unique_ptr<std::map<std::string, SpeakerWavWriters>>
+InitSpeakersWavWriters(const std::map<std::string, SpeakerOutputFilePaths>&
+                           speaker_output_file_paths,
+                       int sample_rate) {
+  // Create map.
+  auto speaker_wav_writers_map =
+      std::make_unique<std::map<std::string, SpeakerWavWriters>>();
+
+  // Add one SpeakerWavWriters instance per speaker into the map.
+  for (auto it = speaker_output_file_paths.begin();
+       it != speaker_output_file_paths.end(); ++it) {
+    speaker_wav_writers_map->emplace(
+        std::piecewise_construct, std::forward_as_tuple(it->first),
+        std::forward_as_tuple(it->second, sample_rate));
+  }
+
+  return speaker_wav_writers_map;
+}
+
+// Reads all the samples for each audio track.
+std::unique_ptr<std::map<std::string, std::vector<int16_t>>> PreloadAudioTracks(
+    const std::map<std::string, std::unique_ptr<WavReaderInterface>>&
+        audiotrack_readers) {
+  // Create map.
+  auto audiotracks_map =
+      std::make_unique<std::map<std::string, std::vector<int16_t>>>();
+
+  // Add audio track vectors.
+  for (auto it = audiotrack_readers.begin(); it != audiotrack_readers.end();
+       ++it) {
+    // Add map entry.
+    audiotracks_map->emplace(std::piecewise_construct,
+                             std::forward_as_tuple(it->first),
+                             std::forward_as_tuple(it->second->NumSamples()));
+
+    // Read samples.
+    it->second->ReadInt16Samples(audiotracks_map->at(it->first));
+  }
+
+  return audiotracks_map;
+}
+
+// Writes all the values in `source_samples` via `wav_writer`. If the number of
+// previously written samples in `wav_writer` is less than `interval_begin`, it
+// adds zeros as left padding. The padding corresponds to intervals during
+// which a speaker is not active.
+void PadLeftWriteChunk(rtc::ArrayView<const int16_t> source_samples,
+                       size_t interval_begin,
+                       WavWriter* wav_writer) {
+  // Add left padding.
+  RTC_CHECK(wav_writer);
+  RTC_CHECK_GE(interval_begin, wav_writer->num_samples());
+  size_t padding_size = interval_begin - wav_writer->num_samples();
+  if (padding_size != 0) {
+    const std::vector<int16_t> padding(padding_size, 0);
+    wav_writer->WriteSamples(padding.data(), padding_size);
+  }
+
+  // Write source samples.
+  wav_writer->WriteSamples(source_samples.data(), source_samples.size());
+}
+
+// Appends zeros via `wav_writer`. The number of zeros is always non-negative
+// and equal to the difference between `pad_samples` and the number of
+// previously written samples.
+void PadRightWrite(WavWriter* wav_writer, size_t pad_samples) {
+  RTC_CHECK(wav_writer);
+  RTC_CHECK_GE(pad_samples, wav_writer->num_samples());
+  size_t padding_size = pad_samples - wav_writer->num_samples();
+  if (padding_size != 0) {
+    const std::vector<int16_t> padding(padding_size, 0);
+    wav_writer->WriteSamples(padding.data(), padding_size);
+  }
+}
+
+void ScaleSignal(rtc::ArrayView<const int16_t> source_samples,
+                 int gain,
+                 rtc::ArrayView<int16_t> output_samples) {
+  const float gain_linear = DbToRatio(gain);
+  RTC_DCHECK_EQ(source_samples.size(), output_samples.size());
+  std::transform(source_samples.begin(), source_samples.end(),
+                 output_samples.begin(), [gain_linear](int16_t x) -> int16_t {
+                   return rtc::saturated_cast<int16_t>(x * gain_linear);
+                 });
+}
+
+}  // namespace
+
+namespace conversational_speech {
+
+std::unique_ptr<std::map<std::string, SpeakerOutputFilePaths>> Simulate(
+    const MultiEndCall& multiend_call,
+    absl::string_view output_path) {
+  // Set output file paths and initialize wav writers.
+  const auto& speaker_names = multiend_call.speaker_names();
+  auto speaker_output_file_paths =
+      InitSpeakerOutputFilePaths(speaker_names, output_path);
+  auto speakers_wav_writers = InitSpeakersWavWriters(
+      *speaker_output_file_paths, multiend_call.sample_rate());
+
+  // Preload all the input audio tracks.
+  const auto& audiotrack_readers = multiend_call.audiotrack_readers();
+  auto audiotracks = PreloadAudioTracks(audiotrack_readers);
+
+  // TODO(alessiob): When speaker_names.size() == 2, near-end and far-end
+  // across the 2 speakers are symmetric; hence, the code below could be
+  // replaced by only creating the near-end or the far-end. However, this
+  // would require splitting the unit tests and documenting the behavior in
+  // README.md. In practice, it should not be an issue since the files are
+  // not expected to be significant in size.
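+  // For illustration (hypothetical numbers, not part of the algorithm): a
+  // turn with gain -6 dB is scaled by DbToRatio(-6) ~= 0.5 in ScaleSignal()
+  // before writing, and a turn that begins at sample 88000 written to a
+  // track holding 80000 samples so far is preceded by 8000 zeros of left
+  // padding in PadLeftWriteChunk().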
+  // Write near-end and far-end output tracks.
+  for (const auto& speaking_turn : multiend_call.speaking_turns()) {
+    const std::string& active_speaker_name = speaking_turn.speaker_name;
+    const auto& source_audiotrack =
+        audiotracks->at(speaking_turn.audiotrack_file_name);
+    std::vector<int16_t> scaled_audiotrack(source_audiotrack.size());
+    ScaleSignal(source_audiotrack, speaking_turn.gain, scaled_audiotrack);
+
+    // Write the active speaker's chunk to the active speaker's near-end.
+    PadLeftWriteChunk(
+        scaled_audiotrack, speaking_turn.begin,
+        speakers_wav_writers->at(active_speaker_name).near_end_wav_writer());
+
+    // Write the active speaker's chunk to the other participants' far-ends.
+    for (const std::string& speaker_name : speaker_names) {
+      if (speaker_name == active_speaker_name)
+        continue;
+      PadLeftWriteChunk(
+          scaled_audiotrack, speaking_turn.begin,
+          speakers_wav_writers->at(speaker_name).far_end_wav_writer());
+    }
+  }
+
+  // Finalize all the output tracks with right padding.
+  // This is required to make the durations of all the output tracks equal.
+  size_t duration_samples = multiend_call.total_duration_samples();
+  for (const std::string& speaker_name : speaker_names) {
+    PadRightWrite(speakers_wav_writers->at(speaker_name).near_end_wav_writer(),
+                  duration_samples);
+    PadRightWrite(speakers_wav_writers->at(speaker_name).far_end_wav_writer(),
+                  duration_samples);
+  }
+
+  return speaker_output_file_paths;
+}
+
+}  // namespace conversational_speech
+}  // namespace test
+}  // namespace webrtc
diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/simulator.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/simulator.h
new file mode 100644
index 0000000000..2f311e16b3
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/simulator.h
@@ -0,0 +1,44 @@
+/*
+ *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_SIMULATOR_H_
+#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_SIMULATOR_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "modules/audio_processing/test/conversational_speech/multiend_call.h"
+
+namespace webrtc {
+namespace test {
+namespace conversational_speech {
+
+struct SpeakerOutputFilePaths {
+  SpeakerOutputFilePaths(absl::string_view new_near_end,
+                         absl::string_view new_far_end)
+      : near_end(new_near_end), far_end(new_far_end) {}
+  // Paths to the near-end and far-end audio track files.
+  const std::string near_end;
+  const std::string far_end;
+};
+
+// Generates the near-end and far-end audio track pairs for each speaker.
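+// A hedged usage sketch of the declaration below, mirroring the unit test
+// (the file names and paths are illustrative assumptions, not a fixed
+// contract):
+//
+//   auto timing = LoadTiming("/path/to/timing.txt");  // "A a.wav 0 0" lines.
+//   MultiEndCall call(timing, "/path/to/tracks",
+//                     std::make_unique<WavReaderFactory>());
+//   RTC_CHECK(call.valid());
+//   auto output_paths = Simulate(call, "/path/to/output");
+//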
+std::unique_ptr<std::map<std::string, SpeakerOutputFilePaths>> Simulate( + const MultiEndCall& multiend_call, + absl::string_view output_path); + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_SIMULATOR_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/timing.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/timing.cc new file mode 100644 index 0000000000..95ec9f542e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/timing.cc @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/conversational_speech/timing.h" + +#include <fstream> +#include <iostream> +#include <string> + +#include "absl/strings/string_view.h" +#include "rtc_base/string_encode.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +bool Turn::operator==(const Turn& b) const { + return b.speaker_name == speaker_name && + b.audiotrack_file_name == audiotrack_file_name && b.offset == offset && + b.gain == gain; +} + +std::vector<Turn> LoadTiming(absl::string_view timing_filepath) { + // Line parser. + auto parse_line = [](absl::string_view line) { + std::vector<absl::string_view> fields = rtc::split(line, ' '); + RTC_CHECK_GE(fields.size(), 3); + RTC_CHECK_LE(fields.size(), 4); + int gain = 0; + if (fields.size() == 4) { + gain = rtc::StringToNumber<int>(fields[3]).value_or(0); + } + return Turn(fields[0], fields[1], + rtc::StringToNumber<int>(fields[2]).value_or(0), gain); + }; + + // Init. + std::vector<Turn> timing; + + // Parse lines. + std::string line; + std::ifstream infile(std::string{timing_filepath}); + while (std::getline(infile, line)) { + if (line.empty()) + continue; + timing.push_back(parse_line(line)); + } + infile.close(); + + return timing; +} + +void SaveTiming(absl::string_view timing_filepath, + rtc::ArrayView<const Turn> timing) { + std::ofstream outfile(std::string{timing_filepath}); + RTC_CHECK(outfile.is_open()); + for (const Turn& turn : timing) { + outfile << turn.speaker_name << " " << turn.audiotrack_file_name << " " + << turn.offset << " " << turn.gain << std::endl; + } + outfile.close(); +} + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/timing.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/timing.h new file mode 100644 index 0000000000..9314f6fc43 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/timing.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_TIMING_H_ +#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_TIMING_H_ + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +struct Turn { + Turn(absl::string_view new_speaker_name, + absl::string_view new_audiotrack_file_name, + int new_offset, + int gain) + : speaker_name(new_speaker_name), + audiotrack_file_name(new_audiotrack_file_name), + offset(new_offset), + gain(gain) {} + bool operator==(const Turn& b) const; + std::string speaker_name; + std::string audiotrack_file_name; + int offset; + int gain; +}; + +// Loads a list of turns from a file. +std::vector<Turn> LoadTiming(absl::string_view timing_filepath); + +// Writes a list of turns into a file. +void SaveTiming(absl::string_view timing_filepath, + rtc::ArrayView<const Turn> timing); + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_TIMING_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_abstract_factory.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_abstract_factory.h new file mode 100644 index 0000000000..14ddfc7539 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_abstract_factory.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_ABSTRACT_FACTORY_H_ +#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_ABSTRACT_FACTORY_H_ + +#include <memory> + +#include "absl/strings/string_view.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_interface.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +class WavReaderAbstractFactory { + public: + virtual ~WavReaderAbstractFactory() = default; + virtual std::unique_ptr<WavReaderInterface> Create( + absl::string_view filepath) const = 0; +}; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_ABSTRACT_FACTORY_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_factory.cc b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_factory.cc new file mode 100644 index 0000000000..99b1686484 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_factory.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/test/conversational_speech/wavreader_factory.h" + +#include <cstddef> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "common_audio/wav_file.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace test { +namespace { + +using conversational_speech::WavReaderInterface; + +class WavReaderAdaptor final : public WavReaderInterface { + public: + explicit WavReaderAdaptor(absl::string_view filepath) + : wav_reader_(filepath) {} + ~WavReaderAdaptor() override = default; + + size_t ReadFloatSamples(rtc::ArrayView<float> samples) override { + return wav_reader_.ReadSamples(samples.size(), samples.begin()); + } + + size_t ReadInt16Samples(rtc::ArrayView<int16_t> samples) override { + return wav_reader_.ReadSamples(samples.size(), samples.begin()); + } + + int SampleRate() const override { return wav_reader_.sample_rate(); } + + size_t NumChannels() const override { return wav_reader_.num_channels(); } + + size_t NumSamples() const override { return wav_reader_.num_samples(); } + + private: + WavReader wav_reader_; +}; + +} // namespace + +namespace conversational_speech { + +WavReaderFactory::WavReaderFactory() = default; + +WavReaderFactory::~WavReaderFactory() = default; + +std::unique_ptr<WavReaderInterface> WavReaderFactory::Create( + absl::string_view filepath) const { + return std::unique_ptr<WavReaderAdaptor>(new WavReaderAdaptor(filepath)); +} + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_factory.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_factory.h new file mode 100644 index 0000000000..f2e5b61055 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_factory.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_FACTORY_H_ +#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_FACTORY_H_ + +#include <memory> + +#include "absl/strings/string_view.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_abstract_factory.h" +#include "modules/audio_processing/test/conversational_speech/wavreader_interface.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +class WavReaderFactory : public WavReaderAbstractFactory { + public: + WavReaderFactory(); + ~WavReaderFactory() override; + std::unique_ptr<WavReaderInterface> Create( + absl::string_view filepath) const override; +}; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_FACTORY_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_interface.h b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_interface.h new file mode 100644 index 0000000000..c74f639461 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/conversational_speech/wavreader_interface.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_INTERFACE_H_ +#define MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_INTERFACE_H_ + +#include <stddef.h> + +#include "api/array_view.h" + +namespace webrtc { +namespace test { +namespace conversational_speech { + +class WavReaderInterface { + public: + virtual ~WavReaderInterface() = default; + + // Returns the number of samples read. + virtual size_t ReadFloatSamples(rtc::ArrayView<float> samples) = 0; + virtual size_t ReadInt16Samples(rtc::ArrayView<int16_t> samples) = 0; + + // Getters. + virtual int SampleRate() const = 0; + virtual size_t NumChannels() const = 0; + virtual size_t NumSamples() const = 0; +}; + +} // namespace conversational_speech +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_CONVERSATIONAL_SPEECH_WAVREADER_INTERFACE_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/debug_dump_replayer.cc b/third_party/libwebrtc/modules/audio_processing/test/debug_dump_replayer.cc new file mode 100644 index 0000000000..2419313e9d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/debug_dump_replayer.cc @@ -0,0 +1,248 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+
+#include "modules/audio_processing/test/debug_dump_replayer.h"
+
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "modules/audio_processing/test/audio_processing_builder_for_testing.h"
+#include "modules/audio_processing/test/protobuf_utils.h"
+#include "modules/audio_processing/test/runtime_setting_util.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace test {
+
+namespace {
+
+void MaybeResetBuffer(std::unique_ptr<ChannelBuffer<float>>* buffer,
+                      const StreamConfig& config) {
+  auto& buffer_ref = *buffer;
+  if (!buffer_ref.get() || buffer_ref->num_frames() != config.num_frames() ||
+      buffer_ref->num_channels() != config.num_channels()) {
+    buffer_ref.reset(
+        new ChannelBuffer<float>(config.num_frames(), config.num_channels()));
+  }
+}
+
+}  // namespace
+
+DebugDumpReplayer::DebugDumpReplayer()
+    : input_(nullptr),  // will be created upon usage.
+      reverse_(nullptr),
+      output_(nullptr),
+      apm_(nullptr),
+      debug_file_(nullptr) {}
+
+DebugDumpReplayer::~DebugDumpReplayer() {
+  if (debug_file_)
+    fclose(debug_file_);
+}
+
+bool DebugDumpReplayer::SetDumpFile(absl::string_view filename) {
+  debug_file_ = fopen(std::string(filename).c_str(), "rb");
+  LoadNextMessage();
+  return debug_file_;
+}
+
+// Returns the next event that has not been run.
+absl::optional<audioproc::Event> DebugDumpReplayer::GetNextEvent() const {
+  if (!has_next_event_)
+    return absl::nullopt;
+  else
+    return next_event_;
+}
+
+// Runs the next event. Returns true if the event was run successfully.
+bool DebugDumpReplayer::RunNextEvent() {
+  if (!has_next_event_)
+    return false;
+  switch (next_event_.type()) {
+    case audioproc::Event::INIT:
+      OnInitEvent(next_event_.init());
+      break;
+    case audioproc::Event::STREAM:
+      OnStreamEvent(next_event_.stream());
+      break;
+    case audioproc::Event::REVERSE_STREAM:
+      OnReverseStreamEvent(next_event_.reverse_stream());
+      break;
+    case audioproc::Event::CONFIG:
+      OnConfigEvent(next_event_.config());
+      break;
+    case audioproc::Event::RUNTIME_SETTING:
+      OnRuntimeSettingEvent(next_event_.runtime_setting());
+      break;
+    case audioproc::Event::UNKNOWN_EVENT:
+      // We do not expect to receive an UNKNOWN event.
+      RTC_CHECK_NOTREACHED();
+  }
+  LoadNextMessage();
+  return true;
+}
+
+const ChannelBuffer<float>* DebugDumpReplayer::GetOutput() const {
+  return output_.get();
+}
+
+StreamConfig DebugDumpReplayer::GetOutputConfig() const {
+  return output_config_;
+}
+
+// OnInitEvent resets the input/output/reverse channel formats.
+void DebugDumpReplayer::OnInitEvent(const audioproc::Init& msg) {
+  RTC_CHECK(msg.has_num_input_channels());
+  RTC_CHECK(msg.has_output_sample_rate());
+  RTC_CHECK(msg.has_num_output_channels());
+  RTC_CHECK(msg.has_reverse_sample_rate());
+  RTC_CHECK(msg.has_num_reverse_channels());
+
+  input_config_ = StreamConfig(msg.sample_rate(), msg.num_input_channels());
+  output_config_ =
+      StreamConfig(msg.output_sample_rate(), msg.num_output_channels());
+  reverse_config_ =
+      StreamConfig(msg.reverse_sample_rate(), msg.num_reverse_channels());
+
+  MaybeResetBuffer(&input_, input_config_);
+  MaybeResetBuffer(&output_, output_config_);
+  MaybeResetBuffer(&reverse_, reverse_config_);
+}
+
+// OnStreamEvent replays an input signal through the APM; verifying the
+// output is left to the caller.
+void DebugDumpReplayer::OnStreamEvent(const audioproc::Stream& msg) {
+  // APM should have been created.
+ RTC_CHECK(apm_.get()); + + apm_->set_stream_analog_level(msg.level()); + RTC_CHECK_EQ(AudioProcessing::kNoError, + apm_->set_stream_delay_ms(msg.delay())); + + if (msg.has_keypress()) { + apm_->set_stream_key_pressed(msg.keypress()); + } else { + apm_->set_stream_key_pressed(true); + } + + RTC_CHECK_EQ(input_config_.num_channels(), + static_cast<size_t>(msg.input_channel_size())); + RTC_CHECK_EQ(input_config_.num_frames() * sizeof(float), + msg.input_channel(0).size()); + + for (int i = 0; i < msg.input_channel_size(); ++i) { + memcpy(input_->channels()[i], msg.input_channel(i).data(), + msg.input_channel(i).size()); + } + + RTC_CHECK_EQ(AudioProcessing::kNoError, + apm_->ProcessStream(input_->channels(), input_config_, + output_config_, output_->channels())); +} + +void DebugDumpReplayer::OnReverseStreamEvent( + const audioproc::ReverseStream& msg) { + // APM should have been created. + RTC_CHECK(apm_.get()); + + RTC_CHECK_GT(msg.channel_size(), 0); + RTC_CHECK_EQ(reverse_config_.num_channels(), + static_cast<size_t>(msg.channel_size())); + RTC_CHECK_EQ(reverse_config_.num_frames() * sizeof(float), + msg.channel(0).size()); + + for (int i = 0; i < msg.channel_size(); ++i) { + memcpy(reverse_->channels()[i], msg.channel(i).data(), + msg.channel(i).size()); + } + + RTC_CHECK_EQ( + AudioProcessing::kNoError, + apm_->ProcessReverseStream(reverse_->channels(), reverse_config_, + reverse_config_, reverse_->channels())); +} + +void DebugDumpReplayer::OnConfigEvent(const audioproc::Config& msg) { + MaybeRecreateApm(msg); + ConfigureApm(msg); +} + +void DebugDumpReplayer::OnRuntimeSettingEvent( + const audioproc::RuntimeSetting& msg) { + RTC_CHECK(apm_.get()); + ReplayRuntimeSetting(apm_.get(), msg); +} + +void DebugDumpReplayer::MaybeRecreateApm(const audioproc::Config& msg) { + // These configurations cannot be changed on the fly. + RTC_CHECK(msg.has_aec_delay_agnostic_enabled()); + RTC_CHECK(msg.has_aec_extended_filter_enabled()); + + // We only create APM once, since changes on these fields should not + // happen in current implementation. + if (!apm_.get()) { + apm_ = AudioProcessingBuilderForTesting().Create(); + } +} + +void DebugDumpReplayer::ConfigureApm(const audioproc::Config& msg) { + AudioProcessing::Config apm_config; + + // AEC2/AECM configs. + RTC_CHECK(msg.has_aec_enabled()); + RTC_CHECK(msg.has_aecm_enabled()); + apm_config.echo_canceller.enabled = msg.aec_enabled() || msg.aecm_enabled(); + apm_config.echo_canceller.mobile_mode = msg.aecm_enabled(); + + // HPF configs. + RTC_CHECK(msg.has_hpf_enabled()); + apm_config.high_pass_filter.enabled = msg.hpf_enabled(); + + // Preamp configs. + RTC_CHECK(msg.has_pre_amplifier_enabled()); + apm_config.pre_amplifier.enabled = msg.pre_amplifier_enabled(); + apm_config.pre_amplifier.fixed_gain_factor = + msg.pre_amplifier_fixed_gain_factor(); + + // NS configs. + RTC_CHECK(msg.has_ns_enabled()); + RTC_CHECK(msg.has_ns_level()); + apm_config.noise_suppression.enabled = msg.ns_enabled(); + apm_config.noise_suppression.level = + static_cast<AudioProcessing::Config::NoiseSuppression::Level>( + msg.ns_level()); + + // TS configs. + RTC_CHECK(msg.has_transient_suppression_enabled()); + apm_config.transient_suppression.enabled = + msg.transient_suppression_enabled(); + + // AGC configs. 
+ RTC_CHECK(msg.has_agc_enabled()); + RTC_CHECK(msg.has_agc_mode()); + RTC_CHECK(msg.has_agc_limiter_enabled()); + apm_config.gain_controller1.enabled = msg.agc_enabled(); + apm_config.gain_controller1.mode = + static_cast<AudioProcessing::Config::GainController1::Mode>( + msg.agc_mode()); + apm_config.gain_controller1.enable_limiter = msg.agc_limiter_enabled(); + RTC_CHECK(msg.has_noise_robust_agc_enabled()); + apm_config.gain_controller1.analog_gain_controller.enabled = + msg.noise_robust_agc_enabled(); + + apm_->ApplyConfig(apm_config); +} + +void DebugDumpReplayer::LoadNextMessage() { + has_next_event_ = + debug_file_ && ReadMessageFromFile(debug_file_, &next_event_); +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/debug_dump_replayer.h b/third_party/libwebrtc/modules/audio_processing/test/debug_dump_replayer.h new file mode 100644 index 0000000000..be21c68663 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/debug_dump_replayer.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_DEBUG_DUMP_REPLAYER_H_ +#define MODULES_AUDIO_PROCESSING_TEST_DEBUG_DUMP_REPLAYER_H_ + +#include <memory> + +#include "absl/strings/string_view.h" +#include "common_audio/channel_buffer.h" +#include "modules/audio_processing/include/audio_processing.h" +#include "rtc_base/ignore_wundef.h" + +RTC_PUSH_IGNORING_WUNDEF() +#include "modules/audio_processing/debug.pb.h" +RTC_POP_IGNORING_WUNDEF() + +namespace webrtc { +namespace test { + +class DebugDumpReplayer { + public: + DebugDumpReplayer(); + ~DebugDumpReplayer(); + + // Set dump file + bool SetDumpFile(absl::string_view filename); + + // Return next event. + absl::optional<audioproc::Event> GetNextEvent() const; + + // Run the next event. Returns true if succeeded. + bool RunNextEvent(); + + const ChannelBuffer<float>* GetOutput() const; + StreamConfig GetOutputConfig() const; + + private: + // Following functions are facilities for replaying debug dumps. + void OnInitEvent(const audioproc::Init& msg); + void OnStreamEvent(const audioproc::Stream& msg); + void OnReverseStreamEvent(const audioproc::ReverseStream& msg); + void OnConfigEvent(const audioproc::Config& msg); + void OnRuntimeSettingEvent(const audioproc::RuntimeSetting& msg); + + void MaybeRecreateApm(const audioproc::Config& msg); + void ConfigureApm(const audioproc::Config& msg); + + void LoadNextMessage(); + + // Buffer for APM input/output. 
+  std::unique_ptr<ChannelBuffer<float>> input_;
+  std::unique_ptr<ChannelBuffer<float>> reverse_;
+  std::unique_ptr<ChannelBuffer<float>> output_;
+
+  rtc::scoped_refptr<AudioProcessing> apm_;
+
+  FILE* debug_file_;
+
+  StreamConfig input_config_;
+  StreamConfig reverse_config_;
+  StreamConfig output_config_;
+
+  bool has_next_event_;
+  audioproc::Event next_event_;
+};
+
+}  // namespace test
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_TEST_DEBUG_DUMP_REPLAYER_H_
diff --git a/third_party/libwebrtc/modules/audio_processing/test/debug_dump_test.cc b/third_party/libwebrtc/modules/audio_processing/test/debug_dump_test.cc
new file mode 100644
index 0000000000..d69d3a4eea
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/debug_dump_test.cc
@@ -0,0 +1,535 @@
+/*
+ *  Copyright (c) 2015 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <stddef.h>  // size_t
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "api/audio/echo_canceller3_factory.h"
+#include "modules/audio_coding/neteq/tools/resample_input_audio_file.h"
+#include "modules/audio_processing/aec_dump/aec_dump_factory.h"
+#include "modules/audio_processing/test/audio_processing_builder_for_testing.h"
+#include "modules/audio_processing/test/debug_dump_replayer.h"
+#include "modules/audio_processing/test/test_utils.h"
+#include "rtc_base/task_queue_for_test.h"
+#include "test/gtest.h"
+#include "test/testsupport/file_utils.h"
+
+namespace webrtc {
+namespace test {
+
+namespace {
+
+void MaybeResetBuffer(std::unique_ptr<ChannelBuffer<float>>* buffer,
+                      const StreamConfig& config) {
+  auto& buffer_ref = *buffer;
+  if (!buffer_ref.get() || buffer_ref->num_frames() != config.num_frames() ||
+      buffer_ref->num_channels() != config.num_channels()) {
+    buffer_ref.reset(
+        new ChannelBuffer<float>(config.num_frames(), config.num_channels()));
+  }
+}
+
+class DebugDumpGenerator {
+ public:
+  DebugDumpGenerator(absl::string_view input_file_name,
+                     int input_rate_hz,
+                     int input_channels,
+                     absl::string_view reverse_file_name,
+                     int reverse_rate_hz,
+                     int reverse_channels,
+                     absl::string_view dump_file_name,
+                     bool enable_pre_amplifier);
+
+  // Constructor that uses default input files.
+  explicit DebugDumpGenerator(const AudioProcessing::Config& apm_config);
+
+  ~DebugDumpGenerator();
+
+  // Changes the sample rate of the input audio to the APM.
+  void SetInputRate(int rate_hz);
+
+  // Sets whether to convert the stereo input signal to mono by discarding
+  // the other channels.
+  void ForceInputMono(bool mono);
+
+  // Changes the sample rate of the reverse audio to the APM.
+  void SetReverseRate(int rate_hz);
+
+  // Sets whether to convert the stereo reverse signal to mono by discarding
+  // the other channels.
+  void ForceReverseMono(bool mono);
+
+  // Sets the required sample rate of the APM output.
+  void SetOutputRate(int rate_hz);
+
+  // Sets the required number of channels of the APM output.
+ void SetOutputChannels(int channels); + + std::string dump_file_name() const { return dump_file_name_; } + + void StartRecording(); + void Process(size_t num_blocks); + void StopRecording(); + AudioProcessing* apm() const { return apm_.get(); } + + private: + static void ReadAndDeinterleave(ResampleInputAudioFile* audio, + int channels, + const StreamConfig& config, + float* const* buffer); + + // APM input/output settings. + StreamConfig input_config_; + StreamConfig reverse_config_; + StreamConfig output_config_; + + // Input file format. + const std::string input_file_name_; + ResampleInputAudioFile input_audio_; + const int input_file_channels_; + + // Reverse file format. + const std::string reverse_file_name_; + ResampleInputAudioFile reverse_audio_; + const int reverse_file_channels_; + + // Buffer for APM input/output. + std::unique_ptr<ChannelBuffer<float>> input_; + std::unique_ptr<ChannelBuffer<float>> reverse_; + std::unique_ptr<ChannelBuffer<float>> output_; + + bool enable_pre_amplifier_; + + TaskQueueForTest worker_queue_; + rtc::scoped_refptr<AudioProcessing> apm_; + + const std::string dump_file_name_; +}; + +DebugDumpGenerator::DebugDumpGenerator(absl::string_view input_file_name, + int input_rate_hz, + int input_channels, + absl::string_view reverse_file_name, + int reverse_rate_hz, + int reverse_channels, + absl::string_view dump_file_name, + bool enable_pre_amplifier) + : input_config_(input_rate_hz, input_channels), + reverse_config_(reverse_rate_hz, reverse_channels), + output_config_(input_rate_hz, input_channels), + input_audio_(input_file_name, input_rate_hz, input_rate_hz), + input_file_channels_(input_channels), + reverse_audio_(reverse_file_name, reverse_rate_hz, reverse_rate_hz), + reverse_file_channels_(reverse_channels), + input_(new ChannelBuffer<float>(input_config_.num_frames(), + input_config_.num_channels())), + reverse_(new ChannelBuffer<float>(reverse_config_.num_frames(), + reverse_config_.num_channels())), + output_(new ChannelBuffer<float>(output_config_.num_frames(), + output_config_.num_channels())), + enable_pre_amplifier_(enable_pre_amplifier), + worker_queue_("debug_dump_generator_worker_queue"), + dump_file_name_(dump_file_name) { + AudioProcessingBuilderForTesting apm_builder; + apm_ = apm_builder.Create(); +} + +DebugDumpGenerator::DebugDumpGenerator( + const AudioProcessing::Config& apm_config) + : DebugDumpGenerator(ResourcePath("near32_stereo", "pcm"), + 32000, + 2, + ResourcePath("far32_stereo", "pcm"), + 32000, + 2, + TempFilename(OutputPath(), "debug_aec"), + apm_config.pre_amplifier.enabled) { + apm_->ApplyConfig(apm_config); +} + +DebugDumpGenerator::~DebugDumpGenerator() { + remove(dump_file_name_.c_str()); +} + +void DebugDumpGenerator::SetInputRate(int rate_hz) { + input_audio_.set_output_rate_hz(rate_hz); + input_config_.set_sample_rate_hz(rate_hz); + MaybeResetBuffer(&input_, input_config_); +} + +void DebugDumpGenerator::ForceInputMono(bool mono) { + const int channels = mono ? 1 : input_file_channels_; + input_config_.set_num_channels(channels); + MaybeResetBuffer(&input_, input_config_); +} + +void DebugDumpGenerator::SetReverseRate(int rate_hz) { + reverse_audio_.set_output_rate_hz(rate_hz); + reverse_config_.set_sample_rate_hz(rate_hz); + MaybeResetBuffer(&reverse_, reverse_config_); +} + +void DebugDumpGenerator::ForceReverseMono(bool mono) { + const int channels = mono ? 
1 : reverse_file_channels_; + reverse_config_.set_num_channels(channels); + MaybeResetBuffer(&reverse_, reverse_config_); +} + +void DebugDumpGenerator::SetOutputRate(int rate_hz) { + output_config_.set_sample_rate_hz(rate_hz); + MaybeResetBuffer(&output_, output_config_); +} + +void DebugDumpGenerator::SetOutputChannels(int channels) { + output_config_.set_num_channels(channels); + MaybeResetBuffer(&output_, output_config_); +} + +void DebugDumpGenerator::StartRecording() { + apm_->AttachAecDump( + AecDumpFactory::Create(dump_file_name_.c_str(), -1, &worker_queue_)); +} + +void DebugDumpGenerator::Process(size_t num_blocks) { + for (size_t i = 0; i < num_blocks; ++i) { + ReadAndDeinterleave(&reverse_audio_, reverse_file_channels_, + reverse_config_, reverse_->channels()); + ReadAndDeinterleave(&input_audio_, input_file_channels_, input_config_, + input_->channels()); + RTC_CHECK_EQ(AudioProcessing::kNoError, apm_->set_stream_delay_ms(100)); + apm_->set_stream_analog_level(100); + if (enable_pre_amplifier_) { + apm_->SetRuntimeSetting( + AudioProcessing::RuntimeSetting::CreateCapturePreGain(1 + i % 10)); + } + apm_->set_stream_key_pressed(i % 10 == 9); + RTC_CHECK_EQ(AudioProcessing::kNoError, + apm_->ProcessStream(input_->channels(), input_config_, + output_config_, output_->channels())); + + RTC_CHECK_EQ( + AudioProcessing::kNoError, + apm_->ProcessReverseStream(reverse_->channels(), reverse_config_, + reverse_config_, reverse_->channels())); + } +} + +void DebugDumpGenerator::StopRecording() { + apm_->DetachAecDump(); +} + +void DebugDumpGenerator::ReadAndDeinterleave(ResampleInputAudioFile* audio, + int channels, + const StreamConfig& config, + float* const* buffer) { + const size_t num_frames = config.num_frames(); + const int out_channels = config.num_channels(); + + std::vector<int16_t> signal(channels * num_frames); + + audio->Read(num_frames * channels, &signal[0]); + + // We only allow reducing number of channels by discarding some channels. + RTC_CHECK_LE(out_channels, channels); + for (int channel = 0; channel < out_channels; ++channel) { + for (size_t i = 0; i < num_frames; ++i) { + buffer[channel][i] = S16ToFloat(signal[i * channels + channel]); + } + } +} + +} // namespace + +class DebugDumpTest : public ::testing::Test { + public: + // VerifyDebugDump replays a debug dump using APM and verifies that the result + // is bit-exact-identical to the output channel in the dump. This is only + // guaranteed if the debug dump is started on the first frame. + void VerifyDebugDump(absl::string_view in_filename); + + private: + DebugDumpReplayer debug_dump_replayer_; +}; + +void DebugDumpTest::VerifyDebugDump(absl::string_view in_filename) { + ASSERT_TRUE(debug_dump_replayer_.SetDumpFile(in_filename)); + + while (const absl::optional<audioproc::Event> event = + debug_dump_replayer_.GetNextEvent()) { + debug_dump_replayer_.RunNextEvent(); + if (event->type() == audioproc::Event::STREAM) { + const audioproc::Stream* msg = &event->stream(); + const StreamConfig output_config = debug_dump_replayer_.GetOutputConfig(); + const ChannelBuffer<float>* output = debug_dump_replayer_.GetOutput(); + // Check that output of APM is bit-exact to the output in the dump. 
+      ASSERT_EQ(output_config.num_channels(),
+                static_cast<size_t>(msg->output_channel_size()));
+      ASSERT_EQ(output_config.num_frames() * sizeof(float),
+                msg->output_channel(0).size());
+      for (int i = 0; i < msg->output_channel_size(); ++i) {
+        ASSERT_EQ(0,
+                  memcmp(output->channels()[i], msg->output_channel(i).data(),
+                         msg->output_channel(i).size()));
+      }
+    }
+  }
+}
+
+TEST_F(DebugDumpTest, SimpleCase) {
+  DebugDumpGenerator generator(/*apm_config=*/{});
+  generator.StartRecording();
+  generator.Process(100);
+  generator.StopRecording();
+  VerifyDebugDump(generator.dump_file_name());
+}
+
+TEST_F(DebugDumpTest, ChangeInputFormat) {
+  DebugDumpGenerator generator(/*apm_config=*/{});
+
+  generator.StartRecording();
+  generator.Process(100);
+  generator.SetInputRate(48000);
+
+  generator.ForceInputMono(true);
+  // The number of output channels must not be larger than the number of
+  // input channels. APM will fail otherwise.
+  generator.SetOutputChannels(1);
+
+  generator.Process(100);
+  generator.StopRecording();
+  VerifyDebugDump(generator.dump_file_name());
+}
+
+TEST_F(DebugDumpTest, ChangeReverseFormat) {
+  DebugDumpGenerator generator(/*apm_config=*/{});
+  generator.StartRecording();
+  generator.Process(100);
+  generator.SetReverseRate(48000);
+  generator.ForceReverseMono(true);
+  generator.Process(100);
+  generator.StopRecording();
+  VerifyDebugDump(generator.dump_file_name());
+}
+
+TEST_F(DebugDumpTest, ChangeOutputFormat) {
+  DebugDumpGenerator generator(/*apm_config=*/{});
+  generator.StartRecording();
+  generator.Process(100);
+  generator.SetOutputRate(48000);
+  generator.SetOutputChannels(1);
+  generator.Process(100);
+  generator.StopRecording();
+  VerifyDebugDump(generator.dump_file_name());
+}
+
+TEST_F(DebugDumpTest, ToggleAec) {
+  AudioProcessing::Config apm_config;
+  apm_config.echo_canceller.enabled = true;
+  DebugDumpGenerator generator(apm_config);
+  generator.StartRecording();
+  generator.Process(100);
+
+  apm_config.echo_canceller.enabled = false;
+  generator.apm()->ApplyConfig(apm_config);
+
+  generator.Process(100);
+  generator.StopRecording();
+  VerifyDebugDump(generator.dump_file_name());
+}
+
+TEST_F(DebugDumpTest, VerifyCombinedExperimentalStringInclusive) {
+  AudioProcessing::Config apm_config;
+  apm_config.echo_canceller.enabled = true;
+  apm_config.gain_controller1.analog_gain_controller.enabled = true;
+  apm_config.gain_controller1.analog_gain_controller.startup_min_volume = 0;
+  // Arbitrarily set clipping gain to 17, which will never be the default.
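+  // Note: a non-default clipped level makes the AgcClippingLevelExperiment
+  // tag appear in the dumped experiments description, as asserted below.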
+ apm_config.gain_controller1.analog_gain_controller.clipped_level_min = 17; + DebugDumpGenerator generator(apm_config); + generator.StartRecording(); + generator.Process(100); + generator.StopRecording(); + + DebugDumpReplayer debug_dump_replayer_; + + ASSERT_TRUE(debug_dump_replayer_.SetDumpFile(generator.dump_file_name())); + + while (const absl::optional<audioproc::Event> event = + debug_dump_replayer_.GetNextEvent()) { + debug_dump_replayer_.RunNextEvent(); + if (event->type() == audioproc::Event::CONFIG) { + const audioproc::Config* msg = &event->config(); + ASSERT_TRUE(msg->has_experiments_description()); + EXPECT_PRED_FORMAT2(::testing::IsSubstring, "EchoController", + msg->experiments_description().c_str()); + EXPECT_PRED_FORMAT2(::testing::IsSubstring, "AgcClippingLevelExperiment", + msg->experiments_description().c_str()); + } + } +} + +TEST_F(DebugDumpTest, VerifyCombinedExperimentalStringExclusive) { + AudioProcessing::Config apm_config; + apm_config.echo_canceller.enabled = true; + DebugDumpGenerator generator(apm_config); + generator.StartRecording(); + generator.Process(100); + generator.StopRecording(); + + DebugDumpReplayer debug_dump_replayer_; + + ASSERT_TRUE(debug_dump_replayer_.SetDumpFile(generator.dump_file_name())); + + while (const absl::optional<audioproc::Event> event = + debug_dump_replayer_.GetNextEvent()) { + debug_dump_replayer_.RunNextEvent(); + if (event->type() == audioproc::Event::CONFIG) { + const audioproc::Config* msg = &event->config(); + ASSERT_TRUE(msg->has_experiments_description()); + EXPECT_PRED_FORMAT2(::testing::IsNotSubstring, + "AgcClippingLevelExperiment", + msg->experiments_description().c_str()); + } + } +} + +TEST_F(DebugDumpTest, VerifyAec3ExperimentalString) { + AudioProcessing::Config apm_config; + apm_config.echo_canceller.enabled = true; + DebugDumpGenerator generator(apm_config); + generator.StartRecording(); + generator.Process(100); + generator.StopRecording(); + + DebugDumpReplayer debug_dump_replayer_; + + ASSERT_TRUE(debug_dump_replayer_.SetDumpFile(generator.dump_file_name())); + + while (const absl::optional<audioproc::Event> event = + debug_dump_replayer_.GetNextEvent()) { + debug_dump_replayer_.RunNextEvent(); + if (event->type() == audioproc::Event::CONFIG) { + const audioproc::Config* msg = &event->config(); + ASSERT_TRUE(msg->has_experiments_description()); + EXPECT_PRED_FORMAT2(::testing::IsSubstring, "EchoController", + msg->experiments_description().c_str()); + } + } +} + +TEST_F(DebugDumpTest, VerifyAgcClippingLevelExperimentalString) { + AudioProcessing::Config apm_config; + apm_config.gain_controller1.analog_gain_controller.enabled = true; + apm_config.gain_controller1.analog_gain_controller.startup_min_volume = 0; + // Arbitrarily set clipping gain to 17, which will never be the default. 
+ apm_config.gain_controller1.analog_gain_controller.clipped_level_min = 17; + DebugDumpGenerator generator(apm_config); + generator.StartRecording(); + generator.Process(100); + generator.StopRecording(); + + DebugDumpReplayer debug_dump_replayer_; + + ASSERT_TRUE(debug_dump_replayer_.SetDumpFile(generator.dump_file_name())); + + while (const absl::optional<audioproc::Event> event = + debug_dump_replayer_.GetNextEvent()) { + debug_dump_replayer_.RunNextEvent(); + if (event->type() == audioproc::Event::CONFIG) { + const audioproc::Config* msg = &event->config(); + ASSERT_TRUE(msg->has_experiments_description()); + EXPECT_PRED_FORMAT2(::testing::IsSubstring, "AgcClippingLevelExperiment", + msg->experiments_description().c_str()); + } + } +} + +TEST_F(DebugDumpTest, VerifyEmptyExperimentalString) { + DebugDumpGenerator generator(/*apm_config=*/{}); + generator.StartRecording(); + generator.Process(100); + generator.StopRecording(); + + DebugDumpReplayer debug_dump_replayer_; + + ASSERT_TRUE(debug_dump_replayer_.SetDumpFile(generator.dump_file_name())); + + while (const absl::optional<audioproc::Event> event = + debug_dump_replayer_.GetNextEvent()) { + debug_dump_replayer_.RunNextEvent(); + if (event->type() == audioproc::Event::CONFIG) { + const audioproc::Config* msg = &event->config(); + ASSERT_TRUE(msg->has_experiments_description()); + EXPECT_EQ(0u, msg->experiments_description().size()); + } + } +} + +// AGC is not supported on Android or iOS. +#if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS) +#define MAYBE_ToggleAgc DISABLED_ToggleAgc +#else +#define MAYBE_ToggleAgc ToggleAgc +#endif +TEST_F(DebugDumpTest, MAYBE_ToggleAgc) { + DebugDumpGenerator generator(/*apm_config=*/{}); + generator.StartRecording(); + generator.Process(100); + + AudioProcessing::Config apm_config = generator.apm()->GetConfig(); + apm_config.gain_controller1.enabled = !apm_config.gain_controller1.enabled; + generator.apm()->ApplyConfig(apm_config); + + generator.Process(100); + generator.StopRecording(); + VerifyDebugDump(generator.dump_file_name()); +} + +TEST_F(DebugDumpTest, ToggleNs) { + DebugDumpGenerator generator(/*apm_config=*/{}); + generator.StartRecording(); + generator.Process(100); + + AudioProcessing::Config apm_config = generator.apm()->GetConfig(); + apm_config.noise_suppression.enabled = !apm_config.noise_suppression.enabled; + generator.apm()->ApplyConfig(apm_config); + + generator.Process(100); + generator.StopRecording(); + VerifyDebugDump(generator.dump_file_name()); +} + +TEST_F(DebugDumpTest, TransientSuppressionOn) { + DebugDumpGenerator generator(/*apm_config=*/{}); + + AudioProcessing::Config apm_config = generator.apm()->GetConfig(); + apm_config.transient_suppression.enabled = true; + generator.apm()->ApplyConfig(apm_config); + + generator.StartRecording(); + generator.Process(100); + generator.StopRecording(); + VerifyDebugDump(generator.dump_file_name()); +} + +TEST_F(DebugDumpTest, PreAmplifierIsOn) { + AudioProcessing::Config apm_config; + apm_config.pre_amplifier.enabled = true; + DebugDumpGenerator generator(apm_config); + generator.StartRecording(); + generator.Process(100); + generator.StopRecording(); + VerifyDebugDump(generator.dump_file_name()); +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools.cc b/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools.cc new file mode 100644 index 0000000000..1d36b954f9 --- /dev/null +++ 
b/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/echo_canceller_test_tools.h" + +#include "rtc_base/checks.h" + +namespace webrtc { + +void RandomizeSampleVector(Random* random_generator, rtc::ArrayView<float> v) { + RandomizeSampleVector(random_generator, v, + /*amplitude=*/32767.f); +} + +void RandomizeSampleVector(Random* random_generator, + rtc::ArrayView<float> v, + float amplitude) { + for (auto& v_k : v) { + v_k = 2 * amplitude * random_generator->Rand<float>() - amplitude; + } +} + +template <typename T> +void DelayBuffer<T>::Delay(rtc::ArrayView<const T> x, + rtc::ArrayView<T> x_delayed) { + RTC_DCHECK_EQ(x.size(), x_delayed.size()); + if (buffer_.empty()) { + std::copy(x.begin(), x.end(), x_delayed.begin()); + } else { + for (size_t k = 0; k < x.size(); ++k) { + x_delayed[k] = buffer_[next_insert_index_]; + buffer_[next_insert_index_] = x[k]; + next_insert_index_ = (next_insert_index_ + 1) % buffer_.size(); + } + } +} + +template class DelayBuffer<float>; +template class DelayBuffer<int>; +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools.h b/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools.h new file mode 100644 index 0000000000..0d70cd39c6 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_ECHO_CANCELLER_TEST_TOOLS_H_ +#define MODULES_AUDIO_PROCESSING_TEST_ECHO_CANCELLER_TEST_TOOLS_H_ + +#include <algorithm> +#include <vector> + +#include "api/array_view.h" +#include "rtc_base/random.h" + +namespace webrtc { + +// Randomizes the elements in a vector with values -32767.f:32767.f. +void RandomizeSampleVector(Random* random_generator, rtc::ArrayView<float> v); + +// Randomizes the elements in a vector with values -amplitude:amplitude. +void RandomizeSampleVector(Random* random_generator, + rtc::ArrayView<float> v, + float amplitude); + +// Class for delaying a signal a fixed number of samples. +template <typename T> +class DelayBuffer { + public: + explicit DelayBuffer(size_t delay) : buffer_(delay) {} + ~DelayBuffer() = default; + + // Produces a delayed signal copy of x. 
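+  // `x` and `x_delayed` must have the same size. State is kept between calls,
+  // so consecutive blocks of a stream can be delayed one after another.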
+ void Delay(rtc::ArrayView<const T> x, rtc::ArrayView<T> x_delayed); + + private: + std::vector<T> buffer_; + size_t next_insert_index_ = 0; +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_ECHO_CANCELLER_TEST_TOOLS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools_unittest.cc b/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools_unittest.cc new file mode 100644 index 0000000000..164d28fa16 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/echo_canceller_test_tools_unittest.cc @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/echo_canceller_test_tools.h" + +#include <vector> + +#include "api/array_view.h" +#include "rtc_base/checks.h" +#include "rtc_base/random.h" +#include "test/gtest.h" + +namespace webrtc { + +TEST(EchoCancellerTestTools, FloatDelayBuffer) { + constexpr size_t kDelay = 10; + DelayBuffer<float> delay_buffer(kDelay); + std::vector<float> v(1000, 0.f); + for (size_t k = 0; k < v.size(); ++k) { + v[k] = k; + } + std::vector<float> v_delayed = v; + constexpr size_t kBlockSize = 50; + for (size_t k = 0; k < rtc::CheckedDivExact(v.size(), kBlockSize); ++k) { + delay_buffer.Delay( + rtc::ArrayView<const float>(&v[k * kBlockSize], kBlockSize), + rtc::ArrayView<float>(&v_delayed[k * kBlockSize], kBlockSize)); + } + for (size_t k = kDelay; k < v.size(); ++k) { + EXPECT_EQ(v[k - kDelay], v_delayed[k]); + } +} + +TEST(EchoCancellerTestTools, IntDelayBuffer) { + constexpr size_t kDelay = 10; + DelayBuffer<int> delay_buffer(kDelay); + std::vector<int> v(1000, 0); + for (size_t k = 0; k < v.size(); ++k) { + v[k] = k; + } + std::vector<int> v_delayed = v; + const size_t kBlockSize = 50; + for (size_t k = 0; k < rtc::CheckedDivExact(v.size(), kBlockSize); ++k) { + delay_buffer.Delay( + rtc::ArrayView<const int>(&v[k * kBlockSize], kBlockSize), + rtc::ArrayView<int>(&v_delayed[k * kBlockSize], kBlockSize)); + } + for (size_t k = kDelay; k < v.size(); ++k) { + EXPECT_EQ(v[k - kDelay], v_delayed[k]); + } +} + +TEST(EchoCancellerTestTools, RandomizeSampleVector) { + Random random_generator(42U); + std::vector<float> v(50, 0.f); + std::vector<float> v_ref = v; + RandomizeSampleVector(&random_generator, v); + EXPECT_NE(v, v_ref); + v_ref = v; + RandomizeSampleVector(&random_generator, v); + EXPECT_NE(v, v_ref); +} + +TEST(EchoCancellerTestTools, RandomizeSampleVectorWithAmplitude) { + Random random_generator(42U); + std::vector<float> v(50, 0.f); + RandomizeSampleVector(&random_generator, v, 1000.f); + EXPECT_GE(1000.f, *std::max_element(v.begin(), v.end())); + EXPECT_LE(-1000.f, *std::min_element(v.begin(), v.end())); + RandomizeSampleVector(&random_generator, v, 100.f); + EXPECT_GE(100.f, *std::max_element(v.begin(), v.end())); + EXPECT_LE(-100.f, *std::min_element(v.begin(), v.end())); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/echo_control_mock.h b/third_party/libwebrtc/modules/audio_processing/test/echo_control_mock.h new file mode 100644 index 0000000000..763d6e4f0b --- /dev/null +++ 
b/third_party/libwebrtc/modules/audio_processing/test/echo_control_mock.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_ECHO_CONTROL_MOCK_H_ +#define MODULES_AUDIO_PROCESSING_TEST_ECHO_CONTROL_MOCK_H_ + +#include "api/audio/echo_control.h" +#include "test/gmock.h" + +namespace webrtc { + +class AudioBuffer; + +class MockEchoControl : public EchoControl { + public: + MOCK_METHOD(void, AnalyzeRender, (AudioBuffer * render), (override)); + MOCK_METHOD(void, AnalyzeCapture, (AudioBuffer * capture), (override)); + MOCK_METHOD(void, + ProcessCapture, + (AudioBuffer * capture, bool echo_path_change), + (override)); + MOCK_METHOD(void, + ProcessCapture, + (AudioBuffer * capture, + AudioBuffer* linear_output, + bool echo_path_change), + (override)); + MOCK_METHOD(EchoControl::Metrics, GetMetrics, (), (const, override)); + MOCK_METHOD(void, SetAudioBufferDelay, (int delay_ms), (override)); + MOCK_METHOD(void, + SetCaptureOutputUsage, + (bool capture_output_used), + (override)); + MOCK_METHOD(bool, ActiveProcessing, (), (const, override)); +}; + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_ECHO_CONTROL_MOCK_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device.cc b/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device.cc new file mode 100644 index 0000000000..3a35ee9d74 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device.cc @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/fake_recording_device.h" + +#include <algorithm> +#include <memory> + +#include "absl/types/optional.h" +#include "modules/audio_processing/agc/gain_map_internal.h" +#include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/numerics/safe_minmax.h" + +namespace webrtc { +namespace test { + +namespace { + +constexpr float kFloatSampleMin = -32768.f; +constexpr float kFloatSampleMax = 32767.0f; + +} // namespace + +// Abstract class for the different fake recording devices. +class FakeRecordingDeviceWorker { + public: + explicit FakeRecordingDeviceWorker(const int initial_mic_level) + : mic_level_(initial_mic_level) {} + int mic_level() const { return mic_level_; } + void set_mic_level(const int level) { mic_level_ = level; } + void set_undo_mic_level(const int level) { undo_mic_level_ = level; } + virtual ~FakeRecordingDeviceWorker() = default; + virtual void ModifyBufferInt16(rtc::ArrayView<int16_t> buffer) = 0; + virtual void ModifyBufferFloat(ChannelBuffer<float>* buffer) = 0; + + protected: + // Mic level to simulate. + int mic_level_; + // Optional mic level to undo. 
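+  // When set to a positive value, the gain that a real device applied at this
+  // level is virtually undone before the simulated level is applied.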
+ absl::optional<int> undo_mic_level_; +}; + +namespace { + +// Identity fake recording device. The samples are not modified, which is +// equivalent to a constant gain curve at 1.0 - only used for testing. +class FakeRecordingDeviceIdentity final : public FakeRecordingDeviceWorker { + public: + explicit FakeRecordingDeviceIdentity(const int initial_mic_level) + : FakeRecordingDeviceWorker(initial_mic_level) {} + ~FakeRecordingDeviceIdentity() override = default; + void ModifyBufferInt16(rtc::ArrayView<int16_t> buffer) override {} + void ModifyBufferFloat(ChannelBuffer<float>* buffer) override {} +}; + +// Linear fake recording device. The gain curve is a linear function mapping the +// mic levels range [0, 255] to [0.0, 1.0]. +class FakeRecordingDeviceLinear final : public FakeRecordingDeviceWorker { + public: + explicit FakeRecordingDeviceLinear(const int initial_mic_level) + : FakeRecordingDeviceWorker(initial_mic_level) {} + ~FakeRecordingDeviceLinear() override = default; + void ModifyBufferInt16(rtc::ArrayView<int16_t> buffer) override { + const size_t number_of_samples = buffer.size(); + int16_t* data = buffer.data(); + // If an undo level is specified, virtually restore the unmodified + // microphone level; otherwise simulate the mic gain only. + const float divisor = + (undo_mic_level_ && *undo_mic_level_ > 0) ? *undo_mic_level_ : 255.f; + for (size_t i = 0; i < number_of_samples; ++i) { + data[i] = rtc::saturated_cast<int16_t>(data[i] * mic_level_ / divisor); + } + } + void ModifyBufferFloat(ChannelBuffer<float>* buffer) override { + // If an undo level is specified, virtually restore the unmodified + // microphone level; otherwise simulate the mic gain only. + const float divisor = + (undo_mic_level_ && *undo_mic_level_ > 0) ? *undo_mic_level_ : 255.f; + for (size_t c = 0; c < buffer->num_channels(); ++c) { + for (size_t i = 0; i < buffer->num_frames(); ++i) { + buffer->channels()[c][i] = + rtc::SafeClamp(buffer->channels()[c][i] * mic_level_ / divisor, + kFloatSampleMin, kFloatSampleMax); + } + } + } +}; + +float ComputeAgc1LinearFactor(const absl::optional<int>& undo_mic_level, + int mic_level) { + // If an undo level is specified, virtually restore the unmodified + // microphone level; otherwise simulate the mic gain only. + const int undo_level = + (undo_mic_level && *undo_mic_level > 0) ? *undo_mic_level : 100; + return DbToRatio(kGainMap[mic_level] - kGainMap[undo_level]); +} + +// Roughly dB-scale fake recording device. Valid levels are [0, 255]. The mic +// applies a gain from kGainMap in agc/gain_map_internal.h. 
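+// The linear factor applied to the samples is derived from the dB difference
+// between the simulated level and the undone level (see
+// ComputeAgc1LinearFactor above).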
+class FakeRecordingDeviceAgc1 final : public FakeRecordingDeviceWorker { + public: + explicit FakeRecordingDeviceAgc1(const int initial_mic_level) + : FakeRecordingDeviceWorker(initial_mic_level) {} + ~FakeRecordingDeviceAgc1() override = default; + void ModifyBufferInt16(rtc::ArrayView<int16_t> buffer) override { + const float scaling_factor = + ComputeAgc1LinearFactor(undo_mic_level_, mic_level_); + const size_t number_of_samples = buffer.size(); + int16_t* data = buffer.data(); + for (size_t i = 0; i < number_of_samples; ++i) { + data[i] = rtc::saturated_cast<int16_t>(data[i] * scaling_factor); + } + } + void ModifyBufferFloat(ChannelBuffer<float>* buffer) override { + const float scaling_factor = + ComputeAgc1LinearFactor(undo_mic_level_, mic_level_); + for (size_t c = 0; c < buffer->num_channels(); ++c) { + for (size_t i = 0; i < buffer->num_frames(); ++i) { + buffer->channels()[c][i] = + rtc::SafeClamp(buffer->channels()[c][i] * scaling_factor, + kFloatSampleMin, kFloatSampleMax); + } + } + } +}; + +} // namespace + +FakeRecordingDevice::FakeRecordingDevice(int initial_mic_level, + int device_kind) { + switch (device_kind) { + case 0: + worker_ = + std::make_unique<FakeRecordingDeviceIdentity>(initial_mic_level); + break; + case 1: + worker_ = std::make_unique<FakeRecordingDeviceLinear>(initial_mic_level); + break; + case 2: + worker_ = std::make_unique<FakeRecordingDeviceAgc1>(initial_mic_level); + break; + default: + RTC_DCHECK_NOTREACHED(); + break; + } +} + +FakeRecordingDevice::~FakeRecordingDevice() = default; + +int FakeRecordingDevice::MicLevel() const { + RTC_CHECK(worker_); + return worker_->mic_level(); +} + +void FakeRecordingDevice::SetMicLevel(const int level) { + RTC_CHECK(worker_); + if (level != worker_->mic_level()) + RTC_LOG(LS_INFO) << "Simulate mic level update: " << level; + worker_->set_mic_level(level); +} + +void FakeRecordingDevice::SetUndoMicLevel(const int level) { + RTC_DCHECK(worker_); + // TODO(alessiob): The behavior with undo level equal to zero is not clear yet + // and will be defined in future CLs once more FakeRecordingDeviceWorker + // implementations need to be added. + RTC_CHECK(level > 0) << "Zero undo mic level is unsupported"; + worker_->set_undo_mic_level(level); +} + +void FakeRecordingDevice::SimulateAnalogGain(rtc::ArrayView<int16_t> buffer) { + RTC_DCHECK(worker_); + worker_->ModifyBufferInt16(buffer); +} + +void FakeRecordingDevice::SimulateAnalogGain(ChannelBuffer<float>* buffer) { + RTC_DCHECK(worker_); + worker_->ModifyBufferFloat(buffer); +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device.h b/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device.h new file mode 100644 index 0000000000..da3c0cf794 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_TEST_FAKE_RECORDING_DEVICE_H_
+#define MODULES_AUDIO_PROCESSING_TEST_FAKE_RECORDING_DEVICE_H_
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "api/array_view.h"
+#include "common_audio/channel_buffer.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace test {
+
+class FakeRecordingDeviceWorker;
+
+// Class for simulating a microphone with analog gain.
+//
+// The intended modes of operation are the following:
+//
+// FakeRecordingDevice fake_mic(255, 1);
+//
+// fake_mic.SetMicLevel(170);
+// fake_mic.SimulateAnalogGain(buffer);
+//
+// When the mic level to undo is known:
+//
+// fake_mic.SetMicLevel(170);
+// fake_mic.SetUndoMicLevel(30);
+// fake_mic.SimulateAnalogGain(buffer);
+//
+// The second option virtually restores the unmodified microphone level.
+// Calling SimulateAnalogGain() will first "undo" the gain applied by the real
+// microphone (e.g., 30).
+class FakeRecordingDevice final {
+ public:
+  FakeRecordingDevice(int initial_mic_level, int device_kind);
+  ~FakeRecordingDevice();
+
+  int MicLevel() const;
+  void SetMicLevel(int level);
+  void SetUndoMicLevel(int level);
+
+  // Simulates the analog gain.
+  // If an undo mic level has been set via SetUndoMicLevel(), the gain applied
+  // by the real microphone at that level is virtually undone first; otherwise
+  // only the simulated mic gain is applied.
+  void SimulateAnalogGain(rtc::ArrayView<int16_t> buffer);
+
+  // Simulates the analog gain.
+  // If an undo mic level has been set via SetUndoMicLevel(), the gain applied
+  // by the real microphone at that level is virtually undone first; otherwise
+  // only the simulated mic gain is applied.
+  void SimulateAnalogGain(ChannelBuffer<float>* buffer);
+
+ private:
+  // Fake recording device worker.
+  std::unique_ptr<FakeRecordingDeviceWorker> worker_;
+};
+
+}  // namespace test
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_TEST_FAKE_RECORDING_DEVICE_H_
diff --git a/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device_unittest.cc b/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device_unittest.cc
new file mode 100644
index 0000000000..2ac8b1dc48
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/fake_recording_device_unittest.cc
@@ -0,0 +1,231 @@
+/*
+ * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/test/fake_recording_device.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "api/array_view.h"
+#include "rtc_base/strings/string_builder.h"
+#include "test/gtest.h"
+
+namespace webrtc {
+namespace test {
+namespace {
+
+constexpr int kInitialMicLevel = 100;
+
+// TODO(alessiob): Add new fake recording device kind values here as they are
+// added in FakeRecordingDevice::FakeRecordingDevice.
+const std::vector<int> kFakeRecDeviceKinds = {0, 1, 2};
+
+const std::vector<std::vector<float>> kTestMultiChannelSamples{
+    std::vector<float>{-10.f, -1.f, -0.1f, 0.f, 0.1f, 1.f, 10.f}};
+
+// Writes samples into ChannelBuffer<float>.
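+// `data` must have the same channel and frame counts as `buff`.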
+void WritesDataIntoChannelBuffer(const std::vector<std::vector<float>>& data,
+                                 ChannelBuffer<float>* buff) {
+  EXPECT_EQ(data.size(), buff->num_channels());
+  EXPECT_EQ(data[0].size(), buff->num_frames());
+  for (size_t c = 0; c < buff->num_channels(); ++c) {
+    for (size_t f = 0; f < buff->num_frames(); ++f) {
+      buff->channels()[c][f] = data[c][f];
+    }
+  }
+}
+
+std::unique_ptr<ChannelBuffer<float>> CreateChannelBufferWithData(
+    const std::vector<std::vector<float>>& data) {
+  auto buff =
+      std::make_unique<ChannelBuffer<float>>(data[0].size(), data.size());
+  WritesDataIntoChannelBuffer(data, buff.get());
+  return buff;
+}
+
+// Checks that the magnitudes of samples modified with monotonically
+// increasing level values also grow monotonically.
+void CheckIfMonotoneSamplesModules(const ChannelBuffer<float>* prev,
+                                   const ChannelBuffer<float>* curr) {
+  RTC_DCHECK_EQ(prev->num_channels(), curr->num_channels());
+  RTC_DCHECK_EQ(prev->num_frames(), curr->num_frames());
+  bool valid = true;
+  for (size_t i = 0; i < prev->num_channels(); ++i) {
+    for (size_t j = 0; j < prev->num_frames(); ++j) {
+      valid = std::fabs(prev->channels()[i][j]) <=
+              std::fabs(curr->channels()[i][j]);
+      if (!valid) {
+        break;
+      }
+    }
+    if (!valid) {
+      break;
+    }
+  }
+  EXPECT_TRUE(valid);
+}
+
+// Checks that the samples in each pair have the same sign unless the sample in
+// `dst` is zero (because of zero gain).
+void CheckSameSign(const ChannelBuffer<float>* src,
+                   const ChannelBuffer<float>* dst) {
+  RTC_DCHECK_EQ(src->num_channels(), dst->num_channels());
+  RTC_DCHECK_EQ(src->num_frames(), dst->num_frames());
+  const auto fsgn = [](float x) { return ((x < 0) ? -1 : (x > 0) ? 1 : 0); };
+  bool valid = true;
+  for (size_t i = 0; i < src->num_channels(); ++i) {
+    for (size_t j = 0; j < src->num_frames(); ++j) {
+      valid = dst->channels()[i][j] == 0.0f ||
+              fsgn(src->channels()[i][j]) == fsgn(dst->channels()[i][j]);
+      if (!valid) {
+        break;
+      }
+    }
+    if (!valid) {
+      break;
+    }
+  }
+  EXPECT_TRUE(valid);
+}
+
+std::string FakeRecordingDeviceKindToString(int fake_rec_device_kind) {
+  rtc::StringBuilder ss;
+  ss << "fake recording device: " << fake_rec_device_kind;
+  return ss.Release();
+}
+
+std::string AnalogLevelToString(int level) {
+  rtc::StringBuilder ss;
+  ss << "analog level: " << level;
+  return ss.Release();
+}
+
+}  // namespace
+
+TEST(FakeRecordingDevice, CheckHelperFunctions) {
+  constexpr size_t kC = 0;  // Channel index.
+  constexpr size_t kS = 1;  // Sample index.
+
+  // Check read.
+  auto buff = CreateChannelBufferWithData(kTestMultiChannelSamples);
+  for (size_t c = 0; c < kTestMultiChannelSamples.size(); ++c) {
+    for (size_t s = 0; s < kTestMultiChannelSamples[0].size(); ++s) {
+      EXPECT_EQ(kTestMultiChannelSamples[c][s], buff->channels()[c][s]);
+    }
+  }
+
+  // Check write.
+  buff->channels()[kC][kS] = -5.0f;
+  RTC_DCHECK_NE(buff->channels()[kC][kS], kTestMultiChannelSamples[kC][kS]);
+
+  // Check reset.
+  WritesDataIntoChannelBuffer(kTestMultiChannelSamples, buff.get());
+  EXPECT_EQ(buff->channels()[kC][kS], kTestMultiChannelSamples[kC][kS]);
+}
+
+// Implicitly checks that the mic and undo levels injected via
+// FakeRecordingDevice are visible to the FakeRecordingDeviceWorker
+// implementation.
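+// Simulating level 200 must amplify each sample at least as much as level
+// 100, and additionally undoing level 100 must amplify even further.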
+TEST(FakeRecordingDevice, TestWorkerAbstractClass) { + FakeRecordingDevice fake_recording_device(kInitialMicLevel, 1); + + auto buff1 = CreateChannelBufferWithData(kTestMultiChannelSamples); + fake_recording_device.SetMicLevel(100); + fake_recording_device.SimulateAnalogGain(buff1.get()); + + auto buff2 = CreateChannelBufferWithData(kTestMultiChannelSamples); + fake_recording_device.SetMicLevel(200); + fake_recording_device.SimulateAnalogGain(buff2.get()); + + for (size_t c = 0; c < kTestMultiChannelSamples.size(); ++c) { + for (size_t s = 0; s < kTestMultiChannelSamples[0].size(); ++s) { + EXPECT_LE(std::abs(buff1->channels()[c][s]), + std::abs(buff2->channels()[c][s])); + } + } + + auto buff3 = CreateChannelBufferWithData(kTestMultiChannelSamples); + fake_recording_device.SetMicLevel(200); + fake_recording_device.SetUndoMicLevel(100); + fake_recording_device.SimulateAnalogGain(buff3.get()); + + for (size_t c = 0; c < kTestMultiChannelSamples.size(); ++c) { + for (size_t s = 0; s < kTestMultiChannelSamples[0].size(); ++s) { + EXPECT_LE(std::abs(buff1->channels()[c][s]), + std::abs(buff3->channels()[c][s])); + EXPECT_LE(std::abs(buff2->channels()[c][s]), + std::abs(buff3->channels()[c][s])); + } + } +} + +TEST(FakeRecordingDevice, GainCurveShouldBeMonotone) { + // Create input-output buffers. + auto buff_prev = CreateChannelBufferWithData(kTestMultiChannelSamples); + auto buff_curr = CreateChannelBufferWithData(kTestMultiChannelSamples); + + // Test different mappings. + for (auto fake_rec_device_kind : kFakeRecDeviceKinds) { + SCOPED_TRACE(FakeRecordingDeviceKindToString(fake_rec_device_kind)); + FakeRecordingDevice fake_recording_device(kInitialMicLevel, + fake_rec_device_kind); + // TODO(alessiob): The test below is designed for state-less recording + // devices. If, for instance, a device has memory, the test might need + // to be redesigned (e.g., re-initialize fake recording device). + + // Apply lowest analog level. + WritesDataIntoChannelBuffer(kTestMultiChannelSamples, buff_prev.get()); + fake_recording_device.SetMicLevel(0); + fake_recording_device.SimulateAnalogGain(buff_prev.get()); + + // Increment analog level to check monotonicity. + for (int i = 1; i <= 255; ++i) { + SCOPED_TRACE(AnalogLevelToString(i)); + WritesDataIntoChannelBuffer(kTestMultiChannelSamples, buff_curr.get()); + fake_recording_device.SetMicLevel(i); + fake_recording_device.SimulateAnalogGain(buff_curr.get()); + CheckIfMonotoneSamplesModules(buff_prev.get(), buff_curr.get()); + + // Update prev. + buff_prev.swap(buff_curr); + } + } +} + +TEST(FakeRecordingDevice, GainCurveShouldNotChangeSign) { + // Create view on original samples. + std::unique_ptr<const ChannelBuffer<float>> buff_orig = + CreateChannelBufferWithData(kTestMultiChannelSamples); + + // Create output buffer. + auto buff = CreateChannelBufferWithData(kTestMultiChannelSamples); + + // Test different mappings. + for (auto fake_rec_device_kind : kFakeRecDeviceKinds) { + SCOPED_TRACE(FakeRecordingDeviceKindToString(fake_rec_device_kind)); + FakeRecordingDevice fake_recording_device(kInitialMicLevel, + fake_rec_device_kind); + + // TODO(alessiob): The test below is designed for state-less recording + // devices. If, for instance, a device has memory, the test might need + // to be redesigned (e.g., re-initialize fake recording device). 
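+    // Sweep all valid mic levels and check that the simulated gain never
+    // flips the sign of any sample.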
+ for (int i = 0; i <= 255; ++i) { + SCOPED_TRACE(AnalogLevelToString(i)); + WritesDataIntoChannelBuffer(kTestMultiChannelSamples, buff.get()); + fake_recording_device.SetMicLevel(i); + fake_recording_device.SimulateAnalogGain(buff.get()); + CheckSameSign(buff_orig.get(), buff.get()); + } + } +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/performance_timer.cc b/third_party/libwebrtc/modules/audio_processing/test/performance_timer.cc new file mode 100644 index 0000000000..1a82258903 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/performance_timer.cc @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/test/performance_timer.h" + +#include <math.h> + +#include <numeric> + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace test { + +PerformanceTimer::PerformanceTimer(int num_frames_to_process) + : clock_(webrtc::Clock::GetRealTimeClock()) { + timestamps_us_.reserve(num_frames_to_process); +} + +PerformanceTimer::~PerformanceTimer() = default; + +void PerformanceTimer::StartTimer() { + start_timestamp_us_ = clock_->TimeInMicroseconds(); +} + +void PerformanceTimer::StopTimer() { + RTC_DCHECK(start_timestamp_us_); + timestamps_us_.push_back(clock_->TimeInMicroseconds() - *start_timestamp_us_); +} + +double PerformanceTimer::GetDurationAverage() const { + return GetDurationAverage(0); +} + +double PerformanceTimer::GetDurationStandardDeviation() const { + return GetDurationStandardDeviation(0); +} + +double PerformanceTimer::GetDurationAverage( + size_t number_of_warmup_samples) const { + RTC_DCHECK_GT(timestamps_us_.size(), number_of_warmup_samples); + const size_t number_of_samples = + timestamps_us_.size() - number_of_warmup_samples; + return static_cast<double>( + std::accumulate(timestamps_us_.begin() + number_of_warmup_samples, + timestamps_us_.end(), static_cast<int64_t>(0))) / + number_of_samples; +} + +double PerformanceTimer::GetDurationStandardDeviation( + size_t number_of_warmup_samples) const { + RTC_DCHECK_GT(timestamps_us_.size(), number_of_warmup_samples); + const size_t number_of_samples = + timestamps_us_.size() - number_of_warmup_samples; + RTC_DCHECK_GT(number_of_samples, 0); + double average_duration = GetDurationAverage(number_of_warmup_samples); + + double variance = std::accumulate( + timestamps_us_.begin() + number_of_warmup_samples, timestamps_us_.end(), + 0.0, [average_duration](const double& a, const int64_t& b) { + return a + (b - average_duration) * (b - average_duration); + }); + + return sqrt(variance / number_of_samples); +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/performance_timer.h b/third_party/libwebrtc/modules/audio_processing/test/performance_timer.h new file mode 100644 index 0000000000..5375ba74e8 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/performance_timer.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. 
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_TEST_PERFORMANCE_TIMER_H_
+#define MODULES_AUDIO_PROCESSING_TEST_PERFORMANCE_TIMER_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "system_wrappers/include/clock.h"
+
+namespace webrtc {
+namespace test {
+
+class PerformanceTimer {
+ public:
+  explicit PerformanceTimer(int num_frames_to_process);
+  ~PerformanceTimer();
+
+  void StartTimer();
+  void StopTimer();
+
+  double GetDurationAverage() const;
+  double GetDurationStandardDeviation() const;
+
+  // These methods are the same as those above, but they ignore the first
+  // `number_of_warmup_samples` measurements.
+  double GetDurationAverage(size_t number_of_warmup_samples) const;
+  double GetDurationStandardDeviation(size_t number_of_warmup_samples) const;
+
+ private:
+  webrtc::Clock* clock_;
+  absl::optional<int64_t> start_timestamp_us_;
+  std::vector<int64_t> timestamps_us_;
+};
+
+}  // namespace test
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_TEST_PERFORMANCE_TIMER_H_
diff --git a/third_party/libwebrtc/modules/audio_processing/test/protobuf_utils.cc b/third_party/libwebrtc/modules/audio_processing/test/protobuf_utils.cc
new file mode 100644
index 0000000000..75574961b0
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/protobuf_utils.cc
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/test/protobuf_utils.h"
+
+#include <memory>
+
+#include "rtc_base/system/arch.h"
+
+namespace {
+// Allocates memory owned by the unique_ptr to fit the raw message and returns
+// the number of bytes read from the string-stream input.
+size_t ReadMessageBytesFromString(std::stringstream* input,
+                                  std::unique_ptr<uint8_t[]>* bytes) {
+  int32_t size = 0;
+  input->read(reinterpret_cast<char*>(&size), sizeof(int32_t));
+  int32_t size_read = input->gcount();
+  if (size_read != sizeof(int32_t))
+    return 0;
+  if (size <= 0)
+    return 0;
+
+  *bytes = std::make_unique<uint8_t[]>(size);
+  input->read(reinterpret_cast<char*>(bytes->get()),
+              size * sizeof((*bytes)[0]));
+  size_read = input->gcount();
+  return size_read == size ? size : 0;
+}
+}  // namespace
+
+namespace webrtc {
+
+size_t ReadMessageBytesFromFile(FILE* file, std::unique_ptr<uint8_t[]>* bytes) {
+// The "wire format" for the size is little-endian. Assume we're running on
+// a little-endian machine.
+#ifndef WEBRTC_ARCH_LITTLE_ENDIAN
+#error "Need to convert message from little-endian."
+#endif
+  int32_t size = 0;
+  if (fread(&size, sizeof(size), 1, file) != 1)
+    return 0;
+  if (size <= 0)
+    return 0;
+
+  *bytes = std::make_unique<uint8_t[]>(size);
+  return fread(bytes->get(), sizeof((*bytes)[0]), size, file);
+}
+
+// Returns true on success, false on error or end-of-file.
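+// Note: each message is stored as a little-endian int32 byte size followed by
+// the serialized payload (see ReadMessageBytesFromFile above).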
+bool ReadMessageFromFile(FILE* file, MessageLite* msg) { + std::unique_ptr<uint8_t[]> bytes; + size_t size = ReadMessageBytesFromFile(file, &bytes); + if (!size) + return false; + + msg->Clear(); + return msg->ParseFromArray(bytes.get(), size); +} + +// Returns true on success, false on error or end of string stream. +bool ReadMessageFromString(std::stringstream* input, MessageLite* msg) { + std::unique_ptr<uint8_t[]> bytes; + size_t size = ReadMessageBytesFromString(input, &bytes); + if (!size) + return false; + + msg->Clear(); + return msg->ParseFromArray(bytes.get(), size); +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/protobuf_utils.h b/third_party/libwebrtc/modules/audio_processing/test/protobuf_utils.h new file mode 100644 index 0000000000..b9c2e819f9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/protobuf_utils.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_PROTOBUF_UTILS_H_ +#define MODULES_AUDIO_PROCESSING_TEST_PROTOBUF_UTILS_H_ + +#include <memory> +#include <sstream> // no-presubmit-check TODO(webrtc:8982) + +#include "rtc_base/ignore_wundef.h" +#include "rtc_base/protobuf_utils.h" + +RTC_PUSH_IGNORING_WUNDEF() +#include "modules/audio_processing/debug.pb.h" +RTC_POP_IGNORING_WUNDEF() + +namespace webrtc { + +// Allocates new memory in the unique_ptr to fit the raw message and returns the +// number of bytes read. +size_t ReadMessageBytesFromFile(FILE* file, std::unique_ptr<uint8_t[]>* bytes); + +// Returns true on success, false on error or end-of-file. +bool ReadMessageFromFile(FILE* file, MessageLite* msg); + +// Returns true on success, false on error or end of string stream. +bool ReadMessageFromString( + std::stringstream* input, // no-presubmit-check TODO(webrtc:8982) + MessageLite* msg); + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_PROTOBUF_UTILS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn new file mode 100644 index 0000000000..e53a829623 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn @@ -0,0 +1,170 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. 
+ +import("../../../../webrtc.gni") + +if (!build_with_chromium) { + group("py_quality_assessment") { + testonly = true + deps = [ + ":scripts", + ":unit_tests", + ] + } + + copy("scripts") { + testonly = true + sources = [ + "README.md", + "apm_quality_assessment.py", + "apm_quality_assessment.sh", + "apm_quality_assessment_boxplot.py", + "apm_quality_assessment_export.py", + "apm_quality_assessment_gencfgs.py", + "apm_quality_assessment_optimize.py", + ] + outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ] + deps = [ + ":apm_configs", + ":lib", + ":output", + "../../../../resources/audio_processing/test/py_quality_assessment:probing_signals", + "../../../../rtc_tools:audioproc_f", + ] + } + + copy("apm_configs") { + testonly = true + sources = [ "apm_configs/default.json" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + outputs = [ + "$root_build_dir/py_quality_assessment/apm_configs/{{source_file_part}}", + ] + } # apm_configs + + copy("lib") { + testonly = true + sources = [ + "quality_assessment/__init__.py", + "quality_assessment/annotations.py", + "quality_assessment/audioproc_wrapper.py", + "quality_assessment/collect_data.py", + "quality_assessment/data_access.py", + "quality_assessment/echo_path_simulation.py", + "quality_assessment/echo_path_simulation_factory.py", + "quality_assessment/eval_scores.py", + "quality_assessment/eval_scores_factory.py", + "quality_assessment/evaluation.py", + "quality_assessment/exceptions.py", + "quality_assessment/export.py", + "quality_assessment/export_unittest.py", + "quality_assessment/external_vad.py", + "quality_assessment/input_mixer.py", + "quality_assessment/input_signal_creator.py", + "quality_assessment/results.css", + "quality_assessment/results.js", + "quality_assessment/signal_processing.py", + "quality_assessment/simulation.py", + "quality_assessment/test_data_generation.py", + "quality_assessment/test_data_generation_factory.py", + ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ] + deps = [ "../../../../resources/audio_processing/test/py_quality_assessment:noise_tracks" ] + } + + copy("output") { + testonly = true + sources = [ "output/README.md" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. + outputs = + [ "$root_build_dir/py_quality_assessment/output/{{source_file_part}}" ] + } + + group("unit_tests") { + testonly = true + visibility = [ ":*" ] # Only targets in this file can depend on this. + deps = [ + ":apm_vad", + ":fake_polqa", + ":lib_unit_tests", + ":scripts_unit_tests", + ":vad", + ] + } + + rtc_executable("fake_polqa") { + testonly = true + sources = [ "quality_assessment/fake_polqa.cc" ] + visibility = [ ":*" ] # Only targets in this file can depend on this. 
+    output_dir = "${root_out_dir}/py_quality_assessment/quality_assessment"
+    deps = [
+      "../../../../rtc_base:checks",
+      "//third_party/abseil-cpp/absl/strings",
+    ]
+  }
+
+  rtc_executable("vad") {
+    testonly = true
+    sources = [ "quality_assessment/vad.cc" ]
+    deps = [
+      "../../../../common_audio",
+      "../../../../rtc_base:logging",
+      "//third_party/abseil-cpp/absl/flags:flag",
+      "//third_party/abseil-cpp/absl/flags:parse",
+    ]
+  }
+
+  rtc_executable("apm_vad") {
+    testonly = true
+    sources = [ "quality_assessment/apm_vad.cc" ]
+    deps = [
+      "../..",
+      "../../../../common_audio",
+      "../../../../rtc_base:logging",
+      "../../vad",
+      "//third_party/abseil-cpp/absl/flags:flag",
+      "//third_party/abseil-cpp/absl/flags:parse",
+    ]
+  }
+
+  rtc_executable("sound_level") {
+    testonly = true
+    sources = [ "quality_assessment/sound_level.cc" ]
+    deps = [
+      "../..",
+      "../../../../common_audio",
+      "../../../../rtc_base:logging",
+      "//third_party/abseil-cpp/absl/flags:flag",
+      "//third_party/abseil-cpp/absl/flags:parse",
+    ]
+  }
+
+  copy("lib_unit_tests") {
+    testonly = true
+    sources = [
+      "quality_assessment/annotations_unittest.py",
+      "quality_assessment/echo_path_simulation_unittest.py",
+      "quality_assessment/eval_scores_unittest.py",
+      "quality_assessment/fake_external_vad.py",
+      "quality_assessment/input_mixer_unittest.py",
+      "quality_assessment/signal_processing_unittest.py",
+      "quality_assessment/simulation_unittest.py",
+      "quality_assessment/test_data_generation_unittest.py",
+    ]
+    visibility = [ ":*" ]  # Only targets in this file can depend on this.
+    outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ]
+  }
+
+  copy("scripts_unit_tests") {
+    testonly = true
+    sources = [ "apm_quality_assessment_unittest.py" ]
+    visibility = [ ":*" ]  # Only targets in this file can depend on this.
+    outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ]
+  }
+}
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/OWNERS b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/OWNERS
new file mode 100644
index 0000000000..9f56bb830d
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/OWNERS
@@ -0,0 +1,5 @@
+aleloi@webrtc.org
+alessiob@webrtc.org
+henrik.lundin@webrtc.org
+ivoc@webrtc.org
+peah@webrtc.org
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/README.md b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/README.md
new file mode 100644
index 0000000000..4156112df2
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/README.md
@@ -0,0 +1,125 @@
+# APM Quality Assessment tool
+
+Python wrapper of APM simulators (e.g., `audioproc_f`) that automates quality
+assessment. The tool can simulate different noise conditions, input signals
+and APM configurations, and it computes different scores. Once the scores are
+computed, the results can be exported to an HTML page that lets you listen to
+the APM input and output signals as well as the reference signal used for
+evaluation.
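+
+For instance, a minimal invocation may look as follows (the input file name is
+a placeholder; the flags are those defined in `apm_quality_assessment.py`, and
+the two environment variables are set up as described below):
+
+```
+$ ./apm_quality_assessment.py \
+  -i probing_signals/my_speech.wav \
+  -o output/ \
+  --polqa_path $POLQA_PATH \
+  --air_db_path $AECHEN_IR_DATABASE_PATH
+```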
+
+## Dependencies
+ - OS: Linux
+ - Python 2.7
+ - Python libraries: enum34, numpy, scipy, pydub (0.17.0+), pandas (0.20.1+),
+   pyquery (1.2+), jsmin (2.2+), csscompressor (0.9.4)
+ - It is recommended that a dedicated Python environment is used
+   - install `virtualenv`
+     - `$ sudo apt-get install python-virtualenv`
+   - set up a new Python environment (e.g., `my_env`)
+     - `$ cd ~ && virtualenv my_env`
+   - activate the new Python environment
+     - `$ source ~/my_env/bin/activate`
+   - add dependencies via `pip`
+     - `(my_env)$ pip install enum34 numpy pydub scipy pandas pyquery jsmin \`
+       `csscompressor`
+ - PolqaOem64 (see http://www.polqa.info/)
+   - Tested with POLQA Library v1.180 / P863 v2.400
+ - Aachen Impulse Response (AIR) Database
+   - Download https://www2.iks.rwth-aachen.de/air/air_database_release_1_4.zip
+ - Input probing signals and noise tracks (you can make your own dataset - *1)
+
+## Build
+ - Compile WebRTC
+ - Go to `out/Default/py_quality_assessment` and check that
+   `apm_quality_assessment.py` exists
+
+## Unit tests
+ - Compile WebRTC
+ - Go to `out/Default/py_quality_assessment`
+ - Run `python -m unittest discover -p "*_unittest.py"`
+
+## First time setup
+ - Deploy PolqaOem64 and set the `POLQA_PATH` environment variable
+   - e.g., `$ export POLQA_PATH=/var/opt/PolqaOem64`
+ - Deploy the AIR Database and set the `AECHEN_IR_DATABASE_PATH` environment
+   variable
+   - e.g., `$ export AECHEN_IR_DATABASE_PATH=/var/opt/AIR_1_4`
+ - Deploy probing signal tracks into
+   - `out/Default/py_quality_assessment/probing_signals` (*1)
+ - Deploy noise tracks into
+   - `out/Default/py_quality_assessment/noise_tracks` (*1, *2)
+
+(*1) You can use custom files as long as they are mono tracks sampled at 48kHz
+encoded in the 16 bit signed format (it is recommended that the tracks are
+converted and exported with Audacity).
+
+## Usage (scores computation)
+ - Go to `out/Default/py_quality_assessment`
+ - Check `apm_quality_assessment.sh`, an example script that parallelizes the
+   experiments
+ - Adjust the script according to your preferences (e.g., output path)
+ - Run `apm_quality_assessment.sh`
+ - The script will end by opening the browser and showing ALL the computed
+   scores
+
+## Usage (export reports)
+Showing all the results at once can be confusing. You may therefore want to
+export separate reports. In this case, you can use the
+`apm_quality_assessment_export.py` script as follows:
+
+ - Set `--output_dir, -o` to the same value used in `apm_quality_assessment.sh`
+ - Use regular expressions to select/filter out scores by
+   - APM configurations: `--config_names, -c`
+   - capture signals: `--capture_names, -i`
+   - render signals: `--render_names, -r`
+   - echo simulator: `--echo_simulator_names, -e`
+   - test data generators: `--test_data_generators, -t`
+   - scores: `--eval_scores, -s`
+ - Assign a suffix to the report name using `-f <suffix>`
+
+For instance:
+
+```
+$ ./apm_quality_assessment_export.py \
+  -o output/ \
+  -c "(^default$)|(.*AE.*)" \
+  -t \(white_noise\) \
+  -s \(polqa\) \
+  -f echo
+```
+
+## Usage (boxplot)
+After generating stats, it can help to visualize how a score depends on a
+certain APM simulator parameter. The `apm_quality_assessment_boxplot.py` script
+helps with that, producing plots similar to [this
+one](https://matplotlib.org/mpl_examples/pylab_examples/boxplot_demo_06.png).
+
+Suppose some scores come from running the APM simulator `audioproc_f` with
+or without the level controller: `--lc=1` or `--lc=0`.
+Then two boxplots side by side can be generated with
+
+```
+$ ./apm_quality_assessment_boxplot.py \
+  -o /path/to/output \
+  -v <score_name> \
+  -n /path/to/dir/with/apm_configs \
+  -z lc
+```
+
+## Troubleshooting
+The input wav file must be:
+ - sampled at a sample rate that is a multiple of 100 (required by POLQA)
+ - in the 16 bit format (required by `audioproc_f`)
+ - encoded in the Microsoft WAV signed 16 bit PCM format (Audacity default
+   when exporting)
+
+Depending on the license, the POLQA tool may take “breaks” as a way to limit the
+throughput. When this happens, the APM Quality Assessment tool is slowed down.
+For more details about this limitation, check Section 10.9.1 in the POLQA manual
+v.1.18.
+
+In case of issues with the POLQA score computation, check
+`py_quality_assessment/eval_scores.py` and adapt
+`PolqaScore._parse_output_file()`.
+The code can also be fixed directly in the build directory (namely,
+`out/Default/py_quality_assessment/eval_scores.py`).
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_configs/default.json b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_configs/default.json
new file mode 100644
index 0000000000..5c3277bac0
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_configs/default.json
@@ -0,0 +1 @@
+{"-all_default": null}
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
new file mode 100755
index 0000000000..e067ecb692
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Perform APM module quality assessment on one or more input files using one or
+   more APM simulator configuration files and one or more test data generators.
+
+Usage: apm_quality_assessment.py -i audio1.wav [audio2.wav ...]
+                                 -c cfg1.json [cfg2.json ...]
+                                 -n white [echo ...]
+                                 -e audio_level [polqa ...]
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_configs/default.json b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_configs/default.json
new file mode 100644
index 0000000000..5c3277bac0
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_configs/default.json
@@ -0,0 +1 @@
+{"-all_default": null}
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
new file mode 100755
index 0000000000..e067ecb692
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Perform APM module quality assessment on one or more input files using one
+   or more APM simulator configuration files and one or more test data
+   generators.
+
+Usage: apm_quality_assessment.py -i audio1.wav [audio2.wav ...]
+                                 -c cfg1.json [cfg2.json ...]
+                                 -n white [echo ...]
+                                 -e audio_level [polqa ...]
+                                 -o /path/to/output
+"""
+
+import argparse
+import logging
+import os
+import sys
+
+import quality_assessment.audioproc_wrapper as audioproc_wrapper
+import quality_assessment.echo_path_simulation as echo_path_simulation
+import quality_assessment.eval_scores as eval_scores
+import quality_assessment.evaluation as evaluation
+import quality_assessment.eval_scores_factory as eval_scores_factory
+import quality_assessment.external_vad as external_vad
+import quality_assessment.test_data_generation as test_data_generation
+import quality_assessment.test_data_generation_factory as \
+    test_data_generation_factory
+import quality_assessment.simulation as simulation
+
+_ECHO_PATH_SIMULATOR_NAMES = (
+    echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
+_TEST_DATA_GENERATOR_CLASSES = (
+    test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
+_TEST_DATA_GENERATORS_NAMES = _TEST_DATA_GENERATOR_CLASSES.keys()
+_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
+_EVAL_SCORE_WORKER_NAMES = _EVAL_SCORE_WORKER_CLASSES.keys()
+
+_DEFAULT_CONFIG_FILE = 'apm_configs/default.json'
+
+_POLQA_BIN_NAME = 'PolqaOem64'
+
+
+def _InstanceArgumentsParser():
+    """Arguments parser factory.
+    """
+    parser = argparse.ArgumentParser(description=(
+        'Perform APM module quality assessment on one or more input files '
+        'using one or more APM simulator configuration files and one or more '
+        'test data generators.'))
+
+    parser.add_argument('-c',
+                        '--config_files',
+                        nargs='+',
+                        required=False,
+                        help=('path to the configuration files defining the '
+                              'arguments with which the APM simulator tool is '
+                              'called'),
+                        default=[_DEFAULT_CONFIG_FILE])
+
+    parser.add_argument(
+        '-i',
+        '--capture_input_files',
+        nargs='+',
+        required=True,
+        help='path to the capture input wav files (one or more)')
+
+    parser.add_argument('-r',
+                        '--render_input_files',
+                        nargs='+',
+                        required=False,
+                        help=('path to the render input wav files; either '
+                              'omitted or one file for each file in '
+                              '--capture_input_files (files will be paired by '
+                              'index)'),
+                        default=None)
+
+    parser.add_argument('-p',
+                        '--echo_path_simulator',
+                        required=False,
+                        help=('custom echo path simulator name; required if '
+                              '--render_input_files is specified'),
+                        choices=_ECHO_PATH_SIMULATOR_NAMES,
+                        default=echo_path_simulation.NoEchoPathSimulator.NAME)
+
+    parser.add_argument('-t',
+                        '--test_data_generators',
+                        nargs='+',
+                        required=False,
+                        help='custom list of test data generators to use',
+                        choices=_TEST_DATA_GENERATORS_NAMES,
+                        default=_TEST_DATA_GENERATORS_NAMES)
+
+    parser.add_argument('--additive_noise_tracks_path', required=False,
+                        help='path to the wav files for the additive noise tracks',
+                        default=test_data_generation. \
+                            AdditiveNoiseTestDataGenerator. \
+                            DEFAULT_NOISE_TRACKS_PATH)
+
+    parser.add_argument('-e',
+                        '--eval_scores',
+                        nargs='+',
+                        required=False,
+                        help='custom list of evaluation scores to use',
+                        choices=_EVAL_SCORE_WORKER_NAMES,
+                        default=_EVAL_SCORE_WORKER_NAMES)
+
+    parser.add_argument('-o',
+                        '--output_dir',
+                        required=False,
+                        help=('base path to the output directory in which the '
+                              'output wav files and the evaluation outcomes '
+                              'are saved'),
+                        default='output')
+
+    parser.add_argument('--polqa_path',
+                        required=True,
+                        help='path to the POLQA tool')
+
+    parser.add_argument('--air_db_path',
+                        required=True,
+                        help='path to the Aachen IR database')
+
+    parser.add_argument('--apm_sim_path', required=False,
+                        help='path to the APM simulator tool',
+                        default=audioproc_wrapper.
\ + AudioProcWrapper. \ + DEFAULT_APM_SIMULATOR_BIN_PATH) + + parser.add_argument('--echo_metric_tool_bin_path', + required=False, + help=('path to the echo metric binary ' + '(required for the echo eval score)'), + default=None) + + parser.add_argument( + '--copy_with_identity_generator', + required=False, + help=('If true, the identity test data generator makes a ' + 'copy of the clean speech input file.'), + default=False) + + parser.add_argument('--external_vad_paths', + nargs='+', + required=False, + help=('Paths to external VAD programs. Each must take' + '\'-i <wav file> -o <output>\' inputs'), + default=[]) + + parser.add_argument('--external_vad_names', + nargs='+', + required=False, + help=('Keys to the vad paths. Must be different and ' + 'as many as the paths.'), + default=[]) + + return parser + + +def _ValidateArguments(args, parser): + if args.capture_input_files and args.render_input_files and (len( + args.capture_input_files) != len(args.render_input_files)): + parser.error( + '--render_input_files and --capture_input_files must be lists ' + 'having the same length') + sys.exit(1) + + if args.render_input_files and not args.echo_path_simulator: + parser.error( + 'when --render_input_files is set, --echo_path_simulator is ' + 'also required') + sys.exit(1) + + if len(args.external_vad_names) != len(args.external_vad_paths): + parser.error('If provided, --external_vad_paths and ' + '--external_vad_names must ' + 'have the same number of arguments.') + sys.exit(1) + + +def main(): + # TODO(alessiob): level = logging.INFO once debugged. + logging.basicConfig(level=logging.DEBUG) + parser = _InstanceArgumentsParser() + args = parser.parse_args() + _ValidateArguments(args, parser) + + simulator = simulation.ApmModuleSimulator( + test_data_generator_factory=( + test_data_generation_factory.TestDataGeneratorFactory( + aechen_ir_database_path=args.air_db_path, + noise_tracks_path=args.additive_noise_tracks_path, + copy_with_identity=args.copy_with_identity_generator)), + evaluation_score_factory=eval_scores_factory. + EvaluationScoreWorkerFactory( + polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME), + echo_metric_tool_bin_path=args.echo_metric_tool_bin_path), + ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path), + evaluator=evaluation.ApmModuleEvaluator(), + external_vads=external_vad.ExternalVad.ConstructVadDict( + args.external_vad_paths, args.external_vad_names)) + simulator.Run(config_filepaths=args.config_files, + capture_input_filepaths=args.capture_input_files, + render_input_filepaths=args.render_input_files, + echo_path_simulator_name=args.echo_path_simulator, + test_data_generator_names=args.test_data_generators, + eval_score_names=args.eval_scores, + output_dir=args.output_dir) + sys.exit(0) + + +if __name__ == '__main__': + main() diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.sh b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.sh new file mode 100755 index 0000000000..aa563ee26b --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. 
All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +# Path to the POLQA tool. +if [ -z ${POLQA_PATH} ]; then # Check if defined. + # Default location. + export POLQA_PATH='/var/opt/PolqaOem64' +fi +if [ -d "${POLQA_PATH}" ]; then + echo "POLQA found in ${POLQA_PATH}" +else + echo "POLQA not found in ${POLQA_PATH}" + exit 1 +fi + +# Path to the Aechen IR database. +if [ -z ${AECHEN_IR_DATABASE_PATH} ]; then # Check if defined. + # Default location. + export AECHEN_IR_DATABASE_PATH='/var/opt/AIR_1_4' +fi +if [ -d "${AECHEN_IR_DATABASE_PATH}" ]; then + echo "AIR database found in ${AECHEN_IR_DATABASE_PATH}" +else + echo "AIR database not found in ${AECHEN_IR_DATABASE_PATH}" + exit 1 +fi + +# Customize probing signals, test data generators and scores if needed. +CAPTURE_SIGNALS=(probing_signals/*.wav) +TEST_DATA_GENERATORS=( \ + "identity" \ + "white_noise" \ + # "environmental_noise" \ + # "reverberation" \ +) +SCORES=( \ + # "polqa" \ + "audio_level_peak" \ + "audio_level_mean" \ +) +OUTPUT_PATH=output + +# Generate standard APM config files. +chmod +x apm_quality_assessment_gencfgs.py +./apm_quality_assessment_gencfgs.py + +# Customize APM configurations if needed. +APM_CONFIGS=(apm_configs/*.json) + +# Add output path if missing. +if [ ! -d ${OUTPUT_PATH} ]; then + mkdir ${OUTPUT_PATH} +fi + +# Start one process for each "probing signal"-"test data source" pair. +chmod +x apm_quality_assessment.py +for capture_signal_filepath in "${CAPTURE_SIGNALS[@]}" ; do + probing_signal_name="$(basename $capture_signal_filepath)" + probing_signal_name="${probing_signal_name%.*}" + for test_data_gen_name in "${TEST_DATA_GENERATORS[@]}" ; do + LOG_FILE="${OUTPUT_PATH}/apm_qa-${probing_signal_name}-"` + `"${test_data_gen_name}.log" + echo "Starting ${probing_signal_name} ${test_data_gen_name} "` + `"(see ${LOG_FILE})" + ./apm_quality_assessment.py \ + --polqa_path ${POLQA_PATH}\ + --air_db_path ${AECHEN_IR_DATABASE_PATH}\ + -i ${capture_signal_filepath} \ + -o ${OUTPUT_PATH} \ + -t ${test_data_gen_name} \ + -c "${APM_CONFIGS[@]}" \ + -e "${SCORES[@]}" > $LOG_FILE 2>&1 & + done +done + +# Join Python processes running apm_quality_assessment.py. +wait + +# Export results. +chmod +x ./apm_quality_assessment_export.py +./apm_quality_assessment_export.py -o ${OUTPUT_PATH} + +# Show results in the browser. +RESULTS_FILE="$(realpath ${OUTPUT_PATH}/results.html)" +sensible-browser "file://${RESULTS_FILE}" > /dev/null 2>&1 & diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_boxplot.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_boxplot.py new file mode 100644 index 0000000000..c425885b95 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_boxplot.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Shows boxplots of given score for different values of selected +parameters. Can be used to compare scores by audioproc_f flag. 
+
+Usage: apm_quality_assessment_boxplot.py -o /path/to/output
+                                         -v polqa
+                                         -n /path/to/dir/with/apm_configs
+                                         -z audioproc_f_arg1 [arg2 ...]
+
+Arguments --config_names, --render_names, --echo_simulator_names,
+--test_data_generators, --eval_scores can be used to filter the data
+used for plotting.
+"""
+
+import collections
+import logging
+import matplotlib.pyplot as plt
+import os
+
+import quality_assessment.data_access as data_access
+import quality_assessment.collect_data as collect_data
+
+
+def InstanceArgumentsParser():
+    """Arguments parser factory.
+    """
+    parser = collect_data.InstanceArgumentsParser()
+    parser.description = (
+        'Shows boxplot of given score for different values of selected '
+        'parameters. Can be used to compare scores by audioproc_f flag')
+
+    parser.add_argument('-v',
+                        '--eval_score',
+                        required=True,
+                        help=('Score name for constructing boxplots'))
+
+    parser.add_argument(
+        '-n',
+        '--config_dir',
+        required=False,
+        help=('path to the folder with the configuration files'),
+        default='apm_configs')
+
+    parser.add_argument('-z',
+                        '--params_to_plot',
+                        required=True,
+                        nargs='+',
+                        help=('audioproc_f parameter values '
+                              'by which to group scores (no leading dash)'))
+
+    return parser
+
+
+def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
+    """Filters data on the values of one or more parameters.
+
+    Args:
+      data_frame: pandas.DataFrame of all used input data.
+
+      filter_params: each config of the input data is assumed to have
+        exactly one parameter from `filter_params` defined. Every value
+        of the parameters in `filter_params` is a key in the returned
+        dict; the associated value is all cells of the data with that
+        value of the parameter.
+
+      score_name: Name of the score whose values are boxplotted. Currently
+        cannot handle more than one score.
+
+      config_dir: path to dir with APM configs.
+
+    Returns: dictionary, key is a param value, result is all scores for
+      that param value (see `filter_params` for explanation).
+    """
+    results = collections.defaultdict(dict)
+    config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
+
+    for config_name in config_names:
+        config_json = data_access.AudioProcConfigFile.Load(
+            os.path.join(config_dir, config_name + '.json'))
+        data_with_config = data_frame[data_frame.apm_config == config_name]
+        data_cell_scores = data_with_config[data_with_config.eval_score_name ==
+                                            score_name]
+
+        # Exactly one of `filter_params` must match:
+        (matching_param, ) = [
+            x for x in filter_params if '-' + x in config_json
+        ]
+
+        # Add scores for every track to the result.
+        for capture_name in data_cell_scores.capture:
+            result_score = float(data_cell_scores[data_cell_scores.capture ==
+                                                  capture_name].score)
+            config_dict = results[config_json['-' + matching_param]]
+            if capture_name not in config_dict:
+                config_dict[capture_name] = {}
+
+            config_dict[capture_name][matching_param] = result_score
+
+    return results
+
+
+def _FlattenToScoresList(config_param_score_dict):
+    """Extracts a list of scores from input data structure.
+
+    Args:
+      config_param_score_dict: of the form {'capture_name':
+        {'param_name' : score_value,.. } ..}
+
+    Returns: Plain list of all score values present in input data
+      structure
+    """
+    result = []
+    for capture_name in config_param_score_dict:
+        result += list(config_param_score_dict[capture_name].values())
+    return result
+
+
+def main():
+    # Init.
+    # TODO(alessiob): INFO once debugged.
+ logging.basicConfig(level=logging.DEBUG) + parser = InstanceArgumentsParser() + args = parser.parse_args() + + # Get the scores. + src_path = collect_data.ConstructSrcPath(args) + logging.debug(src_path) + scores_data_frame = collect_data.FindScores(src_path, args) + + # Filter the data by `args.params_to_plot` + scores_filtered = FilterScoresByParams(scores_data_frame, + args.params_to_plot, + args.eval_score, args.config_dir) + + data_list = sorted(scores_filtered.items()) + data_values = [_FlattenToScoresList(x) for (_, x) in data_list] + data_labels = [x for (x, _) in data_list] + + _, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6)) + axes.boxplot(data_values, labels=data_labels) + axes.set_ylabel(args.eval_score) + axes.set_xlabel('/'.join(args.params_to_plot)) + plt.show() + + +if __name__ == "__main__": + main() diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_export.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_export.py new file mode 100755 index 0000000000..c20accb9dc --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_export.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Export the scores computed by the apm_quality_assessment.py script into an + HTML file. +""" + +import logging +import os +import sys + +import quality_assessment.collect_data as collect_data +import quality_assessment.export as export + + +def _BuildOutputFilename(filename_suffix): + """Builds the filename for the exported file. + + Args: + filename_suffix: suffix for the output file name. + + Returns: + A string. + """ + if filename_suffix is None: + return 'results.html' + return 'results-{}.html'.format(filename_suffix) + + +def main(): + # Init. + logging.basicConfig( + level=logging.DEBUG) # TODO(alessio): INFO once debugged. + parser = collect_data.InstanceArgumentsParser() + parser.add_argument('-f', + '--filename_suffix', + help=('suffix of the exported file')) + parser.description = ('Exports pre-computed APM module quality assessment ' + 'results into HTML tables') + args = parser.parse_args() + + # Get the scores. + src_path = collect_data.ConstructSrcPath(args) + logging.debug(src_path) + scores_data_frame = collect_data.FindScores(src_path, args) + + # Export. 
+    output_filepath = os.path.join(args.output_dir,
+                                   _BuildOutputFilename(args.filename_suffix))
+    exporter = export.HtmlExport(output_filepath)
+    exporter.Export(scores_data_frame)
+
+    logging.info('output file successfully written to %s', output_filepath)
+    sys.exit(0)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_gencfgs.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_gencfgs.py
new file mode 100755
index 0000000000..ca80f85bd1
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_gencfgs.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Generate .json files with which the APM module can be tested using the
+   apm_quality_assessment.py script and audioproc_f as APM simulator.
+"""
+
+import logging
+import os
+
+import quality_assessment.data_access as data_access
+
+OUTPUT_PATH = os.path.abspath('apm_configs')
+
+
+def _GenerateDefaultOverridden(config_override):
+    """Generates one or more APM overridden configurations.
+
+    For each item in config_override, it overrides the default configuration
+    and writes a new APM configuration file.
+
+    The default settings are loaded via "-all_default".
+    Check "src/modules/audio_processing/test/audioproc_float.cc" and search
+    for "if (FLAG_all_default) {".
+
+    For instance, in 55eb6d621489730084927868fed195d3645a9ec9 the default is
+    this:
+      settings.use_aec = rtc::Optional<bool>(true);
+      settings.use_aecm = rtc::Optional<bool>(false);
+      settings.use_agc = rtc::Optional<bool>(true);
+      settings.use_bf = rtc::Optional<bool>(false);
+      settings.use_ed = rtc::Optional<bool>(false);
+      settings.use_hpf = rtc::Optional<bool>(true);
+      settings.use_le = rtc::Optional<bool>(true);
+      settings.use_ns = rtc::Optional<bool>(true);
+      settings.use_ts = rtc::Optional<bool>(true);
+      settings.use_vad = rtc::Optional<bool>(true);
+
+    Args:
+      config_override: dict of APM configuration file names as keys; the
+        values are dict instances encoding the audioproc_f flags.
+    """
+    for config_filename in config_override:
+        config = config_override[config_filename]
+        config['-all_default'] = None
+
+        config_filepath = os.path.join(
+            OUTPUT_PATH, 'default-{}.json'.format(config_filename))
+        logging.debug('config file <%s> | %s', config_filepath, config)
+
+        data_access.AudioProcConfigFile.Save(config_filepath, config)
+        logging.info('config file created: <%s>', config_filepath)
+
+
+def _GenerateAllDefaultButOne():
+    """Disables the flags enabled by default one-by-one.
+    """
+    config_sets = {
+        'no_AEC': {
+            '-aec': 0,
+        },
+        'no_AGC': {
+            '-agc': 0,
+        },
+        'no_HP_filter': {
+            '-hpf': 0,
+        },
+        'no_level_estimator': {
+            '-le': 0,
+        },
+        'no_noise_suppressor': {
+            '-ns': 0,
+        },
+        'no_transient_suppressor': {
+            '-ts': 0,
+        },
+        'no_vad': {
+            '-vad': 0,
+        },
+    }
+    _GenerateDefaultOverridden(config_sets)
+
+
+def _GenerateAllDefaultPlusOne():
+    """Enables the flags disabled by default one-by-one.
+    """
+    config_sets = {
+        'with_AECM': {
+            '-aec': 0,
+            '-aecm': 1,
+        },  # AEC and AECM are exclusive.
+        'with_AGC_limiter': {
+            '-agc_limiter': 1,
+        },
+        'with_AEC_delay_agnostic': {
+            '-delay_agnostic': 1,
+        },
+        'with_drift_compensation': {
+            '-drift_compensation': 1,
+        },
+        'with_residual_echo_detector': {
+            '-ed': 1,
+        },
+        'with_AEC_extended_filter': {
+            '-extended_filter': 1,
+        },
+        'with_LC': {
+            '-lc': 1,
+        },
+        'with_refined_adaptive_filter': {
+            '-refined_adaptive_filter': 1,
+        },
+    }
+    _GenerateDefaultOverridden(config_sets)
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+    _GenerateAllDefaultPlusOne()
+    _GenerateAllDefaultButOne()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_optimize.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_optimize.py
new file mode 100644
index 0000000000..ecae2ed995
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_optimize.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Finds the APM configuration that maximizes a provided metric by
+parsing the output generated by apm_quality_assessment.py.
+"""
+
+from __future__ import division
+
+import collections
+import logging
+import os
+
+import quality_assessment.data_access as data_access
+import quality_assessment.collect_data as collect_data
+
+
+def _InstanceArgumentsParser():
+    """Arguments parser factory. Extends the arguments from 'collect_data'
+    with a few extra for selecting what parameters to optimize for.
+    """
+    parser = collect_data.InstanceArgumentsParser()
+    parser.description = (
+        'Rudimentary optimization of a function over different parameter '
+        'combinations.')
+
+    parser.add_argument(
+        '-n',
+        '--config_dir',
+        required=False,
+        help=('path to the folder with the configuration files'),
+        default='apm_configs')
+
+    parser.add_argument('-p',
+                        '--params',
+                        required=True,
+                        nargs='+',
+                        help=('parameters to parse from the config files in '
+                              'config_dir'))
+
+    parser.add_argument(
+        '-z',
+        '--params_not_to_optimize',
+        required=False,
+        nargs='+',
+        default=[],
+        help=('parameters from `params` not to be optimized for'))
+
+    return parser
+
+
+def _ConfigurationAndScores(data_frame, params, params_not_to_optimize,
+                            config_dir):
+    """Returns a list of all configurations and scores.
+
+    Args:
+      data_frame: A pandas data frame with the scores and config name
+        returned by _FindScores.
+      params: The parameter names to parse from the config files in the
+        config directory.
+
+      params_not_to_optimize: The parameter names which shouldn't affect
+        the optimal parameter selection. E.g., fixed settings and not
+        tunable parameters.
+
+      config_dir: Path to folder with config files.
+
+    Returns:
+      Dictionary of the form
+      {param_combination: [{params: {param1: value1, ...},
+                            scores: {score1: value1, ...}}]}.
+
+      The key `param_combination` runs over all parameter combinations
+      of the parameters in `params` and not in `params_not_to_optimize`.
+      A corresponding value is a list of all param combinations for
+      params in `params_not_to_optimize` and their scores.
+    """
+    results = collections.defaultdict(list)
+    config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
+    score_names = data_frame['eval_score_name'].drop_duplicates(
+    ).values.tolist()
+
+    # Normalize the scores.
+    normalization_constants = {}
+    for score_name in score_names:
+        scores = data_frame[data_frame.eval_score_name == score_name].score
+        normalization_constants[score_name] = max(scores)
+
+    params_to_optimize = [p for p in params if p not in params_not_to_optimize]
+    param_combination = collections.namedtuple("ParamCombination",
+                                               params_to_optimize)
+
+    for config_name in config_names:
+        config_json = data_access.AudioProcConfigFile.Load(
+            os.path.join(config_dir, config_name + ".json"))
+        scores = {}
+        data_cell = data_frame[data_frame.apm_config == config_name]
+        for score_name in score_names:
+            data_cell_scores = data_cell[data_cell.eval_score_name ==
+                                         score_name].score
+            scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
+            scores[score_name] /= normalization_constants[score_name]
+
+        result = {'scores': scores, 'params': {}}
+        config_optimize_params = {}
+        for param in params:
+            if param in params_to_optimize:
+                config_optimize_params[param] = config_json['-' + param]
+            else:
+                result['params'][param] = config_json['-' + param]
+
+        current_param_combination = param_combination(**config_optimize_params)
+        results[current_param_combination].append(result)
+    return results
+
+
+def _FindOptimalParameter(configs_and_scores, score_weighting):
+    """Finds the config producing the maximal score.
+
+    Args:
+      configs_and_scores: structure of the form returned by
+        _ConfigurationAndScores.
+
+      score_weighting: a function to weight together all score values of
+        the form [{params: {param1: value1, ...}, scores:
+        {score1: value1, ...}}] into a numeric value.
+    Returns:
+      the config that has the largest value of `score_weighting` applied
+      to its scores.
+    """
+
+    best_score = float('-inf')
+    best_params = None
+    for config in configs_and_scores:
+        scores_and_params = configs_and_scores[config]
+        current_score = score_weighting(scores_and_params)
+        if current_score > best_score:
+            best_score = current_score
+            best_params = config
+            logging.debug("Score: %f", current_score)
+            logging.debug("Config: %s", str(config))
+    return best_params
+
+
+def _ExampleWeighting(scores_and_configs):
+    """Example argument to `_FindOptimalParameter`.
+    Args:
+      scores_and_configs: a list of configs and scores, in the form
+        described in _FindOptimalParameter.
+    Returns:
+      numeric value, the sum of all scores.
+    """
+    res = 0
+    for score_config in scores_and_configs:
+        res += sum(score_config['scores'].values())
+    return res
+
+
+def main():
+    # Init.
+    # TODO(alessiob): INFO once debugged.
+    logging.basicConfig(level=logging.DEBUG)
+    parser = _InstanceArgumentsParser()
+    args = parser.parse_args()
+
+    # Get the scores.
+    src_path = collect_data.ConstructSrcPath(args)
+    logging.debug('Src path <%s>', src_path)
+    scores_data_frame = collect_data.FindScores(src_path, args)
+    all_scores = _ConfigurationAndScores(scores_data_frame, args.params,
+                                         args.params_not_to_optimize,
+                                         args.config_dir)
+
+    opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
+
+    logging.info('Optimal parameter combination: <%s>', opt_param)
+    logging.info('Its score values: <%s>', all_scores[opt_param])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_unittest.py
new file mode 100644
index 0000000000..80338c1373
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_unittest.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Unit tests for the apm_quality_assessment module.
+"""
+
+import sys
+import unittest
+
+import mock
+
+import apm_quality_assessment
+
+
+class TestSimulationScript(unittest.TestCase):
+    """Unit tests for the apm_quality_assessment module.
+    """
+
+    def testMain(self):
+        # Exit with error code if no arguments are passed.
+        with self.assertRaises(SystemExit) as cm, mock.patch.object(
+                sys, 'argv', ['apm_quality_assessment.py']):
+            apm_quality_assessment.main()
+        self.assertGreater(cm.exception.code, 0)
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/output/README.md b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/output/README.md
new file mode 100644
index 0000000000..66e2a1c848
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/output/README.md
@@ -0,0 +1 @@
+You can use this folder for the output generated by the apm_quality_assessment scripts.
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/__init__.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/__init__.py
new file mode 100644
index 0000000000..b870dfaef3
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
new file mode 100644
index 0000000000..93a8248397
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
@@ -0,0 +1,296 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Extraction of annotations from audio files.
+"""
+
+from __future__ import division
+import logging
+import os
+import shutil
+import struct
+import subprocess
+import sys
+import tempfile
+
+try:
+    import numpy as np
+except ImportError:
+    logging.critical('Cannot import the third-party Python package numpy')
+    sys.exit(1)
+
+from . import external_vad
+from . import exceptions
+from . import signal_processing
+
+
+class AudioAnnotationsExtractor(object):
+    """Extracts annotations from audio files.
+    """
+
+    class VadType(object):
+        ENERGY_THRESHOLD = 1  # TODO(alessiob): Consider switching to P56 standard.
+        WEBRTC_COMMON_AUDIO = 2  # common_audio/vad/include/vad.h
+        WEBRTC_APM = 4  # modules/audio_processing/vad/vad.h
+
+        def __init__(self, value):
+            if (not isinstance(value, int)) or not 0 <= value <= 7:
+                raise exceptions.InitializationException('Invalid vad type: ' +
+                                                         str(value))
+            self._value = value
+
+        def Contains(self, vad_type):
+            return self._value | vad_type == self._value
+
+        def __str__(self):
+            vads = []
+            if self.Contains(self.ENERGY_THRESHOLD):
+                vads.append("energy")
+            if self.Contains(self.WEBRTC_COMMON_AUDIO):
+                vads.append("common_audio")
+            if self.Contains(self.WEBRTC_APM):
+                vads.append("apm")
+            return "VadType({})".format(", ".join(vads))
+
+    _OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
+
+    # Level estimation params.
+    _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
+    _LEVEL_FRAME_SIZE_MS = 1.0
+    # The time constants in ms indicate the time it takes for the level
+    # estimate to go down/up by 1 db if the signal is zero.
+    _LEVEL_ATTACK_MS = 5.0
+    _LEVEL_DECAY_MS = 20.0
+
+    # VAD params.
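+    # Note: _VAD_THRESHOLD below is used as a percentile of the level
+    # envelope (see Extract()); frames whose level exceeds that percentile
+    # are marked as voice by the energy VAD.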
+    _VAD_THRESHOLD = 1
+    _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                    os.pardir, os.pardir)
+    _VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
+
+    _VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')
+
+    def __init__(self, vad_type, external_vads=None):
+        self._signal = None
+        self._level = None
+        self._level_frame_size = None
+        self._common_audio_vad = None
+        self._energy_vad = None
+        self._apm_vad_probs = None
+        self._apm_vad_rms = None
+        self._vad_frame_size = None
+        self._vad_frame_size_ms = None
+        self._c_attack = None
+        self._c_decay = None
+
+        self._vad_type = self.VadType(vad_type)
+        logging.info('VADs used for annotations: ' + str(self._vad_type))
+
+        if external_vads is None:
+            external_vads = {}
+        self._external_vads = external_vads
+
+        assert len(self._external_vads) == len(external_vads), (
+            'The external VAD names must be unique.')
+        for vad in external_vads.values():
+            if not isinstance(vad, external_vad.ExternalVad):
+                raise exceptions.InitializationException('Invalid vad type: ' +
+                                                         str(type(vad)))
+            logging.info('External VAD used for annotation: ' + str(vad.name))
+
+        assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
+            self._VAD_WEBRTC_COMMON_AUDIO_PATH
+        assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
+            self._VAD_WEBRTC_APM_PATH
+
+    @classmethod
+    def GetOutputFileNameTemplate(cls):
+        return cls._OUTPUT_FILENAME_TEMPLATE
+
+    def GetLevel(self):
+        return self._level
+
+    def GetLevelFrameSize(self):
+        return self._level_frame_size
+
+    @classmethod
+    def GetLevelFrameSizeMs(cls):
+        return cls._LEVEL_FRAME_SIZE_MS
+
+    def GetVadOutput(self, vad_type):
+        if vad_type == self.VadType.ENERGY_THRESHOLD:
+            return self._energy_vad
+        elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
+            return self._common_audio_vad
+        elif vad_type == self.VadType.WEBRTC_APM:
+            return (self._apm_vad_probs, self._apm_vad_rms)
+        else:
+            raise exceptions.InitializationException('Invalid vad type: ' +
+                                                     str(vad_type))
+
+    def GetVadFrameSize(self):
+        return self._vad_frame_size
+
+    def GetVadFrameSizeMs(self):
+        return self._vad_frame_size_ms
+
+    def Extract(self, filepath):
+        # Load signal.
+        self._signal = signal_processing.SignalProcessingUtils.LoadWav(
+            filepath)
+        if self._signal.channels != 1:
+            raise NotImplementedError(
+                'Multiple-channel annotations not implemented')
+
+        # Level estimation params.
+        self._level_frame_size = int(self._signal.frame_rate / 1000 *
+                                     (self._LEVEL_FRAME_SIZE_MS))
+        self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
+            self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
+                                     self._LEVEL_ATTACK_MS))
+        self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
+            self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
+                                     self._LEVEL_DECAY_MS))
+
+        # Compute level.
+        self._LevelEstimation()
+
+        # Ideal VAD output; it requires clean speech with high SNR as input.
+        if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
+            # Naive VAD based on level thresholding.
+            vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
+            self._energy_vad = np.uint8(self._level > vad_threshold)
+            self._vad_frame_size = self._level_frame_size
+            self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
+        if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
+            # WebRTC common_audio/ VAD.
+            self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
+        if self._vad_type.Contains(self.VadType.WEBRTC_APM):
+            # WebRTC modules/audio_processing/ VAD.
+ self._RunWebRtcApmVad(filepath) + for extvad_name in self._external_vads: + self._external_vads[extvad_name].Run(filepath) + + def Save(self, output_path, annotation_name=""): + ext_kwargs = { + 'extvad_conf-' + ext_vad: + self._external_vads[ext_vad].GetVadOutput() + for ext_vad in self._external_vads + } + np.savez_compressed(file=os.path.join( + output_path, + self.GetOutputFileNameTemplate().format(annotation_name)), + level=self._level, + level_frame_size=self._level_frame_size, + level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS, + vad_output=self._common_audio_vad, + vad_energy_output=self._energy_vad, + vad_frame_size=self._vad_frame_size, + vad_frame_size_ms=self._vad_frame_size_ms, + vad_probs=self._apm_vad_probs, + vad_rms=self._apm_vad_rms, + **ext_kwargs) + + def _LevelEstimation(self): + # Read samples. + samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData( + self._signal).astype(np.float32) / 32768.0 + num_frames = len(samples) // self._level_frame_size + num_samples = num_frames * self._level_frame_size + + # Envelope. + self._level = np.max(np.reshape(np.abs(samples[:num_samples]), + (num_frames, self._level_frame_size)), + axis=1) + assert len(self._level) == num_frames + + # Envelope smoothing. + smooth = lambda curr, prev, k: (1 - k) * curr + k * prev + self._level[0] = smooth(self._level[0], 0.0, self._c_attack) + for i in range(1, num_frames): + self._level[i] = smooth( + self._level[i], self._level[i - 1], self._c_attack if + (self._level[i] > self._level[i - 1]) else self._c_decay) + + def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate): + self._common_audio_vad = None + self._vad_frame_size = None + + # Create temporary output path. + tmp_path = tempfile.mkdtemp() + output_file_path = os.path.join( + tmp_path, + os.path.split(wav_file_path)[1] + '_vad.tmp') + + # Call WebRTC VAD. + try: + subprocess.call([ + self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o', + output_file_path + ], + cwd=self._VAD_WEBRTC_PATH) + + # Read bytes. + with open(output_file_path, 'rb') as f: + raw_data = f.read() + + # Parse side information. + self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0] + self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000 + assert self._vad_frame_size_ms in [10, 20, 30] + extra_bits = struct.unpack('B', raw_data[-1])[0] + assert 0 <= extra_bits <= 8 + + # Init VAD vector. + num_bytes = len(raw_data) + num_frames = 8 * (num_bytes - + 2) - extra_bits # 8 frames for each byte. + self._common_audio_vad = np.zeros(num_frames, np.uint8) + + # Read VAD decisions. + for i, byte in enumerate(raw_data[1:-1]): + byte = struct.unpack('B', byte)[0] + for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)): + self._common_audio_vad[i * 8 + j] = int(byte & 1) + byte = byte >> 1 + except Exception as e: + logging.error('Error while running the WebRTC VAD (' + e.message + + ')') + finally: + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + + def _RunWebRtcApmVad(self, wav_file_path): + # Create temporary output path. + tmp_path = tempfile.mkdtemp() + output_file_path_probs = os.path.join( + tmp_path, + os.path.split(wav_file_path)[1] + '_vad_probs.tmp') + output_file_path_rms = os.path.join( + tmp_path, + os.path.split(wav_file_path)[1] + '_vad_rms.tmp') + + # Call WebRTC VAD. + try: + subprocess.call([ + self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs', + output_file_path_probs, '-o_rms', output_file_path_rms + ], + cwd=self._VAD_WEBRTC_PATH) + + # Parse annotations. 
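+            # The apm_vad binary dumps each probability/RMS value as a raw
+            # 8-byte double, so the arrays can be read back directly with
+            # numpy.fromfile using the np.double dtype.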
+ self._apm_vad_probs = np.fromfile(output_file_path_probs, + np.double) + self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double) + assert len(self._apm_vad_rms) == len(self._apm_vad_probs) + + except Exception as e: + logging.error('Error while running the WebRTC APM VAD (' + + e.message + ')') + finally: + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py new file mode 100644 index 0000000000..8230208808 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py @@ -0,0 +1,160 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the annotations module. +""" + +from __future__ import division +import logging +import os +import shutil +import tempfile +import unittest + +import numpy as np + +from . import annotations +from . import external_vad +from . import input_signal_creator +from . import signal_processing + + +class TestAnnotationsExtraction(unittest.TestCase): + """Unit tests for the annotations module. + """ + + _CLEAN_TMP_OUTPUT = True + _DEBUG_PLOT_VAD = False + _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType + _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD + | _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO + | _VAD_TYPE_CLASS.WEBRTC_APM) + + def setUp(self): + """Create temporary folder.""" + self._tmp_path = tempfile.mkdtemp() + self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav') + pure_tone, _ = input_signal_creator.InputSignalCreator.Create( + 'pure_tone', [440, 1000]) + signal_processing.SignalProcessingUtils.SaveWav( + self._wav_file_path, pure_tone) + self._sample_rate = pure_tone.frame_rate + + def tearDown(self): + """Recursively delete temporary folder.""" + if self._CLEAN_TMP_OUTPUT: + shutil.rmtree(self._tmp_path) + else: + logging.warning(self.id() + ' did not clean the temporary path ' + + (self._tmp_path)) + + def testFrameSizes(self): + e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) + e.Extract(self._wav_file_path) + samples_to_ms = lambda n, sr: 1000 * n // sr + self.assertEqual( + samples_to_ms(e.GetLevelFrameSize(), self._sample_rate), + e.GetLevelFrameSizeMs()) + self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate), + e.GetVadFrameSizeMs()) + + def testVoiceActivityDetectors(self): + for vad_type_value in range(0, self._ALL_VAD_TYPES + 1): + vad_type = self._VAD_TYPE_CLASS(vad_type_value) + e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value) + e.Extract(self._wav_file_path) + if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD): + # pylint: disable=unpacking-non-sequence + vad_output = e.GetVadOutput( + self._VAD_TYPE_CLASS.ENERGY_THRESHOLD) + self.assertGreater(len(vad_output), 0) + self.assertGreaterEqual( + float(np.sum(vad_output)) / len(vad_output), 0.95) + + if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO): + # pylint: disable=unpacking-non-sequence + vad_output = 
e.GetVadOutput( + self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO) + self.assertGreater(len(vad_output), 0) + self.assertGreaterEqual( + float(np.sum(vad_output)) / len(vad_output), 0.95) + + if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM): + # pylint: disable=unpacking-non-sequence + (vad_probs, + vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM) + self.assertGreater(len(vad_probs), 0) + self.assertGreater(len(vad_rms), 0) + self.assertGreaterEqual( + float(np.sum(vad_probs)) / len(vad_probs), 0.5) + self.assertGreaterEqual( + float(np.sum(vad_rms)) / len(vad_rms), 20000) + + if self._DEBUG_PLOT_VAD: + frame_times_s = lambda num_frames, frame_size_ms: np.arange( + num_frames).astype(np.float32) * frame_size_ms / 1000.0 + level = e.GetLevel() + t_level = frame_times_s(num_frames=len(level), + frame_size_ms=e.GetLevelFrameSizeMs()) + t_vad = frame_times_s(num_frames=len(vad_output), + frame_size_ms=e.GetVadFrameSizeMs()) + import matplotlib.pyplot as plt + plt.figure() + plt.hold(True) + plt.plot(t_level, level) + plt.plot(t_vad, vad_output * np.max(level), '.') + plt.show() + + def testSaveLoad(self): + e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) + e.Extract(self._wav_file_path) + e.Save(self._tmp_path, "fake-annotation") + + data = np.load( + os.path.join( + self._tmp_path, + e.GetOutputFileNameTemplate().format("fake-annotation"))) + np.testing.assert_array_equal(e.GetLevel(), data['level']) + self.assertEqual(np.float32, data['level'].dtype) + np.testing.assert_array_equal( + e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD), + data['vad_energy_output']) + np.testing.assert_array_equal( + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO), + data['vad_output']) + np.testing.assert_array_equal( + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], + data['vad_probs']) + np.testing.assert_array_equal( + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], + data['vad_rms']) + self.assertEqual(np.uint8, data['vad_energy_output'].dtype) + self.assertEqual(np.float64, data['vad_probs'].dtype) + self.assertEqual(np.float64, data['vad_rms'].dtype) + + def testEmptyExternalShouldNotCrash(self): + for vad_type_value in range(0, self._ALL_VAD_TYPES + 1): + annotations.AudioAnnotationsExtractor(vad_type_value, {}) + + def testFakeExternalSaveLoad(self): + def FakeExternalFactory(): + return external_vad.ExternalVad( + os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'fake_external_vad.py'), 'fake') + + for vad_type_value in range(0, self._ALL_VAD_TYPES + 1): + e = annotations.AudioAnnotationsExtractor( + vad_type_value, {'fake': FakeExternalFactory()}) + e.Extract(self._wav_file_path) + e.Save(self._tmp_path, annotation_name="fake-annotation") + data = np.load( + os.path.join( + self._tmp_path, + e.GetOutputFileNameTemplate().format("fake-annotation"))) + self.assertEqual(np.float32, data['extvad_conf-fake'].dtype) + np.testing.assert_almost_equal(np.arange(100, dtype=np.float32), + data['extvad_conf-fake']) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/apm_configs/default.json b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/apm_configs/default.json new file mode 100644 index 0000000000..5c3277bac0 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/apm_configs/default.json @@ -0,0 +1 @@ +{"-all_default": null} diff --git 
a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/apm_vad.cc b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/apm_vad.cc new file mode 100644 index 0000000000..73ce4ed3f7 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/apm_vad.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +// +// Use of this source code is governed by a BSD-style license +// that can be found in the LICENSE file in the root of the source +// tree. An additional intellectual property rights grant can be found +// in the file PATENTS. All contributing project authors may +// be found in the AUTHORS file in the root of the source tree. + +#include <array> +#include <fstream> +#include <memory> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "common_audio/wav_file.h" +#include "modules/audio_processing/vad/voice_activity_detector.h" +#include "rtc_base/logging.h" + +ABSL_FLAG(std::string, i, "", "Input wav file"); +ABSL_FLAG(std::string, o_probs, "", "VAD probabilities output file"); +ABSL_FLAG(std::string, o_rms, "", "VAD output file"); + +namespace webrtc { +namespace test { +namespace { + +constexpr uint8_t kAudioFrameLengthMilliseconds = 10; +constexpr int kMaxSampleRate = 48000; +constexpr size_t kMaxFrameLen = + kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000; + +int main(int argc, char* argv[]) { + absl::ParseCommandLine(argc, argv); + const std::string input_file = absl::GetFlag(FLAGS_i); + const std::string output_probs_file = absl::GetFlag(FLAGS_o_probs); + const std::string output_file = absl::GetFlag(FLAGS_o_rms); + // Open wav input file and check properties. + WavReader wav_reader(input_file); + if (wav_reader.num_channels() != 1) { + RTC_LOG(LS_ERROR) << "Only mono wav files supported"; + return 1; + } + if (wav_reader.sample_rate() > kMaxSampleRate) { + RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate + << ")"; + return 1; + } + const size_t audio_frame_len = rtc::CheckedDivExact( + kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000); + if (audio_frame_len > kMaxFrameLen) { + RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large."; + return 1; + } + + // Create output file and write header. + std::ofstream out_probs_file(output_probs_file, std::ofstream::binary); + std::ofstream out_rms_file(output_file, std::ofstream::binary); + + // Run VAD and write decisions. + VoiceActivityDetector vad; + std::array<int16_t, kMaxFrameLen> samples; + + while (true) { + // Process frame. + const auto read_samples = + wav_reader.ReadSamples(audio_frame_len, samples.data()); + if (read_samples < audio_frame_len) { + break; + } + vad.ProcessChunk(samples.data(), audio_frame_len, wav_reader.sample_rate()); + // Write output. 
+ auto probs = vad.chunkwise_voice_probabilities(); + auto rms = vad.chunkwise_rms(); + RTC_CHECK_EQ(probs.size(), rms.size()); + RTC_CHECK_EQ(sizeof(double), 8); + + for (const auto& p : probs) { + out_probs_file.write(reinterpret_cast<const char*>(&p), 8); + } + for (const auto& r : rms) { + out_rms_file.write(reinterpret_cast<const char*>(&r), 8); + } + } + + out_probs_file.close(); + out_rms_file.close(); + return 0; +} + +} // namespace +} // namespace test +} // namespace webrtc + +int main(int argc, char* argv[]) { + return webrtc::test::main(argc, argv); +} diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/audioproc_wrapper.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/audioproc_wrapper.py new file mode 100644 index 0000000000..04aeaa95b9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/audioproc_wrapper.py @@ -0,0 +1,100 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Class implementing a wrapper for APM simulators. +""" + +import cProfile +import logging +import os +import subprocess + +from . import data_access +from . import exceptions + + +class AudioProcWrapper(object): + """Wrapper for APM simulators. + """ + + DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath( + os.path.join(os.pardir, 'audioproc_f')) + OUTPUT_FILENAME = 'output.wav' + + def __init__(self, simulator_bin_path): + """Ctor. + + Args: + simulator_bin_path: path to the APM simulator binary. + """ + self._simulator_bin_path = simulator_bin_path + self._config = None + self._output_signal_filepath = None + + # Profiler instance to measure running time. + self._profiler = cProfile.Profile() + + @property + def output_filepath(self): + return self._output_signal_filepath + + def Run(self, + config_filepath, + capture_input_filepath, + output_path, + render_input_filepath=None): + """Runs APM simulator. + + Args: + config_filepath: path to the configuration file specifying the arguments + for the APM simulator. + capture_input_filepath: path to the capture audio track input file (aka + forward or near-end). + output_path: path of the audio track output file. + render_input_filepath: path to the render audio track input file (aka + reverse or far-end). + """ + # Init. + self._output_signal_filepath = os.path.join(output_path, + self.OUTPUT_FILENAME) + profiling_stats_filepath = os.path.join(output_path, 'profiling.stats') + + # Skip if the output has already been generated. + if os.path.exists(self._output_signal_filepath) and os.path.exists( + profiling_stats_filepath): + return + + # Load configuration. + self._config = data_access.AudioProcConfigFile.Load(config_filepath) + + # Set remaining parameters. 
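+        # The entries added below mirror the audioproc_f command line flags:
+        # '-i' (capture input), '-o' (output) and '-ri' (render input); all
+        # other flags come from the JSON configuration file.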
+ if not os.path.exists(capture_input_filepath): + raise exceptions.FileNotFoundError( + 'cannot find capture input file') + self._config['-i'] = capture_input_filepath + self._config['-o'] = self._output_signal_filepath + if render_input_filepath is not None: + if not os.path.exists(render_input_filepath): + raise exceptions.FileNotFoundError( + 'cannot find render input file') + self._config['-ri'] = render_input_filepath + + # Build arguments list. + args = [self._simulator_bin_path] + for param_name in self._config: + args.append(param_name) + if self._config[param_name] is not None: + args.append(str(self._config[param_name])) + logging.debug(' '.join(args)) + + # Run. + self._profiler.enable() + subprocess.call(args) + self._profiler.disable() + + # Save profiling stats. + self._profiler.dump_stats(profiling_stats_filepath) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/collect_data.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/collect_data.py new file mode 100644 index 0000000000..38aac0cbe2 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/collect_data.py @@ -0,0 +1,243 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Imports a filtered subset of the scores and configurations computed +by apm_quality_assessment.py into a pandas data frame. +""" + +import argparse +import glob +import logging +import os +import re +import sys + +try: + import pandas as pd +except ImportError: + logging.critical('Cannot import the third-party Python package pandas') + sys.exit(1) + +from . import data_access as data_access +from . import simulation as sim + +# Compiled regular expressions used to extract score descriptors. +RE_CONFIG_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixApmConfig() + + r'(.+)') +RE_CAPTURE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixCapture() + + r'(.+)') +RE_RENDER_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)') +RE_ECHO_SIM_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixEchoSimulator() + + r'(.+)') +RE_TEST_DATA_GEN_NAME = re.compile( + sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + r'(.+)') +RE_TEST_DATA_GEN_PARAMS = re.compile( + sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + r'(.+)') +RE_SCORE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixScore() + + r'(.+)(\..+)') + + +def InstanceArgumentsParser(): + """Arguments parser factory. 
+ """ + parser = argparse.ArgumentParser( + description=('Override this description in a user script by changing' + ' `parser.description` of the returned parser.')) + + parser.add_argument('-o', + '--output_dir', + required=True, + help=('the same base path used with the ' + 'apm_quality_assessment tool')) + + parser.add_argument( + '-c', + '--config_names', + type=re.compile, + help=('regular expression to filter the APM configuration' + ' names')) + + parser.add_argument( + '-i', + '--capture_names', + type=re.compile, + help=('regular expression to filter the capture signal ' + 'names')) + + parser.add_argument('-r', + '--render_names', + type=re.compile, + help=('regular expression to filter the render signal ' + 'names')) + + parser.add_argument( + '-e', + '--echo_simulator_names', + type=re.compile, + help=('regular expression to filter the echo simulator ' + 'names')) + + parser.add_argument('-t', + '--test_data_generators', + type=re.compile, + help=('regular expression to filter the test data ' + 'generator names')) + + parser.add_argument( + '-s', + '--eval_scores', + type=re.compile, + help=('regular expression to filter the evaluation score ' + 'names')) + + return parser + + +def _GetScoreDescriptors(score_filepath): + """Extracts a score descriptor from the given score file path. + + Args: + score_filepath: path to the score file. + + Returns: + A tuple of strings (APM configuration name, capture audio track name, + render audio track name, echo simulator name, test data generator name, + test data generator parameters as string, evaluation score name). + """ + fields = score_filepath.split(os.sep)[-7:] + extract_name = lambda index, reg_expr: (reg_expr.match(fields[index]). + groups(0)[0]) + return ( + extract_name(0, RE_CONFIG_NAME), + extract_name(1, RE_CAPTURE_NAME), + extract_name(2, RE_RENDER_NAME), + extract_name(3, RE_ECHO_SIM_NAME), + extract_name(4, RE_TEST_DATA_GEN_NAME), + extract_name(5, RE_TEST_DATA_GEN_PARAMS), + extract_name(6, RE_SCORE_NAME), + ) + + +def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name, + test_data_gen_name, score_name, args): + """Decides whether excluding a score. + + A set of optional regular expressions in args is used to determine if the + score should be excluded (depending on its |*_name| descriptors). + + Args: + config_name: APM configuration name. + capture_name: capture audio track name. + render_name: render audio track name. + echo_simulator_name: echo simulator name. + test_data_gen_name: test data generator name. + score_name: evaluation score name. + args: parsed arguments. + + Returns: + A boolean. + """ + value_regexpr_pairs = [ + (config_name, args.config_names), + (capture_name, args.capture_names), + (render_name, args.render_names), + (echo_simulator_name, args.echo_simulator_names), + (test_data_gen_name, args.test_data_generators), + (score_name, args.eval_scores), + ] + + # Score accepted if each value matches the corresponding regular expression. + for value, regexpr in value_regexpr_pairs: + if regexpr is None: + continue + if not regexpr.match(value): + return True + + return False + + +def FindScores(src_path, args): + """Given a search path, find scores and return a DataFrame object. + + Args: + src_path: Search path pattern. + args: parsed arguments. + + Returns: + A DataFrame object. + """ + # Get scores. + scores = [] + for score_filepath in glob.iglob(src_path): + # Extract score descriptor fields from the path. 
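+        # The expected layout (see ConstructSrcPath below) is:
+        #   <apm config>/<capture>/<render>/<echo simulator>/
+        #   <test data generator>/<generator params>/<score file>.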
+        (config_name, capture_name, render_name, echo_simulator_name,
+         test_data_gen_name, test_data_gen_params,
+         score_name) = _GetScoreDescriptors(score_filepath)
+
+        # Ignore the score if required.
+        if _ExcludeScore(config_name, capture_name, render_name,
+                         echo_simulator_name, test_data_gen_name, score_name,
+                         args):
+            logging.info('ignored score: %s %s %s %s %s %s', config_name,
+                         capture_name, render_name, echo_simulator_name,
+                         test_data_gen_name, score_name)
+            continue
+
+        # Read metadata and score.
+        metadata = data_access.Metadata.LoadAudioTestDataPaths(
+            os.path.split(score_filepath)[0])
+        score = data_access.ScoreFile.Load(score_filepath)
+
+        # Add a score with its descriptor fields.
+        scores.append((
+            metadata['clean_capture_input_filepath'],
+            metadata['echo_free_capture_filepath'],
+            metadata['echo_filepath'],
+            metadata['render_filepath'],
+            metadata['capture_filepath'],
+            metadata['apm_output_filepath'],
+            metadata['apm_reference_filepath'],
+            config_name,
+            capture_name,
+            render_name,
+            echo_simulator_name,
+            test_data_gen_name,
+            test_data_gen_params,
+            score_name,
+            score,
+        ))
+
+    return pd.DataFrame(data=scores,
+                        columns=(
+                            'clean_capture_input_filepath',
+                            'echo_free_capture_filepath',
+                            'echo_filepath',
+                            'render_filepath',
+                            'capture_filepath',
+                            'apm_output_filepath',
+                            'apm_reference_filepath',
+                            'apm_config',
+                            'capture',
+                            'render',
+                            'echo_simulator',
+                            'test_data_gen',
+                            'test_data_gen_params',
+                            'eval_score_name',
+                            'score',
+                        ))
+
+
+def ConstructSrcPath(args):
+    return os.path.join(
+        args.output_dir,
+        sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
+        sim.ApmModuleSimulator.GetPrefixCapture() + '*',
+        sim.ApmModuleSimulator.GetPrefixRender() + '*',
+        sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
+        sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
+        sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
+        sim.ApmModuleSimulator.GetPrefixScore() + '*')
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
new file mode 100644
index 0000000000..c1aebb67f1
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Data access utility functions and classes.
+"""
+
+import json
+import os
+
+
+def MakeDirectory(path):
+    """Makes a directory recursively without raising an exception if it
+    already exists.
+
+    Args:
+      path: path to the directory to be created.
+    """
+    if os.path.exists(path):
+        return
+    os.makedirs(path)
+
+
+class Metadata(object):
+    """Data access class to save and load metadata.
+    """
+
+    def __init__(self):
+        pass
+
+    _GENERIC_METADATA_SUFFIX = '.mdata'
+    _AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
+
+    @classmethod
+    def LoadFileMetadata(cls, filepath):
+        """Loads generic metadata linked to a file.
+
+        Args:
+          filepath: path to the metadata file to read.
+
+        Returns:
+          A dict.
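+
+        The metadata is read from `filepath + '.mdata'`, a JSON file written
+        by SaveFileMetadata().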
+ """ + with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f: + return json.load(f) + + @classmethod + def SaveFileMetadata(cls, filepath, metadata): + """Saves generic metadata linked to a file. + + Args: + filepath: path to the metadata file to write. + metadata: a dict. + """ + with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f: + json.dump(metadata, f) + + @classmethod + def LoadAudioTestDataPaths(cls, metadata_path): + """Loads the input and the reference audio track paths. + + Args: + metadata_path: path to the directory containing the metadata file. + + Returns: + Tuple with the paths to the input and output audio tracks. + """ + metadata_filepath = os.path.join(metadata_path, + cls._AUDIO_TEST_DATA_FILENAME) + with open(metadata_filepath) as f: + return json.load(f) + + @classmethod + def SaveAudioTestDataPaths(cls, output_path, **filepaths): + """Saves the input and the reference audio track paths. + + Args: + output_path: path to the directory containing the metadata file. + + Keyword Args: + filepaths: collection of audio track file paths to save. + """ + output_filepath = os.path.join(output_path, + cls._AUDIO_TEST_DATA_FILENAME) + with open(output_filepath, 'w') as f: + json.dump(filepaths, f) + + +class AudioProcConfigFile(object): + """Data access to load/save APM simulator argument lists. + + The arguments stored in the config files are used to control the APM flags. + """ + + def __init__(self): + pass + + @classmethod + def Load(cls, filepath): + """Loads a configuration file for an APM simulator. + + Args: + filepath: path to the configuration file. + + Returns: + A dict containing the configuration. + """ + with open(filepath) as f: + return json.load(f) + + @classmethod + def Save(cls, filepath, config): + """Saves a configuration file for an APM simulator. + + Args: + filepath: path to the configuration file. + config: a dict containing the configuration. + """ + with open(filepath, 'w') as f: + json.dump(config, f) + + +class ScoreFile(object): + """Data access class to save and load float scalar scores. + """ + + def __init__(self): + pass + + @classmethod + def Load(cls, filepath): + """Loads a score from file. + + Args: + filepath: path to the score file. + + Returns: + A float encoding the score. + """ + with open(filepath) as f: + return float(f.readline().strip()) + + @classmethod + def Save(cls, filepath, score): + """Saves a score into a file. + + Args: + filepath: path to the score file. + score: float encoding the score. + """ + with open(filepath, 'w') as f: + f.write('{0:f}\n'.format(score)) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation.py new file mode 100644 index 0000000000..65903ea32d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation.py @@ -0,0 +1,136 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Echo path simulation module. +""" + +import hashlib +import os + +from . 
+
+
+class EchoPathSimulator(object):
+    """Abstract class for the echo path simulators.
+
+    In general, an echo path simulator is a function of the render signal and
+    simulates the propagation of the latter into the microphone (e.g., due to
+    mechanical or electrical paths).
+    """
+
+    NAME = None
+    REGISTERED_CLASSES = {}
+
+    def __init__(self):
+        pass
+
+    def Simulate(self, output_path):
+        """Creates the echo signal and stores it in an audio file (abstract
+        method).
+
+        Args:
+          output_path: Path in which any output can be saved.
+
+        Returns:
+          Path to the generated audio track file or None if no echo is present.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def RegisterClass(cls, class_to_register):
+        """Registers an EchoPathSimulator implementation.
+
+        Decorator to automatically register the classes that extend
+        EchoPathSimulator.
+        Example usage:
+
+        @EchoPathSimulator.RegisterClass
+        class NoEchoPathSimulator(EchoPathSimulator):
+          pass
+        """
+        cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
+        return class_to_register
+
+
+@EchoPathSimulator.RegisterClass
+class NoEchoPathSimulator(EchoPathSimulator):
+    """Simulates absence of echo."""
+
+    NAME = 'noecho'
+
+    def __init__(self):
+        EchoPathSimulator.__init__(self)
+
+    def Simulate(self, output_path):
+        return None
+
+
+@EchoPathSimulator.RegisterClass
+class LinearEchoPathSimulator(EchoPathSimulator):
+    """Simulates a linear echo path.
+
+    This class applies a given impulse response to the render input to produce
+    the echo signal; mixing the echo with the capture input signal is done
+    separately (see the input_mixer module).
+    """
+
+    NAME = 'linear'
+
+    def __init__(self, render_input_filepath, impulse_response):
+        """
+        Args:
+          render_input_filepath: Render audio track file.
+          impulse_response: list or numpy vector of float values.
+        """
+        EchoPathSimulator.__init__(self)
+        self._render_input_filepath = render_input_filepath
+        self._impulse_response = impulse_response
+
+    def Simulate(self, output_path):
+        """Simulates the linear echo path."""
+        # Form the file name with a hash of the impulse response.
+        impulse_response_hash = hashlib.sha256(
+            str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
+        echo_filepath = os.path.join(
+            output_path, 'linear_echo_{}.wav'.format(impulse_response_hash))
+
+        # If the simulated echo audio track file does not exist, create it.
+        if not os.path.exists(echo_filepath):
+            render = signal_processing.SignalProcessingUtils.LoadWav(
+                self._render_input_filepath)
+            echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
+                render, self._impulse_response)
+            signal_processing.SignalProcessingUtils.SaveWav(
+                echo_filepath, echo)
+
+        return echo_filepath
+
+
+@EchoPathSimulator.RegisterClass
+class RecordedEchoPathSimulator(EchoPathSimulator):
+    """Uses a recorded echo.
+
+    This class uses the render input file name to build the file name of the
+    corresponding recording containing echo (a predefined suffix is used).
+    Such a file is expected to already exist.
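+    For example (hypothetical file names), with render input
+    `<path>/speech.wav` the echo recording is looked up at
+    `<path>/speech_echo.wav`.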
+ """ + + NAME = 'recorded' + + _FILE_NAME_SUFFIX = '_echo' + + def __init__(self, render_input_filepath): + EchoPathSimulator.__init__(self) + self._render_input_filepath = render_input_filepath + + def Simulate(self, output_path): + """Uses recorded echo path.""" + path, file_name_ext = os.path.split(self._render_input_filepath) + file_name, file_ext = os.path.splitext(file_name_ext) + echo_filepath = os.path.join( + path, '{}{}{}'.format(file_name, self._FILE_NAME_SUFFIX, file_ext)) + assert os.path.exists(echo_filepath), ( + 'cannot find the echo audio track file {}'.format(echo_filepath)) + return echo_filepath diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation_factory.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation_factory.py new file mode 100644 index 0000000000..4b46b36b47 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation_factory.py @@ -0,0 +1,48 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Echo path simulation factory module. +""" + +import numpy as np + +from . import echo_path_simulation + + +class EchoPathSimulatorFactory(object): + + # TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more + # realistic impulse response. + _LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0] * (20 * 48) + [0.15]) + + def __init__(self): + pass + + @classmethod + def GetInstance(cls, echo_path_simulator_class, render_input_filepath): + """Creates an EchoPathSimulator instance given a class object. + + Args: + echo_path_simulator_class: EchoPathSimulator class object (not an + instance). + render_input_filepath: Path to the render audio track file. + + Returns: + An EchoPathSimulator instance. + """ + assert render_input_filepath is not None or ( + echo_path_simulator_class == + echo_path_simulation.NoEchoPathSimulator) + + if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator: + return echo_path_simulation.NoEchoPathSimulator() + elif echo_path_simulator_class == ( + echo_path_simulation.LinearEchoPathSimulator): + return echo_path_simulation.LinearEchoPathSimulator( + render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE) + else: + return echo_path_simulator_class(render_input_filepath) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation_unittest.py new file mode 100644 index 0000000000..b6cc8abdde --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/echo_path_simulation_unittest.py @@ -0,0 +1,82 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. 
All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Unit tests for the echo path simulation module.
+"""
+
+import shutil
+import os
+import tempfile
+import unittest
+
+import pydub
+
+from . import echo_path_simulation
+from . import echo_path_simulation_factory
+from . import signal_processing
+
+
+class TestEchoPathSimulators(unittest.TestCase):
+    """Unit tests for the echo_path_simulation module.
+    """
+
+    def setUp(self):
+        """Creates temporary data."""
+        self._tmp_path = tempfile.mkdtemp()
+
+        # Create and save white noise.
+        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
+        white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+            silence)
+        self._audio_track_num_samples = (
+            signal_processing.SignalProcessingUtils.CountSamples(white_noise))
+        self._audio_track_filepath = os.path.join(self._tmp_path,
+                                                  'white_noise.wav')
+        signal_processing.SignalProcessingUtils.SaveWav(
+            self._audio_track_filepath, white_noise)
+
+        # Make a copy of the white noise audio track file; it will be used by
+        # echo_path_simulation.RecordedEchoPathSimulator.
+        shutil.copy(self._audio_track_filepath,
+                    os.path.join(self._tmp_path, 'white_noise_echo.wav'))
+
+    def tearDown(self):
+        """Recursively deletes temporary folders."""
+        shutil.rmtree(self._tmp_path)
+
+    def testRegisteredClasses(self):
+        # Check that there is at least one registered echo path simulator.
+        registered_classes = (
+            echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
+        self.assertIsInstance(registered_classes, dict)
+        self.assertGreater(len(registered_classes), 0)
+
+        # Instantiate the factory.
+        factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
+
+        # Try each registered echo path simulator.
+        for echo_path_simulator_name in registered_classes:
+            simulator = factory.GetInstance(
+                echo_path_simulator_class=registered_classes[
+                    echo_path_simulator_name],
+                render_input_filepath=self._audio_track_filepath)
+
+            echo_filepath = simulator.Simulate(self._tmp_path)
+            if echo_filepath is None:
+                self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
+                                 echo_path_simulator_name)
+                # No other tests in this case.
+                continue
+
+            # Check that the echo audio track file exists and that its length
+            # is greater than or equal to that of the render audio track.
+            self.assertTrue(os.path.exists(echo_filepath))
+            echo = signal_processing.SignalProcessingUtils.LoadWav(
+                echo_filepath)
+            self.assertGreaterEqual(
+                signal_processing.SignalProcessingUtils.CountSamples(echo),
+                self._audio_track_num_samples)
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
new file mode 100644
index 0000000000..59c5f74be4
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
@@ -0,0 +1,427 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Evaluation score abstract class and implementations.
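+
+New scores are defined by subclassing EvaluationScore, setting NAME,
+implementing _Run(), and decorating the subclass with
+EvaluationScore.RegisterClass.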
+""" + +from __future__ import division +import logging +import os +import re +import subprocess +import sys + +try: + import numpy as np +except ImportError: + logging.critical('Cannot import the third-party Python package numpy') + sys.exit(1) + +from . import data_access +from . import exceptions +from . import signal_processing + + +class EvaluationScore(object): + + NAME = None + REGISTERED_CLASSES = {} + + def __init__(self, score_filename_prefix): + self._score_filename_prefix = score_filename_prefix + self._input_signal_metadata = None + self._reference_signal = None + self._reference_signal_filepath = None + self._tested_signal = None + self._tested_signal_filepath = None + self._output_filepath = None + self._score = None + self._render_signal_filepath = None + + @classmethod + def RegisterClass(cls, class_to_register): + """Registers an EvaluationScore implementation. + + Decorator to automatically register the classes that extend EvaluationScore. + Example usage: + + @EvaluationScore.RegisterClass + class AudioLevelScore(EvaluationScore): + pass + """ + cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register + return class_to_register + + @property + def output_filepath(self): + return self._output_filepath + + @property + def score(self): + return self._score + + def SetInputSignalMetadata(self, metadata): + """Sets input signal metadata. + + Args: + metadata: dict instance. + """ + self._input_signal_metadata = metadata + + def SetReferenceSignalFilepath(self, filepath): + """Sets the path to the audio track used as reference signal. + + Args: + filepath: path to the reference audio track. + """ + self._reference_signal_filepath = filepath + + def SetTestedSignalFilepath(self, filepath): + """Sets the path to the audio track used as test signal. + + Args: + filepath: path to the test audio track. + """ + self._tested_signal_filepath = filepath + + def SetRenderSignalFilepath(self, filepath): + """Sets the path to the audio track used as render signal. + + Args: + filepath: path to the test audio track. + """ + self._render_signal_filepath = filepath + + def Run(self, output_path): + """Extracts the score for the set test data pair. + + Args: + output_path: path to the directory where the output is written. + """ + self._output_filepath = os.path.join( + output_path, self._score_filename_prefix + self.NAME + '.txt') + try: + # If the score has already been computed, load. + self._LoadScore() + logging.debug('score found and loaded') + except IOError: + # Compute the score. + logging.debug('score not found, compute') + self._Run(output_path) + + def _Run(self, output_path): + # Abstract method. + raise NotImplementedError() + + def _LoadReferenceSignal(self): + assert self._reference_signal_filepath is not None + self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav( + self._reference_signal_filepath) + + def _LoadTestedSignal(self): + assert self._tested_signal_filepath is not None + self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav( + self._tested_signal_filepath) + + def _LoadScore(self): + return data_access.ScoreFile.Load(self._output_filepath) + + def _SaveScore(self): + return data_access.ScoreFile.Save(self._output_filepath, self._score) + + +@EvaluationScore.RegisterClass +class AudioLevelPeakScore(EvaluationScore): + """Peak audio level score. + + Defined as the difference between the peak audio level of the tested and + the reference signals. 
+
+    Unit: dB
+    Ideal: 0 dB
+    Worst case: +/-inf dB
+    """
+
+    NAME = 'audio_level_peak'
+
+    def __init__(self, score_filename_prefix):
+        EvaluationScore.__init__(self, score_filename_prefix)
+
+    def _Run(self, output_path):
+        self._LoadReferenceSignal()
+        self._LoadTestedSignal()
+        self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
+        self._SaveScore()
+
+
+@EvaluationScore.RegisterClass
+class MeanAudioLevelScore(EvaluationScore):
+    """Mean audio level score.
+
+    Defined as the difference between the mean audio level of the tested and
+    the reference signals.
+
+    Unit: dB
+    Ideal: 0 dB
+    Worst case: +/-inf dB
+    """
+
+    NAME = 'audio_level_mean'
+
+    def __init__(self, score_filename_prefix):
+        EvaluationScore.__init__(self, score_filename_prefix)
+
+    def _Run(self, output_path):
+        self._LoadReferenceSignal()
+        self._LoadTestedSignal()
+
+        dbfs_diffs_sum = 0.0
+        seconds = min(len(self._tested_signal), len(
+            self._reference_signal)) // 1000
+        for t in range(seconds):
+            # Average over consecutive 1 second windows (pydub slices are
+            # indexed in milliseconds).
+            t0 = t * 1000
+            t1 = t0 + 1000
+            dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
+                               self._reference_signal[t0:t1].dBFS)
+        self._score = dbfs_diffs_sum / float(seconds)
+        self._SaveScore()
+
+
+@EvaluationScore.RegisterClass
+class EchoMetric(EvaluationScore):
+    """Echo score.
+
+    Proportion of detected echo.
+
+    Unit: ratio
+    Ideal: 0
+    Worst case: 1
+    """
+
+    NAME = 'echo_metric'
+
+    def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
+        EvaluationScore.__init__(self, score_filename_prefix)
+
+        # Echo detector binary file path.
+        self._echo_detector_bin_filepath = echo_detector_bin_filepath
+        if not os.path.exists(self._echo_detector_bin_filepath):
+            logging.error('cannot find EchoMetric tool binary file')
+            raise exceptions.FileNotFoundError()
+
+        self._echo_detector_bin_path, _ = os.path.split(
+            self._echo_detector_bin_filepath)
+
+    def _Run(self, output_path):
+        echo_detector_out_filepath = os.path.join(output_path,
+                                                  'echo_detector.out')
+        if os.path.exists(echo_detector_out_filepath):
+            os.unlink(echo_detector_out_filepath)
+
+        logging.debug("Render signal filepath: %s",
+                      self._render_signal_filepath)
+        if not os.path.exists(self._render_signal_filepath):
+            logging.error(
+                "Render input required for evaluating the echo metric.")
+
+        args = [
+            self._echo_detector_bin_filepath, '--output_file',
+            echo_detector_out_filepath, '--', '-i',
+            self._tested_signal_filepath, '-ri', self._render_signal_filepath
+        ]
+        logging.debug(' '.join(args))
+        subprocess.call(args, cwd=self._echo_detector_bin_path)
+
+        # Parse the echo detector tool output and extract the score.
+        self._score = self._ParseOutputFile(echo_detector_out_filepath)
+        self._SaveScore()
+
+    @classmethod
+    def _ParseOutputFile(cls, echo_metric_file_path):
+        """
+        Parses the echo detector tool output (a single float value).
+
+        Args:
+          echo_metric_file_path: path to the echo detector tool output file.
+
+        Returns:
+          The score as a number in [0, 1].
+        """
+        with open(echo_metric_file_path) as f:
+            return float(f.read())
+
+
+@EvaluationScore.RegisterClass
+class PolqaScore(EvaluationScore):
+    """POLQA score.
+
+    See http://www.polqa.info/.
+
+    Unit: MOS
+    Ideal: 4.5
+    Worst case: 1.0
+    """
+
+    NAME = 'polqa'
+
+    def __init__(self, score_filename_prefix, polqa_bin_filepath):
+        EvaluationScore.__init__(self, score_filename_prefix)
+
+        # POLQA binary file path.
+        self._polqa_bin_filepath = polqa_bin_filepath
+        if not os.path.exists(self._polqa_bin_filepath):
+            logging.error('cannot find POLQA tool binary file')
+            raise exceptions.FileNotFoundError()
+
+        # Path to the POLQA directory with binary and license files.
+        self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
+
+    def _Run(self, output_path):
+        polqa_out_filepath = os.path.join(output_path, 'polqa.out')
+        if os.path.exists(polqa_out_filepath):
+            os.unlink(polqa_out_filepath)
+
+        args = [
+            self._polqa_bin_filepath,
+            '-t',
+            '-q',
+            '-Overwrite',
+            '-Ref',
+            self._reference_signal_filepath,
+            '-Test',
+            self._tested_signal_filepath,
+            '-LC',
+            'NB',
+            '-Out',
+            polqa_out_filepath,
+        ]
+        logging.debug(' '.join(args))
+        subprocess.call(args, cwd=self._polqa_tool_path)
+
+        # Parse POLQA tool output and extract the score.
+        polqa_output = self._ParseOutputFile(polqa_out_filepath)
+        self._score = float(polqa_output['PolqaScore'])
+
+        self._SaveScore()
+
+    @classmethod
+    def _ParseOutputFile(cls, polqa_out_filepath):
+        """
+        Parses the POLQA tool output formatted as a table ('-t' option).
+
+        Args:
+          polqa_out_filepath: path to the POLQA tool output file.
+
+        Returns:
+          A dict.
+        """
+        data = []
+        with open(polqa_out_filepath) as f:
+            for line in f:
+                line = line.strip()
+                if len(line) == 0 or line.startswith('*'):
+                    # Ignore comments.
+                    continue
+                # Read fields.
+                data.append(re.split(r'\t+', line))
+
+        # Two rows expected (header and values).
+        assert len(data) == 2, 'Cannot parse POLQA output'
+        number_of_fields = len(data[0])
+        assert number_of_fields == len(data[1])
+
+        # Build and return a dictionary with field names (header) as keys and
+        # the corresponding field values as values.
+        return {
+            data[0][index]: data[1][index]
+            for index in range(number_of_fields)
+        }
+
+
+@EvaluationScore.RegisterClass
+class TotalHarmonicDistorsionScore(EvaluationScore):
+    """Total harmonic distortion plus noise (THD+N) score.
+
+    See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
+
+    Unit: -.
+    Ideal: 0.
+    Worst case: +inf
+    """
+
+    NAME = 'thd'
+
+    def __init__(self, score_filename_prefix):
+        EvaluationScore.__init__(self, score_filename_prefix)
+        self._input_frequency = None
+
+    def _Run(self, output_path):
+        self._CheckInputSignal()
+
+        self._LoadTestedSignal()
+        if self._tested_signal.channels != 1:
+            raise exceptions.EvaluationScoreException(
+                'unsupported number of channels')
+        samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
+            self._tested_signal)
+
+        # Init.
+        num_samples = len(samples)
+        duration = len(self._tested_signal) / 1000.0
+        scaling = 2.0 / num_samples
+        max_freq = self._tested_signal.frame_rate / 2
+        f0_freq = float(self._input_frequency)
+        t = np.linspace(0, duration, num_samples)
+
+        # Analyze harmonics.
+        b_terms = []
+        n = 1
+        while f0_freq * n < max_freq:
+            x_n = np.sum(
+                samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
+            y_n = np.sum(
+                samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
+            b_terms.append(np.sqrt(x_n**2 + y_n**2))
+            n += 1
+
+        output_without_fundamental = samples - b_terms[0] * np.sin(
+            2.0 * np.pi * f0_freq * t)
+        distortion_and_noise = np.sqrt(
+            np.sum(output_without_fundamental**2) * np.pi * scaling)
+
+        # TODO(alessiob): Fix or remove if not needed.
+        # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
+
+        # TODO(alessiob): Check the range of `thd_plus_noise` and update the
+        # class docstring above accordingly.
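+        # The `b_terms` above are single-frequency DFT projections of the
+        # signal at the harmonics of the fundamental f0: b_n = sqrt(x_n^2 +
+        # y_n^2), with x_n and y_n the sine and cosine correlations at n * f0.
+        # Subtracting the fundamental leaves all harmonics plus noise, so the
+        # ratio below estimates THD+N relative to the fundamental amplitude.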
+        thd_plus_noise = distortion_and_noise / b_terms[0]
+
+        self._score = thd_plus_noise
+        self._SaveScore()
+
+    def _CheckInputSignal(self):
+        # Check input signal and get properties.
+        try:
+            if self._input_signal_metadata['signal'] != 'pure_tone':
+                raise exceptions.EvaluationScoreException(
+                    'The THD score requires a pure tone as input signal')
+            self._input_frequency = self._input_signal_metadata['frequency']
+            if self._input_signal_metadata[
+                    'test_data_gen_name'] != 'identity' or (
+                        self._input_signal_metadata['test_data_gen_config'] !=
+                        'default'):
+                raise exceptions.EvaluationScoreException(
+                    'The THD score cannot be used with any test data generator '
+                    'other than "identity"')
+        except TypeError:
+            raise exceptions.EvaluationScoreException(
+                'The THD score requires an input signal with associated metadata'
+            )
+        except KeyError:
+            raise exceptions.EvaluationScoreException(
+                'Invalid input signal metadata to compute the THD score')
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py
new file mode 100644
index 0000000000..5749a8924b
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""EvaluationScore factory class.
+"""
+
+import logging
+
+from . import exceptions
+from . import eval_scores
+
+
+class EvaluationScoreWorkerFactory(object):
+    """Factory class used to instantiate evaluation score workers.
+
+    The ctor gets the parameters that are used to instantiate the evaluation
+    score workers.
+    """
+
+    def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
+        self._score_filename_prefix = None
+        self._polqa_tool_bin_path = polqa_tool_bin_path
+        self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
+
+    def SetScoreFilenamePrefix(self, prefix):
+        self._score_filename_prefix = prefix
+
+    def GetInstance(self, evaluation_score_class):
+        """Creates an EvaluationScore instance given a class object.
+
+        Args:
+          evaluation_score_class: EvaluationScore class object (not an
+            instance).
+
+        Returns:
+          An EvaluationScore instance.
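+
+        Example usage (a sketch; the path to the fake_polqa test binary is
+        hypothetical):
+
+          factory = EvaluationScoreWorkerFactory(
+              polqa_tool_bin_path='/path/to/fake_polqa',
+              echo_metric_tool_bin_path=None)
+          factory.SetScoreFilenamePrefix('score-')
+          worker = factory.GetInstance(eval_scores.AudioLevelPeakScore)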
+ """ + if self._score_filename_prefix is None: + raise exceptions.InitializationException( + 'The score file name prefix for evaluation score workers is not set' + ) + logging.debug('factory producing a %s evaluation score', + evaluation_score_class) + + if evaluation_score_class == eval_scores.PolqaScore: + return eval_scores.PolqaScore(self._score_filename_prefix, + self._polqa_tool_bin_path) + elif evaluation_score_class == eval_scores.EchoMetric: + return eval_scores.EchoMetric(self._score_filename_prefix, + self._echo_metric_tool_bin_path) + else: + return evaluation_score_class(self._score_filename_prefix) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py new file mode 100644 index 0000000000..12e043320e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py @@ -0,0 +1,137 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the eval_scores module. +""" + +import os +import shutil +import tempfile +import unittest + +import pydub + +from . import data_access +from . import eval_scores +from . import eval_scores_factory +from . import signal_processing + + +class TestEvalScores(unittest.TestCase): + """Unit tests for the eval_scores module. + """ + + def setUp(self): + """Create temporary output folder and two audio track files.""" + self._output_path = tempfile.mkdtemp() + + # Create fake reference and tested (i.e., APM output) audio track files. + silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000) + fake_reference_signal = (signal_processing.SignalProcessingUtils. + GenerateWhiteNoise(silence)) + fake_tested_signal = (signal_processing.SignalProcessingUtils. + GenerateWhiteNoise(silence)) + + # Save fake audio tracks. + self._fake_reference_signal_filepath = os.path.join( + self._output_path, 'fake_ref.wav') + signal_processing.SignalProcessingUtils.SaveWav( + self._fake_reference_signal_filepath, fake_reference_signal) + self._fake_tested_signal_filepath = os.path.join( + self._output_path, 'fake_test.wav') + signal_processing.SignalProcessingUtils.SaveWav( + self._fake_tested_signal_filepath, fake_tested_signal) + + def tearDown(self): + """Recursively delete temporary folder.""" + shutil.rmtree(self._output_path) + + def testRegisteredClasses(self): + # Evaluation score names to exclude (tested separately). + exceptions = ['thd', 'echo_metric'] + + # Preliminary check. + self.assertTrue(os.path.exists(self._output_path)) + + # Check that there is at least one registered evaluation score worker. + registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES + self.assertIsInstance(registered_classes, dict) + self.assertGreater(len(registered_classes), 0) + + # Instance evaluation score workers factory with fake dependencies. 
+ eval_score_workers_factory = ( + eval_scores_factory.EvaluationScoreWorkerFactory( + polqa_tool_bin_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'), + echo_metric_tool_bin_path=None)) + eval_score_workers_factory.SetScoreFilenamePrefix('scores-') + + # Try each registered evaluation score worker. + for eval_score_name in registered_classes: + if eval_score_name in exceptions: + continue + + # Instance evaluation score worker. + eval_score_worker = eval_score_workers_factory.GetInstance( + registered_classes[eval_score_name]) + + # Set fake input metadata and reference and test file paths, then run. + eval_score_worker.SetReferenceSignalFilepath( + self._fake_reference_signal_filepath) + eval_score_worker.SetTestedSignalFilepath( + self._fake_tested_signal_filepath) + eval_score_worker.Run(self._output_path) + + # Check output. + score = data_access.ScoreFile.Load( + eval_score_worker.output_filepath) + self.assertTrue(isinstance(score, float)) + + def testTotalHarmonicDistorsionScore(self): + # Init. + pure_tone_freq = 5000.0 + eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-') + eval_score_worker.SetInputSignalMetadata({ + 'signal': + 'pure_tone', + 'frequency': + pure_tone_freq, + 'test_data_gen_name': + 'identity', + 'test_data_gen_config': + 'default', + }) + template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000) + + # Create 3 test signals: pure tone, pure tone + white noise, white noise + # only. + pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone( + template, pure_tone_freq) + white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise( + template) + noisy_tone = signal_processing.SignalProcessingUtils.MixSignals( + pure_tone, white_noise) + + # Compute scores for increasingly distorted pure tone signals. + scores = [None, None, None] + for index, tested_signal in enumerate( + [pure_tone, noisy_tone, white_noise]): + # Save signal. + tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav') + signal_processing.SignalProcessingUtils.SaveWav( + tmp_filepath, tested_signal) + + # Compute score. + eval_score_worker.SetTestedSignalFilepath(tmp_filepath) + eval_score_worker.Run(self._output_path) + scores[index] = eval_score_worker.score + + # Remove output file to avoid caching. + os.remove(eval_score_worker.output_filepath) + + # Validate scores (lowest score with a pure tone). + self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)])) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py new file mode 100644 index 0000000000..2599085329 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py @@ -0,0 +1,57 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Evaluator of the APM module. +""" + +import logging + + +class ApmModuleEvaluator(object): + """APM evaluator class. 
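+
+    Stateless: the Run() class method drives the given evaluation score
+    workers for one APM output / reference pair and returns the resulting
+    scores.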
+ """ + + def __init__(self): + pass + + @classmethod + def Run(cls, evaluation_score_workers, apm_input_metadata, + apm_output_filepath, reference_input_filepath, + render_input_filepath, output_path): + """Runs the evaluation. + + Iterates over the given evaluation score workers. + + Args: + evaluation_score_workers: list of EvaluationScore instances. + apm_input_metadata: dictionary with metadata of the APM input. + apm_output_filepath: path to the audio track file with the APM output. + reference_input_filepath: path to the reference audio track file. + output_path: output path. + + Returns: + A dict of evaluation score name and score pairs. + """ + # Init. + scores = {} + + for evaluation_score_worker in evaluation_score_workers: + logging.info(' computing <%s> score', + evaluation_score_worker.NAME) + evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata) + evaluation_score_worker.SetReferenceSignalFilepath( + reference_input_filepath) + evaluation_score_worker.SetTestedSignalFilepath( + apm_output_filepath) + evaluation_score_worker.SetRenderSignalFilepath( + render_input_filepath) + + evaluation_score_worker.Run(output_path) + scores[ + evaluation_score_worker.NAME] = evaluation_score_worker.score + + return scores diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py new file mode 100644 index 0000000000..893901d359 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py @@ -0,0 +1,45 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Exception classes. +""" + + +class FileNotFoundError(Exception): + """File not found exception. + """ + pass + + +class SignalProcessingException(Exception): + """Signal processing exception. + """ + pass + + +class InputMixerException(Exception): + """Input mixer exception. + """ + pass + + +class InputSignalCreatorException(Exception): + """Input signal creator exception. + """ + pass + + +class EvaluationScoreException(Exception): + """Evaluation score exception. + """ + pass + + +class InitializationException(Exception): + """Initialization exception. + """ + pass diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/export.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/export.py new file mode 100644 index 0000000000..fe3a6c7cb9 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/export.py @@ -0,0 +1,426 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. 
+ +import functools +import hashlib +import logging +import os +import re +import sys + +try: + import csscompressor +except ImportError: + logging.critical( + 'Cannot import the third-party Python package csscompressor') + sys.exit(1) + +try: + import jsmin +except ImportError: + logging.critical('Cannot import the third-party Python package jsmin') + sys.exit(1) + + +class HtmlExport(object): + """HTML exporter class for APM quality scores.""" + + _NEW_LINE = '\n' + + # CSS and JS file paths. + _PATH = os.path.dirname(os.path.realpath(__file__)) + _CSS_FILEPATH = os.path.join(_PATH, 'results.css') + _CSS_MINIFIED = True + _JS_FILEPATH = os.path.join(_PATH, 'results.js') + _JS_MINIFIED = True + + def __init__(self, output_filepath): + self._scores_data_frame = None + self._output_filepath = output_filepath + + def Export(self, scores_data_frame): + """Exports scores into an HTML file. + + Args: + scores_data_frame: DataFrame instance. + """ + self._scores_data_frame = scores_data_frame + html = [ + '<html>', + self._BuildHeader(), + ('<script type="text/javascript">' + '(function () {' + 'window.addEventListener(\'load\', function () {' + 'var inspector = new AudioInspector();' + '});' + '})();' + '</script>'), '<body>', + self._BuildBody(), '</body>', '</html>' + ] + self._Save(self._output_filepath, self._NEW_LINE.join(html)) + + def _BuildHeader(self): + """Builds the <head> section of the HTML file. + + The header contains the page title and either embedded or linked CSS and JS + files. + + Returns: + A string with <head>...</head> HTML. + """ + html = ['<head>', '<title>Results</title>'] + + # Add Material Design hosted libs. + html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/' + 'css?family=Roboto:300,400,500,700" type="text/css">') + html.append( + '<link rel="stylesheet" href="https://fonts.googleapis.com/' + 'icon?family=Material+Icons">') + html.append( + '<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/' + 'material.indigo-pink.min.css">') + html.append('<script defer src="https://code.getmdl.io/1.3.0/' + 'material.min.js"></script>') + + # Embed custom JavaScript and CSS files. + html.append('<script>') + with open(self._JS_FILEPATH) as f: + html.append( + jsmin.jsmin(f.read()) if self._JS_MINIFIED else ( + f.read().rstrip())) + html.append('</script>') + html.append('<style>') + with open(self._CSS_FILEPATH) as f: + html.append( + csscompressor.compress(f.read()) if self._CSS_MINIFIED else ( + f.read().rstrip())) + html.append('</style>') + + html.append('</head>') + + return self._NEW_LINE.join(html) + + def _BuildBody(self): + """Builds the content of the <body> section.""" + score_names = self._scores_data_frame[ + 'eval_score_name'].drop_duplicates().values.tolist() + + html = [ + ('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header ' + 'mdl-layout--fixed-tabs">'), + '<header class="mdl-layout__header">', + '<div class="mdl-layout__header-row">', + '<span class="mdl-layout-title">APM QA results ({})</span>'.format( + self._output_filepath), + '</div>', + ] + + # Tab selectors. 
+ html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">') + for tab_index, score_name in enumerate(score_names): + is_active = tab_index == 0 + html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">' + '{}</a>'.format(tab_index, + ' is-active' if is_active else '', + self._FormatName(score_name))) + html.append('</div>') + + html.append('</header>') + html.append( + '<main class="mdl-layout__content" style="overflow-x: auto;">') + + # Tabs content. + for tab_index, score_name in enumerate(score_names): + html.append('<section class="mdl-layout__tab-panel{}" ' + 'id="score-tab-{}">'.format( + ' is-active' if is_active else '', tab_index)) + html.append('<div class="page-content">') + html.append( + self._BuildScoreTab(score_name, ('s{}'.format(tab_index), ))) + html.append('</div>') + html.append('</section>') + + html.append('</main>') + html.append('</div>') + + # Add snackbar for notifications. + html.append( + '<div id="snackbar" aria-live="assertive" aria-atomic="true"' + ' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">' + '<div class="mdl-snackbar__text"></div>' + '<button type="button" class="mdl-snackbar__action"></button>' + '</div>') + + return self._NEW_LINE.join(html) + + def _BuildScoreTab(self, score_name, anchor_data): + """Builds the content of a tab.""" + # Find unique values. + scores = self._scores_data_frame[ + self._scores_data_frame.eval_score_name == score_name] + apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config'])) + test_data_gen_configs = sorted( + self._FindUniqueTuples(scores, + ['test_data_gen', 'test_data_gen_params'])) + + html = [ + '<div class="mdl-grid">', + '<div class="mdl-layout-spacer"></div>', + '<div class="mdl-cell mdl-cell--10-col">', + ('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" ' + 'style="width: 100%;">'), + ] + + # Header. + html.append('<thead><tr><th>APM config / Test data generator</th>') + for test_data_gen_info in test_data_gen_configs: + html.append('<th>{} {}</th>'.format( + self._FormatName(test_data_gen_info[0]), + test_data_gen_info[1])) + html.append('</tr></thead>') + + # Body. + html.append('<tbody>') + for apm_config in apm_configs: + html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>') + for test_data_gen_info in test_data_gen_configs: + dialog_id = self._ScoreStatsInspectorDialogId( + score_name, apm_config[0], test_data_gen_info[0], + test_data_gen_info[1]) + html.append( + '<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'. + format( + dialog_id, + self._BuildScoreTableCell(score_name, + test_data_gen_info[0], + test_data_gen_info[1], + apm_config[0]))) + html.append('</tr>') + html.append('</tbody>') + + html.append( + '</table></div><div class="mdl-layout-spacer"></div></div>') + + html.append( + self._BuildScoreStatsInspectorDialogs(score_name, apm_configs, + test_data_gen_configs, + anchor_data)) + + return self._NEW_LINE.join(html) + + def _BuildScoreTableCell(self, score_name, test_data_gen, + test_data_gen_params, apm_config): + """Builds the content of a table cell for a score table.""" + scores = self._SliceDataForScoreTableCell(score_name, apm_config, + test_data_gen, + test_data_gen_params) + stats = self._ComputeScoreStats(scores) + + html = [] + items_id_prefix = (score_name + test_data_gen + test_data_gen_params + + apm_config) + if stats['count'] == 1: + # Show the only available score. 
+ item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest() + html.append('<div id="single-value-{0}">{1:f}</div>'.format( + item_id, scores['score'].mean())) + html.append( + '<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}' + '</div>'.format(item_id, 'single value')) + else: + # Show stats. + for stat_name in ['min', 'max', 'mean', 'std dev']: + item_id = hashlib.md5( + (items_id_prefix + stat_name).encode('utf-8')).hexdigest() + html.append('<div id="stats-{0}">{1:f}</div>'.format( + item_id, stats[stat_name])) + html.append( + '<div class="mdl-tooltip" data-mdl-for="stats-{}">{}' + '</div>'.format(item_id, stat_name)) + + return self._NEW_LINE.join(html) + + def _BuildScoreStatsInspectorDialogs(self, score_name, apm_configs, + test_data_gen_configs, anchor_data): + """Builds a set of score stats inspector dialogs.""" + html = [] + for apm_config in apm_configs: + for test_data_gen_info in test_data_gen_configs: + dialog_id = self._ScoreStatsInspectorDialogId( + score_name, apm_config[0], test_data_gen_info[0], + test_data_gen_info[1]) + + html.append('<dialog class="mdl-dialog" id="{}" ' + 'style="width: 40%;">'.format(dialog_id)) + + # Content. + html.append('<div class="mdl-dialog__content">') + html.append( + '<h6><strong>APM config preset</strong>: {}<br/>' + '<strong>Test data generator</strong>: {} ({})</h6>'. + format(self._FormatName(apm_config[0]), + self._FormatName(test_data_gen_info[0]), + test_data_gen_info[1])) + html.append( + self._BuildScoreStatsInspectorDialog( + score_name, apm_config[0], test_data_gen_info[0], + test_data_gen_info[1], anchor_data + (dialog_id, ))) + html.append('</div>') + + # Actions. + html.append('<div class="mdl-dialog__actions">') + html.append('<button type="button" class="mdl-button" ' + 'onclick="closeScoreStatsInspector()">' + 'Close</button>') + html.append('</div>') + + html.append('</dialog>') + + return self._NEW_LINE.join(html) + + def _BuildScoreStatsInspectorDialog(self, score_name, apm_config, + test_data_gen, test_data_gen_params, + anchor_data): + """Builds one score stats inspector dialog.""" + scores = self._SliceDataForScoreTableCell(score_name, apm_config, + test_data_gen, + test_data_gen_params) + + capture_render_pairs = sorted( + self._FindUniqueTuples(scores, ['capture', 'render'])) + echo_simulators = sorted( + self._FindUniqueTuples(scores, ['echo_simulator'])) + + html = [ + '<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">' + ] + + # Header. + html.append('<thead><tr><th>Capture-Render / Echo simulator</th>') + for echo_simulator in echo_simulators: + html.append('<th>' + self._FormatName(echo_simulator[0]) + '</th>') + html.append('</tr></thead>') + + # Body. + html.append('<tbody>') + for row, (capture, render) in enumerate(capture_render_pairs): + html.append('<tr><td><div>{}</div><div>{}</div></td>'.format( + capture, render)) + for col, echo_simulator in enumerate(echo_simulators): + score_tuple = self._SliceDataForScoreStatsTableCell( + scores, capture, render, echo_simulator[0]) + cell_class = 'r{}c{}'.format(row, col) + html.append('<td class="single-score-cell {}">{}</td>'.format( + cell_class, + self._BuildScoreStatsInspectorTableCell( + score_tuple, anchor_data + (cell_class, )))) + html.append('</tr>') + html.append('</tbody>') + + html.append('</table>') + + # Placeholder for the audio inspector. 
+ html.append('<div class="audio-inspector-placeholder"></div>') + + return self._NEW_LINE.join(html) + + def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data): + """Builds the content of a cell of a score stats inspector.""" + anchor = '&'.join(anchor_data) + html = [('<div class="v">{}</div>' + '<button class="mdl-button mdl-js-button mdl-button--icon"' + ' data-anchor="{}">' + '<i class="material-icons mdl-color-text--blue-grey">link</i>' + '</button>').format(score_tuple.score, anchor)] + + # Add all the available file paths as hidden data. + for field_name in score_tuple.keys(): + if field_name.endswith('_filepath'): + html.append( + '<input type="hidden" name="{}" value="{}">'.format( + field_name, score_tuple[field_name])) + + return self._NEW_LINE.join(html) + + def _SliceDataForScoreTableCell(self, score_name, apm_config, + test_data_gen, test_data_gen_params): + """Slices `self._scores_data_frame` to extract the data for a tab.""" + masks = [] + masks.append(self._scores_data_frame.eval_score_name == score_name) + masks.append(self._scores_data_frame.apm_config == apm_config) + masks.append(self._scores_data_frame.test_data_gen == test_data_gen) + masks.append(self._scores_data_frame.test_data_gen_params == + test_data_gen_params) + mask = functools.reduce((lambda i1, i2: i1 & i2), masks) + del masks + return self._scores_data_frame[mask] + + @classmethod + def _SliceDataForScoreStatsTableCell(cls, scores, capture, render, + echo_simulator): + """Slices `scores` to extract the data for a tab.""" + masks = [] + + masks.append(scores.capture == capture) + masks.append(scores.render == render) + masks.append(scores.echo_simulator == echo_simulator) + mask = functools.reduce((lambda i1, i2: i1 & i2), masks) + del masks + + sliced_data = scores[mask] + assert len(sliced_data) == 1, 'single score is expected' + return sliced_data.iloc[0] + + @classmethod + def _FindUniqueTuples(cls, data_frame, fields): + """Slices `data_frame` to a list of fields and finds unique tuples.""" + return data_frame[fields].drop_duplicates().values.tolist() + + @classmethod + def _ComputeScoreStats(cls, data_frame): + """Computes score stats.""" + scores = data_frame['score'] + return { + 'count': scores.count(), + 'min': scores.min(), + 'max': scores.max(), + 'mean': scores.mean(), + 'std dev': scores.std(), + } + + @classmethod + def _ScoreStatsInspectorDialogId(cls, score_name, apm_config, + test_data_gen, test_data_gen_params): + """Assigns a unique name to a dialog.""" + return 'score-stats-dialog-' + hashlib.md5( + 'score-stats-inspector-{}-{}-{}-{}'.format( + score_name, apm_config, test_data_gen, + test_data_gen_params).encode('utf-8')).hexdigest() + + @classmethod + def _Save(cls, output_filepath, html): + """Writes the HTML file. + + Args: + output_filepath: output file path. + html: string with the HTML content. + """ + with open(output_filepath, 'w') as f: + f.write(html) + + @classmethod + def _FormatName(cls, name): + """Formats a name. + + Args: + name: a string. + + Returns: + A copy of name in which underscores and dashes are replaced with a space. 
+ """ + return re.sub(r'[_\-]', ' ', name) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/export_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/export_unittest.py new file mode 100644 index 0000000000..412aa7c4e7 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/export_unittest.py @@ -0,0 +1,86 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the export module. +""" + +import logging +import os +import shutil +import tempfile +import unittest + +import pyquery as pq + +from . import audioproc_wrapper +from . import collect_data +from . import eval_scores_factory +from . import evaluation +from . import export +from . import simulation +from . import test_data_generation_factory + + +class TestExport(unittest.TestCase): + """Unit tests for the export module. + """ + + _CLEAN_TMP_OUTPUT = True + + def setUp(self): + """Creates temporary data to export.""" + self._tmp_path = tempfile.mkdtemp() + + # Run a fake experiment to produce data to export. + simulator = simulation.ApmModuleSimulator( + test_data_generator_factory=( + test_data_generation_factory.TestDataGeneratorFactory( + aechen_ir_database_path='', + noise_tracks_path='', + copy_with_identity=False)), + evaluation_score_factory=( + eval_scores_factory.EvaluationScoreWorkerFactory( + polqa_tool_bin_path=os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'fake_polqa'), + echo_metric_tool_bin_path=None)), + ap_wrapper=audioproc_wrapper.AudioProcWrapper( + audioproc_wrapper.AudioProcWrapper. + DEFAULT_APM_SIMULATOR_BIN_PATH), + evaluator=evaluation.ApmModuleEvaluator()) + simulator.Run( + config_filepaths=['apm_configs/default.json'], + capture_input_filepaths=[ + os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'), + os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'), + ], + test_data_generator_names=['identity', 'white_noise'], + eval_score_names=['audio_level_peak', 'audio_level_mean'], + output_dir=self._tmp_path) + + # Export results. + p = collect_data.InstanceArgumentsParser() + args = p.parse_args(['--output_dir', self._tmp_path]) + src_path = collect_data.ConstructSrcPath(args) + self._data_to_export = collect_data.FindScores(src_path, args) + + def tearDown(self): + """Recursively deletes temporary folders.""" + if self._CLEAN_TMP_OUTPUT: + shutil.rmtree(self._tmp_path) + else: + logging.warning(self.id() + ' did not clean the temporary path ' + + (self._tmp_path)) + + def testCreateHtmlReport(self): + fn_out = os.path.join(self._tmp_path, 'results.html') + exporter = export.HtmlExport(fn_out) + exporter.Export(self._data_to_export) + + document = pq.PyQuery(filename=fn_out) + self.assertIsInstance(document, pq.PyQuery) + # TODO(alessiob): Use PyQuery API to check the HTML file. 
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py
new file mode 100644
index 0000000000..a7db7b4840
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+
+from __future__ import division
+
+import logging
+import os
+import subprocess
+import shutil
+import sys
+import tempfile
+
+try:
+    import numpy as np
+except ImportError:
+    logging.critical('Cannot import the third-party Python package numpy')
+    sys.exit(1)
+
+from . import signal_processing
+
+
+class ExternalVad(object):
+    def __init__(self, path_to_binary, name):
+        """Args:
+        path_to_binary: path to binary that accepts '-i <wav>', '-o
+          <float probabilities>'. There must be one float value per
+          10 ms of audio.
+        name: a name to identify the external VAD. Used for saving
+          the output as extvad_output-<name>.
+        """
+        self._path_to_binary = path_to_binary
+        self.name = name
+        assert os.path.exists(self._path_to_binary), (self._path_to_binary)
+        self._vad_output = None
+
+    def Run(self, wav_file_path):
+        _signal = signal_processing.SignalProcessingUtils.LoadWav(
+            wav_file_path)
+        if _signal.channels != 1:
+            raise NotImplementedError('Multiple-channel'
+                                      ' annotations not implemented')
+        if _signal.frame_rate != 48000:
+            raise NotImplementedError('Frame rates '
+                                      'other than 48000 not implemented')
+
+        tmp_path = tempfile.mkdtemp()
+        try:
+            output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
+            subprocess.call([
+                self._path_to_binary, '-i', wav_file_path, '-o',
+                output_file_path
+            ])
+            self._vad_output = np.fromfile(output_file_path, np.float32)
+        except Exception as e:
+            logging.error('Error while running the ' + self.name + ' VAD (' +
+                          str(e) + ')')
+        finally:
+            if os.path.exists(tmp_path):
+                shutil.rmtree(tmp_path)
+
+    def GetVadOutput(self):
+        assert self._vad_output is not None
+        return self._vad_output
+
+    @classmethod
+    def ConstructVadDict(cls, vad_paths, vad_names):
+        external_vads = {}
+        for path, name in zip(vad_paths, vad_names):
+            external_vads[name] = ExternalVad(path, name)
+        return external_vads
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py
new file mode 100755
index 0000000000..f679f8c94a
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py
@@ -0,0 +1,25 @@
+#!/usr/bin/python
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
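+"""Fake external VAD used by the unit tests: ignores the input WAV file given
+with '-i' and writes a fixed sequence of 100 float32 values to the file given
+with '-o'.
+"""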
+import argparse
+import numpy as np
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', required=True)
+    parser.add_argument('-o', required=True)
+
+    args = parser.parse_args()
+
+    array = np.arange(100, dtype=np.float32)
+    # tofile() writes raw bytes, so the output must be opened in binary mode.
+    with open(args.o, 'wb') as f:
+        array.tofile(f)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_polqa.cc b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_polqa.cc
new file mode 100644
index 0000000000..6f3b2d1dd7
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_polqa.cc
@@ -0,0 +1,56 @@
+/*
+ *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS. All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+namespace test {
+namespace {
+
+const char* const kErrorMessage = "-Out /path/to/output/file is mandatory";
+
+// Writes fake output intended to be parsed by
+// quality_assessment.eval_scores.PolqaScore.
+void WriteOutputFile(absl::string_view output_file_path) {
+  RTC_CHECK_NE(output_file_path, "");
+  std::ofstream out(std::string{output_file_path});
+  RTC_CHECK(!out.bad());
+  out << "* Fake Polqa output" << std::endl;
+  out << "FakeField1\tPolqaScore\tFakeField2" << std::endl;
+  out << "FakeValue1\t3.25\tFakeValue2" << std::endl;
+  out.close();
+}
+
+}  // namespace
+
+int main(int argc, char* argv[]) {
+  // Find "-Out" and use its next argument as output file path.
+  RTC_CHECK_GE(argc, 3) << kErrorMessage;
+  const std::string kSoughtFlagName = "-Out";
+  for (int i = 1; i < argc - 1; ++i) {
+    if (kSoughtFlagName.compare(argv[i]) == 0) {
+      WriteOutputFile(argv[i + 1]);
+      return 0;
+    }
+  }
+  RTC_FATAL() << kErrorMessage;
+}
+
+}  // namespace test
+}  // namespace webrtc
+
+int main(int argc, char* argv[]) {
+  return webrtc::test::main(argc, argv);
+}
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py
new file mode 100644
index 0000000000..af022bd461
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Input mixer module.
+"""
+
+import logging
+import os
+
+from . import exceptions
+from . import signal_processing
+
+
+class ApmInputMixer(object):
+    """Class to mix a set of audio segments down to the APM input."""
+
+    _HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def HardClippingLogMessage(cls):
+        """Returns the log message used when hard clipping is detected in the
+        mix.
+
+        This method is mainly intended to be used by the unit tests.
+        """
+        return cls._HARD_CLIPPING_LOG_MSG
+
+    @classmethod
+    def Mix(cls, output_path, capture_input_filepath, echo_filepath):
+        """Mixes capture and echo.
+
+        Creates the overall capture input for APM by mixing the "echo-free"
+        capture signal with the echo signal (e.g., echo simulated via the
+        echo_path_simulation module).
+
+        The echo signal cannot be shorter than the capture signal and the
+        generated mix will have the same duration as the capture signal. The
+        latter property is enforced in order to let the input of APM and the
+        reference signal created by TestDataGenerator have the same length
+        (required for the evaluation step).
+
+        Hard clipping may occur in the mix; a warning is raised when this
+        happens.
+
+        If `echo_filepath` is None, nothing is done and
+        `capture_input_filepath` is returned.
+
+        Args:
+            output_path: directory in which the mix file is written.
+            capture_input_filepath: path to the echo-free capture audio track.
+            echo_filepath: path to the echo audio track, or None.
+
+        Returns:
+            Path to the mix audio track file.
+        """
+        if echo_filepath is None:
+            return capture_input_filepath
+
+        # Build the mix output file name as a function of the echo file name.
+        # This ensures that if the internal parameters of the echo path
+        # simulator change, no erroneous cache hit occurs.
+        echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
+        capture_input_file_name, _ = os.path.splitext(
+            os.path.split(capture_input_filepath)[1])
+        mix_filepath = os.path.join(
+            output_path,
+            'mix_capture_{}_{}.wav'.format(capture_input_file_name,
+                                           echo_file_name))
+
+        # Create the mix if not done yet.
+        mix = None
+        if not os.path.exists(mix_filepath):
+            echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
+                capture_input_filepath)
+            echo = signal_processing.SignalProcessingUtils.LoadWav(
+                echo_filepath)
+
+            if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
+                    signal_processing.SignalProcessingUtils.CountSamples(
+                        echo_free_capture)):
+                raise exceptions.InputMixerException(
+                    'echo cannot be shorter than capture')
+
+            mix = echo_free_capture.overlay(echo)
+            signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
+
+        # Check if hard clipping occurs.
+        if mix is None:
+            mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
+        if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
+            logging.warning(cls._HARD_CLIPPING_LOG_MSG)
+
+        return mix_filepath
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer_unittest.py
new file mode 100644
index 0000000000..4fd5e4f1ee
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer_unittest.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS.
All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the input mixer module. +""" + +import logging +import os +import shutil +import tempfile +import unittest + +import mock + +from . import exceptions +from . import input_mixer +from . import signal_processing + + +class TestApmInputMixer(unittest.TestCase): + """Unit tests for the ApmInputMixer class. + """ + + # Audio track file names created in setUp(). + _FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer'] + + # Target peak power level (dBFS) of each audio track file created in setUp(). + # These values are hand-crafted in order to make saturation happen when + # capture and echo_2 are mixed and the contrary for capture and echo_1. + # None means that the power is not changed. + _MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None] + + # Audio track file durations in milliseconds. + _DURATIONS = [1000, 1000, 1000, 800, 1200] + + _SAMPLE_RATE = 48000 + + def setUp(self): + """Creates temporary data.""" + self._tmp_path = tempfile.mkdtemp() + + # Create audio track files. + self._audio_tracks = {} + for filename, peak_power, duration in zip(self._FILENAMES, + self._MAX_PEAK_POWER_LEVELS, + self._DURATIONS): + audio_track_filepath = os.path.join(self._tmp_path, + '{}.wav'.format(filename)) + + # Create a pure tone with the target peak power level. + template = signal_processing.SignalProcessingUtils.GenerateSilence( + duration=duration, sample_rate=self._SAMPLE_RATE) + signal = signal_processing.SignalProcessingUtils.GeneratePureTone( + template) + if peak_power is not None: + signal = signal.apply_gain(-signal.max_dBFS + peak_power) + + signal_processing.SignalProcessingUtils.SaveWav( + audio_track_filepath, signal) + self._audio_tracks[filename] = { + 'filepath': + audio_track_filepath, + 'num_samples': + signal_processing.SignalProcessingUtils.CountSamples(signal) + } + + def tearDown(self): + """Recursively deletes temporary folders.""" + shutil.rmtree(self._tmp_path) + + def testCheckMixSameDuration(self): + """Checks the duration when mixing capture and echo with same duration.""" + mix_filepath = input_mixer.ApmInputMixer.Mix( + self._tmp_path, self._audio_tracks['capture']['filepath'], + self._audio_tracks['echo_1']['filepath']) + self.assertTrue(os.path.exists(mix_filepath)) + + mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath) + self.assertEqual( + self._audio_tracks['capture']['num_samples'], + signal_processing.SignalProcessingUtils.CountSamples(mix)) + + def testRejectShorterEcho(self): + """Rejects echo signals that are shorter than the capture signal.""" + try: + _ = input_mixer.ApmInputMixer.Mix( + self._tmp_path, self._audio_tracks['capture']['filepath'], + self._audio_tracks['shorter']['filepath']) + self.fail('no exception raised') + except exceptions.InputMixerException: + pass + + def testCheckMixDurationWithLongerEcho(self): + """Checks the duration when mixing an echo longer than the capture.""" + mix_filepath = input_mixer.ApmInputMixer.Mix( + self._tmp_path, self._audio_tracks['capture']['filepath'], + self._audio_tracks['longer']['filepath']) + self.assertTrue(os.path.exists(mix_filepath)) + + mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath) + self.assertEqual( + self._audio_tracks['capture']['num_samples'], + signal_processing.SignalProcessingUtils.CountSamples(mix)) + + def testCheckOutputFileNamesConflict(self): + """Checks that different echo files lead to different output file names.""" + 
+        mix1_filepath = input_mixer.ApmInputMixer.Mix(
+            self._tmp_path, self._audio_tracks['capture']['filepath'],
+            self._audio_tracks['echo_1']['filepath'])
+        self.assertTrue(os.path.exists(mix1_filepath))
+
+        mix2_filepath = input_mixer.ApmInputMixer.Mix(
+            self._tmp_path, self._audio_tracks['capture']['filepath'],
+            self._audio_tracks['echo_2']['filepath'])
+        self.assertTrue(os.path.exists(mix2_filepath))
+
+        self.assertNotEqual(mix1_filepath, mix2_filepath)
+
+    def testHardClippingLogExpected(self):
+        """Checks that a hard clipping warning is raised when it occurs."""
+        logging.warning = mock.MagicMock(name='warning')
+        _ = input_mixer.ApmInputMixer.Mix(
+            self._tmp_path, self._audio_tracks['capture']['filepath'],
+            self._audio_tracks['echo_2']['filepath'])
+        logging.warning.assert_called_once_with(
+            input_mixer.ApmInputMixer.HardClippingLogMessage())
+
+    def testHardClippingLogNotExpected(self):
+        """Checks that no hard clipping warning is raised when none occurs."""
+        logging.warning = mock.MagicMock(name='warning')
+        _ = input_mixer.ApmInputMixer.Mix(
+            self._tmp_path, self._audio_tracks['capture']['filepath'],
+            self._audio_tracks['echo_1']['filepath'])
+        self.assertNotIn(
+            mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
+            logging.warning.call_args_list)
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
new file mode 100644
index 0000000000..b64fdcca89
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Input signal creator module.
+"""
+
+from . import exceptions
+from . import signal_processing
+
+
+class InputSignalCreator(object):
+    """Input signal creator class.
+    """
+
+    @classmethod
+    def Create(cls, name, raw_params):
+        """Creates an input signal and its metadata.
+
+        Args:
+            name: Input signal creator name.
+            raw_params: Tuple of parameters to pass to the specific signal
+                creator.
+
+        Returns:
+            (AudioSegment, dict) tuple.
+        """
+        try:
+            signal = None
+            params = {}
+
+            if name == 'pure_tone':
+                params['frequency'] = float(raw_params[0])
+                params['duration'] = int(raw_params[1])
+                signal = cls._CreatePureTone(params['frequency'],
+                                             params['duration'])
+            else:
+                raise exceptions.InputSignalCreatorException(
+                    'Invalid input signal creator name')
+
+            # Complete metadata.
+            params['signal'] = name
+
+            return signal, params
+        except (TypeError, AssertionError) as e:
+            raise exceptions.InputSignalCreatorException(
+                'Invalid signal creator parameters: {}'.format(e))
+
+    @classmethod
+    def _CreatePureTone(cls, frequency, duration):
+        """Generates a pure tone at 48000 Hz.
+
+        Args:
+            frequency: Float in (0, 24000] (Hz).
+            duration: Integer (milliseconds).
+
+        Returns:
+            AudioSegment instance.
+        """
+        assert 0 < frequency <= 24000
+        assert duration > 0
+        template = signal_processing.SignalProcessingUtils.GenerateSilence(
+            duration)
+        return signal_processing.SignalProcessingUtils.GeneratePureTone(
+            template, frequency)
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/results.css b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/results.css
new file mode 100644
index 0000000000..2f406bb002
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/results.css
@@ -0,0 +1,32 @@
+/* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+td.selected-score {
+  background-color: #DDD;
+}
+
+td.single-score-cell {
+  text-align: center;
+}
+
+.audio-inspector {
+  text-align: center;
+}
+
+.audio-inspector div {
+  margin-bottom: 0;
+  padding-bottom: 0;
+  padding-top: 0;
+}
+
+.audio-inspector div div {
+  margin-bottom: 0;
+  padding-bottom: 0;
+  padding-top: 0;
+}
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/results.js b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/results.js
new file mode 100644
index 0000000000..8e47411058
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/results.js
@@ -0,0 +1,376 @@
+// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+//
+// Use of this source code is governed by a BSD-style license
+// that can be found in the LICENSE file in the root of the source
+// tree. An additional intellectual property rights grant can be found
+// in the file PATENTS. All contributing project authors may
+// be found in the AUTHORS file in the root of the source tree.
+
+/**
+ * Opens the score stats inspector dialog.
+ * @param {String} dialogId: identifier of the dialog to show.
+ * @return {DOMElement} The dialog element that has been opened.
+ */
+function openScoreStatsInspector(dialogId) {
+  var dialog = document.getElementById(dialogId);
+  dialog.showModal();
+  return dialog;
+}
+
+/**
+ * Closes the score stats inspector dialog.
+ */
+function closeScoreStatsInspector() {
+  var dialog = document.querySelector('dialog[open]');
+  if (dialog == null)
+    return;
+  dialog.close();
+}
+
+/**
+ * Audio inspector class.
+ * @constructor
+ */
+function AudioInspector() {
+  console.debug('Creating an AudioInspector instance.');
+  this.audioPlayer_ = new Audio();
+  this.metadata_ = {};
+  this.currentScore_ = null;
+  this.audioInspector_ = null;
+  this.snackbarContainer_ = document.querySelector('#snackbar');
+
+  // Get base URL without anchors.
+  this.baseUrl_ = window.location.href;
+  var index = this.baseUrl_.indexOf('#');
+  if (index > 0)
+    this.baseUrl_ = this.baseUrl_.substr(0, index);
+  console.info('Base URL set to "' + this.baseUrl_ + '".');
+
+  window.event.stopPropagation();
+  this.createTextAreasForCopy_();
+  this.createAudioInspector_();
+  this.initializeEventHandlers_();
+
+  // When MDL is ready, parse the anchor (if any) to show the requested
+  // experiment.
+ var self = this; + document.querySelectorAll('header a')[0].addEventListener( + 'mdl-componentupgraded', function() { + if (!self.parseWindowAnchor()) { + // If not experiment is requested, open the first section. + console.info('No anchor parsing, opening the first section.'); + document.querySelectorAll('header a > span')[0].click(); + } + }); +} + +/** + * Parse the anchor in the window URL. + * @return {bool} True if the parsing succeeded. + */ +AudioInspector.prototype.parseWindowAnchor = function() { + var index = location.href.indexOf('#'); + if (index == -1) { + console.debug('No # found in the URL.'); + return false; + } + + var anchor = location.href.substr(index - location.href.length + 1); + console.info('Anchor changed: "' + anchor + '".'); + + var parts = anchor.split('&'); + if (parts.length != 3) { + console.info('Ignoring anchor with invalid number of fields.'); + return false; + } + + var openDialog = document.querySelector('dialog[open]'); + try { + // Open the requested dialog if not already open. + if (!openDialog || openDialog.id != parts[1]) { + !openDialog || openDialog.close(); + document.querySelectorAll('header a > span')[ + parseInt(parts[0].substr(1))].click(); + openDialog = openScoreStatsInspector(parts[1]); + } + + // Trigger click on cell. + var cell = openDialog.querySelector('td.' + parts[2]); + cell.focus(); + cell.click(); + + this.showNotification_('Experiment selected.'); + return true; + } catch (e) { + this.showNotification_('Cannot select experiment :('); + console.error('Exception caught while selecting experiment: "' + e + '".'); + } + + return false; +} + +/** + * Set up the inspector for a new score. + * @param {DOMElement} element: Element linked to the selected score. + */ +AudioInspector.prototype.selectedScoreChange = function(element) { + if (this.currentScore_ == element) { return; } + if (this.currentScore_ != null) { + this.currentScore_.classList.remove('selected-score'); + } + this.currentScore_ = element; + this.currentScore_.classList.add('selected-score'); + this.stopAudio(); + + // Read metadata. + var matches = element.querySelectorAll('input[type=hidden]'); + this.metadata_ = {}; + for (var index = 0; index < matches.length; ++index) { + this.metadata_[matches[index].name] = matches[index].value; + } + + // Show the audio inspector interface. + var container = element.parentNode.parentNode.parentNode.parentNode; + var audioInspectorPlaceholder = container.querySelector( + '.audio-inspector-placeholder'); + this.moveInspector_(audioInspectorPlaceholder); +}; + +/** + * Stop playing audio. + */ +AudioInspector.prototype.stopAudio = function() { + console.info('Pausing audio play out.'); + this.audioPlayer_.pause(); +}; + +/** + * Show a text message using the snackbar. + */ +AudioInspector.prototype.showNotification_ = function(text) { + try { + this.snackbarContainer_.MaterialSnackbar.showSnackbar({ + message: text, timeout: 2000}); + } catch (e) { + // Fallback to an alert. + alert(text); + console.warn('Cannot use snackbar: "' + e + '"'); + } +} + +/** + * Move the audio inspector DOM node into the given parent. + * @param {DOMElement} newParentNode: New parent for the inspector. + */ +AudioInspector.prototype.moveInspector_ = function(newParentNode) { + newParentNode.appendChild(this.audioInspector_); +}; + +/** + * Play audio file from url. + * @param {string} metadataFieldName: Metadata field name. 
+ */
+AudioInspector.prototype.playAudio = function(metadataFieldName) {
+  if (this.metadata_[metadataFieldName] == undefined) { return; }
+  if (this.metadata_[metadataFieldName] == 'None') {
+    alert('The selected stream was not used during the experiment.');
+    return;
+  }
+  this.stopAudio();
+  this.audioPlayer_.src = this.metadata_[metadataFieldName];
+  console.debug('Audio source URL: "' + this.audioPlayer_.src + '"');
+  this.audioPlayer_.play();
+  console.info('Playing out audio.');
+};
+
+/**
+ * Create hidden text areas to copy URLs.
+ *
+ * For each dialog, one text area is created since it is not possible to select
+ * text on a text area outside of the active dialog.
+ */
+AudioInspector.prototype.createTextAreasForCopy_ = function() {
+  document.querySelectorAll('dialog.mdl-dialog').forEach(function(element) {
+    var textArea = document.createElement("textarea");
+    textArea.classList.add('url-copy');
+    textArea.style.position = 'fixed';
+    textArea.style.bottom = 0;
+    textArea.style.left = 0;
+    textArea.style.width = '2em';
+    textArea.style.height = '2em';
+    textArea.style.border = 'none';
+    textArea.style.outline = 'none';
+    textArea.style.boxShadow = 'none';
+    textArea.style.background = 'transparent';
+    textArea.style.fontSize = '6px';
+    element.appendChild(textArea);
+  });
+};
+
+/**
+ * Create audio inspector.
+ */
+AudioInspector.prototype.createAudioInspector_ = function() {
+  var buttonIndex = 0;
+  function getButtonHtml(icon, toolTipText, caption, metadataFieldName) {
+    var buttonId = 'audioInspectorButton' + buttonIndex++;
+    var html = caption == null ? '' : caption;
+    html += '<button class="mdl-button mdl-js-button mdl-button--icon ' +
+              'mdl-js-ripple-effect" id="' + buttonId + '">' +
+            '<i class="material-icons">' + icon + '</i>' +
+            '<div class="mdl-tooltip" data-mdl-for="' + buttonId + '">' +
+               toolTipText +
+            '</div>';
+    if (metadataFieldName != null) {
+      html += '<input type="hidden" value="' + metadataFieldName + '">';
+    }
+    html += '</button>';
+
+    return html;
+  }
+
+  // TODO(alessiob): Add timeline and highlight current track by changing icon
+  // color.
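+  // A possible direction (a sketch, not upstream code): keep a reference to
+  // the <i> icon of the button passed to playAudio(), then toggle a CSS
+  // class on it, e.g. icon.classList.add('playing') when playback starts and
+  // icon.classList.remove('playing') from stopAudio().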
+ + this.audioInspector_ = document.createElement('div'); + this.audioInspector_.classList.add('audio-inspector'); + this.audioInspector_.innerHTML = + '<div class="mdl-grid">' + + '<div class="mdl-layout-spacer"></div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('play_arrow', 'Simulated echo', 'E<sub>in</sub>', + 'echo_filepath') + + '</div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('stop', 'Stop playing [S]', null, '__stop__') + + '</div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('play_arrow', 'Render stream', 'R<sub>in</sub>', + 'render_filepath') + + '</div>' + + '<div class="mdl-layout-spacer"></div>' + + '</div>' + + '<div class="mdl-grid">' + + '<div class="mdl-layout-spacer"></div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('play_arrow', 'Capture stream (APM input) [1]', + 'Y\'<sub>in</sub>', 'capture_filepath') + + '</div>' + + '<div class="mdl-cell mdl-cell--2-col"><strong>APM</strong></div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('play_arrow', 'APM output [2]', 'Y<sub>out</sub>', + 'apm_output_filepath') + + '</div>' + + '<div class="mdl-layout-spacer"></div>' + + '</div>' + + '<div class="mdl-grid">' + + '<div class="mdl-layout-spacer"></div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('play_arrow', 'Echo-free capture stream', + 'Y<sub>in</sub>', 'echo_free_capture_filepath') + + '</div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('play_arrow', 'Clean capture stream', + 'Y<sub>clean</sub>', 'clean_capture_input_filepath') + + '</div>' + + '<div class="mdl-cell mdl-cell--2-col">' + + getButtonHtml('play_arrow', 'APM reference [3]', 'Y<sub>ref</sub>', + 'apm_reference_filepath') + + '</div>' + + '<div class="mdl-layout-spacer"></div>' + + '</div>'; + + // Add an invisible node as initial container for the audio inspector. + var parent = document.createElement('div'); + parent.style.display = 'none'; + this.moveInspector_(parent); + document.body.appendChild(parent); +}; + +/** + * Initialize event handlers. + */ +AudioInspector.prototype.initializeEventHandlers_ = function() { + var self = this; + + // Score cells. + document.querySelectorAll('td.single-score-cell').forEach(function(element) { + element.onclick = function() { + self.selectedScoreChange(this); + } + }); + + // Copy anchor URLs icons. + if (document.queryCommandSupported('copy')) { + document.querySelectorAll('td.single-score-cell button').forEach( + function(element) { + element.onclick = function() { + // Find the text area in the dialog. + var textArea = element.closest('dialog').querySelector( + 'textarea.url-copy'); + + // Copy. + textArea.value = self.baseUrl_ + '#' + element.getAttribute( + 'data-anchor'); + textArea.select(); + try { + if (!document.execCommand('copy')) + throw 'Copy returned false'; + self.showNotification_('Experiment URL copied.'); + } catch (e) { + self.showNotification_('Cannot copy experiment URL :('); + console.error(e); + } + } + }); + } else { + self.showNotification_( + 'The copy command is disabled. URL copy is not enabled.'); + } + + // Audio inspector buttons. + this.audioInspector_.querySelectorAll('button').forEach(function(element) { + var target = element.querySelector('input[type=hidden]'); + if (target == null) { return; } + element.onclick = function() { + if (target.value == '__stop__') { + self.stopAudio(); + } else { + self.playAudio(target.value); + } + }; + }); + + // Dialog close handlers. 
+  document.querySelectorAll('dialog').forEach(function(element) {
+    element.onclose = function() {
+      self.stopAudio();
+    };
+  });
+
+  // Keyboard shortcuts.
+  window.onkeyup = function(e) {
+    var key = e.keyCode ? e.keyCode : e.which;
+    switch (key) {
+      case 49:  // 1.
+        self.playAudio('capture_filepath');
+        break;
+      case 50:  // 2.
+        self.playAudio('apm_output_filepath');
+        break;
+      case 51:  // 3.
+        self.playAudio('apm_reference_filepath');
+        break;
+      case 83:  // S.
+      case 115:  // s.
+        self.stopAudio();
+        break;
+    }
+  };
+
+  // Hash change.
+  window.onhashchange = function(e) {
+    self.parseWindowAnchor();
+  };
+};
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
new file mode 100644
index 0000000000..95e801903d
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
@@ -0,0 +1,359 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Signal processing utility module.
+"""
+
+import array
+import enum
+import logging
+import os
+import sys
+
+try:
+    import numpy as np
+except ImportError:
+    logging.critical('Cannot import the third-party Python package numpy')
+    sys.exit(1)
+
+try:
+    import pydub
+    import pydub.generators
+except ImportError:
+    logging.critical('Cannot import the third-party Python package pydub')
+    sys.exit(1)
+
+try:
+    import scipy.signal
+    import scipy.fftpack
+except ImportError:
+    logging.critical('Cannot import the third-party Python package scipy')
+    sys.exit(1)
+
+from . import exceptions
+
+
+class SignalProcessingUtils(object):
+    """Collection of signal processing utilities.
+    """
+
+    @enum.unique
+    class MixPadding(enum.Enum):
+        NO_PADDING = 0
+        ZERO_PADDING = 1
+        LOOP = 2
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def LoadWav(cls, filepath, channels=1):
+        """Loads wav file.
+
+        Args:
+            filepath: path to the wav audio track file to load.
+            channels: number of channels (downmixing to mono by default).
+
+        Returns:
+            AudioSegment instance.
+        """
+        if not os.path.exists(filepath):
+            logging.error('cannot find the <%s> audio track file', filepath)
+            raise exceptions.FileNotFoundError()
+        return pydub.AudioSegment.from_file(filepath,
+                                            format='wav',
+                                            channels=channels)
+
+    @classmethod
+    def SaveWav(cls, output_filepath, signal):
+        """Saves wav file.
+
+        Args:
+            output_filepath: path to the wav audio track file to save.
+            signal: AudioSegment instance.
+        """
+        return signal.export(output_filepath, format='wav')
+
+    @classmethod
+    def CountSamples(cls, signal):
+        """Number of samples per channel.
+
+        Args:
+            signal: AudioSegment instance.
+
+        Returns:
+            An integer.
+        """
+        number_of_samples = len(signal.get_array_of_samples())
+        assert signal.channels > 0
+        assert number_of_samples % signal.channels == 0
+        return number_of_samples // signal.channels
+
+    @classmethod
+    def GenerateSilence(cls, duration=1000, sample_rate=48000):
+        """Generates silence.
+
+        This method can also be used to create a template AudioSegment instance.
+        A template can then be used with other Generate*() methods accepting
+        an AudioSegment instance as argument.
+
+        Args:
+            duration: duration in ms.
+            sample_rate: sample rate.
+
+        Returns:
+            AudioSegment instance.
+        """
+        return pydub.AudioSegment.silent(duration, sample_rate)
+
+    @classmethod
+    def GeneratePureTone(cls, template, frequency=440.0):
+        """Generates a pure tone.
+
+        The pure tone is generated with the same duration and in the same
+        format of the given template signal.
+
+        Args:
+            template: AudioSegment instance.
+            frequency: Frequency of the pure tone in Hz.
+
+        Returns:
+            AudioSegment instance.
+        """
+        if frequency > template.frame_rate >> 1:
+            raise exceptions.SignalProcessingException('Invalid frequency')
+
+        generator = pydub.generators.Sine(sample_rate=template.frame_rate,
+                                          bit_depth=template.sample_width * 8,
+                                          freq=frequency)
+
+        return generator.to_audio_segment(duration=len(template), volume=0.0)
+
+    @classmethod
+    def GenerateWhiteNoise(cls, template):
+        """Generates white noise.
+
+        The white noise is generated with the same duration and in the same
+        format of the given template signal.
+
+        Args:
+            template: AudioSegment instance.
+
+        Returns:
+            AudioSegment instance.
+        """
+        generator = pydub.generators.WhiteNoise(
+            sample_rate=template.frame_rate,
+            bit_depth=template.sample_width * 8)
+        return generator.to_audio_segment(duration=len(template), volume=0.0)
+
+    @classmethod
+    def AudioSegmentToRawData(cls, signal):
+        samples = signal.get_array_of_samples()
+        if samples.typecode != 'h':
+            raise exceptions.SignalProcessingException(
+                'Unsupported samples type')
+        return np.array(samples, np.int16)
+
+    @classmethod
+    def Fft(cls, signal, normalize=True):
+        if signal.channels != 1:
+            raise NotImplementedError('multiple-channel FFT not implemented')
+        x = cls.AudioSegmentToRawData(signal).astype(np.float32)
+        if normalize:
+            x /= max(abs(np.max(x)), 1.0)
+        y = scipy.fftpack.fft(x)
+        return y[:len(y) // 2]
+
+    @classmethod
+    def DetectHardClipping(cls, signal, threshold=2):
+        """Detects hard clipping.
+
+        Hard clipping is simply detected by counting samples that touch either
+        the lower or upper bound too many times in a row (according to
+        `threshold`). The presence of a single sequence of samples meeting such
+        property is enough to label the signal as hard clipped.
+
+        Args:
+            signal: AudioSegment instance.
+            threshold: minimum number of samples at full-scale in a row.
+
+        Returns:
+            True if hard clipping is detected, False otherwise.
+        """
+        if signal.channels != 1:
+            raise NotImplementedError(
+                'multiple-channel clipping not implemented')
+        if signal.sample_width != 2:  # Note that sample_width is in bytes.
+            raise exceptions.SignalProcessingException(
+                'hard-clipping detection only supported for 16 bit samples')
+        samples = cls.AudioSegmentToRawData(signal)
+
+        # Detect adjacent clipped samples.
+        samples_type_info = np.iinfo(samples.dtype)
+        mask_min = samples == samples_type_info.min
+        mask_max = samples == samples_type_info.max
+
+        def HasLongSequence(vector, min_length=threshold):
+            """Returns True if there are one or more long sequences of True
+            flags."""
+            seq_length = 0
+            for b in vector:
+                seq_length = seq_length + 1 if b else 0
+                if seq_length >= min_length:
+                    return True
+            return False
+
+        return HasLongSequence(mask_min) or HasLongSequence(mask_max)
+
+    @classmethod
+    def ApplyImpulseResponse(cls, signal, impulse_response):
+        """Applies an impulse response to a signal.
+
+        Args:
+            signal: AudioSegment instance.
+            impulse_response: list or numpy vector of float values.
+
+        Returns:
+            AudioSegment instance.
+        """
+        # Get samples.
+        assert signal.channels == 1, (
+            'multiple-channel recordings not supported')
+        samples = signal.get_array_of_samples()
+
+        # Convolve.
+        logging.info(
+            'applying %d order impulse response to a signal lasting %d ms',
+            len(impulse_response), len(signal))
+        convolved_samples = scipy.signal.fftconvolve(in1=samples,
+                                                     in2=impulse_response,
+                                                     mode='full').astype(
+                                                         np.int16)
+        logging.info('convolution computed')
+
+        # Cast.
+        convolved_samples = array.array(signal.array_type, convolved_samples)
+
+        # Verify.
+        logging.debug('signal length: %d samples', len(samples))
+        logging.debug('convolved signal length: %d samples',
+                      len(convolved_samples))
+        assert len(convolved_samples) > len(samples)
+
+        # Generate convolved signal AudioSegment instance.
+        convolved_signal = pydub.AudioSegment(data=convolved_samples,
+                                              metadata={
+                                                  'sample_width':
+                                                  signal.sample_width,
+                                                  'frame_rate':
+                                                  signal.frame_rate,
+                                                  'frame_width':
+                                                  signal.frame_width,
+                                                  'channels': signal.channels,
+                                              })
+        assert len(convolved_signal) > len(signal)
+
+        return convolved_signal
+
+    @classmethod
+    def Normalize(cls, signal):
+        """Normalizes a signal.
+
+        Args:
+            signal: AudioSegment instance.
+
+        Returns:
+            An AudioSegment instance.
+        """
+        return signal.apply_gain(-signal.max_dBFS)
+
+    @classmethod
+    def Copy(cls, signal):
+        """Makes a copy of a signal.
+
+        Args:
+            signal: AudioSegment instance.
+
+        Returns:
+            An AudioSegment instance.
+        """
+        return pydub.AudioSegment(data=signal.get_array_of_samples(),
+                                  metadata={
+                                      'sample_width': signal.sample_width,
+                                      'frame_rate': signal.frame_rate,
+                                      'frame_width': signal.frame_width,
+                                      'channels': signal.channels,
+                                  })
+
+    @classmethod
+    def MixSignals(cls,
+                   signal,
+                   noise,
+                   target_snr=0.0,
+                   pad_noise=MixPadding.NO_PADDING):
+        """Mixes `signal` and `noise` with a target SNR.
+
+        Mixes `signal` and `noise` with a desired SNR by scaling `noise`.
+        If the target SNR is +/- infinite, a copy of signal/noise is returned.
+        If `signal` is shorter than `noise`, the length of the mix equals that
+        of `signal`. Otherwise, the mix length depends on whether padding is
+        applied. When padding is not applied, that is `pad_noise` is set to
+        NO_PADDING (default), the mix length equals that of `noise` - i.e.,
+        `signal` is truncated. Otherwise, `noise` is extended and the
+        resulting mix has the same length as `signal`.
+
+        Args:
+            signal: AudioSegment instance (signal).
+            noise: AudioSegment instance (noise).
+            target_snr: float, numpy.Inf or -numpy.Inf (dB).
+            pad_noise: SignalProcessingUtils.MixPadding, default: NO_PADDING.
+
+        Returns:
+            An AudioSegment instance.
+        """
+        # Handle infinite target SNR.
+        if target_snr == -np.Inf:
+            # Return a copy of noise.
+            logging.warning('SNR = -Inf, returning noise')
+            return cls.Copy(noise)
+        elif target_snr == np.Inf:
+            # Return a copy of signal.
+            logging.warning('SNR = +Inf, returning signal')
+            return cls.Copy(signal)
+
+        # Check signal and noise power.
+        signal_power = float(signal.dBFS)
+        noise_power = float(noise.dBFS)
+        if signal_power == -np.Inf:
+            logging.error('signal has -Inf power, cannot mix')
+            raise exceptions.SignalProcessingException(
+                'cannot mix a signal with -Inf power')
+        if noise_power == -np.Inf:
+            logging.error('noise has -Inf power, cannot mix')
+            raise exceptions.SignalProcessingException(
+                'cannot mix noise with -Inf power')
+
+        # Mix.
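+        # The scaling below follows from the SNR definition in dB:
+        # target_snr = signal_power - (noise_power + gain_db), hence
+        # gain_db = signal_power - noise_power - target_snr. For example, a
+        # signal at -20 dBFS, noise at -30 dBFS and a target SNR of 4 dB give
+        # a noise gain of -20 - (-30) - 4 = 6 dB.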
+        gain_db = signal_power - noise_power - target_snr
+        signal_duration = len(signal)
+        noise_duration = len(noise)
+        if signal_duration <= noise_duration:
+            # Ignore `pad_noise`: `noise` is truncated if longer than
+            # `signal`, so the mix will have the same length as `signal`.
+            return signal.overlay(noise.apply_gain(gain_db))
+        elif pad_noise == cls.MixPadding.NO_PADDING:
+            # `signal` is longer than `noise`, but no padding is applied to
+            # `noise`. Truncate `signal`.
+            return noise.overlay(signal, gain_during_overlay=gain_db)
+        elif pad_noise == cls.MixPadding.ZERO_PADDING:
+            # TODO(alessiob): Check that this works as expected.
+            return signal.overlay(noise.apply_gain(gain_db))
+        elif pad_noise == cls.MixPadding.LOOP:
+            # `signal` is longer than `noise`, extend `noise` by looping.
+            return signal.overlay(noise.apply_gain(gain_db), loop=True)
+        else:
+            raise exceptions.SignalProcessingException('invalid padding type')
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing_unittest.py
new file mode 100644
index 0000000000..881fb66800
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing_unittest.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Unit tests for the signal_processing module.
+"""
+
+import unittest
+
+import numpy as np
+import pydub
+
+from . import exceptions
+from . import signal_processing
+
+
+class TestSignalProcessing(unittest.TestCase):
+    """Unit tests for the signal_processing module.
+    """
+
+    def testMixSignals(self):
+        # Generate a template signal with which white noise can be generated.
+        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
+
+        # Generate two distinct AudioSegment instances with 1 second of white
+        # noise.
+        signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+            silence)
+        noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+            silence)
+
+        # Extract samples.
+        signal_samples = signal.get_array_of_samples()
+        noise_samples = noise.get_array_of_samples()
+
+        # Test target SNR -Inf (noise expected).
+        mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
+            signal, noise, -np.Inf)
+        self.assertEqual(len(noise), len(mix_neg_inf))  # Check duration.
+        mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
+        self.assertTrue(  # Check samples.
+            all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))
+
+        # Test target SNR 0.0 (different data expected).
+        mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
+            signal, noise, 0.0)
+        self.assertEqual(len(signal), len(mix_0))  # Check duration.
+        self.assertEqual(len(noise), len(mix_0))
+        mix_0_samples = mix_0.get_array_of_samples()
+        self.assertTrue(
+            any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
+        self.assertTrue(
+            any([x != y for x, y in zip(noise_samples, mix_0_samples)]))
+
+        # Test target SNR +Inf (signal expected).
+        mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
+            signal, noise, np.Inf)
+        self.assertEqual(len(signal), len(mix_pos_inf))  # Check duration.
+        mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
+        self.assertTrue(  # Check samples.
+            all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))
+
+    def testMixSignalsMinInfPower(self):
+        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
+        signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+            silence)
+
+        with self.assertRaises(exceptions.SignalProcessingException):
+            _ = signal_processing.SignalProcessingUtils.MixSignals(
+                signal, silence, 0.0)
+
+        with self.assertRaises(exceptions.SignalProcessingException):
+            _ = signal_processing.SignalProcessingUtils.MixSignals(
+                silence, signal, 0.0)
+
+    def testMixSignalNoiseDifferentLengths(self):
+        # Test signals.
+        shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+            pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
+        longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+            pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
+
+        # When the signal is shorter than the noise, the mix length always
+        # equals that of the signal regardless of whether padding is applied.
+        # No noise padding, length of signal less than that of noise.
+        mix = signal_processing.SignalProcessingUtils.MixSignals(
+            signal=shorter,
+            noise=longer,
+            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
+            NO_PADDING)
+        self.assertEqual(len(shorter), len(mix))
+        # With noise padding, length of signal less than that of noise.
+        mix = signal_processing.SignalProcessingUtils.MixSignals(
+            signal=shorter,
+            noise=longer,
+            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
+            ZERO_PADDING)
+        self.assertEqual(len(shorter), len(mix))
+
+        # When the signal is longer than the noise, the mix length depends on
+        # whether padding is applied.
+        # No noise padding, length of signal greater than that of noise.
+        mix = signal_processing.SignalProcessingUtils.MixSignals(
+            signal=longer,
+            noise=shorter,
+            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
+            NO_PADDING)
+        self.assertEqual(len(shorter), len(mix))
+        # With noise padding, length of signal greater than that of noise.
+        mix = signal_processing.SignalProcessingUtils.MixSignals(
+            signal=longer,
+            noise=shorter,
+            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
+            ZERO_PADDING)
+        self.assertEqual(len(longer), len(mix))
+
+    def testMixSignalNoisePaddingTypes(self):
+        # Test signals.
+        shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+            pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
+        longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
+            pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
+
+        # Zero padding: expect pure tone only in 1-2s.
+        mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
+            signal=longer,
+            noise=shorter,
+            target_snr=-6,
+            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
+            ZERO_PADDING)
+
+        # Loop: expect pure tone plus noise in 1-2s.
+ mix_loop = signal_processing.SignalProcessingUtils.MixSignals( + signal=longer, + noise=shorter, + target_snr=-6, + pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP) + + def Energy(signal): + samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData( + signal).astype(np.float32) + return np.sum(samples * samples) + + e_mix_zero_pad = Energy(mix_zero_pad[-1000:]) + e_mix_loop = Energy(mix_loop[-1000:]) + self.assertLess(0, e_mix_zero_pad) + self.assertLess(e_mix_zero_pad, e_mix_loop) + + def testMixSignalSnr(self): + # Test signals. + tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone( + pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0) + tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone( + pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0) + + def ToneAmplitudes(mix): + """Returns the amplitude of the coefficients #16 and #192, which + correspond to the tones at 250 and 3k Hz respectively.""" + mix_fft = np.absolute( + signal_processing.SignalProcessingUtils.Fft(mix)) + return mix_fft[16], mix_fft[192] + + mix = signal_processing.SignalProcessingUtils.MixSignals( + signal=tone_low, noise=tone_high, target_snr=-6) + ampl_low, ampl_high = ToneAmplitudes(mix) + self.assertLess(ampl_low, ampl_high) + + mix = signal_processing.SignalProcessingUtils.MixSignals( + signal=tone_high, noise=tone_low, target_snr=-6) + ampl_low, ampl_high = ToneAmplitudes(mix) + self.assertLess(ampl_high, ampl_low) + + mix = signal_processing.SignalProcessingUtils.MixSignals( + signal=tone_low, noise=tone_high, target_snr=6) + ampl_low, ampl_high = ToneAmplitudes(mix) + self.assertLess(ampl_high, ampl_low) + + mix = signal_processing.SignalProcessingUtils.MixSignals( + signal=tone_high, noise=tone_low, target_snr=6) + ampl_low, ampl_high = ToneAmplitudes(mix) + self.assertLess(ampl_low, ampl_high) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py new file mode 100644 index 0000000000..69b3a1624e --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py @@ -0,0 +1,446 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""APM module simulator. +""" + +import logging +import os + +from . import annotations +from . import data_access +from . import echo_path_simulation +from . import echo_path_simulation_factory +from . import eval_scores +from . import exceptions +from . import input_mixer +from . import input_signal_creator +from . import signal_processing +from . import test_data_generation + + +class ApmModuleSimulator(object): + """Audio processing module (APM) simulator class. 
+ """ + + _TEST_DATA_GENERATOR_CLASSES = ( + test_data_generation.TestDataGenerator.REGISTERED_CLASSES) + _EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES + + _PREFIX_APM_CONFIG = 'apmcfg-' + _PREFIX_CAPTURE = 'capture-' + _PREFIX_RENDER = 'render-' + _PREFIX_ECHO_SIMULATOR = 'echosim-' + _PREFIX_TEST_DATA_GEN = 'datagen-' + _PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-' + _PREFIX_SCORE = 'score-' + + def __init__(self, + test_data_generator_factory, + evaluation_score_factory, + ap_wrapper, + evaluator, + external_vads=None): + if external_vads is None: + external_vads = {} + self._test_data_generator_factory = test_data_generator_factory + self._evaluation_score_factory = evaluation_score_factory + self._audioproc_wrapper = ap_wrapper + self._evaluator = evaluator + self._annotator = annotations.AudioAnnotationsExtractor( + annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD + | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO + | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM, + external_vads) + + # Init. + self._test_data_generator_factory.SetOutputDirectoryPrefix( + self._PREFIX_TEST_DATA_GEN_PARAMS) + self._evaluation_score_factory.SetScoreFilenamePrefix( + self._PREFIX_SCORE) + + # Properties for each run. + self._base_output_path = None + self._output_cache_path = None + self._test_data_generators = None + self._evaluation_score_workers = None + self._config_filepaths = None + self._capture_input_filepaths = None + self._render_input_filepaths = None + self._echo_path_simulator_class = None + + @classmethod + def GetPrefixApmConfig(cls): + return cls._PREFIX_APM_CONFIG + + @classmethod + def GetPrefixCapture(cls): + return cls._PREFIX_CAPTURE + + @classmethod + def GetPrefixRender(cls): + return cls._PREFIX_RENDER + + @classmethod + def GetPrefixEchoSimulator(cls): + return cls._PREFIX_ECHO_SIMULATOR + + @classmethod + def GetPrefixTestDataGenerator(cls): + return cls._PREFIX_TEST_DATA_GEN + + @classmethod + def GetPrefixTestDataGeneratorParameters(cls): + return cls._PREFIX_TEST_DATA_GEN_PARAMS + + @classmethod + def GetPrefixScore(cls): + return cls._PREFIX_SCORE + + def Run(self, + config_filepaths, + capture_input_filepaths, + test_data_generator_names, + eval_score_names, + output_dir, + render_input_filepaths=None, + echo_path_simulator_name=( + echo_path_simulation.NoEchoPathSimulator.NAME)): + """Runs the APM simulation. + + Initializes paths and required instances, then runs all the simulations. + The render input can be optionally added. If added, the number of capture + input audio tracks and the number of render input audio tracks have to be + equal. The two lists are used to form pairs of capture and render input. + + Args: + config_filepaths: set of APM configuration files to test. + capture_input_filepaths: set of capture input audio track files to test. + test_data_generator_names: set of test data generator names to test. + eval_score_names: set of evaluation score names to test. + output_dir: base path to the output directory for wav files and outcomes. + render_input_filepaths: set of render input audio track files to test. + echo_path_simulator_name: name of the echo path simulator to use when + render input is provided. 
+        """
+        assert render_input_filepaths is None or (
+            len(capture_input_filepaths) == len(render_input_filepaths)), (
+                'render input set size not matching input set size')
+        assert render_input_filepaths is None or echo_path_simulator_name in (
+            echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
+                'invalid echo path simulator')
+        self._base_output_path = os.path.abspath(output_dir)
+
+        # Output path used to cache the data shared across simulations.
+        self._output_cache_path = os.path.join(self._base_output_path,
+                                               '_cache')
+
+        # Instantiate test data generators.
+        self._test_data_generators = [
+            self._test_data_generator_factory.GetInstance(
+                test_data_generators_class=(
+                    self._TEST_DATA_GENERATOR_CLASSES[name]))
+            for name in test_data_generator_names
+        ]
+
+        # Instantiate evaluation score workers.
+        self._evaluation_score_workers = [
+            self._evaluation_score_factory.GetInstance(
+                evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name])
+            for name in eval_score_names
+        ]
+
+        # Set APM configuration file paths.
+        self._config_filepaths = self._CreatePathsCollection(config_filepaths)
+
+        # Set probing signal file paths.
+        if render_input_filepaths is None:
+            # Capture input only.
+            self._capture_input_filepaths = self._CreatePathsCollection(
+                capture_input_filepaths)
+            self._render_input_filepaths = None
+        else:
+            # Set both capture and render input signals.
+            self._SetTestInputSignalFilePaths(capture_input_filepaths,
+                                              render_input_filepaths)
+
+        # Set the echo path simulator class.
+        self._echo_path_simulator_class = (
+            echo_path_simulation.EchoPathSimulator.
+            REGISTERED_CLASSES[echo_path_simulator_name])
+
+        self._SimulateAll()
+
+    def _SimulateAll(self):
+        """Runs all the simulations.
+
+        Iterates over the combinations of APM configurations, probing signals,
+        and test data generators. This method is mainly responsible for the
+        creation of the cache and output directories required in order to call
+        _Simulate().
+        """
+        without_render_input = self._render_input_filepaths is None
+
+        # Try different APM config files.
+        for config_name in self._config_filepaths:
+            config_filepath = self._config_filepaths[config_name]
+
+            # Try different capture-render pairs.
+            for capture_input_name in self._capture_input_filepaths:
+                # Output path for the capture signal annotations.
+                capture_annotations_cache_path = os.path.join(
+                    self._output_cache_path,
+                    self._PREFIX_CAPTURE + capture_input_name)
+                data_access.MakeDirectory(capture_annotations_cache_path)
+
+                # Capture.
+                capture_input_filepath = self._capture_input_filepaths[
+                    capture_input_name]
+                if not os.path.exists(capture_input_filepath):
+                    # If the input signal file does not exist, try to create
+                    # it using the available input signal creators.
+                    self._CreateInputSignal(capture_input_filepath)
+                assert os.path.exists(capture_input_filepath)
+                self._ExtractCaptureAnnotations(
+                    capture_input_filepath, capture_annotations_cache_path)
+
+                # Render and simulated echo path (optional).
+                render_input_filepath = None if without_render_input else (
+                    self._render_input_filepaths[capture_input_name])
+                render_input_name = '(none)' if without_render_input else (
+                    self._ExtractFileName(render_input_filepath))
+                echo_path_simulator = (
+                    echo_path_simulation_factory.EchoPathSimulatorFactory.
+                    GetInstance(self._echo_path_simulator_class,
+                                render_input_filepath))
+
+                # Try different test data generators.
+                for test_data_generators in self._test_data_generators:
+                    logging.info(
+                        'APM config preset: <%s>, capture: <%s>, '
+                        'render: <%s>, test data generator: <%s>, '
+                        'echo simulator: <%s>',
+                        config_name, capture_input_name, render_input_name,
+                        test_data_generators.NAME, echo_path_simulator.NAME)
+
+                    # Output path for the generated test data.
+                    test_data_cache_path = os.path.join(
+                        capture_annotations_cache_path,
+                        self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
+                    data_access.MakeDirectory(test_data_cache_path)
+                    logging.debug('test data cache path: <%s>',
+                                  test_data_cache_path)
+
+                    # Output path for the echo simulator and APM input mixer
+                    # output.
+                    echo_test_data_cache_path = os.path.join(
+                        test_data_cache_path,
+                        'echosim-{}'.format(echo_path_simulator.NAME))
+                    data_access.MakeDirectory(echo_test_data_cache_path)
+                    logging.debug('echo test data cache path: <%s>',
+                                  echo_test_data_cache_path)
+
+                    # Full output path.
+                    output_path = os.path.join(
+                        self._base_output_path,
+                        self._PREFIX_APM_CONFIG + config_name,
+                        self._PREFIX_CAPTURE + capture_input_name,
+                        self._PREFIX_RENDER + render_input_name,
+                        self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
+                        self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
+                    data_access.MakeDirectory(output_path)
+                    logging.debug('output path: <%s>', output_path)
+
+                    self._Simulate(test_data_generators,
+                                   capture_input_filepath,
+                                   render_input_filepath, test_data_cache_path,
+                                   echo_test_data_cache_path, output_path,
+                                   config_filepath, echo_path_simulator)
+
+    @staticmethod
+    def _CreateInputSignal(input_signal_filepath):
+        """Creates a missing input signal file.
+
+        The file name is parsed to extract the input signal creator and its
+        params. If a creator is matched and the parameters are valid, a new
+        signal is generated and written in `input_signal_filepath`.
+
+        Args:
+            input_signal_filepath: Path to the input signal audio file to
+                write.
+
+        Raises:
+            InputSignalCreatorException
+        """
+        filename = os.path.splitext(
+            os.path.split(input_signal_filepath)[-1])[0]
+        filename_parts = filename.split('-')
+
+        if len(filename_parts) < 2:
+            raise exceptions.InputSignalCreatorException(
+                'Cannot parse input signal file name')
+
+        signal, metadata = input_signal_creator.InputSignalCreator.Create(
+            filename_parts[0], filename_parts[1].split('_'))
+
+        signal_processing.SignalProcessingUtils.SaveWav(
+            input_signal_filepath, signal)
+        data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
+
+    def _ExtractCaptureAnnotations(self,
+                                   input_filepath,
+                                   output_path,
+                                   annotation_name=""):
+        self._annotator.Extract(input_filepath)
+        self._annotator.Save(output_path, annotation_name)
+
+    def _Simulate(self, test_data_generators, clean_capture_input_filepath,
+                  render_input_filepath, test_data_cache_path,
+                  echo_test_data_cache_path, output_path, config_filepath,
+                  echo_path_simulator):
+        """Runs a single set of simulations.
+
+        Simulates a given combination of APM configuration, probing signal,
+        and test data generator. It iterates over the test data generator
+        internal configurations.
+
+        Args:
+            test_data_generators: TestDataGenerator instance.
+            clean_capture_input_filepath: capture input audio track file to be
+                processed by a test data generator and not affected by echo.
+            render_input_filepath: render input audio track file to test.
+            test_data_cache_path: path for the generated test audio track
+                files.
+            echo_test_data_cache_path: path for the echo simulator.
+            output_path: base output path for the test data generator.
+ config_filepath: APM configuration file to test. + echo_path_simulator: EchoPathSimulator instance. + """ + # Generate pairs of noisy input and reference signal files. + test_data_generators.Generate( + input_signal_filepath=clean_capture_input_filepath, + test_data_cache_path=test_data_cache_path, + base_output_path=output_path) + + # Extract metadata linked to the clean input file (if any). + apm_input_metadata = None + try: + apm_input_metadata = data_access.Metadata.LoadFileMetadata( + clean_capture_input_filepath) + except IOError as e: + apm_input_metadata = {} + apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME + apm_input_metadata['test_data_gen_config'] = None + + # For each test data pair, simulate a call and evaluate. + for config_name in test_data_generators.config_names: + logging.info(' - test data generator config: <%s>', config_name) + apm_input_metadata['test_data_gen_config'] = config_name + + # Paths to the test data generator output. + # Note that the reference signal does not depend on the render input + # which is optional. + noisy_capture_input_filepath = ( + test_data_generators.noisy_signal_filepaths[config_name]) + reference_signal_filepath = ( + test_data_generators.reference_signal_filepaths[config_name]) + + # Output path for the evaluation (e.g., APM output file). + evaluation_output_path = test_data_generators.apm_output_paths[ + config_name] + + # Paths to the APM input signals. + echo_path_filepath = echo_path_simulator.Simulate( + echo_test_data_cache_path) + apm_input_filepath = input_mixer.ApmInputMixer.Mix( + echo_test_data_cache_path, noisy_capture_input_filepath, + echo_path_filepath) + + # Extract annotations for the APM input mix. + apm_input_basepath, apm_input_filename = os.path.split( + apm_input_filepath) + self._ExtractCaptureAnnotations( + apm_input_filepath, apm_input_basepath, + os.path.splitext(apm_input_filename)[0] + '-') + + # Simulate a call using APM. + self._audioproc_wrapper.Run( + config_filepath=config_filepath, + capture_input_filepath=apm_input_filepath, + render_input_filepath=render_input_filepath, + output_path=evaluation_output_path) + + try: + # Evaluate. + self._evaluator.Run( + evaluation_score_workers=self._evaluation_score_workers, + apm_input_metadata=apm_input_metadata, + apm_output_filepath=self._audioproc_wrapper. + output_filepath, + reference_input_filepath=reference_signal_filepath, + render_input_filepath=render_input_filepath, + output_path=evaluation_output_path, + ) + + # Save simulation metadata. + data_access.Metadata.SaveAudioTestDataPaths( + output_path=evaluation_output_path, + clean_capture_input_filepath=clean_capture_input_filepath, + echo_free_capture_filepath=noisy_capture_input_filepath, + echo_filepath=echo_path_filepath, + render_filepath=render_input_filepath, + capture_filepath=apm_input_filepath, + apm_output_filepath=self._audioproc_wrapper. + output_filepath, + apm_reference_filepath=reference_signal_filepath, + apm_config_filepath=config_filepath, + ) + except exceptions.EvaluationScoreException as e: + logging.warning('the evaluation failed: %s', e.message) + continue + + def _SetTestInputSignalFilePaths(self, capture_input_filepaths, + render_input_filepaths): + """Sets input and render input file paths collections. + + Pairs the input and render input files by storing the file paths into two + collections. The key is the file name of the input file. + + Args: + capture_input_filepaths: list of file paths. + render_input_filepaths: list of file paths. 
+ """ + self._capture_input_filepaths = {} + self._render_input_filepaths = {} + assert len(capture_input_filepaths) == len(render_input_filepaths) + for capture_input_filepath, render_input_filepath in zip( + capture_input_filepaths, render_input_filepaths): + name = self._ExtractFileName(capture_input_filepath) + self._capture_input_filepaths[name] = os.path.abspath( + capture_input_filepath) + self._render_input_filepaths[name] = os.path.abspath( + render_input_filepath) + + @classmethod + def _CreatePathsCollection(cls, filepaths): + """Creates a collection of file paths. + + Given a list of file paths, makes a collection with one item for each file + path. The value is absolute path, the key is the file name without + extenstion. + + Args: + filepaths: list of file paths. + + Returns: + A dict. + """ + filepaths_collection = {} + for filepath in filepaths: + name = cls._ExtractFileName(filepath) + filepaths_collection[name] = os.path.abspath(filepath) + return filepaths_collection + + @classmethod + def _ExtractFileName(cls, filepath): + return os.path.splitext(os.path.split(filepath)[-1])[0] diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py new file mode 100644 index 0000000000..78ca17f589 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py @@ -0,0 +1,203 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the simulation module. +""" + +import logging +import os +import shutil +import tempfile +import unittest + +import mock +import pydub + +from . import audioproc_wrapper +from . import eval_scores_factory +from . import evaluation +from . import external_vad +from . import signal_processing +from . import simulation +from . import test_data_generation_factory + + +class TestApmModuleSimulator(unittest.TestCase): + """Unit tests for the ApmModuleSimulator class. + """ + + def setUp(self): + """Create temporary folders and fake audio track.""" + self._output_path = tempfile.mkdtemp() + self._tmp_path = tempfile.mkdtemp() + + silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000) + fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise( + silence) + self._fake_audio_track_path = os.path.join(self._output_path, + 'fake.wav') + signal_processing.SignalProcessingUtils.SaveWav( + self._fake_audio_track_path, fake_signal) + + def tearDown(self): + """Recursively delete temporary folders.""" + shutil.rmtree(self._output_path) + shutil.rmtree(self._tmp_path) + + def testSimulation(self): + # Instance dependencies to mock and inject. + ap_wrapper = audioproc_wrapper.AudioProcWrapper( + audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH) + evaluator = evaluation.ApmModuleEvaluator() + ap_wrapper.Run = mock.MagicMock(name='Run') + evaluator.Run = mock.MagicMock(name='Run') + + # Instance non-mocked dependencies. 
+        test_data_generator_factory = (
+            test_data_generation_factory.TestDataGeneratorFactory(
+                aechen_ir_database_path='',
+                noise_tracks_path='',
+                copy_with_identity=False))
+        evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
+            polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
+                                             'fake_polqa'),
+            echo_metric_tool_bin_path=None)
+
+        # Instantiate the simulator.
+        simulator = simulation.ApmModuleSimulator(
+            test_data_generator_factory=test_data_generator_factory,
+            evaluation_score_factory=evaluation_score_factory,
+            ap_wrapper=ap_wrapper,
+            evaluator=evaluator,
+            external_vads={
+                'fake':
+                external_vad.ExternalVad(
+                    os.path.join(os.path.dirname(__file__),
+                                 'fake_external_vad.py'), 'fake')
+            })
+
+        # What to simulate.
+        config_files = ['apm_configs/default.json']
+        input_files = [self._fake_audio_track_path]
+        test_data_generators = ['identity', 'white_noise']
+        eval_scores = ['audio_level_mean', 'polqa']
+
+        # Run all simulations.
+        simulator.Run(config_filepaths=config_files,
+                      capture_input_filepaths=input_files,
+                      test_data_generator_names=test_data_generators,
+                      eval_score_names=eval_scores,
+                      output_dir=self._output_path)
+
+        # Check.
+        # TODO(alessiob): Once the TestDataGenerator classes can be configured
+        # by the client code (e.g., number of SNR pairs for the white noise
+        # test data generator), the exact number of calls to ap_wrapper.Run
+        # and evaluator.Run is known; use that with assertEqual.
+        min_number_of_simulations = len(config_files) * len(input_files) * len(
+            test_data_generators)
+        self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
+                                min_number_of_simulations)
+        self.assertGreaterEqual(len(evaluator.Run.call_args_list),
+                                min_number_of_simulations)
+
+    def testInputSignalCreation(self):
+        # Instantiate the simulator.
+        simulator = simulation.ApmModuleSimulator(
+            test_data_generator_factory=(
+                test_data_generation_factory.TestDataGeneratorFactory(
+                    aechen_ir_database_path='',
+                    noise_tracks_path='',
+                    copy_with_identity=False)),
+            evaluation_score_factory=(
+                eval_scores_factory.EvaluationScoreWorkerFactory(
+                    polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
+                                                     'fake_polqa'),
+                    echo_metric_tool_bin_path=None)),
+            ap_wrapper=audioproc_wrapper.AudioProcWrapper(
+                audioproc_wrapper.AudioProcWrapper.
+                DEFAULT_APM_SIMULATOR_BIN_PATH),
+            evaluator=evaluation.ApmModuleEvaluator())
+
+        # Nonexistent input files, to be silently created.
+        input_files = [
+            os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
+            os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
+        ]
+        self.assertFalse(
+            any([os.path.exists(input_file) for input_file in input_files]))
+
+        # The input files are created during the simulation.
+        simulator.Run(config_filepaths=['apm_configs/default.json'],
+                      capture_input_filepaths=input_files,
+                      test_data_generator_names=['identity'],
+                      eval_score_names=['audio_level_peak'],
+                      output_dir=self._output_path)
+        self.assertTrue(
+            all([os.path.exists(input_file) for input_file in input_files]))
+
+    def testPureToneGenerationWithTotalHarmonicDistortion(self):
+        logging.warning = mock.MagicMock(name='warning')
+
+        # Instantiate the simulator.
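+        # Note: input file names encode the signal creator and its parameters
+        # (split on '-' and '_', see _CreateInputSignal); 'pure_tone-440_1000'
+        # presumably requests a 440 Hz tone of 1000 ms, which the THD score
+        # needs as a reference.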
+ simulator = simulation.ApmModuleSimulator( + test_data_generator_factory=( + test_data_generation_factory.TestDataGeneratorFactory( + aechen_ir_database_path='', + noise_tracks_path='', + copy_with_identity=False)), + evaluation_score_factory=( + eval_scores_factory.EvaluationScoreWorkerFactory( + polqa_tool_bin_path=os.path.join(os.path.dirname(__file__), + 'fake_polqa'), + echo_metric_tool_bin_path=None)), + ap_wrapper=audioproc_wrapper.AudioProcWrapper( + audioproc_wrapper.AudioProcWrapper. + DEFAULT_APM_SIMULATOR_BIN_PATH), + evaluator=evaluation.ApmModuleEvaluator()) + + # What to simulate. + config_files = ['apm_configs/default.json'] + input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')] + eval_scores = ['thd'] + + # Should work. + simulator.Run(config_filepaths=config_files, + capture_input_filepaths=input_files, + test_data_generator_names=['identity'], + eval_score_names=eval_scores, + output_dir=self._output_path) + self.assertFalse(logging.warning.called) + + # Warning expected. + simulator.Run( + config_filepaths=config_files, + capture_input_filepaths=input_files, + test_data_generator_names=['white_noise'], # Not allowed with THD. + eval_score_names=eval_scores, + output_dir=self._output_path) + logging.warning.assert_called_with('the evaluation failed: %s', ( + 'The THD score cannot be used with any test data generator other than ' + '"identity"')) + + # # Init. + # generator = test_data_generation.IdentityTestDataGenerator('tmp') + # input_signal_filepath = os.path.join( + # self._test_data_cache_path, 'pure_tone-440_1000.wav') + + # # Check that the input signal is generated. + # self.assertFalse(os.path.exists(input_signal_filepath)) + # generator.Generate( + # input_signal_filepath=input_signal_filepath, + # test_data_cache_path=self._test_data_cache_path, + # base_output_path=self._base_output_path) + # self.assertTrue(os.path.exists(input_signal_filepath)) + + # # Check input signal properties. + # input_signal = signal_processing.SignalProcessingUtils.LoadWav( + # input_signal_filepath) + # self.assertEqual(1000, len(input_signal)) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/sound_level.cc b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/sound_level.cc new file mode 100644 index 0000000000..1f24d9d370 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/sound_level.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. +// +// Use of this source code is governed by a BSD-style license +// that can be found in the LICENSE file in the root of the source +// tree. An additional intellectual property rights grant can be found +// in the file PATENTS. All contributing project authors may +// be found in the AUTHORS file in the root of the source tree. 
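+
+// Command-line level estimator: reads a mono wav file (-i) and writes the
+// frame-by-frame peak level, smoothed with the given attack/decay time
+// constants, to a binary file (-ol), plus a small config dump (-oc).
+// Hypothetical invocation, assuming the build target is named sound_level:
+//   sound_level -i in.wav -oc config.txt -ol levels.dat -a 5 -d 20 -f 10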
+ +#include <algorithm> +#include <array> +#include <cmath> +#include <fstream> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "common_audio/include/audio_util.h" +#include "common_audio/wav_file.h" +#include "rtc_base/logging.h" + +ABSL_FLAG(std::string, i, "", "Input wav file"); +ABSL_FLAG(std::string, oc, "", "Config output file"); +ABSL_FLAG(std::string, ol, "", "Levels output file"); +ABSL_FLAG(float, a, 5.f, "Attack (ms)"); +ABSL_FLAG(float, d, 20.f, "Decay (ms)"); +ABSL_FLAG(int, f, 10, "Frame length (ms)"); + +namespace webrtc { +namespace test { +namespace { + +constexpr int kMaxSampleRate = 48000; +constexpr uint8_t kMaxFrameLenMs = 30; +constexpr size_t kMaxFrameLen = kMaxFrameLenMs * kMaxSampleRate / 1000; + +const double kOneDbReduction = DbToRatio(-1.0); + +int main(int argc, char* argv[]) { + absl::ParseCommandLine(argc, argv); + // Check parameters. + if (absl::GetFlag(FLAGS_f) < 1 || absl::GetFlag(FLAGS_f) > kMaxFrameLenMs) { + RTC_LOG(LS_ERROR) << "Invalid frame length (min: 1, max: " << kMaxFrameLenMs + << ")"; + return 1; + } + if (absl::GetFlag(FLAGS_a) < 0 || absl::GetFlag(FLAGS_d) < 0) { + RTC_LOG(LS_ERROR) << "Attack and decay must be non-negative"; + return 1; + } + + // Open wav input file and check properties. + const std::string input_file = absl::GetFlag(FLAGS_i); + const std::string config_output_file = absl::GetFlag(FLAGS_oc); + const std::string levels_output_file = absl::GetFlag(FLAGS_ol); + WavReader wav_reader(input_file); + if (wav_reader.num_channels() != 1) { + RTC_LOG(LS_ERROR) << "Only mono wav files supported"; + return 1; + } + if (wav_reader.sample_rate() > kMaxSampleRate) { + RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate + << ")"; + return 1; + } + + // Map from milliseconds to samples. + const size_t audio_frame_length = rtc::CheckedDivExact( + absl::GetFlag(FLAGS_f) * wav_reader.sample_rate(), 1000); + auto time_const = [](double c) { + return std::pow(kOneDbReduction, absl::GetFlag(FLAGS_f) / c); + }; + const float attack = + absl::GetFlag(FLAGS_a) == 0.0 ? 0.0 : time_const(absl::GetFlag(FLAGS_a)); + const float decay = + absl::GetFlag(FLAGS_d) == 0.0 ? 0.0 : time_const(absl::GetFlag(FLAGS_d)); + + // Write config to file. + std::ofstream out_config(config_output_file); + out_config << "{" + "'frame_len_ms': " + << absl::GetFlag(FLAGS_f) + << ", " + "'attack_ms': " + << absl::GetFlag(FLAGS_a) + << ", " + "'decay_ms': " + << absl::GetFlag(FLAGS_d) << "}\n"; + out_config.close(); + + // Measure level frame-by-frame. + std::ofstream out_levels(levels_output_file, std::ofstream::binary); + std::array<int16_t, kMaxFrameLen> samples; + float level_prev = 0.f; + while (true) { + // Process frame. + const auto read_samples = + wav_reader.ReadSamples(audio_frame_length, samples.data()); + if (read_samples < audio_frame_length) + break; // EOF. + + // Frame peak level. + std::transform(samples.begin(), samples.begin() + audio_frame_length, + samples.begin(), [](int16_t s) { return std::abs(s); }); + const int16_t peak_level = *std::max_element( + samples.cbegin(), samples.cbegin() + audio_frame_length); + const float level_curr = static_cast<float>(peak_level) / 32768.f; + + // Temporal smoothing. + auto smooth = [&level_prev, &level_curr](float c) { + return (1.0 - c) * level_curr + c * level_prev; + }; + level_prev = smooth(level_curr > level_prev ? attack : decay); + + // Write output. 
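+    // One raw float per frame (native byte order); the Python side is
+    // expected to read the sequence back as float32 values in [0.0, 1.0].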
+    out_levels.write(reinterpret_cast<const char*>(&level_prev),
+                     sizeof(float));
+  }
+  out_levels.close();
+
+  return 0;
+}
+
+}  // namespace
+}  // namespace test
+}  // namespace webrtc
+
+int main(int argc, char* argv[]) {
+  return webrtc::test::main(argc, argv);
+}
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
new file mode 100644
index 0000000000..7e86faccec
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
@@ -0,0 +1,526 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""Test data generators producing signal pairs intended to be used to
+test the APM module. Each pair consists of a noisy input and a reference
+signal. The former is used as APM input and is generated by adding noise to
+a clean audio track. The reference is the expected APM output.
+
+Throughout this file, the following naming convention is used:
+  - input signal: the clean signal (e.g., speech),
+  - noise signal: the noise to be summed up to the input signal (e.g., white
+    noise, Gaussian noise),
+  - noisy signal: input + noise.
+The noise signal may or may not be a function of the clean signal. For
+instance, white noise is independently generated, whereas reverberation is
+obtained by convolving the input signal with an impulse response.
+"""
+
+import logging
+import os
+import shutil
+import sys
+
+try:
+    import scipy.io
+except ImportError:
+    logging.critical('Cannot import the third-party Python package scipy')
+    sys.exit(1)
+
+from . import data_access
+from . import exceptions
+from . import signal_processing
+
+
+class TestDataGenerator(object):
+    """Abstract class responsible for the generation of noisy signals.
+
+    Given a clean signal, it generates two streams named noisy signal and
+    reference. The former is the clean signal deteriorated by the noise
+    source, the latter goes through the same deterioration process, but more
+    "gently". Noisy signal and reference are produced so that the reference
+    is the signal expected at the output of the APM module when the latter is
+    fed with the noisy signal.
+
+    A test data generator generates one or more pairs.
+    """
+
+    NAME = None
+    REGISTERED_CLASSES = {}
+
+    def __init__(self, output_directory_prefix):
+        self._output_directory_prefix = output_directory_prefix
+        # Init dictionaries with one entry for each test data generator
+        # configuration (e.g., different SNRs).
+        # Noisy audio track files (stored separately in a cache folder).
+        self._noisy_signal_filepaths = None
+        # Path to be used for the APM simulation output files.
+        self._apm_output_paths = None
+        # Reference audio track files (stored separately in a cache folder).
+        self._reference_signal_filepaths = None
+        self.Clear()
+
+    @classmethod
+    def RegisterClass(cls, class_to_register):
+        """Registers a TestDataGenerator implementation.
+
+        Decorator to automatically register the classes that extend
+        TestDataGenerator.
+        Example usage:
+
+        @TestDataGenerator.RegisterClass
+        class IdentityGenerator(TestDataGenerator):
+            pass
+        """
+        cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
+        return class_to_register
+
+    @property
+    def config_names(self):
+        return self._noisy_signal_filepaths.keys()
+
+    @property
+    def noisy_signal_filepaths(self):
+        return self._noisy_signal_filepaths
+
+    @property
+    def apm_output_paths(self):
+        return self._apm_output_paths
+
+    @property
+    def reference_signal_filepaths(self):
+        return self._reference_signal_filepaths
+
+    def Generate(self, input_signal_filepath, test_data_cache_path,
+                 base_output_path):
+        """Generates a set of noisy input and reference audio track file pairs.
+
+        This method initializes an empty set of pairs and calls the
+        _Generate() method implemented in a concrete class.
+
+        Args:
+          input_signal_filepath: path to the clean input audio track file.
+          test_data_cache_path: path to the cache of the generated audio track
+                                files.
+          base_output_path: base path where output is written.
+        """
+        self.Clear()
+        self._Generate(input_signal_filepath, test_data_cache_path,
+                       base_output_path)
+
+    def Clear(self):
+        """Clears the generated output path dictionaries.
+        """
+        self._noisy_signal_filepaths = {}
+        self._apm_output_paths = {}
+        self._reference_signal_filepaths = {}
+
+    def _Generate(self, input_signal_filepath, test_data_cache_path,
+                  base_output_path):
+        """Abstract method to be implemented in each concrete class.
+        """
+        raise NotImplementedError()
+
+    def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
+                          snr_value_pairs):
+        """Adds noisy-reference signal pairs.
+
+        Args:
+          base_output_path: noisy tracks base output path.
+          noisy_mix_filepaths: nested dictionary of noisy signal paths
+                               organized by noisy track name and SNR level.
+          snr_value_pairs: list of SNR pairs.
+        """
+        for noise_track_name in noisy_mix_filepaths:
+            for snr_noisy, snr_reference in snr_value_pairs:
+                config_name = '{0}_{1:d}_{2:d}_SNR'.format(
+                    noise_track_name, snr_noisy, snr_reference)
+                output_path = self._MakeDir(base_output_path, config_name)
+                self._AddNoiseReferenceFilesPair(
+                    config_name=config_name,
+                    noisy_signal_filepath=noisy_mix_filepaths[noise_track_name]
+                    [snr_noisy],
+                    reference_signal_filepath=noisy_mix_filepaths[
+                        noise_track_name][snr_reference],
+                    output_path=output_path)
+
+    def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
+                                    reference_signal_filepath, output_path):
+        """Adds one noisy-reference signal pair.
+
+        Args:
+          config_name: name of the APM configuration.
+          noisy_signal_filepath: path to the noisy audio track file.
+          reference_signal_filepath: path to the reference audio track file.
+          output_path: APM output path.
+        """
+        assert config_name not in self._noisy_signal_filepaths
+        self._noisy_signal_filepaths[config_name] = os.path.abspath(
+            noisy_signal_filepath)
+        self._apm_output_paths[config_name] = os.path.abspath(output_path)
+        self._reference_signal_filepaths[config_name] = os.path.abspath(
+            reference_signal_filepath)
+
+    def _MakeDir(self, base_output_path, test_data_generator_config_name):
+        output_path = os.path.join(
+            base_output_path,
+            self._output_directory_prefix + test_data_generator_config_name)
+        data_access.MakeDirectory(output_path)
+        return output_path
+
+
+@TestDataGenerator.RegisterClass
+class IdentityTestDataGenerator(TestDataGenerator):
+    """Generator that adds no noise.
+
+    Both the noisy and the reference signals are the input signal.
+ """ + + NAME = 'identity' + + def __init__(self, output_directory_prefix, copy_with_identity): + TestDataGenerator.__init__(self, output_directory_prefix) + self._copy_with_identity = copy_with_identity + + @property + def copy_with_identity(self): + return self._copy_with_identity + + def _Generate(self, input_signal_filepath, test_data_cache_path, + base_output_path): + config_name = 'default' + output_path = self._MakeDir(base_output_path, config_name) + + if self._copy_with_identity: + input_signal_filepath_new = os.path.join( + test_data_cache_path, + os.path.split(input_signal_filepath)[1]) + logging.info('copying ' + input_signal_filepath + ' to ' + + (input_signal_filepath_new)) + shutil.copy(input_signal_filepath, input_signal_filepath_new) + input_signal_filepath = input_signal_filepath_new + + self._AddNoiseReferenceFilesPair( + config_name=config_name, + noisy_signal_filepath=input_signal_filepath, + reference_signal_filepath=input_signal_filepath, + output_path=output_path) + + +@TestDataGenerator.RegisterClass +class WhiteNoiseTestDataGenerator(TestDataGenerator): + """Generator that adds white noise. + """ + + NAME = 'white_noise' + + # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs. + # The reference (second value of each pair) always has a lower amount of noise + # - i.e., the SNR is 10 dB higher. + _SNR_VALUE_PAIRS = [ + [20, 30], # Smallest noise. + [10, 20], + [5, 15], + [0, 10], # Largest noise. + ] + + _NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav' + + def __init__(self, output_directory_prefix): + TestDataGenerator.__init__(self, output_directory_prefix) + + def _Generate(self, input_signal_filepath, test_data_cache_path, + base_output_path): + # Load the input signal. + input_signal = signal_processing.SignalProcessingUtils.LoadWav( + input_signal_filepath) + + # Create the noise track. + noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise( + input_signal) + + # Create the noisy mixes (once for each unique SNR value). + noisy_mix_filepaths = {} + snr_values = set( + [snr for pair in self._SNR_VALUE_PAIRS for snr in pair]) + for snr in snr_values: + noisy_signal_filepath = os.path.join( + test_data_cache_path, + self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr)) + + # Create and save if not done. + if not os.path.exists(noisy_signal_filepath): + # Create noisy signal. + noisy_signal = signal_processing.SignalProcessingUtils.MixSignals( + input_signal, noise_signal, snr) + + # Save. + signal_processing.SignalProcessingUtils.SaveWav( + noisy_signal_filepath, noisy_signal) + + # Add file to the collection of mixes. + noisy_mix_filepaths[snr] = noisy_signal_filepath + + # Add all the noisy-reference signal pairs. + for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS: + config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence) + output_path = self._MakeDir(base_output_path, config_name) + self._AddNoiseReferenceFilesPair( + config_name=config_name, + noisy_signal_filepath=noisy_mix_filepaths[snr_noisy], + reference_signal_filepath=noisy_mix_filepaths[snr_refence], + output_path=output_path) + + +# TODO(alessiob): remove comment when class implemented. +# @TestDataGenerator.RegisterClass +class NarrowBandNoiseTestDataGenerator(TestDataGenerator): + """Generator that adds narrow-band noise. 
+ """ + + NAME = 'narrow_band_noise' + + def __init__(self, output_directory_prefix): + TestDataGenerator.__init__(self, output_directory_prefix) + + def _Generate(self, input_signal_filepath, test_data_cache_path, + base_output_path): + # TODO(alessiob): implement. + pass + + +@TestDataGenerator.RegisterClass +class AdditiveNoiseTestDataGenerator(TestDataGenerator): + """Generator that adds noise loops. + + This generator uses all the wav files in a given path (default: noise_tracks/) + and mixes them to the clean speech with different target SNRs (hard-coded). + """ + + NAME = 'additive_noise' + _NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav' + + DEFAULT_NOISE_TRACKS_PATH = os.path.join(os.path.dirname(__file__), + os.pardir, 'noise_tracks') + + # TODO(alessiob): Make the list of SNR pairs customizable. + # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs. + # The reference (second value of each pair) always has a lower amount of noise + # - i.e., the SNR is 10 dB higher. + _SNR_VALUE_PAIRS = [ + [20, 30], # Smallest noise. + [10, 20], + [5, 15], + [0, 10], # Largest noise. + ] + + def __init__(self, output_directory_prefix, noise_tracks_path): + TestDataGenerator.__init__(self, output_directory_prefix) + self._noise_tracks_path = noise_tracks_path + self._noise_tracks_file_names = [ + n for n in os.listdir(self._noise_tracks_path) + if n.lower().endswith('.wav') + ] + if len(self._noise_tracks_file_names) == 0: + raise exceptions.InitializationException( + 'No wav files found in the noise tracks path %s' % + (self._noise_tracks_path)) + + def _Generate(self, input_signal_filepath, test_data_cache_path, + base_output_path): + """Generates test data pairs using environmental noise. + + For each noise track and pair of SNR values, the following two audio tracks + are created: the noisy signal and the reference signal. The former is + obtained by mixing the (clean) input signal to the corresponding noise + track enforcing the target SNR. + """ + # Init. + snr_values = set( + [snr for pair in self._SNR_VALUE_PAIRS for snr in pair]) + + # Load the input signal. + input_signal = signal_processing.SignalProcessingUtils.LoadWav( + input_signal_filepath) + + noisy_mix_filepaths = {} + for noise_track_filename in self._noise_tracks_file_names: + # Load the noise track. + noise_track_name, _ = os.path.splitext(noise_track_filename) + noise_track_filepath = os.path.join(self._noise_tracks_path, + noise_track_filename) + if not os.path.exists(noise_track_filepath): + logging.error('cannot find the <%s> noise track', + noise_track_filename) + raise exceptions.FileNotFoundError() + + noise_signal = signal_processing.SignalProcessingUtils.LoadWav( + noise_track_filepath) + + # Create the noisy mixes (once for each unique SNR value). + noisy_mix_filepaths[noise_track_name] = {} + for snr in snr_values: + noisy_signal_filepath = os.path.join( + test_data_cache_path, + self._NOISY_SIGNAL_FILENAME_TEMPLATE.format( + noise_track_name, snr)) + + # Create and save if not done. + if not os.path.exists(noisy_signal_filepath): + # Create noisy signal. + noisy_signal = signal_processing.SignalProcessingUtils.MixSignals( + input_signal, + noise_signal, + snr, + pad_noise=signal_processing.SignalProcessingUtils. + MixPadding.LOOP) + + # Save. + signal_processing.SignalProcessingUtils.SaveWav( + noisy_signal_filepath, noisy_signal) + + # Add file to the collection of mixes. + noisy_mix_filepaths[noise_track_name][ + snr] = noisy_signal_filepath + + # Add all the noise-SNR pairs. 
+ self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths, + self._SNR_VALUE_PAIRS) + + +@TestDataGenerator.RegisterClass +class ReverberationTestDataGenerator(TestDataGenerator): + """Generator that adds reverberation noise. + + TODO(alessiob): Make this class more generic since the impulse response can be + anything (not just reverberation); call it e.g., + ConvolutionalNoiseTestDataGenerator. + """ + + NAME = 'reverberation' + + _IMPULSE_RESPONSES = { + 'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo. + 'booth': 'air_binaural_booth_0_0_1.mat', # Short echo. + } + _MAX_IMPULSE_RESPONSE_LENGTH = None + + # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs. + # The reference (second value of each pair) always has a lower amount of noise + # - i.e., the SNR is 5 dB higher. + _SNR_VALUE_PAIRS = [ + [3, 8], # Smallest noise. + [-3, 2], # Largest noise. + ] + + _NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav' + _NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav' + + def __init__(self, output_directory_prefix, aechen_ir_database_path): + TestDataGenerator.__init__(self, output_directory_prefix) + self._aechen_ir_database_path = aechen_ir_database_path + + def _Generate(self, input_signal_filepath, test_data_cache_path, + base_output_path): + """Generates test data pairs using reverberation noise. + + For each impulse response, one noise track is created. For each impulse + response and pair of SNR values, the following 2 audio tracks are + created: the noisy signal and the reference signal. The former is + obtained by mixing the (clean) input signal to the corresponding noise + track enforcing the target SNR. + """ + # Init. + snr_values = set( + [snr for pair in self._SNR_VALUE_PAIRS for snr in pair]) + + # Load the input signal. + input_signal = signal_processing.SignalProcessingUtils.LoadWav( + input_signal_filepath) + + noisy_mix_filepaths = {} + for impulse_response_name in self._IMPULSE_RESPONSES: + noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format( + impulse_response_name) + noise_track_filepath = os.path.join(test_data_cache_path, + noise_track_filename) + noise_signal = None + try: + # Load noise track. + noise_signal = signal_processing.SignalProcessingUtils.LoadWav( + noise_track_filepath) + except exceptions.FileNotFoundError: + # Generate noise track by applying the impulse response. + impulse_response_filepath = os.path.join( + self._aechen_ir_database_path, + self._IMPULSE_RESPONSES[impulse_response_name]) + noise_signal = self._GenerateNoiseTrack( + noise_track_filepath, input_signal, + impulse_response_filepath) + assert noise_signal is not None + + # Create the noisy mixes (once for each unique SNR value). + noisy_mix_filepaths[impulse_response_name] = {} + for snr in snr_values: + noisy_signal_filepath = os.path.join( + test_data_cache_path, + self._NOISY_SIGNAL_FILENAME_TEMPLATE.format( + impulse_response_name, snr)) + + # Create and save if not done. + if not os.path.exists(noisy_signal_filepath): + # Create noisy signal. + noisy_signal = signal_processing.SignalProcessingUtils.MixSignals( + input_signal, noise_signal, snr) + + # Save. + signal_processing.SignalProcessingUtils.SaveWav( + noisy_signal_filepath, noisy_signal) + + # Add file to the collection of mixes. + noisy_mix_filepaths[impulse_response_name][ + snr] = noisy_signal_filepath + + # Add all the noise-SNR pairs. 
+        self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
+                               self._SNR_VALUE_PAIRS)
+
+    def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
+                            impulse_response_filepath):
+        """Generates a noise track.
+
+        Generates a signal by convolving `input_signal` with the impulse
+        response in `impulse_response_filepath`; then saves it to
+        `noise_track_filepath`.
+
+        Args:
+          noise_track_filepath: output file path for the noise track.
+          input_signal: (clean) input signal samples.
+          impulse_response_filepath: impulse response file path.
+
+        Returns:
+          AudioSegment instance.
+        """
+        # Load impulse response.
+        data = scipy.io.loadmat(impulse_response_filepath)
+        impulse_response = data['h_air'].flatten()
+        if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
+            logging.info('truncating impulse response from %d to %d samples',
+                         len(impulse_response),
+                         self._MAX_IMPULSE_RESPONSE_LENGTH)
+            impulse_response = impulse_response[:self.
+                                                _MAX_IMPULSE_RESPONSE_LENGTH]
+
+        # Apply impulse response.
+        processed_signal = (
+            signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
+                input_signal, impulse_response))
+
+        # Save.
+        signal_processing.SignalProcessingUtils.SaveWav(
+            noise_track_filepath, processed_signal)
+
+        return processed_signal
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_factory.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_factory.py
new file mode 100644
index 0000000000..948888e775
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_factory.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+"""TestDataGenerator factory class.
+"""
+
+import logging
+
+from . import exceptions
+from . import test_data_generation
+
+
+class TestDataGeneratorFactory(object):
+    """Factory class used to create test data generators.
+
+    Usage: create a factory by passing to the ctor the parameters with which
+    the generators will be produced.
+    """
+
+    def __init__(self, aechen_ir_database_path, noise_tracks_path,
+                 copy_with_identity):
+        """Ctor.
+
+        Args:
+          aechen_ir_database_path: Path to the Aachen Impulse Response (AIR)
+                                   database.
+          noise_tracks_path: Path to the noise tracks to add.
+          copy_with_identity: Flag indicating whether the identity generator
+                              has to make copies of the clean speech input
+                              files.
+        """
+        self._output_directory_prefix = None
+        self._aechen_ir_database_path = aechen_ir_database_path
+        self._noise_tracks_path = noise_tracks_path
+        self._copy_with_identity = copy_with_identity
+
+    def SetOutputDirectoryPrefix(self, prefix):
+        self._output_directory_prefix = prefix
+
+    def GetInstance(self, test_data_generators_class):
+        """Creates a TestDataGenerator instance given a class object.
+
+        Args:
+          test_data_generators_class: TestDataGenerator class object (not an
+                                      instance).
+
+        Returns:
+          TestDataGenerator instance.
+ """ + if self._output_directory_prefix is None: + raise exceptions.InitializationException( + 'The output directory prefix for test data generators is not set' + ) + logging.debug('factory producing %s', test_data_generators_class) + + if test_data_generators_class == ( + test_data_generation.IdentityTestDataGenerator): + return test_data_generation.IdentityTestDataGenerator( + self._output_directory_prefix, self._copy_with_identity) + elif test_data_generators_class == ( + test_data_generation.ReverberationTestDataGenerator): + return test_data_generation.ReverberationTestDataGenerator( + self._output_directory_prefix, self._aechen_ir_database_path) + elif test_data_generators_class == ( + test_data_generation.AdditiveNoiseTestDataGenerator): + return test_data_generation.AdditiveNoiseTestDataGenerator( + self._output_directory_prefix, self._noise_tracks_path) + else: + return test_data_generators_class(self._output_directory_prefix) diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py new file mode 100644 index 0000000000..f75098ae2c --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py @@ -0,0 +1,207 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +"""Unit tests for the test_data_generation module. +""" + +import os +import shutil +import tempfile +import unittest + +import numpy as np +import scipy.io + +from . import test_data_generation +from . import test_data_generation_factory +from . import signal_processing + + +class TestTestDataGenerators(unittest.TestCase): + """Unit tests for the test_data_generation module. + """ + + def setUp(self): + """Create temporary folders.""" + self._base_output_path = tempfile.mkdtemp() + self._test_data_cache_path = tempfile.mkdtemp() + self._fake_air_db_path = tempfile.mkdtemp() + + # Fake AIR DB impulse responses. + # TODO(alessiob): ReverberationTestDataGenerator will change to allow custom + # impulse responses. When changed, the coupling below between + # impulse_response_mat_file_names and + # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed. + impulse_response_mat_file_names = [ + 'air_binaural_lecture_0_0_1.mat', + 'air_binaural_booth_0_0_1.mat', + ] + for impulse_response_mat_file_name in impulse_response_mat_file_names: + data = {'h_air': np.random.rand(1, 1000).astype('<f8')} + scipy.io.savemat( + os.path.join(self._fake_air_db_path, + impulse_response_mat_file_name), data) + + def tearDown(self): + """Recursively delete temporary folders.""" + shutil.rmtree(self._base_output_path) + shutil.rmtree(self._test_data_cache_path) + shutil.rmtree(self._fake_air_db_path) + + def testTestDataGenerators(self): + # Preliminary check. + self.assertTrue(os.path.exists(self._base_output_path)) + self.assertTrue(os.path.exists(self._test_data_cache_path)) + + # Check that there is at least one registered test data generator. 
+        registered_classes = (
+            test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
+        self.assertIsInstance(registered_classes, dict)
+        self.assertGreater(len(registered_classes), 0)
+
+        # Instantiate the generators factory.
+        generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
+            aechen_ir_database_path=self._fake_air_db_path,
+            noise_tracks_path=(test_data_generation.
+                               AdditiveNoiseTestDataGenerator.
+                               DEFAULT_NOISE_TRACKS_PATH),
+            copy_with_identity=False)
+        generators_factory.SetOutputDirectoryPrefix('datagen-')
+
+        # Use a simple input file as clean input signal.
+        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
+                                             'tone-880.wav')
+        self.assertTrue(os.path.exists(input_signal_filepath))
+
+        # Load input signal.
+        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
+            input_signal_filepath)
+
+        # Try each registered test data generator.
+        for generator_name in registered_classes:
+            # Instantiate the test data generator.
+            generator = generators_factory.GetInstance(
+                registered_classes[generator_name])
+
+            # Generate the noisy input - reference pairs.
+            generator.Generate(input_signal_filepath=input_signal_filepath,
+                               test_data_cache_path=self._test_data_cache_path,
+                               base_output_path=self._base_output_path)
+
+            # Perform checks.
+            self._CheckGeneratedPairsListSizes(generator)
+            self._CheckGeneratedPairsSignalDurations(generator, input_signal)
+            self._CheckGeneratedPairsOutputPaths(generator)
+
+    def testIdentityTestDataGenerator(self):
+        # Preliminary check.
+        self.assertTrue(os.path.exists(self._base_output_path))
+        self.assertTrue(os.path.exists(self._test_data_cache_path))
+
+        # Use a simple input file as clean input signal.
+        input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
+                                             'tone-880.wav')
+        self.assertTrue(os.path.exists(input_signal_filepath))
+
+        def GetNoiseReferenceFilePaths(identity_generator):
+            noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
+            reference_signal_filepaths = (
+                identity_generator.reference_signal_filepaths)
+            assert noisy_signal_filepaths.keys(
+            ) == reference_signal_filepaths.keys()
+            assert len(noisy_signal_filepaths.keys()) == 1
+            # list() is needed since dict views are not indexable in Python 3.
+            key = list(noisy_signal_filepaths.keys())[0]
+            return noisy_signal_filepaths[key], reference_signal_filepaths[key]
+
+        # Test the `copy_with_identity` flag.
+        for copy_with_identity in [False, True]:
+            # Instantiate the generator through the factory.
+            factory = test_data_generation_factory.TestDataGeneratorFactory(
+                aechen_ir_database_path='',
+                noise_tracks_path='',
+                copy_with_identity=copy_with_identity)
+            factory.SetOutputDirectoryPrefix('datagen-')
+            generator = factory.GetInstance(
+                test_data_generation.IdentityTestDataGenerator)
+            # Check `copy_with_identity` is set correctly.
+            self.assertEqual(copy_with_identity, generator.copy_with_identity)
+
+            # Generate test data and extract the paths to the noise and the
+            # reference files.
+            generator.Generate(input_signal_filepath=input_signal_filepath,
+                               test_data_cache_path=self._test_data_cache_path,
+                               base_output_path=self._base_output_path)
+            noisy_signal_filepath, reference_signal_filepath = (
+                GetNoiseReferenceFilePaths(generator))
+
+            # Check that a copy is made if and only if `copy_with_identity` is
+            # True.
+            if copy_with_identity:
+                self.assertNotEqual(noisy_signal_filepath,
+                                    input_signal_filepath)
+                self.assertNotEqual(reference_signal_filepath,
+                                    input_signal_filepath)
+            else:
+                self.assertEqual(noisy_signal_filepath, input_signal_filepath)
+                self.assertEqual(reference_signal_filepath,
+                                 input_signal_filepath)
+
+    def _CheckGeneratedPairsListSizes(self, generator):
+        config_names = generator.config_names
+        number_of_pairs = len(config_names)
+        self.assertEqual(number_of_pairs,
+                         len(generator.noisy_signal_filepaths))
+        self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
+        self.assertEqual(number_of_pairs,
+                         len(generator.reference_signal_filepaths))
+
+    def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
+        """Checks the duration of the generated signals.
+
+        Checks that the noisy input and the reference tracks are audio files
+        with duration equal to or greater than that of the input signal.
+
+        Args:
+          generator: TestDataGenerator instance.
+          input_signal: AudioSegment instance.
+        """
+        input_signal_length = (
+            signal_processing.SignalProcessingUtils.CountSamples(input_signal))
+
+        # Iterate over the noisy signal - reference pairs.
+        for config_name in generator.config_names:
+            # Load the noisy input file.
+            noisy_signal_filepath = generator.noisy_signal_filepaths[
+                config_name]
+            noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
+                noisy_signal_filepath)
+
+            # Check the noisy input signal length.
+            noisy_signal_length = (signal_processing.SignalProcessingUtils.
+                                   CountSamples(noisy_signal))
+            self.assertGreaterEqual(noisy_signal_length, input_signal_length)
+
+            # Load the reference file.
+            reference_signal_filepath = generator.reference_signal_filepaths[
+                config_name]
+            reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
+                reference_signal_filepath)
+
+            # Check the reference signal length.
+            reference_signal_length = (signal_processing.SignalProcessingUtils.
+                                       CountSamples(reference_signal))
+            self.assertGreaterEqual(reference_signal_length,
+                                    input_signal_length)
+
+    def _CheckGeneratedPairsOutputPaths(self, generator):
+        """Checks that the output path created by the generator exists.
+
+        Args:
+          generator: TestDataGenerator instance.
+        """
+        # Iterate over the noisy signal - reference pairs.
+        for config_name in generator.config_names:
+            output_path = generator.apm_output_paths[config_name]
+            self.assertTrue(os.path.exists(output_path))
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/vad.cc b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/vad.cc
new file mode 100644
index 0000000000..b47f6221cb
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/vad.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+//
+// Use of this source code is governed by a BSD-style license
+// that can be found in the LICENSE file in the root of the source
+// tree. An additional intellectual property rights grant can be found
+// in the file PATENTS. All contributing project authors may
+// be found in the AUTHORS file in the root of the source tree.
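+
+// Command-line VAD annotator: reads a mono wav file (-i) and writes the
+// voice activity decisions to a binary file (-o). Output layout (inferred
+// from the code below): one header byte with the frame length in ms, then
+// one bit per frame (least significant bit first) packed into bytes, and a
+// final byte holding the number of unused padding bits in the last bitmask
+// byte.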
+ +#include "common_audio/vad/include/vad.h" + +#include <array> +#include <fstream> +#include <memory> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "common_audio/wav_file.h" +#include "rtc_base/logging.h" + +ABSL_FLAG(std::string, i, "", "Input wav file"); +ABSL_FLAG(std::string, o, "", "VAD output file"); + +namespace webrtc { +namespace test { +namespace { + +// The allowed values are 10, 20 or 30 ms. +constexpr uint8_t kAudioFrameLengthMilliseconds = 30; +constexpr int kMaxSampleRate = 48000; +constexpr size_t kMaxFrameLen = + kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000; + +constexpr uint8_t kBitmaskBuffSize = 8; + +int main(int argc, char* argv[]) { + absl::ParseCommandLine(argc, argv); + const std::string input_file = absl::GetFlag(FLAGS_i); + const std::string output_file = absl::GetFlag(FLAGS_o); + // Open wav input file and check properties. + WavReader wav_reader(input_file); + if (wav_reader.num_channels() != 1) { + RTC_LOG(LS_ERROR) << "Only mono wav files supported"; + return 1; + } + if (wav_reader.sample_rate() > kMaxSampleRate) { + RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate + << ")"; + return 1; + } + const size_t audio_frame_length = rtc::CheckedDivExact( + kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000); + if (audio_frame_length > kMaxFrameLen) { + RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large."; + return 1; + } + + // Create output file and write header. + std::ofstream out_file(output_file, std::ofstream::binary); + const char audio_frame_length_ms = kAudioFrameLengthMilliseconds; + out_file.write(&audio_frame_length_ms, 1); // Header. + + // Run VAD and write decisions. + std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal); + std::array<int16_t, kMaxFrameLen> samples; + char buff = 0; // Buffer to write one bit per frame. + uint8_t next = 0; // Points to the next bit to write in `buff`. + while (true) { + // Process frame. + const auto read_samples = + wav_reader.ReadSamples(audio_frame_length, samples.data()); + if (read_samples < audio_frame_length) + break; + const auto is_speech = vad->VoiceActivity( + samples.data(), audio_frame_length, wav_reader.sample_rate()); + + // Write output. + buff = is_speech ? buff | (1 << next) : buff & ~(1 << next); + if (++next == kBitmaskBuffSize) { + out_file.write(&buff, 1); // Flush. + buff = 0; // Reset. + next = 0; + } + } + + // Finalize. + char extra_bits = 0; + if (next > 0) { + extra_bits = kBitmaskBuffSize - next; + out_file.write(&buff, 1); // Flush. + } + out_file.write(&extra_bits, 1); + out_file.close(); + + return 0; +} + +} // namespace +} // namespace test +} // namespace webrtc + +int main(int argc, char* argv[]) { + return webrtc::test::main(argc, argv); +} diff --git a/third_party/libwebrtc/modules/audio_processing/test/runtime_setting_util.cc b/third_party/libwebrtc/modules/audio_processing/test/runtime_setting_util.cc new file mode 100644 index 0000000000..4899d2d459 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/runtime_setting_util.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */
+
+#include "modules/audio_processing/test/runtime_setting_util.h"
+
+#include "rtc_base/checks.h"
+
+namespace webrtc {
+
+void ReplayRuntimeSetting(AudioProcessing* apm,
+                          const webrtc::audioproc::RuntimeSetting& setting) {
+  RTC_CHECK(apm);
+  // TODO(bugs.webrtc.org/9138): Add ability to handle different types
+  // of settings. Currently CapturePreGain, CaptureFixedPostGain,
+  // PlayoutVolumeChange, PlayoutAudioDeviceChange and CaptureOutputUsed
+  // are supported. Note: the check below must list every setting handled
+  // by the branches that follow, otherwise those branches are unreachable.
+  RTC_CHECK(setting.has_capture_pre_gain() ||
+            setting.has_capture_fixed_post_gain() ||
+            setting.has_playout_volume_change() ||
+            setting.has_playout_audio_device_change() ||
+            setting.has_capture_output_used());
+
+  if (setting.has_capture_pre_gain()) {
+    apm->SetRuntimeSetting(
+        AudioProcessing::RuntimeSetting::CreateCapturePreGain(
+            setting.capture_pre_gain()));
+  } else if (setting.has_capture_fixed_post_gain()) {
+    apm->SetRuntimeSetting(
+        AudioProcessing::RuntimeSetting::CreateCaptureFixedPostGain(
+            setting.capture_fixed_post_gain()));
+  } else if (setting.has_playout_volume_change()) {
+    apm->SetRuntimeSetting(
+        AudioProcessing::RuntimeSetting::CreatePlayoutVolumeChange(
+            setting.playout_volume_change()));
+  } else if (setting.has_playout_audio_device_change()) {
+    apm->SetRuntimeSetting(
+        AudioProcessing::RuntimeSetting::CreatePlayoutAudioDeviceChange(
+            {setting.playout_audio_device_change().id(),
+             setting.playout_audio_device_change().max_volume()}));
+  } else if (setting.has_capture_output_used()) {
+    apm->SetRuntimeSetting(
+        AudioProcessing::RuntimeSetting::CreateCaptureOutputUsedSetting(
+            setting.capture_output_used()));
+  }
+}
+}  // namespace webrtc
diff --git a/third_party/libwebrtc/modules/audio_processing/test/runtime_setting_util.h b/third_party/libwebrtc/modules/audio_processing/test/runtime_setting_util.h
new file mode 100644
index 0000000000..d8cbe82076
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/runtime_setting_util.h
@@ -0,0 +1,23 @@
+/*
+ *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS. All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_TEST_RUNTIME_SETTING_UTIL_H_
+#define MODULES_AUDIO_PROCESSING_TEST_RUNTIME_SETTING_UTIL_H_
+
+#include "modules/audio_processing/include/audio_processing.h"
+#include "modules/audio_processing/test/protobuf_utils.h"
+
+namespace webrtc {
+
+void ReplayRuntimeSetting(AudioProcessing* apm,
+                          const webrtc::audioproc::RuntimeSetting& setting);
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_TEST_RUNTIME_SETTING_UTIL_H_
diff --git a/third_party/libwebrtc/modules/audio_processing/test/simulator_buffers.cc b/third_party/libwebrtc/modules/audio_processing/test/simulator_buffers.cc
new file mode 100644
index 0000000000..458f6ced76
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/simulator_buffers.cc
@@ -0,0 +1,86 @@
+/*
+ *  Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS. All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */ + +#include "modules/audio_processing/test/simulator_buffers.h" + +#include "modules/audio_processing/test/audio_buffer_tools.h" +#include "rtc_base/checks.h" + +namespace webrtc { +namespace test { + +SimulatorBuffers::SimulatorBuffers(int render_input_sample_rate_hz, + int capture_input_sample_rate_hz, + int render_output_sample_rate_hz, + int capture_output_sample_rate_hz, + size_t num_render_input_channels, + size_t num_capture_input_channels, + size_t num_render_output_channels, + size_t num_capture_output_channels) { + Random rand_gen(42); + CreateConfigAndBuffer(render_input_sample_rate_hz, num_render_input_channels, + &rand_gen, &render_input_buffer, &render_input_config, + &render_input, &render_input_samples); + + CreateConfigAndBuffer(render_output_sample_rate_hz, + num_render_output_channels, &rand_gen, + &render_output_buffer, &render_output_config, + &render_output, &render_output_samples); + + CreateConfigAndBuffer(capture_input_sample_rate_hz, + num_capture_input_channels, &rand_gen, + &capture_input_buffer, &capture_input_config, + &capture_input, &capture_input_samples); + + CreateConfigAndBuffer(capture_output_sample_rate_hz, + num_capture_output_channels, &rand_gen, + &capture_output_buffer, &capture_output_config, + &capture_output, &capture_output_samples); + + UpdateInputBuffers(); +} + +SimulatorBuffers::~SimulatorBuffers() = default; + +void SimulatorBuffers::CreateConfigAndBuffer( + int sample_rate_hz, + size_t num_channels, + Random* rand_gen, + std::unique_ptr<AudioBuffer>* buffer, + StreamConfig* config, + std::vector<float*>* buffer_data, + std::vector<float>* buffer_data_samples) { + int samples_per_channel = rtc::CheckedDivExact(sample_rate_hz, 100); + *config = StreamConfig(sample_rate_hz, num_channels); + buffer->reset( + new AudioBuffer(config->sample_rate_hz(), config->num_channels(), + config->sample_rate_hz(), config->num_channels(), + config->sample_rate_hz(), config->num_channels())); + + buffer_data_samples->resize(samples_per_channel * num_channels); + for (auto& v : *buffer_data_samples) { + v = rand_gen->Rand<float>(); + } + + buffer_data->resize(num_channels); + for (size_t ch = 0; ch < num_channels; ++ch) { + (*buffer_data)[ch] = &(*buffer_data_samples)[ch * samples_per_channel]; + } +} + +void SimulatorBuffers::UpdateInputBuffers() { + test::CopyVectorToAudioBuffer(capture_input_config, capture_input_samples, + capture_input_buffer.get()); + test::CopyVectorToAudioBuffer(render_input_config, render_input_samples, + render_input_buffer.get()); +} + +} // namespace test +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/simulator_buffers.h b/third_party/libwebrtc/modules/audio_processing/test/simulator_buffers.h new file mode 100644 index 0000000000..36dcf301a2 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/simulator_buffers.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_SIMULATOR_BUFFERS_H_ +#define MODULES_AUDIO_PROCESSING_TEST_SIMULATOR_BUFFERS_H_ + +#include <memory> +#include <vector> + +#include "modules/audio_processing/audio_buffer.h" +#include "modules/audio_processing/include/audio_processing.h" +#include "rtc_base/random.h" + +namespace webrtc { +namespace test { + +struct SimulatorBuffers { + SimulatorBuffers(int render_input_sample_rate_hz, + int capture_input_sample_rate_hz, + int render_output_sample_rate_hz, + int capture_output_sample_rate_hz, + size_t num_render_input_channels, + size_t num_capture_input_channels, + size_t num_render_output_channels, + size_t num_capture_output_channels); + ~SimulatorBuffers(); + + void CreateConfigAndBuffer(int sample_rate_hz, + size_t num_channels, + Random* rand_gen, + std::unique_ptr<AudioBuffer>* buffer, + StreamConfig* config, + std::vector<float*>* buffer_data, + std::vector<float>* buffer_data_samples); + + void UpdateInputBuffers(); + + std::unique_ptr<AudioBuffer> render_input_buffer; + std::unique_ptr<AudioBuffer> capture_input_buffer; + std::unique_ptr<AudioBuffer> render_output_buffer; + std::unique_ptr<AudioBuffer> capture_output_buffer; + StreamConfig render_input_config; + StreamConfig capture_input_config; + StreamConfig render_output_config; + StreamConfig capture_output_config; + std::vector<float*> render_input; + std::vector<float> render_input_samples; + std::vector<float*> capture_input; + std::vector<float> capture_input_samples; + std::vector<float*> render_output; + std::vector<float> render_output_samples; + std::vector<float*> capture_output; + std::vector<float> capture_output_samples; +}; + +} // namespace test +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_SIMULATOR_BUFFERS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/test_utils.cc b/third_party/libwebrtc/modules/audio_processing/test/test_utils.cc new file mode 100644 index 0000000000..9aeebe5155 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/test_utils.cc @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2015 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#include "modules/audio_processing/test/test_utils.h" + +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "rtc_base/checks.h" +#include "rtc_base/system/arch.h" + +namespace webrtc { + +ChannelBufferWavReader::ChannelBufferWavReader(std::unique_ptr<WavReader> file) + : file_(std::move(file)) {} + +ChannelBufferWavReader::~ChannelBufferWavReader() = default; + +bool ChannelBufferWavReader::Read(ChannelBuffer<float>* buffer) { + RTC_CHECK_EQ(file_->num_channels(), buffer->num_channels()); + interleaved_.resize(buffer->size()); + if (file_->ReadSamples(interleaved_.size(), &interleaved_[0]) != + interleaved_.size()) { + return false; + } + + FloatS16ToFloat(&interleaved_[0], interleaved_.size(), &interleaved_[0]); + Deinterleave(&interleaved_[0], buffer->num_frames(), buffer->num_channels(), + buffer->channels()); + return true; +} + +ChannelBufferWavWriter::ChannelBufferWavWriter(std::unique_ptr<WavWriter> file) + : file_(std::move(file)) {} + +ChannelBufferWavWriter::~ChannelBufferWavWriter() = default; + +void ChannelBufferWavWriter::Write(const ChannelBuffer<float>& buffer) { + RTC_CHECK_EQ(file_->num_channels(), buffer.num_channels()); + interleaved_.resize(buffer.size()); + Interleave(buffer.channels(), buffer.num_frames(), buffer.num_channels(), + &interleaved_[0]); + FloatToFloatS16(&interleaved_[0], interleaved_.size(), &interleaved_[0]); + file_->WriteSamples(&interleaved_[0], interleaved_.size()); +} + +ChannelBufferVectorWriter::ChannelBufferVectorWriter(std::vector<float>* output) + : output_(output) { + RTC_DCHECK(output_); +} + +ChannelBufferVectorWriter::~ChannelBufferVectorWriter() = default; + +void ChannelBufferVectorWriter::Write(const ChannelBuffer<float>& buffer) { + // Account for sample rate changes throughout a simulation. + interleaved_buffer_.resize(buffer.size()); + Interleave(buffer.channels(), buffer.num_frames(), buffer.num_channels(), + interleaved_buffer_.data()); + size_t old_size = output_->size(); + output_->resize(old_size + interleaved_buffer_.size()); + FloatToFloatS16(interleaved_buffer_.data(), interleaved_buffer_.size(), + output_->data() + old_size); +} + +FILE* OpenFile(absl::string_view filename, absl::string_view mode) { + std::string filename_str(filename); + FILE* file = fopen(filename_str.c_str(), std::string(mode).c_str()); + if (!file) { + printf("Unable to open file %s\n", filename_str.c_str()); + exit(1); + } + return file; +} + +void SetFrameSampleRate(Int16FrameData* frame, int sample_rate_hz) { + frame->sample_rate_hz = sample_rate_hz; + frame->samples_per_channel = + AudioProcessing::kChunkSizeMs * sample_rate_hz / 1000; +} + +} // namespace webrtc diff --git a/third_party/libwebrtc/modules/audio_processing/test/test_utils.h b/third_party/libwebrtc/modules/audio_processing/test/test_utils.h new file mode 100644 index 0000000000..bf82f9d66d --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/test_utils.h @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. 
+ */ + +#ifndef MODULES_AUDIO_PROCESSING_TEST_TEST_UTILS_H_ +#define MODULES_AUDIO_PROCESSING_TEST_TEST_UTILS_H_ + +#include <math.h> + +#include <iterator> +#include <limits> +#include <memory> +#include <sstream> // no-presubmit-check TODO(webrtc:8982) +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "common_audio/channel_buffer.h" +#include "common_audio/wav_file.h" +#include "modules/audio_processing/include/audio_processing.h" + +namespace webrtc { + +static const AudioProcessing::Error kNoErr = AudioProcessing::kNoError; +#define EXPECT_NOERR(expr) EXPECT_EQ(kNoErr, (expr)) + +// Encapsulates samples and metadata for an integer frame. +struct Int16FrameData { + // Max data size that matches the data size of the AudioFrame class, providing + // storage for 8 channels of 96 kHz data. + static const int kMaxDataSizeSamples = 7680; + + Int16FrameData() { + sample_rate_hz = 0; + num_channels = 0; + samples_per_channel = 0; + data.fill(0); + } + + void CopyFrom(const Int16FrameData& src) { + samples_per_channel = src.samples_per_channel; + sample_rate_hz = src.sample_rate_hz; + num_channels = src.num_channels; + + const size_t length = samples_per_channel * num_channels; + RTC_CHECK_LE(length, kMaxDataSizeSamples); + memcpy(data.data(), src.data.data(), sizeof(int16_t) * length); + } + std::array<int16_t, kMaxDataSizeSamples> data; + int32_t sample_rate_hz; + size_t num_channels; + size_t samples_per_channel; +}; + +// Reads ChannelBuffers from a provided WavReader. +class ChannelBufferWavReader final { + public: + explicit ChannelBufferWavReader(std::unique_ptr<WavReader> file); + ~ChannelBufferWavReader(); + + ChannelBufferWavReader(const ChannelBufferWavReader&) = delete; + ChannelBufferWavReader& operator=(const ChannelBufferWavReader&) = delete; + + // Reads data from the file according to the `buffer` format. Returns false if + // a full buffer can't be read from the file. + bool Read(ChannelBuffer<float>* buffer); + + private: + std::unique_ptr<WavReader> file_; + std::vector<float> interleaved_; +}; + +// Writes ChannelBuffers to a provided WavWriter. +class ChannelBufferWavWriter final { + public: + explicit ChannelBufferWavWriter(std::unique_ptr<WavWriter> file); + ~ChannelBufferWavWriter(); + + ChannelBufferWavWriter(const ChannelBufferWavWriter&) = delete; + ChannelBufferWavWriter& operator=(const ChannelBufferWavWriter&) = delete; + + void Write(const ChannelBuffer<float>& buffer); + + private: + std::unique_ptr<WavWriter> file_; + std::vector<float> interleaved_; +}; + +// Takes a pointer to a vector. Allows appending the samples of channel buffers +// to the given vector, by interleaving the samples and converting them to float +// S16. +class ChannelBufferVectorWriter final { + public: + explicit ChannelBufferVectorWriter(std::vector<float>* output); + ChannelBufferVectorWriter(const ChannelBufferVectorWriter&) = delete; + ChannelBufferVectorWriter& operator=(const ChannelBufferVectorWriter&) = + delete; + ~ChannelBufferVectorWriter(); + + // Creates an interleaved copy of `buffer`, converts the samples to float S16 + // and appends the result to output_. + void Write(const ChannelBuffer<float>& buffer); + + private: + std::vector<float> interleaved_buffer_; + std::vector<float>* output_; +}; + +// Exits on failure; do not use in unit tests. 
+FILE* OpenFile(absl::string_view filename, absl::string_view mode); + +void SetFrameSampleRate(Int16FrameData* frame, int sample_rate_hz); + +template <typename T> +void SetContainerFormat(int sample_rate_hz, + size_t num_channels, + Int16FrameData* frame, + std::unique_ptr<ChannelBuffer<T> >* cb) { + SetFrameSampleRate(frame, sample_rate_hz); + frame->num_channels = num_channels; + cb->reset(new ChannelBuffer<T>(frame->samples_per_channel, num_channels)); +} + +template <typename T> +float ComputeSNR(const T* ref, const T* test, size_t length, float* variance) { + float mse = 0; + float mean = 0; + *variance = 0; + for (size_t i = 0; i < length; ++i) { + T error = ref[i] - test[i]; + mse += error * error; + *variance += ref[i] * ref[i]; + mean += ref[i]; + } + mse /= length; + *variance /= length; + mean /= length; + *variance -= mean * mean; + + float snr = 100; // We assign 100 dB to the zero-error case. + if (mse > 0) + snr = 10 * log10(*variance / mse); + return snr; +} + +// Returns a vector<T> parsed from whitespace delimited values in to_parse, +// or an empty vector if the string could not be parsed. +template <typename T> +std::vector<T> ParseList(absl::string_view to_parse) { + std::vector<T> values; + + std::istringstream str( // no-presubmit-check TODO(webrtc:8982) + std::string{to_parse}); + std::copy( + std::istream_iterator<T>(str), // no-presubmit-check TODO(webrtc:8982) + std::istream_iterator<T>(), // no-presubmit-check TODO(webrtc:8982) + std::back_inserter(values)); + + return values; +} + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_TEST_TEST_UTILS_H_ diff --git a/third_party/libwebrtc/modules/audio_processing/test/unittest.proto b/third_party/libwebrtc/modules/audio_processing/test/unittest.proto new file mode 100644 index 0000000000..07d1cda6c8 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/unittest.proto @@ -0,0 +1,48 @@ +syntax = "proto2"; +option optimize_for = LITE_RUNTIME; +package webrtc.audioproc; + +message Test { + optional int32 num_reverse_channels = 1; + optional int32 num_input_channels = 2; + optional int32 num_output_channels = 3; + optional int32 sample_rate = 4; + + message Frame { + } + + repeated Frame frame = 5; + + optional int32 analog_level_average = 6; + optional int32 max_output_average = 7; + optional int32 has_voice_count = 9; + optional int32 is_saturated_count = 10; + + message EchoMetrics { + optional float echo_return_loss = 1; + optional float echo_return_loss_enhancement = 2; + optional float divergent_filter_fraction = 3; + optional float residual_echo_likelihood = 4; + optional float residual_echo_likelihood_recent_max = 5; + } + + repeated EchoMetrics echo_metrics = 11; + + message DelayMetrics { + optional int32 median = 1; + optional int32 std = 2; + } + + repeated DelayMetrics delay_metrics = 12; + + optional float rms_dbfs_average = 13; + + optional float ns_speech_probability_average = 14; + + optional bool use_aec_extended_filter = 15; +} + +message OutputData { + repeated Test test = 1; +} + diff --git a/third_party/libwebrtc/modules/audio_processing/test/wav_based_simulator.cc b/third_party/libwebrtc/modules/audio_processing/test/wav_based_simulator.cc new file mode 100644 index 0000000000..ee87f9e1a8 --- /dev/null +++ b/third_party/libwebrtc/modules/audio_processing/test/wav_based_simulator.cc @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. 
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS. All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "modules/audio_processing/test/wav_based_simulator.h"
+
+#include <stdio.h>
+
+#include <iostream>
+
+#include "absl/strings/string_view.h"
+#include "modules/audio_processing/logging/apm_data_dumper.h"
+#include "modules/audio_processing/test/test_utils.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/system/file_wrapper.h"
+
+namespace webrtc {
+namespace test {
+
+std::vector<WavBasedSimulator::SimulationEventType>
+WavBasedSimulator::GetCustomEventChain(absl::string_view filename) {
+  std::vector<WavBasedSimulator::SimulationEventType> call_chain;
+  FileWrapper file_wrapper = FileWrapper::OpenReadOnly(filename);
+
+  RTC_CHECK(file_wrapper.is_open())
+      << "Could not open the custom call order file";
+
+  char c;
+  size_t num_read = file_wrapper.Read(&c, sizeof(char));
+  while (num_read > 0) {
+    switch (c) {
+      case 'r':
+        call_chain.push_back(SimulationEventType::kProcessReverseStream);
+        break;
+      case 'c':
+        call_chain.push_back(SimulationEventType::kProcessStream);
+        break;
+      case '\n':
+        break;
+      default:
+        RTC_FATAL() << "Incorrect custom call order file";
+    }
+
+    num_read = file_wrapper.Read(&c, sizeof(char));
+  }
+
+  return call_chain;
+}
+
+WavBasedSimulator::WavBasedSimulator(
+    const SimulationSettings& settings,
+    rtc::scoped_refptr<AudioProcessing> audio_processing,
+    std::unique_ptr<AudioProcessingBuilder> ap_builder)
+    : AudioProcessingSimulator(settings,
+                               std::move(audio_processing),
+                               std::move(ap_builder)) {
+  if (settings_.call_order_input_filename) {
+    call_chain_ = WavBasedSimulator::GetCustomEventChain(
+        *settings_.call_order_input_filename);
+  } else {
+    call_chain_ = WavBasedSimulator::GetDefaultEventChain();
+  }
+}
+
+WavBasedSimulator::~WavBasedSimulator() = default;
+
+std::vector<WavBasedSimulator::SimulationEventType>
+WavBasedSimulator::GetDefaultEventChain() {
+  std::vector<WavBasedSimulator::SimulationEventType> call_chain(2);
+  call_chain[0] = SimulationEventType::kProcessStream;
+  call_chain[1] = SimulationEventType::kProcessReverseStream;
+  return call_chain;
+}
+
+void WavBasedSimulator::PrepareProcessStreamCall() {
+  if (settings_.fixed_interface) {
+    fwd_frame_.CopyFrom(*in_buf_);
+  }
+  ap_->set_stream_key_pressed(settings_.override_key_pressed.value_or(false));
+
+  if (!settings_.use_stream_delay || *settings_.use_stream_delay) {
+    RTC_CHECK_EQ(AudioProcessing::kNoError,
+                 ap_->set_stream_delay_ms(settings_.stream_delay ?
*settings_.stream_delay : 0)); + } +} + +void WavBasedSimulator::PrepareReverseProcessStreamCall() { + if (settings_.fixed_interface) { + rev_frame_.CopyFrom(*reverse_in_buf_); + } +} + +void WavBasedSimulator::Process() { + ConfigureAudioProcessor(); + + Initialize(); + + bool samples_left_to_process = true; + int call_chain_index = 0; + int capture_frames_since_init = 0; + constexpr int kInitIndex = 1; + while (samples_left_to_process) { + switch (call_chain_[call_chain_index]) { + case SimulationEventType::kProcessStream: + SelectivelyToggleDataDumping(kInitIndex, capture_frames_since_init); + + samples_left_to_process = HandleProcessStreamCall(); + ++capture_frames_since_init; + break; + case SimulationEventType::kProcessReverseStream: + if (settings_.reverse_input_filename) { + samples_left_to_process = HandleProcessReverseStreamCall(); + } + break; + default: + RTC_CHECK_NOTREACHED(); + } + + call_chain_index = (call_chain_index + 1) % call_chain_.size(); + } + + DetachAecDump(); +} + +void WavBasedSimulator::Analyze() { + std::cout << "Inits:" << std::endl; + std::cout << "1: -->" << std::endl; + std::cout << " Time:" << std::endl; + std::cout << " Capture: 0 s (0 frames) " << std::endl; + std::cout << " Render: 0 s (0 frames)" << std::endl; +} + +bool WavBasedSimulator::HandleProcessStreamCall() { + bool samples_left_to_process = buffer_reader_->Read(in_buf_.get()); + if (samples_left_to_process) { + PrepareProcessStreamCall(); + ProcessStream(settings_.fixed_interface); + } + return samples_left_to_process; +} + +bool WavBasedSimulator::HandleProcessReverseStreamCall() { + bool samples_left_to_process = + reverse_buffer_reader_->Read(reverse_in_buf_.get()); + if (samples_left_to_process) { + PrepareReverseProcessStreamCall(); + ProcessReverseStream(settings_.fixed_interface); + } + return samples_left_to_process; +} + +void WavBasedSimulator::Initialize() { + std::unique_ptr<WavReader> in_file( + new WavReader(settings_.input_filename->c_str())); + int input_sample_rate_hz = in_file->sample_rate(); + int input_num_channels = in_file->num_channels(); + buffer_reader_.reset(new ChannelBufferWavReader(std::move(in_file))); + + int output_sample_rate_hz = settings_.output_sample_rate_hz + ? *settings_.output_sample_rate_hz + : input_sample_rate_hz; + int output_num_channels = settings_.output_num_channels + ? *settings_.output_num_channels + : input_num_channels; + + int reverse_sample_rate_hz = 48000; + int reverse_num_channels = 1; + int reverse_output_sample_rate_hz = 48000; + int reverse_output_num_channels = 1; + if (settings_.reverse_input_filename) { + std::unique_ptr<WavReader> reverse_in_file( + new WavReader(settings_.reverse_input_filename->c_str())); + reverse_sample_rate_hz = reverse_in_file->sample_rate(); + reverse_num_channels = reverse_in_file->num_channels(); + reverse_buffer_reader_.reset( + new ChannelBufferWavReader(std::move(reverse_in_file))); + + reverse_output_sample_rate_hz = + settings_.reverse_output_sample_rate_hz + ? *settings_.reverse_output_sample_rate_hz + : reverse_sample_rate_hz; + reverse_output_num_channels = settings_.reverse_output_num_channels + ? 
*settings_.reverse_output_num_channels : reverse_num_channels;
+  }
+
+  SetupBuffersConfigsOutputs(
+      input_sample_rate_hz, output_sample_rate_hz, reverse_sample_rate_hz,
+      reverse_output_sample_rate_hz, input_num_channels, output_num_channels,
+      reverse_num_channels, reverse_output_num_channels);
+}
+
+}  // namespace test
+}  // namespace webrtc
diff --git a/third_party/libwebrtc/modules/audio_processing/test/wav_based_simulator.h b/third_party/libwebrtc/modules/audio_processing/test/wav_based_simulator.h
new file mode 100644
index 0000000000..44e9ee2b7f
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/wav_based_simulator.h
@@ -0,0 +1,63 @@
+/*
+ *  Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS. All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef MODULES_AUDIO_PROCESSING_TEST_WAV_BASED_SIMULATOR_H_
+#define MODULES_AUDIO_PROCESSING_TEST_WAV_BASED_SIMULATOR_H_
+
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "modules/audio_processing/test/audio_processing_simulator.h"
+
+namespace webrtc {
+namespace test {
+
+// Used to perform an audio processing simulation from wav files.
+class WavBasedSimulator final : public AudioProcessingSimulator {
+ public:
+  WavBasedSimulator(const SimulationSettings& settings,
+                    rtc::scoped_refptr<AudioProcessing> audio_processing,
+                    std::unique_ptr<AudioProcessingBuilder> ap_builder);
+
+  WavBasedSimulator() = delete;
+  WavBasedSimulator(const WavBasedSimulator&) = delete;
+  WavBasedSimulator& operator=(const WavBasedSimulator&) = delete;
+
+  ~WavBasedSimulator() override;
+
+  // Processes the WAV input.
+  void Process() override;
+
+  // Only analyzes the data for the simulation, without performing any
+  // processing.
+  void Analyze() override;
+
+ private:
+  enum SimulationEventType {
+    kProcessStream,
+    kProcessReverseStream,
+  };
+
+  void Initialize();
+  bool HandleProcessStreamCall();
+  bool HandleProcessReverseStreamCall();
+  void PrepareProcessStreamCall();
+  void PrepareReverseProcessStreamCall();
+  static std::vector<SimulationEventType> GetDefaultEventChain();
+  static std::vector<SimulationEventType> GetCustomEventChain(
+      absl::string_view filename);
+
+  std::vector<SimulationEventType> call_chain_;
+};
+
+}  // namespace test
+}  // namespace webrtc
+
+#endif  // MODULES_AUDIO_PROCESSING_TEST_WAV_BASED_SIMULATOR_H_
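A note on the SimulatorBuffers struct added above: its constructor takes the four sample rates followed by the four channel counts, seeds every stream with random samples from Random(42), and sizes each buffer for one 10 ms frame (sample_rate_hz / 100 samples per channel). The following is a minimal sketch of how a test might construct it; the file name buffers_demo.cc is hypothetical and not part of this commit.

// buffers_demo.cc - hypothetical driver, not part of the commit.
#include "modules/audio_processing/test/simulator_buffers.h"

int main() {
  // All four streams at 16 kHz, stereo render, mono capture.
  webrtc::test::SimulatorBuffers buffers(
      /*render_input_sample_rate_hz=*/16000,
      /*capture_input_sample_rate_hz=*/16000,
      /*render_output_sample_rate_hz=*/16000,
      /*capture_output_sample_rate_hz=*/16000,
      /*num_render_input_channels=*/2,
      /*num_capture_input_channels=*/1,
      /*num_render_output_channels=*/2,
      /*num_capture_output_channels=*/1);

  // Refill the input AudioBuffers from the stored sample vectors, e.g.
  // before each simulated 10 ms frame.
  buffers.UpdateInputBuffers();
  return 0;
}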
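Similarly, the ParseList and ComputeSNR templates in test_utils.h are header-only and can be exercised stand-alone. The sketch below assumes only the signatures shown in the diff; the file name snr_demo.cc and the sample values are made up for illustration. Note that ComputeSNR reports 100 dB for a bit-exact match and returns the reference signal's variance through its out-parameter.

// snr_demo.cc - hypothetical driver, not part of the commit.
#include <cstdint>
#include <cstdio>
#include <vector>

#include "modules/audio_processing/test/test_utils.h"

int main() {
  // ParseList<T> splits a whitespace-delimited string into values.
  std::vector<int16_t> ref = webrtc::ParseList<int16_t>("100 -200 300 -400");
  std::vector<int16_t> degraded =
      webrtc::ParseList<int16_t>("101 -199 299 -401");

  // ComputeSNR returns the SNR in dB and writes the variance of the
  // reference signal to `variance`.
  float variance = 0.f;
  float snr = webrtc::ComputeSNR(ref.data(), degraded.data(), ref.size(),
                                 &variance);
  std::printf("SNR: %.1f dB (reference variance %.1f)\n", snr, variance);
  return 0;
}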
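Finally, for the custom call order file consumed by WavBasedSimulator::GetCustomEventChain: it is plain text in which each 'c' queues one ProcessStream (capture) call and each 'r' queues one ProcessReverseStream (render) call; newlines are ignored and any other character terminates the run via RTC_FATAL. A hypothetical file requesting two capture calls per render call (illustrative content, not part of the commit):

ccr
ccr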