# Scraped page header: qgyd2021, commit 7e17176 ("update")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://pytorch.org/hub/snakers4_silero-vad_vad/
https://github.com/snakers4/silero-vad
"""
import argparse
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch
from project_settings import project_path
def get_args():
    """Build the command-line interface and parse ``sys.argv``.

    Path defaults are resolved against ``project_path``; the numeric options
    tune the VAD thresholds and timing used by ``main``.

    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # (flag, path relative to the project root)
    path_options = [
        ("--wav_file", "data/early_media/3300999628164249998.wav"),
        ("--model_name", "pretrained_models/silero_vad/silero_vad.jit"),
    ]
    for flag, relative_path in path_options:
        parser.add_argument(
            flag,
            default=(project_path / relative_path).as_posix(),
            type=str,
        )

    # (flag, default, converter) -- registered in the same order as before
    numeric_options = [
        ("--threshold", 0.5, float),
        ("--min_speech_duration_ms", 250, int),
        ("--speech_pad_ms", 30, int),
        ("--max_speech_duration_s", float("inf"), float),
        ("--window_size_samples", 512, int),
        ("--min_silence_duration_ms", 100, int),
    ]
    for flag, default, converter in numeric_options:
        parser.add_argument(flag, default=default, type=converter)

    return parser.parse_args()
def make_visualization(probs, step):
    """Draw an area chart of per-window speech probabilities.

    :param probs: sequence of speech probabilities, one value per window.
    :param step: distance between consecutive windows on the x axis (the axis
        is labelled "seconds", so this is presumably the window length in
        seconds -- confirm at the call site).

    The chart is drawn through ``DataFrame.plot`` but never shown or saved
    here; the caller is responsible for that.
    """
    import pandas as pd

    positions = [k * step for k in range(len(probs))]
    frame = pd.DataFrame({'probs': probs}, index=positions)
    frame.plot(
        figsize=(16, 8),
        kind='area',
        ylim=[0, 1.05],
        xlim=[0, len(probs) * step],
        xlabel='seconds',
        ylabel='speech probability',
        colormap='tab20',
    )
def plot(signal, sample_rate, speeches):
    """Plot the waveform with dashed markers at every speech segment boundary.

    :param signal: audio samples, shape (frames,) or (frames, channels).
        Integer PCM (e.g. int16 from ``scipy.io.wavfile.read``) is scaled to
        [-1, 1) using the dtype's full range (int16 -> divided by 32768).
        Floating-point input, including a CPU ``torch.Tensor``, is treated as
        already normalized and plotted unchanged -- dividing it again would
        shrink the waveform to near zero.
    :param sample_rate: samples per second, used to build the time axis.
    :param speeches: iterable of dicts with "start" and "end" in seconds;
        starts are marked green, ends red.
    """
    samples = np.asarray(signal)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        # Centre unsigned formats (uint8 -> 128) and scale by half the range.
        offset = (info.max + info.min + 1) / 2
        scale = (info.max - info.min + 1) / 2
        samples = (samples - offset) / scale

    time = np.arange(0, len(samples)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, samples, color="b")
    for speech in speeches:
        plt.axvline(x=speech["start"], ymin=0.25, ymax=0.75, color="g", linestyle="--")
        plt.axvline(x=speech["end"], ymin=0.25, ymax=0.75, color="r", linestyle="--")
    plt.show()
    return
