# wav2tsv / vad_utils.py
# mizoru's picture
# fully fledged
# 50a5992
# raw
# history blame
# 8 kB
# @title Define funcs
import torch
import torchaudio
from typing import Callable, List
import torch.nn.functional as F
import warnings
import pandas as pd
from matplotlib import pyplot as plt
def get_speech_probs(audio: torch.Tensor,
                     threshold: float = 0.5,
                     sampling_rate: int = 16000,
                     window_size_samples: int = 512,
                     progress_tracking_callback: Callable[[float], None] = None):
    """Return one speech probability per fixed-size window of ``audio``.

    The audio is cut into consecutive, non-overlapping windows of
    ``window_size_samples`` samples (the last one is zero-padded on the right)
    and each window is passed to the module-level ``model`` object, which is
    not a parameter of this function and must exist before the call.

    Args:
        audio: One-dimensional audio signal. Anything that is not already a
            tensor is converted with ``torch.Tensor``; leading dimensions of
            size 1 are squeezed away.
        threshold: Accepted for signature compatibility; not used here.
        sampling_rate: Sampling rate of ``audio``. Integer multiples of 16000
            above 16000 are reduced to 16000 by keeping every n-th sample.
        window_size_samples: Number of samples fed to the model per call.
        progress_tracking_callback: Optional callable that receives the
            progress in percent (0-100) after every processed window.

    Returns:
        A list with one value per window, obtained from
        ``model(chunk, sampling_rate).item()``.

    Raises:
        TypeError: If ``audio`` cannot be converted to a tensor.
        ValueError: If ``audio`` still has more than one dimension after
            squeezing.
    """
    if not torch.is_tensor(audio):
        try:
            audio = torch.Tensor(audio)
        except Exception as err:
            # Keep the underlying conversion error attached as the cause.
            raise TypeError("Audio cannot be casted to tensor. Cast it manually") from err

    if audio.dim() > 1:
        # squeeze(0) is a no-op on a leading dimension larger than 1, so
        # genuinely multi-channel input survives and is rejected below.
        for _ in range(audio.dim()):
            audio = audio.squeeze(0)
        if audio.dim() > 1:
            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")

    if sampling_rate > 16000 and sampling_rate % 16000 == 0:
        # Plain decimation: keep every `decimation`-th sample, no filtering.
        decimation = sampling_rate // 16000
        sampling_rate = 16000
        audio = audio[::decimation]
        warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!')

    if sampling_rate == 8000 and window_size_samples > 768:
        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
    if window_size_samples not in [256, 512, 768, 1024, 1536]:
        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')

    # Clear the model's internal state before processing a new signal.
    model.reset_states()

    audio_length_samples = len(audio)
    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            # Right-pad the final, shorter window with zeros to the full size.
            chunk = F.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_probs.append(model(chunk, sampling_rate).item())

        # Calculate progress and send it to the callback function.
        if progress_tracking_callback:
            processed = min(current_start_sample + window_size_samples, audio_length_samples)
            progress_tracking_callback((processed / audio_length_samples) * 100)
    return speech_probs
