"""
Voice activity detection with the Silero VAD model: run the TorchScript model
over fixed-size windows and turn the per-window speech probabilities into
speech segments, following the segmentation logic of silero-vad's
get_speech_timestamps.

https://pytorch.org/hub/snakers4_silero-vad_vad/
https://github.com/snakers4/silero-vad
"""
import argparse

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_name",
        default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
        type=str,
    )
    parser.add_argument("--threshold", default=0.5, type=float)
    parser.add_argument("--min_speech_duration_ms", default=250, type=int)
    parser.add_argument("--speech_pad_ms", default=30, type=int)
    parser.add_argument("--max_speech_duration_s", default=float("inf"), type=float)
    parser.add_argument("--window_size_samples", default=512, type=int)
    parser.add_argument("--min_silence_duration_ms", default=100, type=int)

    args = parser.parse_args()
    return args
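
# Example invocation (the script filename here is hypothetical; the flags are
# the ones defined above):
#   python silero_vad_example.py --wav_file audio.wav --threshold 0.5 --min_silence_duration_ms 100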


def make_visualization(probs, step):
    # Plot the per-window speech probabilities over time (helper, not called
    # from main()); `step` is the duration of one window in seconds.
    import pandas as pd
    pd.DataFrame(
        {"probs": probs},
        index=[x * step for x in range(len(probs))],
    ).plot(
        figsize=(16, 8),
        kind="area",
        ylim=[0, 1.05],
        xlim=[0, len(probs) * step],
        xlabel="seconds",
        ylabel="speech probability",
        colormap="tab20",
    )


def plot(signal, sample_rate, speeches):
    # `signal` arrives already scaled to [-1, 1] by main(), so plot it as-is.
    time = np.arange(0, len(signal)) / sample_rate

    plt.figure(figsize=(12, 5))
    plt.plot(time, signal, color="b")

    # Mark each detected segment: green dashed line at its start, red at its end.
    for speech in speeches:
        start = speech["start"]
        end = speech["end"]
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color="g", linestyle="--")
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color="r", linestyle="--")

    plt.show()
    return


def main():
    args = get_args()

    # Load the TorchScript Silero VAD model and clear its recurrent state.
    with open(args.model_name, "rb") as f:
        model = torch.jit.load(f, map_location="cpu")
    model.reset_states()

    # Read the wav and scale int16 samples to floats in [-1, 1]; the jit model
    # expects mono audio at 8000 or 16000 Hz.
    sample_rate, signal = wavfile.read(args.wav_file)
    signal = signal / 32768
    signal = torch.tensor(signal, dtype=torch.float32)

    # Convert the millisecond / second arguments into sample counts.
    min_speech_samples = sample_rate * args.min_speech_duration_ms / 1000
    speech_pad_samples = sample_rate * args.speech_pad_ms / 1000
    max_speech_samples = sample_rate * args.max_speech_duration_s - args.window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sample_rate * args.min_silence_duration_ms / 1000
    # 98 ms: the shorter silence accepted as a split point when a segment is
    # approaching max_speech_duration_s.
    min_silence_samples_at_max_speech = sample_rate * 98 / 1000
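    # Worked example (assuming 8000 Hz telephony audio): the defaults give
    #   min_speech_samples  = 8000 * 250 / 1000 = 2000 samples
    #   speech_pad_samples  = 8000 *  30 / 1000 =  240 samples
    #   min_silence_samples = 8000 * 100 / 1000 =  800 samples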

    audio_length_samples = len(signal)

    # Slide a fixed-size window over the audio and collect one speech
    # probability per window from the model.
    speech_probs = []
    for start in range(0, audio_length_samples, args.window_size_samples):
        chunk = signal[start: start + args.window_size_samples]
        if len(chunk) < args.window_size_samples:
            # Zero-pad the last chunk up to a full window.
            chunk = torch.nn.functional.pad(chunk, (0, int(args.window_size_samples - len(chunk))))

        speech_prob = model(chunk, sample_rate).item()
        speech_probs.append(speech_prob)
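
    # Optional: visualize the raw probabilities with the helper defined above;
    # one window spans args.window_size_samples / sample_rate seconds.
    # make_visualization(speech_probs, args.window_size_samples / sample_rate)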

    # State machine over the window probabilities: open a segment when the
    # probability rises above `threshold`, close it after min_silence_duration_ms
    # of probabilities below `neg_threshold`, and force a split when a segment
    # would exceed max_speech_samples.
    triggered = False
    speeches = list()
    current_speech = dict()
    neg_threshold = args.threshold - 0.15
    temp_end = 0
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= args.threshold) and temp_end:
            # Speech resumed before the silence ran out; drop the tentative end.
            temp_end = 0
            if next_start < prev_end:
                next_start = args.window_size_samples * i

        if (speech_prob >= args.threshold) and not triggered:
            triggered = True
            current_speech["start"] = args.window_size_samples * i
            continue

        if triggered and (args.window_size_samples * i) - current_speech["start"] > max_speech_samples:
            # Segment too long: split at the last long-enough silence if one was
            # seen, otherwise cut right here.
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = args.window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if speech_prob < neg_threshold and triggered:
            if not temp_end:
                temp_end = args.window_size_samples * i
            # Remember silences longer than 98 ms as candidate split points.
            if ((args.window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:
                prev_end = temp_end

            if (args.window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                # Enough silence: close the segment and keep it if long enough.
                current_speech["end"] = temp_end
                if (current_speech["end"] - current_speech["start"]) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    # Close a segment that is still open when the audio ends.
    if current_speech and (audio_length_samples - current_speech["start"]) > min_speech_samples:
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)
    # Pad each segment with speech_pad_samples on both sides; when two segments
    # sit closer than twice the padding, split the gap between them instead.
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(max(0, speeches[i + 1]["start"] - silence_duration // 2))
            else:
                speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
                speeches[i + 1]["start"] = int(max(0, speeches[i + 1]["start"] - speech_pad_samples))
        else:
            speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))

    # Convert sample indices to seconds for printing and plotting.
    for speech_dict in speeches:
        speech_dict["start"] = round(speech_dict["start"] / sample_rate, 1)
        speech_dict["end"] = round(speech_dict["end"] / sample_rate, 1)

    print(speeches)
    plot(signal, sample_rate, speeches)

    return


if __name__ == '__main__':
    main()
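
# Expected console output: a list of segments in seconds, e.g.
#   [{'start': 1.2, 'end': 3.4}, {'start': 5.0, 'end': 7.8}]
# (values illustrative), followed by the waveform plot with segment markers.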