#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Streaming voice-activity detection built on ``webrtcvad``.

Run as a script to detect speech endpoints in a wav file and plot them over
the waveform.
"""
import argparse
import collections
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import webrtcvad

from project_settings import project_path


class Frame(object):
    """A fixed-length chunk of audio samples.

    :param signal: the samples of this frame.
    :param timestamp: start time of the frame in seconds, counted from the
        first sample fed to the owning ``WebRTCVad``.
    :param duration: length of the frame in seconds.
    """

    def __init__(self, signal: np.ndarray, timestamp: float, duration: float):
        self.signal = signal
        self.timestamp = timestamp
        self.duration = duration


class WebRTCVad(object):
    """Incremental speech segmenter on top of ``webrtcvad.Vad``.

    Audio may be fed through :meth:`vad` in several successive chunks of one
    stream; samples that do not fill a whole frame are carried over to the
    next call. After the last chunk, call :meth:`last_vad_segments` once to
    flush pending state. Both return ``[start, end]`` pairs in seconds.

    :param agg: webrtcvad aggressiveness mode, 0-3.
    :param frame_duration_ms: length of each frame passed to webrtcvad.
    :param padding_duration_ms: length of the ring buffer used to decide when
        to switch between the voiced and unvoiced states.
    :param silence_duration_threshold: voiced spans separated by a silence
        longer than this (seconds) are reported separately; shorter gaps are
        merged into one span.
    :param sample_rate: sample rate of the incoming signal in Hz.
    """

    def __init__(self,
                 agg: int = 3,
                 frame_duration_ms: int = 30,
                 padding_duration_ms: int = 300,
                 silence_duration_threshold: float = 0.3,
                 sample_rate: int = 8000
                 ):
        self.agg = agg
        self.frame_duration_ms = frame_duration_ms
        self.padding_duration_ms = padding_duration_ms
        self.silence_duration_threshold = silence_duration_threshold
        self.sample_rate = sample_rate

        self._vad = webrtcvad.Vad(mode=agg)

        # frames
        self.frame_length = int(sample_rate * (frame_duration_ms / 1000.0))
        self.frame_timestamp = 0.0  # start time of the next frame, in seconds
        self.signal_cache = None  # leftover samples shorter than one frame

        # segments
        self.num_padding_frames = int(padding_duration_ms / frame_duration_ms)
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False  # True while inside a voiced segment
        self.voiced_frames: List[Frame] = list()
        self.segments = list()  # NOTE(review): not used by this class

        # vad segments: the most recent span that may still be merged
        self.is_first_segment = True
        self.timestamp_start = 0.0
        self.timestamp_end = 0.0

    def signal_to_frames(self, signal: np.ndarray) -> List[Frame]:
        """Split ``signal`` into consecutive frames of ``frame_length`` samples.

        Timestamps continue from ``frame_timestamp``, which is advanced by one
        frame duration per frame produced.
        """
        frames = list()
        duration = float(self.frame_length) / self.sample_rate
        for offset in range(0, len(signal), self.frame_length):
            sub_signal = signal[offset:offset + self.frame_length]
            frames.append(Frame(sub_signal, self.frame_timestamp, duration))
            self.frame_timestamp += duration
        return frames

    @staticmethod
    def _frames_to_segment(frames: List[Frame]) -> list:
        """Build ``[samples, first_frame_timestamp, last_frame_timestamp]``."""
        return [
            np.concatenate([f.signal for f in frames]),
            frames[0].timestamp,
            frames[-1].timestamp,
        ]

    def segments_generator(self, signal: np.ndarray):
        """Yield voiced segments completed while processing ``signal``.

        Each segment is ``[samples, start_timestamp, end_timestamp]``, where
        ``end_timestamp`` is the start time of the segment's last frame. A
        segment still open when the signal runs out stays in
        ``voiced_frames`` until later input closes it or
        :meth:`last_vad_segments` flushes it.
        """
        # Signal rounding: only whole frames are processed; the remainder is
        # cached and prepended to the next call's signal.
        if self.signal_cache is not None:
            signal = np.concatenate([self.signal_cache, signal])
        rest = len(signal) % self.frame_length
        if rest == 0:
            self.signal_cache = None
            signal_ = signal
        else:
            self.signal_cache = signal[-rest:]
            signal_ = signal[:-rest]

        threshold = 0.9 * self.ring_buffer.maxlen
        for frame in self.signal_to_frames(signal_):
            # NOTE(review): webrtcvad expects 16-bit mono PCM bytes; this
            # assumes ``signal`` is an int16 array -- confirm for new inputs.
            is_speech = self._vad.is_speech(frame.signal.tobytes(), self.sample_rate)

            if not self.triggered:
                self.ring_buffer.append((frame, is_speech))
                num_voiced = sum(1 for _, speech in self.ring_buffer if speech)
                # Enter the voiced state once the buffer is (almost) all
                # speech; the buffered frames become the start of the segment.
                if num_voiced > threshold:
                    self.triggered = True
                    self.voiced_frames.extend(f for f, _ in self.ring_buffer)
                    self.ring_buffer.clear()
            else:
                self.voiced_frames.append(frame)
                self.ring_buffer.append((frame, is_speech))
                num_unvoiced = sum(1 for _, speech in self.ring_buffer if not speech)
                # Leave the voiced state once the buffer is (almost) all silence.
                if num_unvoiced > threshold:
                    self.triggered = False
                    yield self._frames_to_segment(self.voiced_frames)
                    self.ring_buffer.clear()
                    self.voiced_frames = []

    def vad_segments_generator(self, segments_generator):
        """Merge voiced segments separated by short silences.

        Consumes ``[samples, start, end]`` segments and yields ``[start, end]``
        spans (timestamps rounded to 4 decimals) once the silence before the
        next segment exceeds ``silence_duration_threshold``. The latest span,
        which may still be extended, is kept in ``timestamp_start`` /
        ``timestamp_end``.
        """
        for segment in segments_generator:
            start = round(segment[1], 4)
            end = round(segment[2], 4)

            if self.is_first_segment:
                self.timestamp_start = start
                self.timestamp_end = end
                self.is_first_segment = False
                continue

            # From here on timestamp_start is always a real value, even when
            # it is 0.0, so it must not be tested for truthiness.
            sil_duration = start - self.timestamp_end
            if sil_duration > self.silence_duration_threshold:
                yield [self.timestamp_start, self.timestamp_end]
                self.timestamp_start = start
            self.timestamp_end = end

    def vad(self, signal: np.ndarray) -> List[list]:
        """Process ``signal`` and return the ``[start, end]`` spans completed so far."""
        segments = self.segments_generator(signal)
        return list(self.vad_segments_generator(segments))

    def last_vad_segments(self) -> List[list]:
        """Flush pending state and return the remaining ``[start, end]`` spans.

        Intended to be called once after the final :meth:`vad` call. Returns
        an empty list when no speech was ever detected.
        """
        # A segment still open at end of input is closed here.
        segments = []
        if self.voiced_frames:
            segments.append(self._frames_to_segment(self.voiced_frames))

        vad_segments = list(self.vad_segments_generator(segments))
        # The span held in timestamp_start/end has no successor to compare
        # against, so it is emitted now -- but only if one was ever recorded.
        if not self.is_first_segment:
            vad_segments.append([self.timestamp_start, self.timestamp_end])
        return vad_segments


def get_args():
    """Parse command-line options for the demo script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--agg",
        default=3,
        type=int,
        help="The level of aggressiveness of the VAD: [0-3]"
    )
    parser.add_argument(
        "--frame_duration_ms",
        default=30,
        type=int,
    )
    parser.add_argument(
        "--silence_duration_threshold",
        default=0.3,
        type=float,
        help="minimum silence duration, in seconds."
    )
    args = parser.parse_args()
    return args


SAMPLE_RATE = 8000


def main():
    """Detect speech endpoints in ``--wav_file``, print them, and plot them."""
    args = get_args()

    w_vad = WebRTCVad(
        agg=args.agg,
        frame_duration_ms=args.frame_duration_ms,
        silence_duration_threshold=args.silence_duration_threshold,
        sample_rate=SAMPLE_RATE,
    )

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError(
            f"expected sample rate {SAMPLE_RATE}, got {sample_rate}"
        )

    vad_segments = list()

    segments = w_vad.vad(signal)
    vad_segments += segments
    for segment in segments:
        print(segment)

    # last vad segment
    segments = w_vad.last_vad_segments()
    vad_segments += segments
    for segment in segments:
        print(segment)

    # plot
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    # NOTE(review): scaling by 32768 assumes int16 samples.
    plt.plot(time, signal / 32768, color='b')
    for start, end in vad_segments:
        # mark the start endpoint
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='开始端点')
        # mark the end endpoint
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='结束端点')

    plt.show()


if __name__ == '__main__':
    main()