def probs2speech_timestamps(speech_probs, audio_length_samples,
                            threshold: float = 0.5,
                            sampling_rate: int = 16000,
                            min_speech_duration_ms: int = 250,
                            max_speech_duration_s: float = float('inf'),
                            min_silence_duration_ms: int = 100,
                            window_size_samples: int = 512,
                            speech_pad_ms: int = 30,
                            return_seconds: bool = False,
                            rounding: int = 1,):
    """Turn per-window speech probabilities into padded speech segments.

    Window ``i`` is taken to start at sample ``window_size_samples * i``.

    Args:
        speech_probs: Iterable of probabilities, one per window.
        audio_length_samples: Total length of the audio in samples; used to
            close a segment still open after the last window and to clamp the
            end padding.
        threshold: A window with probability >= ``threshold`` counts as
            speech. A window with probability below
            ``max(threshold - 0.15, 0.01)`` counts as silence; values in
            between change nothing.
        sampling_rate: Sampling rate used to convert the millisecond/second
            parameters to samples (and samples to seconds on output).
        min_speech_duration_ms: Segments that are not longer than this are
            dropped.
        max_speech_duration_s: Segments longer than this are split, at the
            last silence longer than 98 ms if one was seen, otherwise at the
            current window.
        min_silence_duration_ms: Silence that must accumulate before an open
            segment is closed.
        window_size_samples: Number of samples each probability covers.
        speech_pad_ms: Padding added on both sides of every segment; adjacent
            segments share the gap between them when it is too small.
        return_seconds: Return times in seconds instead of samples.
        rounding: Number of decimals kept when ``return_seconds`` is true.

    Returns:
        A list of dicts with ``'start'`` and ``'end'`` keys, in order. When
        ``return_seconds`` is false and ``sampling_rate // 16000`` is greater
        than 1, the sample indices are multiplied by that factor.
    """
    # Only applied to the output when it is > 1 (see the end of the function).
    step = sampling_rate // 16000

    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    triggered = False
    speeches = []
    current_speech = {}
    # Floor the silence threshold: with threshold <= 0.15 the plain difference
    # is <= 0, no probability can fall below it and an open segment would
    # never be closed.
    neg_threshold = max(threshold - 0.15, 0.01)
    temp_end = 0  # to save potential segment end (and tolerate some silence)
    prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached

    for i, speech_prob in enumerate(speech_probs):
        window_start = window_size_samples * i
        is_speech = speech_prob >= threshold

        if is_speech and temp_end:
            # Speech resumed before the pending silence was long enough.
            temp_end = 0
            if next_start < prev_end:
                next_start = window_start

        if is_speech and not triggered:
            triggered = True
            current_speech['start'] = window_start
            continue

        if triggered and window_start - current_speech['start'] > max_speech_samples:
            if prev_end:
                # Split at the last long-enough silence seen in this segment.
                current_speech['end'] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
                    triggered = False
                else:
                    current_speech['start'] = next_start
                prev_end = next_start = temp_end = 0
            else:
                # No usable silence: cut right at the current window.
                current_speech['end'] = window_start
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if speech_prob < neg_threshold and triggered:
            if not temp_end:
                temp_end = window_start
            silence_so_far = window_start - temp_end
            if silence_so_far > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
                prev_end = temp_end
            if silence_so_far < min_silence_samples:
                continue
            current_speech['end'] = temp_end
            if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                speeches.append(current_speech)
            current_speech = {}
            prev_end = next_start = temp_end = 0
            triggered = False
            continue

    # A segment still open at the end of the probabilities runs to the end of the audio.
    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
        current_speech['end'] = audio_length_samples
        speeches.append(current_speech)

    # Pad every segment; neighbours closer than two paddings split the gap.
    last_index = len(speeches) - 1
    for i, speech in enumerate(speeches):
        if i == 0:
            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
        if i != last_index:
            following = speeches[i + 1]
            silence_duration = following['start'] - speech['end']
            if silence_duration < 2 * speech_pad_samples:
                speech['end'] += int(silence_duration // 2)
                following['start'] = int(max(0, following['start'] - silence_duration // 2))
            else:
                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
                following['start'] = int(max(0, following['start'] - speech_pad_samples))
        else:
            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))

    if return_seconds:
        for speech_dict in speeches:
            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, rounding)
            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, rounding)
    elif step > 1:
        for speech_dict in speeches:
            speech_dict['start'] *= step
            speech_dict['end'] *= step
    return speeches
def make_visualization(probs, step):
    """Draw ``probs`` as an area chart and return the matplotlib figure.

    The k-th value is placed at x = ``k * step``. The x axis is labelled
    'seconds', so ``step`` is presumably the window length in seconds
    (TODO confirm against the caller).
    """
    fig, ax = plt.subplots(figsize=(16, 8))
    n_values = len(probs)
    x_positions = [k * step for k in range(n_values)]
    frame = pd.DataFrame({'probs': probs}, index=x_positions)
    frame.plot(
        ax=ax,
        kind='area',
        ylim=[0, 1.05],
        xlim=[0, n_values * step],
        xlabel='seconds',
        ylabel='speech probability',
        colormap='tab20',
    )
    return fig
# Restrict torch to a single thread.
torch.set_num_threads(1)

# Forwarded to the hub entry point as ``onnx=``; it is already switched on here.
USE_ONNX = True

# Fetch the model together with its helper utilities. Add force_reload=True
# to the call to refresh a previously downloaded copy.
model, utils = torch.hub.load(
    'snakers4/silero-vad',
    'silero_vad',
    onnx=USE_ONNX,
)

# Only the third item of ``utils`` is kept, under the name ``read_audio``.
_, _, read_audio, *_ = utils