summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py')
-rw-r--r--third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py
new file mode 100644
index 0000000000..a7db7b4840
--- /dev/null
+++ b/third_party/libwebrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+
+from __future__ import division
+
+import logging
+import os
+import subprocess
+import shutil
+import sys
+import tempfile
+
+try:
+ import numpy as np
+except ImportError:
+ logging.critical('Cannot import the third-party Python package numpy')
+ sys.exit(1)
+
+from . import signal_processing
+
+
+class ExternalVad(object):
+ def __init__(self, path_to_binary, name):
+ """Args:
+ path_to_binary: path to binary that accepts '-i <wav>', '-o
+ <float probabilities>'. There must be one float value per
+ 10ms audio
+ name: a name to identify the external VAD. Used for saving
+ the output as extvad_output-<name>.
+ """
+ self._path_to_binary = path_to_binary
+ self.name = name
+ assert os.path.exists(self._path_to_binary), (self._path_to_binary)
+ self._vad_output = None
+
+ def Run(self, wav_file_path):
+ _signal = signal_processing.SignalProcessingUtils.LoadWav(
+ wav_file_path)
+ if _signal.channels != 1:
+ raise NotImplementedError('Multiple-channel'
+ ' annotations not implemented')
+ if _signal.frame_rate != 48000:
+ raise NotImplementedError('Frame rates '
+ 'other than 48000 not implemented')
+
+ tmp_path = tempfile.mkdtemp()
+ try:
+ output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
+ subprocess.call([
+ self._path_to_binary, '-i', wav_file_path, '-o',
+ output_file_path
+ ])
+ self._vad_output = np.fromfile(output_file_path, np.float32)
+ except Exception as e:
+ logging.error('Error while running the ' + self.name + ' VAD (' +
+ e.message + ')')
+ finally:
+ if os.path.exists(tmp_path):
+ shutil.rmtree(tmp_path)
+
+ def GetVadOutput(self):
+ assert self._vad_output is not None
+ return self._vad_output
+
+ @classmethod
+ def ConstructVadDict(cls, vad_paths, vad_names):
+ external_vads = {}
+ for path, name in zip(vad_paths, vad_names):
+ external_vads[name] = ExternalVad(path, name)
+ return external_vads