#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Voice activity detection with Silero VAD.

References:
https://pytorch.org/hub/snakers4_silero-vad_vad/
https://github.com/snakers4/silero-vad
"""
import argparse

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_name",
        default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
        type=str,
    )
    parser.add_argument("--threshold", default=0.5, type=float)
    parser.add_argument("--min_speech_duration_ms", default=250, type=int)
    parser.add_argument("--speech_pad_ms", default=30, type=int)
    parser.add_argument("--max_speech_duration_s", default=float("inf"), type=float)
    parser.add_argument("--window_size_samples", default=512, type=int)
    parser.add_argument("--min_silence_duration_ms", default=100, type=int)
    args = parser.parse_args()
    return args


def make_visualization(probs, step):
    """Plot the per-window speech probabilities as an area chart."""
    import pandas as pd
    pd.DataFrame(
        {"probs": probs},
        index=[x * step for x in range(len(probs))],
    ).plot(
        figsize=(16, 8),
        kind="area",
        ylim=[0, 1.05],
        xlim=[0, len(probs) * step],
        xlabel="seconds",
        ylabel="speech probability",
        colormap="tab20",
    )


def plot(signal, sample_rate, speeches):
    """Plot the waveform with green/red markers at each speech start/end (in seconds)."""
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    # `signal` is already normalized to [-1, 1] in main(), so plot it as-is
    # (the original divided by 32768 a second time here).
    plt.plot(time, signal, color="b")
    for speech in speeches:
        start = speech["start"]
        end = speech["end"]
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color="g", linestyle="--")
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color="r", linestyle="--")
    plt.show()
    return


def main():
    args = get_args()

    with open(args.model_name, "rb") as f:
        model = torch.jit.load(f, map_location="cpu")
    model.reset_states()

    sample_rate, signal = wavfile.read(args.wav_file)
    # normalize int16 samples to [-1, 1]
    signal = signal / 32768
    signal = torch.tensor(signal, dtype=torch.float32)

    # convert the millisecond/second arguments into sample counts
    min_speech_samples = sample_rate * args.min_speech_duration_ms / 1000
    speech_pad_samples = sample_rate * args.speech_pad_ms / 1000
    max_speech_samples = sample_rate * args.max_speech_duration_s - args.window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sample_rate * args.min_silence_duration_ms / 1000
    # 98 ms of silence is enough to remember a candidate split point
    # for when a segment grows past max_speech_duration_s
    min_silence_samples_at_max_speech = sample_rate * 98 / 1000

    audio_length_samples = len(signal)

    # probs: run the model on consecutive windows to get one speech probability per window
    speech_probs = []
    for start in range(0, audio_length_samples, args.window_size_samples):
        chunk = signal[start: start + args.window_size_samples]
        if len(chunk) < args.window_size_samples:
            chunk = torch.nn.functional.pad(chunk, (0, int(args.window_size_samples - len(chunk))))
        speech_prob = model(chunk, sample_rate).item()
        speech_probs.append(speech_prob)

    # segments: turn the probability sequence into {"start", "end"} spans (in samples)
    triggered = False
    speeches = list()
    current_speech = dict()
    # hysteresis: a segment only ends once the probability drops below threshold - 0.15
    neg_threshold = args.threshold - 0.15
    temp_end = 0
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= args.threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = args.window_size_samples * i

        if (speech_prob >= args.threshold) and not triggered:
            triggered = True
            current_speech["start"] = args.window_size_samples * i
            continue

        # the current segment exceeds max_speech_duration_s: split it
        if triggered and (args.window_size_samples * i) - current_speech["start"] > max_speech_samples:
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = args.window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if speech_prob < neg_threshold and triggered:
            if not temp_end:
                temp_end = args.window_size_samples * i
            if ((args.window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (args.window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                if (current_speech["end"] - current_speech["start"]) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    # close a segment that is still open at the end of the audio
    if current_speech and (audio_length_samples - current_speech["start"]) > min_speech_samples:
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    # pad each segment by speech_pad_ms, sharing the silence between adjacent segments
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(max(0, speeches[i + 1]["start"] - silence_duration // 2))
            else:
                speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
                speeches[i + 1]["start"] = int(max(0, speeches[i + 1]["start"] - speech_pad_samples))
        else:
            speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))

    # convert segment boundaries from samples to seconds
    for speech_dict in speeches:
        speech_dict["start"] = round(speech_dict["start"] / sample_rate, 1)
        speech_dict["end"] = round(speech_dict["end"] / sample_rate, 1)

    print(speeches)

    plot(signal, sample_rate, speeches)
    return


if __name__ == '__main__':
    main()