qgyd2021 committed
Commit 7e17176
1 parent: e1eca0e
examples/silerovad/vad.py CHANGED
@@ -6,6 +6,8 @@ https://github.com/snakers4/silero-vad
 """
 import argparse
 
+import matplotlib.pyplot as plt
+import numpy as np
 from scipy.io import wavfile
 import torch
 
@@ -35,6 +37,33 @@ def get_args():
     return args
 
 
+def make_visualization(probs, step):
+    import pandas as pd
+    pd.DataFrame({'probs': probs},
+                 index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
+                 kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
+                 xlabel='seconds',
+                 ylabel='speech probability',
+                 colormap='tab20')
+
+
+def plot(signal, sample_rate, speeches):
+    time = np.arange(0, len(signal)) / sample_rate
+
+    plt.figure(figsize=(12, 5))
+
+    plt.plot(time, signal / 32768, color="b")
+
+    for speech in speeches:
+        start = speech["start"]
+        end = speech["end"]
+        plt.axvline(x=start, ymin=0.25, ymax=0.75, color="g", linestyle="--")
+        plt.axvline(x=end, ymin=0.25, ymax=0.75, color="r", linestyle="--")
+
+    plt.show()
+    return
+
+
 def main():
     args = get_args()
 
@@ -45,7 +74,6 @@ def main():
     sample_rate, signal = wavfile.read(args.wav_file)
     signal = signal / 32768
     signal = torch.tensor(signal, dtype=torch.float32)
-    print(signal)
 
     min_speech_samples = sample_rate * args.min_speech_duration_ms / 1000
     speech_pad_samples = sample_rate * args.speech_pad_ms / 1000
@@ -53,9 +81,11 @@ def main():
     min_silence_samples = sample_rate * args.min_silence_duration_ms / 1000
     min_silence_samples_at_max_speech = sample_rate * 98 / 1000
 
+    audio_length_samples = len(signal)
+
     # probs
     speech_probs = []
-    for start in range(0, len(signal), args.window_size_samples):
+    for start in range(0, audio_length_samples, args.window_size_samples):
         chunk = signal[start: start + args.window_size_samples]
         if len(chunk) < args.window_size_samples:
             chunk = torch.nn.functional.pad(chunk, (0, int(args.window_size_samples - len(chunk))))
@@ -63,8 +93,6 @@ def main():
         speech_prob = model(chunk, sample_rate).item()
         speech_probs.append(speech_prob)
 
-    print(speech_probs)
-
     # segments
     triggered = False
     speeches = list()
@@ -107,6 +135,7 @@ def main():
                 temp_end = args.window_size_samples * i
             if ((args.window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:
                 prev_end = temp_end
+
             if (args.window_size_samples * i) - temp_end < min_silence_samples:
                 continue
             else:
@@ -118,10 +147,32 @@ def main():
                 triggered = False
                 continue
 
-    if current_speech and (args.audio_length_samples - current_speech["start"]) > min_speech_samples:
-        current_speech["end"] = args.audio_length_samples
+    if current_speech and (audio_length_samples - current_speech["start"]) > min_speech_samples:
+        current_speech["end"] = audio_length_samples
         speeches.append(current_speech)
 
+    for i, speech in enumerate(speeches):
+        if i == 0:
+            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
+        if i != len(speeches) - 1:
+            silence_duration = speeches[i+1]["start"] - speech["end"]
+            if silence_duration < 2 * speech_pad_samples:
+                speech["end"] += int(silence_duration // 2)
+                speeches[i+1]["start"] = int(max(0, speeches[i+1]["start"] - silence_duration // 2))
+            else:
+                speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
+                speeches[i+1]["start"] = int(max(0, speeches[i+1]["start"] - speech_pad_samples))
+        else:
+            speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
+
+    # in seconds
+    for speech_dict in speeches:
+        speech_dict["start"] = round(speech_dict["start"] / sample_rate, 1)
+        speech_dict["end"] = round(speech_dict["end"] / sample_rate, 1)
+
+    print(speeches)
+    plot(signal, sample_rate, speeches)
+
     return
 
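For reference, a small worked example (not part of the commit) of the segment-padding pass added in the final hunk above. The values are illustrative: at 8000 Hz with speech_pad_ms=30, speech_pad_samples is 240, and two segments whose gap is smaller than 2 * speech_pad_samples split the gap between them instead of overlapping.

# Illustrative only: replays the padding loop from the hunk above on made-up segments.
audio_length_samples = 40000
speech_pad_samples = 240  # 8000 Hz * 30 ms / 1000
speeches = [
    {"start": 4000, "end": 8000},
    {"start": 8200, "end": 16000},   # gap of 200 samples < 2 * 240, so the gap is shared
]

for i, speech in enumerate(speeches):
    if i == 0:
        speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
    if i != len(speeches) - 1:
        silence_duration = speeches[i + 1]["start"] - speech["end"]
        if silence_duration < 2 * speech_pad_samples:
            # close neighbours each take half of the gap
            speech["end"] += int(silence_duration // 2)
            speeches[i + 1]["start"] = int(max(0, speeches[i + 1]["start"] - silence_duration // 2))
        else:
            speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
            speeches[i + 1]["start"] = int(max(0, speeches[i + 1]["start"] - speech_pad_samples))
    else:
        speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))

print(speeches)  # [{'start': 3760, 'end': 8100}, {'start': 8100, 'end': 16240}]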
main.py CHANGED
@@ -15,44 +15,65 @@ from PIL import Image
 
 from project_settings import project_path, temp_directory
 from toolbox.webrtcvad.vad import WebRTCVad
+from toolbox.vad.vad import Vad, WebRTCVoiceClassifier, SileroVoiceClassifier
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--webrtcvad_examples_file",
-        default=(project_path / "webrtcvad_examples.json").as_posix(),
+        "--ring_vad_examples_file",
+        default=(project_path / "ring_vad_examples.json").as_posix(),
         type=str
     )
     args = parser.parse_args()
     return args
 
 
-webrtcvad: WebRTCVad = None
+vad: Vad = None
 
 
-def click_webrtcvad_button(audio: Tuple[int, np.ndarray],
-                           agg: int = 3,
-                           frame_duration_ms: int = 30,
-                           padding_duration_ms: int = 300,
-                           silence_duration_threshold: float = 0.3,
-                           ):
-    global webrtcvad
+def click_ring_vad_button(audio: Tuple[int, np.ndarray],
+                          model_name: str,
+                          agg: int = 3,
+                          frame_duration_ms: int = 30,
+                          padding_duration_ms: int = 300,
+                          silence_duration_threshold: float = 0.3,
+                          start_ring_rate: float = 0.9,
+                          end_ring_rate: float = 0.1,
+                          ):
+    global vad
 
+    if audio is None:
+        return None, "please upload audio."
     sample_rate, signal = audio
 
-    webrtcvad = WebRTCVad(agg=int(agg),
-                          frame_duration_ms=frame_duration_ms,
-                          padding_duration_ms=padding_duration_ms,
-                          silence_duration_threshold=silence_duration_threshold,
-                          sample_rate=sample_rate,
-                          )
-
-    vad_segments = list()
-    segments = webrtcvad.vad(signal)
-    vad_segments += segments
-    segments = webrtcvad.last_vad_segments()
-    vad_segments += segments
+    if model_name == "webrtcvad" and frame_duration_ms not in (10, 20, 30):
+        return None, "only 10, 20, 30 available for `frame_duration_ms`."
+
+    if model_name == "webrtcvad":
+        model = WebRTCVoiceClassifier(agg=agg)
+    elif model_name == "silerovad":
+        model = SileroVoiceClassifier(model_name=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix())
+    else:
+        return None, "`model_name` not valid."
+
+    vad = Vad(model=model,
+              start_ring_rate=start_ring_rate,
+              end_ring_rate=end_ring_rate,
+              frame_duration_ms=frame_duration_ms,
+              padding_duration_ms=padding_duration_ms,
+              silence_duration_threshold=silence_duration_threshold,
+              sample_rate=sample_rate,
+              )
+
+    try:
+        vad_segments = list()
+        segments = vad.vad(signal)
+        vad_segments += segments
+        segments = vad.last_vad_segments()
+        vad_segments += segments
+    except Exception as e:
+        return None, str(e)
 
     time = np.arange(0, len(signal)) / sample_rate
     plt.figure(figsize=(12, 5))
@@ -77,8 +98,8 @@ def main():
     """
 
     # examples
-    with open(args.webrtcvad_examples_file, "r", encoding="utf-8") as f:
-        webrtcvad_examples = json.load(f)
+    with open(args.ring_vad_examples_file, "r", encoding="utf-8") as f:
+        ring_vad_examples = json.load(f)
 
     # ui
     with gr.Blocks() as blocks:
@@ -87,50 +108,62 @@ def main():
         with gr.Row():
             with gr.Column(scale=5):
                 with gr.Tabs():
-                    with gr.TabItem("webrtcvad"):
+                    with gr.TabItem("ring_vad"):
                         gr.Markdown(value="")
 
                         with gr.Row():
                             with gr.Column(scale=1):
-                                webrtcvad_wav = gr.Audio(label="wav")
+                                ring_wav = gr.Audio(label="wav")
+
+                                with gr.Row():
+                                    ring_model_name = gr.Dropdown(choices=["webrtcvad", "silerovad"], value="webrtcvad", label="model_name")
+
+                                with gr.Row():
+                                    ring_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")
+                                    ring_frame_duration_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_duration_ms")
 
                                 with gr.Row():
-                                    webrtcvad_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")
-                                    webrtcvad_frame_duration_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_duration_ms")
+                                    ring_padding_duration_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_duration_ms")
+                                    ring_silence_duration_threshold = gr.Slider(minimum=0, maximum=1.0, value=0.3, step=0.1, label="silence_duration_threshold")
 
                                 with gr.Row():
-                                    webrtcvad_padding_duration_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_duration_ms")
-                                    webrtcvad_silence_duration_threshold = gr.Slider(minimum=0, maximum=1.0, value=0.3, step=0.1, label="silence_duration_threshold")
+                                    ring_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="start_ring_rate")
+                                    ring_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="end_ring_rate")
 
-                                webrtcvad_button = gr.Button("retrieval", variant="primary")
+                                ring_button = gr.Button("retrieval", variant="primary")
 
                             with gr.Column(scale=1):
-                                webrtcvad_image = gr.Image(label="image", height=300, width=720, show_label=False)
-                                webrtcvad_end_points = gr.TextArea(label="end_points", max_lines=35)
+                                ring_image = gr.Image(label="image", height=300, width=720, show_label=False)
+                                ring_end_points = gr.TextArea(label="end_points", max_lines=35)
 
                         gr.Examples(
-                            examples=webrtcvad_examples,
+                            examples=ring_vad_examples,
                             inputs=[
-                                webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
-                                webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
+                                ring_wav,
+                                ring_model_name, ring_agg, ring_frame_duration_ms,
+                                ring_padding_duration_ms, ring_silence_duration_threshold,
+                                ring_start_ring_rate, ring_end_ring_rate
                             ],
-                            outputs=[webrtcvad_image, webrtcvad_end_points],
-                            fn=click_webrtcvad_button
+                            outputs=[ring_image, ring_end_points],
+                            fn=click_ring_vad_button
                         )
 
                         # click event
-                        webrtcvad_button.click(
-                            click_webrtcvad_button,
+                        ring_button.click(
+                            click_ring_vad_button,
                             inputs=[
-                                webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
-                                webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
+                                ring_wav,
+                                ring_model_name, ring_agg, ring_frame_duration_ms,
+                                ring_padding_duration_ms, ring_silence_duration_threshold,
+                                ring_start_ring_rate, ring_end_ring_rate
                             ],
-                            outputs=[webrtcvad_image, webrtcvad_end_points],
+                            outputs=[ring_image, ring_end_points],
                         )
 
     blocks.queue().launch(
         share=False if platform.system() == "Windows" else False,
-        server_name="0.0.0.0", server_port=7860
+        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
+        server_port=7860
     )
     return
 
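A hedged sketch (not part of the commit) of driving the new callback outside the Gradio UI. It assumes the app file is importable as `main` and that, on success, `click_ring_vad_button` returns the (image, end_points) pair wired to `ring_image` and `ring_end_points` above; on invalid input it returns (None, message), as in the diff.

# Illustrative only: call the new callback directly on one of the example wav files.
from scipy.io import wavfile

from main import click_ring_vad_button

sample_rate, signal = wavfile.read("data/early_media/3300999628164249998.wav")

image, end_points = click_ring_vad_button(
    audio=(sample_rate, signal),
    model_name="webrtcvad",       # or "silerovad"
    agg=3,
    frame_duration_ms=30,         # webrtcvad accepts only 10, 20 or 30 ms frames
    padding_duration_ms=300,
    silence_duration_threshold=0.3,
    start_ring_rate=0.9,
    end_ring_rate=0.1,
)
print(end_points)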
ring_vad_examples.json ADDED
@@ -0,0 +1,18 @@
+[
+  [
+    "data/early_media/3300999628164249998.wav",
+    "webrtcvad", 3, 30, 300, 0.3, 0.9, 0.1
+  ],
+  [
+    "data/early_media/3300999628164852605.wav",
+    "webrtcvad", 3, 30, 300, 0.3, 0.9, 0.1
+  ],
+  [
+    "data/early_media/3300999628164249998.wav",
+    "silerovad", 3, 35, 350, 0.35, 0.5, 0.5
+  ],
+  [
+    "data/early_media/3300999628164852605.wav",
+    "silerovad", 3, 35, 350, 0.35, 0.5, 0.5
+  ]
+]
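Each row supplies, in order, the values for the `gr.Examples` inputs wired up in main.py: the wav path for `ring_wav`, then `model_name`, `agg`, `frame_duration_ms`, `padding_duration_ms`, `silence_duration_threshold`, `start_ring_rate` and `end_ring_rate`.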
toolbox/vad/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/vad/vad.py ADDED
@@ -0,0 +1,299 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import collections
+from typing import List
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.io import wavfile
+import torch
+import webrtcvad
+
+from project_settings import project_path
+
+
+class FrameVoiceClassifier(object):
+    def predict(self, chunk: np.ndarray) -> float:
+        raise NotImplementedError
+
+
+class WebRTCVoiceClassifier(FrameVoiceClassifier):
+    def __init__(self,
+                 agg: int = 3,
+                 sample_rate: int = 8000
+                 ):
+        self.agg = agg
+        self.sample_rate = sample_rate
+
+        self.model = webrtcvad.Vad(mode=agg)
+
+    def predict(self, chunk: np.ndarray) -> float:
+        if chunk.dtype != np.int16:
+            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
+
+        audio_bytes = bytes(chunk)
+        is_speech = self.model.is_speech(audio_bytes, self.sample_rate)
+        return 1.0 if is_speech else 0.0
+
+
+class SileroVoiceClassifier(FrameVoiceClassifier):
+    def __init__(self,
+                 model_name: str,
+                 sample_rate: int = 8000):
+        self.model_name = model_name
+        self.sample_rate = sample_rate
+
+        with open(self.model_name, "rb") as f:
+            model = torch.jit.load(f, map_location="cpu")
+        self.model = model
+        self.model.reset_states()
+
+    def predict(self, chunk: np.ndarray) -> float:
+        if self.sample_rate / len(chunk) > 31.25:
+            raise AssertionError("chunk samples number {} is less than {}".format(len(chunk), self.sample_rate / 31.25))
+        if chunk.dtype != np.int16:
+            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
+
+        chunk = chunk / 32768
+        chunk = torch.tensor(chunk, dtype=torch.float32)
+        speech_prob = self.model(chunk, self.sample_rate).item()
+        return float(speech_prob)
+
+
+class Frame(object):
+    def __init__(self, signal: np.ndarray, timestamp, duration):
+        self.signal = signal
+        self.timestamp = timestamp
+        self.duration = duration
+
+
+class Vad(object):
+    def __init__(self,
+                 model: FrameVoiceClassifier,
+                 start_ring_rate: float = 0.5,
+                 end_ring_rate: float = 0.5,
+                 frame_duration_ms: int = 30,
+                 padding_duration_ms: int = 300,
+                 silence_duration_threshold: float = 0.3,
+                 sample_rate: int = 8000
+                 ):
+        self.model = model
+        self.start_ring_rate = start_ring_rate
+        self.end_ring_rate = end_ring_rate
+        self.frame_duration_ms = frame_duration_ms
+        self.padding_duration_ms = padding_duration_ms
+        self.silence_duration_threshold = silence_duration_threshold
+        self.sample_rate = sample_rate
+
+        # frames
+        self.frame_length = int(sample_rate * (frame_duration_ms / 1000.0))
+        self.frame_timestamp = 0.0
+        self.signal_cache = None
+
+        # segments
+        self.num_padding_frames = int(padding_duration_ms / frame_duration_ms)
+        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
+        self.triggered = False
+        self.voiced_frames: List[Frame] = list()
+        self.segments = list()
+
+        # vad segments
+        self.is_first_segment = True
+        self.timestamp_start = 0.0
+        self.timestamp_end = 0.0
+
+    def signal_to_frames(self, signal: np.ndarray):
+        frames = list()
+
+        l = len(signal)
+
+        duration = float(self.frame_length) / self.sample_rate
+
+        for offset in range(0, l, self.frame_length):
+            sub_signal = signal[offset:offset+self.frame_length]
+
+            frame = Frame(sub_signal, self.frame_timestamp, duration)
+            self.frame_timestamp += duration
+
+            frames.append(frame)
+        return frames
+
+    def segments_generator(self, signal: np.ndarray):
+        # signal rounding
+        if self.signal_cache is not None:
+            signal = np.concatenate([self.signal_cache, signal])
+
+        rest = len(signal) % self.frame_length
+
+        if rest == 0:
+            self.signal_cache = None
+            signal_ = signal
+        else:
+            self.signal_cache = signal[-rest:]
+            signal_ = signal[:-rest]
+
+        # frames
+        frames = self.signal_to_frames(signal_)
+
+        for frame in frames:
+            speech_prob = self.model.predict(frame.signal)
+
+            if not self.triggered:
+                self.ring_buffer.append((frame, speech_prob))
+                num_voiced = sum([p for _, p in self.ring_buffer])
+
+                if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
+                    self.triggered = True
+
+                    for f, _ in self.ring_buffer:
+                        self.voiced_frames.append(f)
+                    self.ring_buffer.clear()
+            else:
+                self.voiced_frames.append(frame)
+                self.ring_buffer.append((frame, speech_prob))
+                num_voiced = sum([p for _, p in self.ring_buffer])
+
+                if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
+                    self.triggered = False
+                    segment = [
+                        np.concatenate([f.signal for f in self.voiced_frames]),
+                        self.voiced_frames[0].timestamp,
+                        self.voiced_frames[-1].timestamp,
+                    ]
+                    yield segment
+                    self.ring_buffer.clear()
+                    self.voiced_frames = []
+
+    def vad_segments_generator(self, segments_generator):
+        segments = list(segments_generator)
+
+        for i, segment in enumerate(segments):
+            start = round(segment[1], 4)
+            end = round(segment[2], 4)
+
+            if self.is_first_segment:
+                self.timestamp_start = start
+                self.timestamp_end = end
+                self.is_first_segment = False
+                continue
+
+            if self.timestamp_start:
+                sil_duration = start - self.timestamp_end
+                if sil_duration > self.silence_duration_threshold:
+                    vad_segment = [self.timestamp_start, self.timestamp_end]
+                    yield vad_segment
+
+                    self.timestamp_start = start
+                    self.timestamp_end = end
+                else:
+                    self.timestamp_end = end
+
+    def vad(self, signal: np.ndarray) -> List[list]:
+        segments = self.segments_generator(signal)
+        vad_segments = self.vad_segments_generator(segments)
+        vad_segments = list(vad_segments)
+        return vad_segments
+
+    def last_vad_segments(self) -> List[list]:
+        # last segments
+        if len(self.voiced_frames) == 0:
+            segments = []
+        else:
+            segment = [
+                np.concatenate([f.signal for f in self.voiced_frames]),
+                self.voiced_frames[0].timestamp,
+                self.voiced_frames[-1].timestamp
+            ]
+            segments = [segment]
+
+        # last vad segments
+        vad_segments = self.vad_segments_generator(segments)
+        vad_segments = list(vad_segments)
+
+        vad_segments = vad_segments + [[self.timestamp_start, self.timestamp_end]]
+        return vad_segments
+
+
+def make_visualization(signal: np.ndarray, sample_rate: int, vad_segments: list):
+    time = np.arange(0, len(signal)) / sample_rate
+    plt.figure(figsize=(12, 5))
+    plt.plot(time, signal / 32768, color='b')
+    for start, end in vad_segments:
+        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start endpoint')  # mark the start endpoint
+        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end endpoint')  # mark the end endpoint
+
+    plt.show()
+    return
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wav_file",
+        default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(),
+        type=str,
+    )
+    parser.add_argument(
+        "--model_name",
+        default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
+        type=str,
+    )
+    parser.add_argument(
+        "--frame_duration_ms",
+        default=30,
+        type=int,
+    )
+    parser.add_argument(
+        "--silence_duration_threshold",
+        default=0.3,
+        type=float,
+        help="minimum silence duration, in seconds."
+    )
+    args = parser.parse_args()
+    return args
+
+
+SAMPLE_RATE = 8000
+
+
+def main():
+    args = get_args()
+
+    sample_rate, signal = wavfile.read(args.wav_file)
+    if SAMPLE_RATE != sample_rate:
+        raise AssertionError
+
+    # model = SileroVoiceClassifier(model_name=args.model_name, sample_rate=SAMPLE_RATE)
+    model = WebRTCVoiceClassifier(agg=1, sample_rate=SAMPLE_RATE)
+
+    vad = Vad(model=model,
+              start_ring_rate=0.9,
+              end_ring_rate=0.1,
+              frame_duration_ms=30,
+              padding_duration_ms=300,
+              silence_duration_threshold=0.30,
+              sample_rate=SAMPLE_RATE,
+              )
+    print(vad)
+
+    vad_segments = list()
+
+    segments = vad.vad(signal)
+    vad_segments += segments
+    for segment in segments:
+        print(segment)
+
+    # last vad segment
+    segments = vad.last_vad_segments()
+    vad_segments += segments
+    for segment in segments:
+        print(segment)
+
+    # plot
+    make_visualization(signal, SAMPLE_RATE, vad_segments)
+    return
+
+
+if __name__ == '__main__':
+    main()
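A hedged usage sketch (not part of the commit) of the new Vad class in streaming style, assuming the example wav files are 8 kHz 16-bit mono PCM (the module's own main() asserts 8000 Hz). Chunks are fed incrementally through vad(), and last_vad_segments() flushes whatever is still buffered at the end.

# Illustrative only: feed the signal in half-second chunks to mimic a live stream.
from scipy.io import wavfile

from toolbox.vad.vad import Vad, WebRTCVoiceClassifier

sample_rate, signal = wavfile.read("data/early_media/3300999628164249998.wav")

vad = Vad(model=WebRTCVoiceClassifier(agg=3, sample_rate=sample_rate),
          start_ring_rate=0.9,
          end_ring_rate=0.1,
          frame_duration_ms=30,
          padding_duration_ms=300,
          silence_duration_threshold=0.3,
          sample_rate=sample_rate,
          )

vad_segments = list()
chunk_size = 4000  # 0.5 s at 8000 Hz; leftover samples are cached inside Vad
for offset in range(0, len(signal), chunk_size):
    vad_segments += vad.vad(signal[offset: offset + chunk_size])
vad_segments += vad.last_vad_segments()  # flush the trailing buffered frames

for start, end in vad_segments:
    print(round(start, 4), round(end, 4))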
toolbox/webrtcvad/vad.py CHANGED
@@ -168,7 +168,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--wav_file",
-        default=(project_path / "data/3300999628164249998.wav").as_posix(),
+        default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(),
         type=str,
     )
     parser.add_argument(
webrtcvad_examples.json DELETED
@@ -1,8 +0,0 @@
-[
-  [
-    "data/early_media/3300999628164249998.wav"
-  ],
-  [
-    "data/early_media/3300999628164852605.wav"
-  ]
-]