#!/usr/bin/python3 # -*- coding: utf-8 -*- """ https://pytorch.org/hub/snakers4_silero-vad_vad/ https://github.com/snakers4/silero-vad """ import argparse from scipy.io import wavfile import torch from project_settings import project_path def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--wav_file", default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(), type=str, ) parser.add_argument( "--model_name", default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(), type=str, ) parser.add_argument("--threshold", default=0.5, type=float) parser.add_argument("--min_speech_duration_ms", default=250, type=int) parser.add_argument("--speech_pad_ms", default=30, type=int) parser.add_argument("--max_speech_duration_s", default=float("inf"), type=float) parser.add_argument("--window_size_samples", default=512, type=int) parser.add_argument("--min_silence_duration_ms", default=100, type=int) args = parser.parse_args() return args def main(): args = get_args() with open(args.model_name, "rb") as f: model = torch.jit.load(f, map_location="cpu") model.reset_states() sample_rate, signal = wavfile.read(args.wav_file) signal = signal / 32768 signal = torch.tensor(signal, dtype=torch.float32) print(signal) min_speech_samples = sample_rate * args.min_speech_duration_ms / 1000 speech_pad_samples = sample_rate * args.speech_pad_ms / 1000 max_speech_samples = sample_rate * args.max_speech_duration_s - args.window_size_samples - 2 * speech_pad_samples min_silence_samples = sample_rate * args.min_silence_duration_ms / 1000 min_silence_samples_at_max_speech = sample_rate * 98 / 1000 # probs speech_probs = [] for start in range(0, len(signal), args.window_size_samples): chunk = signal[start: start + args.window_size_samples] if len(chunk) < args.window_size_samples: chunk = torch.nn.functional.pad(chunk, (0, int(args.window_size_samples - len(chunk)))) speech_prob = model(chunk, sample_rate).item() speech_probs.append(speech_prob) print(speech_probs) # segments triggered = False speeches = list() current_speech = dict() neg_threshold = args.threshold - 0.15 temp_end = 0 prev_end = next_start = 0 for i, speech_prob in enumerate(speech_probs): if (speech_prob >= args.threshold) and temp_end: temp_end = 0 if next_start < prev_end: next_start = args.window_size_samples * i if (speech_prob >= args.threshold) and not triggered: triggered = True current_speech["start"] = args.window_size_samples * i continue if triggered and (args.window_size_samples * i) - current_speech["start"] > max_speech_samples: if prev_end: current_speech["end"] = prev_end speeches.append(current_speech) current_speech = {} if next_start < prev_end: triggered = False else: current_speech["start"] = next_start prev_end = next_start = temp_end = 0 else: current_speech["end"] = args.window_size_samples * i speeches.append(current_speech) current_speech = {} prev_end = next_start = temp_end = 0 triggered = False continue if speech_prob < neg_threshold and triggered: if not temp_end: temp_end = args.window_size_samples * i if ((args.window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech: prev_end = temp_end if (args.window_size_samples * i) - temp_end < min_silence_samples: continue else: current_speech["end"] = temp_end if (current_speech["end"] - current_speech["start"]) > min_speech_samples: speeches.append(current_speech) current_speech = {} prev_end = next_start = temp_end = 0 triggered = False continue if current_speech and (args.audio_length_samples - current_speech["start"]) > min_speech_samples: current_speech["end"] = args.audio_length_samples speeches.append(current_speech) return if __name__ == '__main__': main()