qgyd2021 committed on
Commit
40f83cf
1 Parent(s): 8343d8d
Files changed (3)
  1. examples/silerovad/vad.py +129 -0
  2. main.py +9 -9
  3. requirements.txt +2 -0
examples/silerovad/vad.py ADDED
@@ -0,0 +1,129 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://pytorch.org/hub/snakers4_silero-vad_vad/
+ https://github.com/snakers4/silero-vad
+ """
+ import argparse
+
+ from scipy.io import wavfile
+ import torch
+
+ from project_settings import project_path
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--wav_file",
+         default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(),
+         type=str,
+     )
+     parser.add_argument(
+         "--model_name",
+         default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
+         type=str,
+     )
+     parser.add_argument("--threshold", default=0.5, type=float)
+     parser.add_argument("--min_speech_duration_ms", default=250, type=int)
+     parser.add_argument("--speech_pad_ms", default=30, type=int)
+     parser.add_argument("--max_speech_duration_s", default=float("inf"), type=float)
+     parser.add_argument("--window_size_samples", default=512, type=int)
+     parser.add_argument("--min_silence_duration_ms", default=100, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     with open(args.model_name, "rb") as f:
+         model = torch.jit.load(f, map_location="cpu")
+     model.reset_states()
+
+     sample_rate, signal = wavfile.read(args.wav_file)
+     signal = signal / 32768
+     signal = torch.tensor(signal, dtype=torch.float32)
+     print(signal)
+
+     min_speech_samples = sample_rate * args.min_speech_duration_ms / 1000
+     speech_pad_samples = sample_rate * args.speech_pad_ms / 1000
+     max_speech_samples = sample_rate * args.max_speech_duration_s - args.window_size_samples - 2 * speech_pad_samples
+     min_silence_samples = sample_rate * args.min_silence_duration_ms / 1000
+     min_silence_samples_at_max_speech = sample_rate * 98 / 1000
+
+     # probs: speech probability of each window_size_samples chunk
+     speech_probs = []
+     for start in range(0, len(signal), args.window_size_samples):
+         chunk = signal[start: start + args.window_size_samples]
+         if len(chunk) < args.window_size_samples:
+             chunk = torch.nn.functional.pad(chunk, (0, int(args.window_size_samples - len(chunk))))
+
+         speech_prob = model(chunk, sample_rate).item()
+         speech_probs.append(speech_prob)
+
+     print(speech_probs)
+
+     # segments: threshold the per-window probabilities into speech segments (start/end in samples)
+     triggered = False
+     speeches = list()
+     current_speech = dict()
+     neg_threshold = args.threshold - 0.15
+     temp_end = 0
+     prev_end = next_start = 0
+
+     for i, speech_prob in enumerate(speech_probs):
+         if (speech_prob >= args.threshold) and temp_end:
+             temp_end = 0
+             if next_start < prev_end:
+                 next_start = args.window_size_samples * i
+
+         if (speech_prob >= args.threshold) and not triggered:
+             triggered = True
+             current_speech["start"] = args.window_size_samples * i
+             continue
+
+         if triggered and (args.window_size_samples * i) - current_speech["start"] > max_speech_samples:
+             if prev_end:
+                 current_speech["end"] = prev_end
+                 speeches.append(current_speech)
+                 current_speech = {}
+                 if next_start < prev_end:
+                     triggered = False
+                 else:
+                     current_speech["start"] = next_start
+                 prev_end = next_start = temp_end = 0
+             else:
+                 current_speech["end"] = args.window_size_samples * i
+                 speeches.append(current_speech)
+                 current_speech = {}
+                 prev_end = next_start = temp_end = 0
+                 triggered = False
+                 continue
+
+         if speech_prob < neg_threshold and triggered:
+             if not temp_end:
+                 temp_end = args.window_size_samples * i
+             if ((args.window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:
+                 prev_end = temp_end
+             if (args.window_size_samples * i) - temp_end < min_silence_samples:
+                 continue
+             else:
+                 current_speech["end"] = temp_end
+                 if (current_speech["end"] - current_speech["start"]) > min_speech_samples:
+                     speeches.append(current_speech)
+                 current_speech = {}
+                 prev_end = next_start = temp_end = 0
+                 triggered = False
+                 continue
+
+     if current_speech and (len(signal) - current_speech["start"]) > min_speech_samples:
+         current_speech["end"] = len(signal)
+         speeches.append(current_speech)
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
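
Note: vad.py prints the waveform and the per-window speech probabilities, while the resulting speeches segments are kept as raw sample offsets. For reference, a minimal sketch of a hypothetical helper (not part of this commit) that converts those offsets into seconds, assuming speeches and sample_rate as computed in main():

def segments_to_seconds(speeches: list, sample_rate: int) -> list:
    # each segment is a dict of sample offsets, e.g. {"start": 4000, "end": 12000}
    return [
        {
            "start": round(segment["start"] / sample_rate, 3),
            "end": round(segment["end"] / sample_rate, 3),
        }
        for segment in speeches
    ]

At an 8 kHz sample rate, for example, {"start": 4000, "end": 12000} would map to the span from 0.5 s to 1.5 s of audio.
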
main.py CHANGED
@@ -105,15 +105,15 @@ def main():
  webrtcvad_image = gr.Image(label="image", height=300, width=720, show_label=False)
  webrtcvad_end_points = gr.TextArea(label="end_points", max_lines=35)
 
- gr.Examples(
-     examples=webrtcvad_examples,
-     inputs=[
-         webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
-         webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
-     ],
-     outputs=[webrtcvad_image, webrtcvad_end_points],
-     fn=click_webrtcvad_button
- )
+ # gr.Examples(
+ #     examples=webrtcvad_examples,
+ #     inputs=[
+ #         webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
+ #         webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
+ #     ],
+ #     outputs=[webrtcvad_image, webrtcvad_end_points],
+ #     fn=click_webrtcvad_button
+ # )
 
  # click event
  webrtcvad_button.click(
requirements.txt CHANGED
@@ -4,3 +4,5 @@ wave==0.0.2
  matplotlib==3.7.4
  scipy==1.10.1
  pillow==10.2.0
+ torch==2.1.2
+ torchaudio==2.1.2