def _to_float_mono(samples, sample_rate):
    """Convert raw WAV samples into a mono float32 tensor for the VAD model.

    Integer PCM is scaled to [-1, 1) using the dtype's full range (int16 is
    divided by 32768; unsigned 8-bit is re-centred around 128). Float data is
    used as is. Multi-channel audio, shape (frames, channels), is averaged to
    mono. Rates that are an exact multiple of 16000 above 16000 are decimated
    to 16000 by keeping every n-th sample, mirroring upstream silero-vad's
    ``get_speech_timestamps`` (the models presumably accept only 8000/16000 Hz).

    :return: ``(sample_rate, signal)`` -- the possibly reduced rate and the tensor.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        offset = (info.max + info.min + 1) / 2
        scale = (info.max - info.min + 1) / 2
        samples = (samples.astype(np.float64) - offset) / scale
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    if sample_rate > 16000 and sample_rate % 16000 == 0:
        step = sample_rate // 16000
        sample_rate = 16000
        samples = samples[::step]
    return sample_rate, torch.tensor(np.ascontiguousarray(samples), dtype=torch.float32)


def _window_speech_probs(model, signal, sample_rate, window_size_samples):
    """Return the model's speech probability for each consecutive window of ``signal``.

    The last window is zero-padded to full length. Windows are fed strictly in
    order; the model appears to be stateful (``main`` calls ``reset_states()``
    before this), so the order matters.
    """
    probs = []
    # Pure inference: no autograd graph is needed for the scalar outputs.
    with torch.no_grad():
        for start in range(0, len(signal), window_size_samples):
            chunk = signal[start: start + window_size_samples]
            if len(chunk) < window_size_samples:
                chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
            probs.append(model(chunk, sample_rate).item())
    return probs


def _probs_to_speeches(speech_probs, *, window_size_samples, threshold, min_speech_samples,
                       max_speech_samples, min_silence_samples,
                       min_silence_samples_at_max_speech, audio_length_samples):
    """Turn per-window probabilities into speech segments (positions in samples).

    Hysteresis: a segment opens when a probability reaches ``threshold`` and
    becomes a candidate end once it drops below ``threshold - 0.15``; it closes
    once that silence has lasted ``min_silence_samples``. Closed segments not
    longer than ``min_speech_samples`` are dropped. A segment exceeding
    ``max_speech_samples`` is split at the last recorded silence longer than
    ``min_silence_samples_at_max_speech``, or at the current window if none.

    :return: list of ``{"start": int, "end": int}`` dicts.
    """
    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15
    temp_end = 0  # tentative end of the current segment (silence seen, not yet long enough)
    prev_end = next_start = 0  # candidate split points for the max-length rule

    for i, speech_prob in enumerate(speech_probs):
        position = window_size_samples * i

        if (speech_prob >= threshold) and temp_end:
            # Speech resumed: cancel the tentative end.
            temp_end = 0
            if next_start < prev_end:
                next_start = position

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech["start"] = position
            continue

        if triggered and position - current_speech["start"] > max_speech_samples:
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:
                    # Still in silence after the split point.
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = position
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = position
            if position - temp_end > min_silence_samples_at_max_speech:
                # Long enough silence to be a split point if the segment grows too long.
                prev_end = temp_end
            if position - temp_end < min_silence_samples:
                continue
            current_speech["end"] = temp_end
            if (current_speech["end"] - current_speech["start"]) > min_speech_samples:
                speeches.append(current_speech)
            current_speech = {}
            prev_end = next_start = temp_end = 0
            triggered = False
            continue

    # A segment still open at the end of the audio runs to the last sample.
    if current_speech and (audio_length_samples - current_speech["start"]) > min_speech_samples:
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)
    return speeches


def _pad_speeches(speeches, speech_pad_samples, audio_length_samples):
    """Widen each segment by ``speech_pad_samples`` in place, clamped to the audio.

    When the gap between neighbours is shorter than two paddings, it is split
    evenly between them instead, so padded segments never overlap.
    """
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            next_speech = speeches[i + 1]
            silence_duration = next_speech["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                next_speech["start"] = int(max(0, next_speech["start"] - silence_duration // 2))
            else:
                speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
                next_speech["start"] = int(max(0, next_speech["start"] - speech_pad_samples))
        else:
            speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
    return speeches


def main():
    """Run Silero VAD over a WAV file, print the speech segments and plot them.

    The segmentation follows ``get_speech_timestamps`` from snakers4/silero-vad.
    The printed and plotted timestamps are in seconds, rounded to 0.1 s.
    """
    args = get_args()

    with open(args.model_name, "rb") as f:
        model = torch.jit.load(f, map_location="cpu")
    model.reset_states()

    raw_sample_rate, raw_signal = wavfile.read(args.wav_file)
    sample_rate, signal = _to_float_mono(raw_signal, raw_sample_rate)
    audio_length_samples = len(signal)

    window_size_samples = args.window_size_samples
    speech_pad_samples = sample_rate * args.speech_pad_ms / 1000

    # probs
    speech_probs = _window_speech_probs(model, signal, sample_rate, window_size_samples)

    # segments
    speeches = _probs_to_speeches(
        speech_probs,
        window_size_samples=window_size_samples,
        threshold=args.threshold,
        min_speech_samples=sample_rate * args.min_speech_duration_ms / 1000,
        max_speech_samples=(
            sample_rate * args.max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
        ),
        min_silence_samples=sample_rate * args.min_silence_duration_ms / 1000,
        min_silence_samples_at_max_speech=sample_rate * 98 / 1000,
        audio_length_samples=audio_length_samples,
    )
    _pad_speeches(speeches, speech_pad_samples, audio_length_samples)

    # in seconds
    for speech_dict in speeches:
        speech_dict["start"] = round(speech_dict["start"] / sample_rate, 1)
        speech_dict["end"] = round(speech_dict["end"] / sample_rate, 1)
    print(speeches)

    # Hand plot() the untouched WAV samples: it scales integer PCM itself, so
    # passing the already-normalised tensor would scale it twice. Timestamps are
    # in seconds, so the original rate still lines up with them.
    plot(raw_signal, raw_sample_rate, speeches)
    return
if __name__ == "__main__":
    # Script entry point: run the VAD demo using command-line arguments.
    main()