update
- data/call_monitor/id-ID/noise/000ad44a-fbad-4a22-ba5a-c6dc855779b2_id-ID_1672040947119.wav +3 -0
- data/call_monitor/id-ID/noise/000da369-6652-4601-b241-33ffbd52a224_id-ID_1676000326981.wav +3 -0
- data/call_monitor/id-ID/noise/00a0a2a3-14ff-4a84-8aee-b18b2fb65355_id-ID_1680237229413.wav +3 -0
- data/call_monitor/id-ID/noise_mute/000d7fba-80ce-4bd7-84fe-e9c43de30f4a_id-ID_1678495379262.wav +3 -0
- data/call_monitor/id-ID/voice/000a3f9a-b2bf-46fd-9c69-477fc62cda51_id-ID_1671935534167 - 副本.wav +3 -0
- data/call_monitor/id-ID/voice/000a3f9a-b2bf-46fd-9c69-477fc62cda51_id-ID_1671935534167.wav +3 -0
- data/call_monitor/id-ID/voice/000cb369-a0ee-44aa-a213-18b036f1baf7_id-ID_1678762306513.wav +3 -0
- data/call_monitor/id-ID/voicemail/000b03b3-172e-4784-8510-24cf37e205ba_id-ID_1672193551438.wav +3 -0
- data/call_monitor/id-ID/voicemail/00a20d31-e1cb-4c70-821b-6fd151b260ae_id-ID_1671762897272.wav +3 -0
- data/early_media/62/33009996287818451333.wav +3 -0
- data/early_media/62/3300999628999191096.wav +3 -0
- main.py +38 -23
- ring_vad_examples.json +24 -4
- toolbox/torch/__init__.py +6 -0
- toolbox/torch/utils/__init__.py +6 -0
- toolbox/torch/utils/data/__init__.py +6 -0
- toolbox/torch/utils/data/vocabulary.py +211 -0
- toolbox/torch/utils/utils.py +26 -0
- toolbox/vad/vad.py +114 -61
- trained_models/cnn_voicemail_common_20231130/cnn_voicemail.pth +3 -0
- trained_models/cnn_voicemail_common_20231130/labels.json +10 -0
data/call_monitor/id-ID/noise/000ad44a-fbad-4a22-ba5a-c6dc855779b2_id-ID_1672040947119.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7deca6895788f2fe7f7d2324dffabc39581ee6edfa4c6619d458790a2ca79b65
+size 32044
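The `.wav` payloads in this commit are stored through Git LFS, so what actually lands in the repository is the three-line pointer shown in each diff: the spec version, the SHA-256 of the real file, and its byte size. A minimal sketch (not part of this commit) of reading those fields back out of a pointer:

```python
# Sketch only: split a Git LFS pointer into its "version", "oid" and
# "size" fields. The pointer text is copied from the diff above.
def parse_lfs_pointer(text: str) -> dict:
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:7deca6895788f2fe7f7d2324dffabc39581ee6edfa4c6619d458790a2ca79b65\n"
    "size 32044\n"
)
print(parse_lfs_pointer(pointer))
```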
data/call_monitor/id-ID/noise/000da369-6652-4601-b241-33ffbd52a224_id-ID_1676000326981.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0678b0d7a759bdefe33725b2f224661beecce0e9fda52998d3535acab9e1c6e8
+size 32044
data/call_monitor/id-ID/noise/00a0a2a3-14ff-4a84-8aee-b18b2fb65355_id-ID_1680237229413.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cae813da4e2628586537cb41f0db6dc18f2021725a0a4827e1f7794dad727381
+size 32044
data/call_monitor/id-ID/noise_mute/000d7fba-80ce-4bd7-84fe-e9c43de30f4a_id-ID_1678495379262.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc09aeec1b37f75c65df3e83065ff54d5c2390b69847d174f65d3cb69f95da52
+size 32044
data/call_monitor/id-ID/voice/000a3f9a-b2bf-46fd-9c69-477fc62cda51_id-ID_1671935534167 - 副本.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf9e6ef0ee87be308c8a59a1459836dc9229c83be37c5e7204586c385d8d7a84
+size 32044
data/call_monitor/id-ID/voice/000a3f9a-b2bf-46fd-9c69-477fc62cda51_id-ID_1671935534167.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf9e6ef0ee87be308c8a59a1459836dc9229c83be37c5e7204586c385d8d7a84
+size 32044
data/call_monitor/id-ID/voice/000cb369-a0ee-44aa-a213-18b036f1baf7_id-ID_1678762306513.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f18b06b287ed16faf1bb231b5758b127562e90633e17f4ca931c48a0373b6b5
+size 32044
data/call_monitor/id-ID/voicemail/000b03b3-172e-4784-8510-24cf37e205ba_id-ID_1672193551438.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50f33d1c4b76ebeb028041d465a6f75965f0c6a584f19c38da4bfb104d0b3e26
+size 32044
data/call_monitor/id-ID/voicemail/00a20d31-e1cb-4c70-821b-6fd151b260ae_id-ID_1671762897272.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64c8496fa1fc98d40b145d5f4a1e07d2b1bf742348549d2518cb52a36130be05
+size 32044
data/early_media/62/33009996287818451333.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a68356976cde2101182663b90c2272be5730b73f341f1ef7aa76f2716dae7637
+size 155884
data/early_media/62/3300999628999191096.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12a24a8927ac75d5bf0549c7ac4f0fe9339b73b4a035b3953a375676761c71c3
+size 186604
main.py
CHANGED
@@ -15,7 +15,7 @@ from PIL import Image
 
 from project_settings import project_path, temp_directory
 from toolbox.webrtcvad.vad import WebRTCVad
-from toolbox.vad.vad import Vad, WebRTCVoiceClassifier, SileroVoiceClassifier
+from toolbox.vad.vad import Vad, WebRTCVoiceClassifier, SileroVoiceClassifier, CallVoiceClassifier, process_speech_probs
 
 
 def get_args():
@@ -35,9 +35,10 @@ vad: Vad = None
 def click_ring_vad_button(audio: Tuple[int, np.ndarray],
                           model_name: str,
                           agg: int = 3,
-
-
-
+                          frame_length_ms: int = 30,
+                          frame_step_ms: int = 30,
+                          padding_length_ms: int = 300,
+                          max_silence_length_ms: int = 300,
                           start_ring_rate: float = 0.9,
                           end_ring_rate: float = 0.1,
                           ):
@@ -47,22 +48,24 @@ def click_ring_vad_button(audio: Tuple[int, np.ndarray],
         return None, "please upload audio."
     sample_rate, signal = audio
 
-    if model_name == "webrtcvad" and
+    if model_name == "webrtcvad" and frame_length_ms not in (10, 20, 30):
         return None, "only 10, 20, 30 available for `frame_duration_ms`."
 
     if model_name == "webrtcvad":
         model = WebRTCVoiceClassifier(agg=agg)
     elif model_name == "silerovad":
-        model = SileroVoiceClassifier(
+        model = SileroVoiceClassifier(model_path=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix())
+    elif model_name == "call_voice":
+        model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())
     else:
        return None, "`model_name` not valid."
 
     vad = Vad(model=model,
              start_ring_rate=start_ring_rate,
              end_ring_rate=end_ring_rate,
-
-
-
+              frame_length_ms=frame_length_ms,
+              padding_length_ms=padding_length_ms,
+              max_silence_length_ms=max_silence_length_ms,
              sample_rate=sample_rate,
              )
@@ -75,12 +78,21 @@ def click_ring_vad_button(audio: Tuple[int, np.ndarray],
     except Exception as e:
         return None, str(e)
 
+    # speech_probs
+    speech_probs = process_speech_probs(
+        signal=signal,
+        speech_probs=vad.speech_probs,
+        frame_step=vad.frame_step,
+    )
+
     time = np.arange(0, len(signal)) / sample_rate
     plt.figure(figsize=(12, 5))
-    plt.plot(time, signal / 32768, color=
+    plt.plot(time, signal / 32768, color="b")
+    plt.plot(time, speech_probs * 2, color="gray")
+
     for start, end in vad_segments:
-        plt.axvline(x=start, ymin=0.
-        plt.axvline(x=end, ymin=0.
+        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="start endpoint")
+        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="end endpoint")
 
     temp_image_file = temp_directory / "temp.jpg"
     plt.savefig(temp_image_file)
@@ -116,19 +128,20 @@ def main():
         ring_wav = gr.Audio(label="wav")
 
         with gr.Row():
-            ring_model_name = gr.Dropdown(choices=["webrtcvad", "silerovad"], value="webrtcvad", label="model_name")
+            ring_model_name = gr.Dropdown(choices=["webrtcvad", "silerovad", "call_voice"], value="webrtcvad", label="model_name")
+            ring_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")
 
         with gr.Row():
-
-
+            ring_frame_length_ms = gr.Slider(minimum=0, maximum=1000, value=30, label="frame_length_ms")
+            ring_frame_step_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_step_ms")
 
         with gr.Row():
-
-
+            ring_padding_length_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_length_ms")
+            ring_max_silence_length_ms = gr.Slider(minimum=0, maximum=1000, value=300, step=0.1, label="max_silence_length_ms")
 
         with gr.Row():
-            ring_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.
-            ring_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.
+            ring_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.05, label="start_ring_rate")
+            ring_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="end_ring_rate")
 
         ring_button = gr.Button("retrieval", variant="primary")
@@ -140,8 +153,9 @@ def main():
             examples=ring_vad_examples,
             inputs=[
                 ring_wav,
-                ring_model_name, ring_agg,
-
+                ring_model_name, ring_agg,
+                ring_frame_length_ms, ring_frame_step_ms,
+                ring_padding_length_ms, ring_max_silence_length_ms,
                 ring_start_ring_rate, ring_end_ring_rate
             ],
             outputs=[ring_image, ring_end_points],
@@ -153,8 +167,9 @@ def main():
             click_ring_vad_button,
             inputs=[
                 ring_wav,
-                ring_model_name, ring_agg,
-
+                ring_model_name, ring_agg,
+                ring_frame_length_ms, ring_frame_step_ms,
+                ring_padding_length_ms, ring_max_silence_length_ms,
                 ring_start_ring_rate, ring_end_ring_rate
             ],
             outputs=[ring_image, ring_end_points],
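Outside of Gradio, the reworked handler boils down to: pick a frame classifier by `model_name`, build a `Vad` with the new frame/padding/silence parameters, run it over the signal, then expand the per-frame probabilities with `process_speech_probs` for plotting. A hedged sketch of that flow, assuming an 8 kHz int16 wav and `scipy` as the reader (neither is pinned down by this diff):

```python
# Sketch only: mirrors click_ring_vad_button without the Gradio plumbing.
from scipy.io import wavfile

from toolbox.vad.vad import Vad, WebRTCVoiceClassifier, process_speech_probs

sample_rate, signal = wavfile.read("data/early_media/3300999628164249998.wav")

model = WebRTCVoiceClassifier(agg=3)
vad = Vad(model=model,
          start_ring_rate=0.9, end_ring_rate=0.1,
          frame_length_ms=30, frame_step_ms=30,
          padding_length_ms=300, max_silence_length_ms=300,
          sample_rate=sample_rate)

vad_segments = vad.vad(signal)                     # [[start_s, end_s], ...]
speech_probs = process_speech_probs(signal=signal,
                                    speech_probs=vad.speech_probs,
                                    frame_step=vad.frame_step)
```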
ring_vad_examples.json
CHANGED
@@ -1,18 +1,38 @@
 [
     [
         "data/early_media/3300999628164249998.wav",
-        "webrtcvad", 3, 30, 300,
+        "webrtcvad", 3, 30, 300, 300, 300, 0.9, 0.1
     ],
     [
         "data/early_media/3300999628164852605.wav",
-        "webrtcvad", 3, 30, 300,
+        "webrtcvad", 3, 30, 300, 300, 300, 0.9, 0.1
     ],
     [
         "data/early_media/3300999628164249998.wav",
-        "silerovad", 3, 35, 350,
+        "silerovad", 3, 35, 350, 350, 350, 0.7, 0.3
     ],
     [
         "data/early_media/3300999628164852605.wav",
-        "silerovad", 3, 35, 350,
+        "silerovad", 3, 35, 350, 350, 350, 0.5, 0.5
+    ],
+    [
+        "data/early_media/3300999628164852605.wav",
+        "call_voice", 3, 300, 30, 300, 300, 0.2, 0.1
+    ],
+    [
+        "data/early_media/62/3300999628999191096.wav",
+        "call_voice", 3, 300, 30, 300, 300, 0.2, 0.1
+    ],
+    [
+        "data/early_media/62/33009996287818451333.wav",
+        "call_voice", 3, 300, 30, 300, 300, 0.2, 0.1
+    ],
+    [
+        "data/call_monitor/id-ID/noise_mute/000d7fba-80ce-4bd7-84fe-e9c43de30f4a_id-ID_1678495379262.wav",
+        "silerovad", 3, 35, 350, 350, 350, 0.7, 0.3
+    ],
+    [
+        "data/call_monitor/id-ID/noise/00a0a2a3-14ff-4a84-8aee-b18b2fb65355_id-ID_1680237229413.wav",
+        "silerovad", 3, 35, 350, 350, 350, 0.7, 0.3
     ]
 ]
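Each example row is applied positionally to the `inputs` list wired up in `main.py`, so the nine values are: wav path, `model_name`, `agg`, `frame_length_ms`, `frame_step_ms`, `padding_length_ms`, `max_silence_length_ms`, `start_ring_rate`, `end_ring_rate`. Making the mapping explicit:

```python
# The order of each JSON row matches the Gradio `inputs` list in main.py.
names = ["wav", "model_name", "agg",
         "frame_length_ms", "frame_step_ms",
         "padding_length_ms", "max_silence_length_ms",
         "start_ring_rate", "end_ring_rate"]
row = ["data/early_media/62/3300999628999191096.wav",
       "call_voice", 3, 300, 30, 300, 300, 0.2, 0.1]
print(dict(zip(names, row)))
```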
toolbox/torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/torch/utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/torch/utils/data/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/torch/utils/data/vocabulary.py
ADDED
@@ -0,0 +1,211 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from collections import defaultdict, OrderedDict
+import os
+from typing import Any, Callable, Dict, Iterable, List, Set
+
+
+def namespace_match(pattern: str, namespace: str):
+    """
+    Matches a namespace pattern against a namespace string. For example, ``*tags`` matches
+    ``passage_tags`` and ``question_tags`` and ``tokens`` matches ``tokens`` but not
+    ``stemmed_tokens``.
+    """
+    if pattern[0] == '*' and namespace.endswith(pattern[1:]):
+        return True
+    elif pattern == namespace:
+        return True
+    return False
+
+
+class _NamespaceDependentDefaultDict(defaultdict):
+    def __init__(self,
+                 non_padded_namespaces: Set[str],
+                 padded_function: Callable[[], Any],
+                 non_padded_function: Callable[[], Any]) -> None:
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._padded_function = padded_function
+        self._non_padded_function = non_padded_function
+        super(_NamespaceDependentDefaultDict, self).__init__()
+
+    def __missing__(self, key: str):
+        if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
+            value = self._non_padded_function()
+        else:
+            value = self._padded_function()
+        dict.__setitem__(self, key, value)
+        return value
+
+    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
+        # add non_padded_namespaces which weren't already present
+        self._non_padded_namespaces.update(non_padded_namespaces)
+
+
+class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
+    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
+        super(_TokenToIndexDefaultDict, self).__init__(non_padded_namespaces,
+                                                       lambda: {padding_token: 0, oov_token: 1},
+                                                       lambda: {})
+
+
+class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
+    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
+        super(_IndexToTokenDefaultDict, self).__init__(non_padded_namespaces,
+                                                       lambda: {0: padding_token, 1: oov_token},
+                                                       lambda: {})
+
+
+DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
+DEFAULT_PADDING_TOKEN = '[PAD]'
+DEFAULT_OOV_TOKEN = '[UNK]'
+NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
+
+
+class Vocabulary(object):
+    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._padding_token = DEFAULT_PADDING_TOKEN
+        self._oov_token = DEFAULT_OOV_TOKEN
+        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+
+    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
+        if token not in self._token_to_index[namespace]:
+            index = len(self._token_to_index[namespace])
+            self._token_to_index[namespace][token] = index
+            self._index_to_token[namespace][index] = token
+            return index
+        else:
+            return self._token_to_index[namespace][token]
+
+    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
+        return self._index_to_token[namespace]
+
+    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
+        return self._token_to_index[namespace]
+
+    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
+        if token in self._token_to_index[namespace]:
+            return self._token_to_index[namespace][token]
+        else:
+            return self._token_to_index[namespace][self._oov_token]
+
+    def get_token_from_index(self, index: int, namespace: str = 'tokens'):
+        return self._index_to_token[namespace][index]
+
+    def get_vocab_size(self, namespace: str = 'tokens') -> int:
+        return len(self._token_to_index[namespace])
+
+    def save_to_files(self, directory: str):
+        os.makedirs(directory, exist_ok=True)
+        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
+            for namespace_str in self._non_padded_namespaces:
+                f.write('{}\n'.format(namespace_str))
+
+        for namespace, token_to_index in self._token_to_index.items():
+            filename = os.path.join(directory, '{}.txt'.format(namespace))
+            with open(filename, 'w', encoding='utf-8') as f:
+                for token, _ in token_to_index.items():
+                    f.write('{}\n'.format(token))
+
+    @classmethod
+    def from_files(cls, directory: str) -> 'Vocabulary':
+        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
+            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]
+
+        vocab = cls(non_padded_namespaces=non_padded_namespaces)
+
+        for namespace_filename in os.listdir(directory):
+            if namespace_filename == NAMESPACE_PADDING_FILE:
+                continue
+            if namespace_filename.startswith("."):
+                continue
+            namespace = namespace_filename.replace('.txt', '')
+            if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
+                is_padded = False
+            else:
+                is_padded = True
+            filename = os.path.join(directory, namespace_filename)
+            vocab.set_from_file(filename, is_padded, namespace=namespace)
+
+        return vocab
+
+    def set_from_file(self,
+                      filename: str,
+                      is_padded: bool = True,
+                      oov_token: str = DEFAULT_OOV_TOKEN,
+                      namespace: str = "tokens"
+                      ):
+        if is_padded:
+            self._token_to_index[namespace] = {self._padding_token: 0}
+            self._index_to_token[namespace] = {0: self._padding_token}
+        else:
+            self._token_to_index[namespace] = {}
+            self._index_to_token[namespace] = {}
+
+        with open(filename, 'r', encoding='utf-8') as f:
+            index = 1 if is_padded else 0
+            for row in f:
+                token = str(row).strip()
+                if token == oov_token:
+                    token = self._oov_token
+                self._token_to_index[namespace][token] = index
+                self._index_to_token[namespace][index] = token
+                index += 1
+
+    def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens"):
+        result = list()
+        for token in tokens:
+            idx = self._token_to_index[namespace].get(token)
+            if idx is None:
+                idx = self._token_to_index[namespace][self._oov_token]
+            result.append(idx)
+        return result
+
+    def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens"):
+        result = list()
+        for idx in ids:
+            idx = self._index_to_token[namespace][idx]
+            result.append(idx)
+        return result
+
+    def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens"):
+        pad_idx = self._token_to_index[namespace][self._padding_token]
+
+        length = len(ids)
+        if length > max_length:
+            result = ids[:max_length]
+        else:
+            result = ids + [pad_idx] * (max_length - length)
+        return result
+
+
+def demo1():
+    import jieba
+
+    vocabulary = Vocabulary()
+    vocabulary.add_token_to_namespace('白天', 'tokens')
+    vocabulary.add_token_to_namespace('晚上', 'tokens')
+
+    text = '不是在白天, 就是在晚上'
+    tokens = jieba.lcut(text)
+
+    print(tokens)
+
+    ids = vocabulary.convert_tokens_to_ids(tokens)
+    print(ids)
+
+    padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
+    print(padded_idx)
+
+    tokens = vocabulary.convert_ids_to_tokens(padded_idx)
+    print(tokens)
+    return
+
+
+if __name__ == '__main__':
+    demo1()
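A short usage sketch for the new `Vocabulary` (the directory name is illustrative): `labels` matches the `*labels` non-padded pattern, so it gets no `[PAD]`/`[UNK]` entries and indices start at 0, and `save_to_files`/`from_files` round-trip it through plain text files:

```python
from toolbox.torch.utils.data.vocabulary import Vocabulary

vocab = Vocabulary()
for token in ["voice", "voicemail", "noise"]:
    vocab.add_token_to_namespace(token, namespace="labels")

vocab.save_to_files("temp/vocabulary")   # writes labels.txt + non_padded_namespaces.txt
vocab2 = Vocabulary.from_files("temp/vocabulary")

print(vocab2.get_token_index("voice", namespace="labels"))                    # 0
print(vocab2.convert_tokens_to_ids(["noise", "voice"], namespace="labels"))   # [2, 0]
```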
toolbox/torch/utils/utils.py
ADDED
@@ -0,0 +1,26 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import Dict
+
+import torch
+
+
+def get_text_field_mask(text_field_tensors: Dict[str, torch.Tensor],
+                        num_wrapping_dims: int = 0) -> torch.LongTensor:
+
+    tensor_dims = [(tensor.dim(), tensor) for tensor in text_field_tensors.values()]
+    tensor_dims.sort(key=lambda x: x[0])
+
+    smallest_dim = tensor_dims[0][0] - num_wrapping_dims
+    if smallest_dim == 2:
+        token_tensor = tensor_dims[0][1]
+        return (token_tensor != 0).long()
+    elif smallest_dim == 3:
+        character_tensor = tensor_dims[0][1]
+        return ((character_tensor > 0).long().sum(dim=-1) > 0).long()
+    else:
+        raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(smallest_dim))
+
+
+if __name__ == '__main__':
+    pass
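`get_text_field_mask` takes the dict of token-id tensors produced for a text field and returns 1 where a position holds a real token and 0 where it holds padding (index 0). A minimal check:

```python
import torch

from toolbox.torch.utils.utils import get_text_field_mask

text_field_tensors = {
    "tokens": torch.tensor([[5, 9, 2, 0, 0],
                            [7, 3, 0, 0, 0]]),  # 0 is the padding index
}
print(get_text_field_mask(text_field_tensors))
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])
```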
toolbox/vad/vad.py
CHANGED
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 import argparse
 import collections
+import os
 from typing import List
 
 import matplotlib.pyplot as plt
@@ -11,6 +12,7 @@ import torch
 import webrtcvad
 
 from project_settings import project_path
+from toolbox.torch.utils.data.vocabulary import Vocabulary
 
 
 class FrameVoiceClassifier(object):
@@ -39,12 +41,12 @@ class WebRTCVoiceClassifier(FrameVoiceClassifier):
 
 class SileroVoiceClassifier(FrameVoiceClassifier):
     def __init__(self,
-
+                 model_path: str,
                  sample_rate: int = 8000):
-        self.
+        self.model_path = model_path
         self.sample_rate = sample_rate
 
-        with open(self.
+        with open(self.model_path, "rb") as f:
             model = torch.jit.load(f, map_location="cpu")
         self.model = model
         self.model.reset_states()
@@ -61,11 +63,39 @@ class SileroVoiceClassifier(FrameVoiceClassifier):
         return float(speech_prob)
 
 
+class CallVoiceClassifier(FrameVoiceClassifier):
+    def __init__(self,
+                 model_path: str,
+                 sample_rate: int = 8000):
+        self.model_path = model_path
+        self.sample_rate = sample_rate
+
+        self.model = torch.jit.load(os.path.join(model_path, "cnn_voicemail.pth"))
+
+    def predict(self, chunk: np.ndarray) -> float:
+        if chunk.dtype != np.int16:
+            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
+
+        chunk = chunk / 32768
+
+        inputs = torch.tensor(chunk, dtype=torch.float32)
+        inputs = torch.unsqueeze(inputs, dim=0)
+
+        try:
+            outputs = self.model(inputs)
+        except RuntimeError as e:
+            print(inputs.shape)
+            raise e
+
+        probs = outputs["probs"]
+        voice_prob = probs[0][2]
+        return float(voice_prob)
+
+
 class Frame(object):
-    def __init__(self, signal: np.ndarray,
+    def __init__(self, signal: np.ndarray, timestamp_s: float):
         self.signal = signal
-        self.
-        self.duration = duration
+        self.timestamp_s = timestamp_s
 
 
 class Vad(object):
@@ -73,26 +103,28 @@ class Vad(object):
                  model: FrameVoiceClassifier,
                  start_ring_rate: float = 0.5,
                  end_ring_rate: float = 0.5,
-
-
-
+                 frame_length_ms: int = 30,
+                 frame_step_ms: int = 30,
+                 padding_length_ms: int = 300,
+                 max_silence_length_ms: int = 300,
                  sample_rate: int = 8000
                  ):
         self.model = model
         self.start_ring_rate = start_ring_rate
         self.end_ring_rate = end_ring_rate
-        self.
-        self.
-        self.
+        self.frame_length_ms = frame_length_ms
+        self.padding_length_ms = padding_length_ms
+        self.max_silence_length_ms = max_silence_length_ms
         self.sample_rate = sample_rate
 
         # frames
-        self.frame_length = int(sample_rate * (
-        self.
-        self.
+        self.frame_length = int(sample_rate * (frame_length_ms / 1000.0))
+        self.frame_step = int(sample_rate * (frame_step_ms / 1000.0))
+        self.frame_timestamp_s = 0.0
+        self.signal_cache = np.zeros(shape=(self.frame_length,), dtype=np.int16)
 
         # segments
-        self.num_padding_frames = int(
+        self.num_padding_frames = int(padding_length_ms / frame_step_ms)
         self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
         self.triggered = False
         self.voiced_frames: List[Frame] = list()
@@ -100,21 +132,23 @@ class Vad(object):
 
         # vad segments
         self.is_first_segment = True
-        self.
-        self.
+        self.timestamp_start_s = 0.0
+        self.timestamp_end_s = 0.0
+
+        # speech probs
+        self.speech_probs: List[float] = list()
 
     def signal_to_frames(self, signal: np.ndarray):
         frames = list()
 
         l = len(signal)
 
-
+        duration_s = float(self.frame_step) / self.sample_rate
 
-        for offset in range(0, l, self.
+        for offset in range(0, l - self.frame_length + 1, self.frame_step):
             sub_signal = signal[offset:offset+self.frame_length]
-
-
-            self.frame_timestamp += duration
+            frame = Frame(sub_signal, self.frame_timestamp_s)
+            self.frame_timestamp_s += duration_s
 
             frames.append(frame)
         return frames
@@ -124,7 +158,8 @@ class Vad(object):
         if self.signal_cache is not None:
             signal = np.concatenate([self.signal_cache, signal])
 
-        rest
+        # rest
+        rest = (len(signal) - self.frame_length) % self.frame_step
 
         if rest == 0:
             self.signal_cache = None
@@ -138,6 +173,7 @@ class Vad(object):
 
         for frame in frames:
             speech_prob = self.model.predict(frame.signal)
+            self.speech_probs.append(speech_prob)
 
             if not self.triggered:
                 self.ring_buffer.append((frame, speech_prob))
@@ -158,8 +194,8 @@ class Vad(object):
                     self.triggered = False
                     segment = [
                         np.concatenate([f.signal for f in self.voiced_frames]),
-                        self.voiced_frames[0].
-                        self.voiced_frames[-1].
+                        self.voiced_frames[0].timestamp_s,
+                        self.voiced_frames[-1].timestamp_s,
                     ]
                     yield segment
                     self.ring_buffer.clear()
@@ -173,21 +209,21 @@ class Vad(object):
             end = round(segment[2], 4)
 
             if self.is_first_segment:
-                self.
-                self.
+                self.timestamp_start_s = start
+                self.timestamp_end_s = end
                 self.is_first_segment = False
                 continue
 
-            if self.
-
-                if
-                    vad_segment = [self.
+            if self.timestamp_start_s:
+                silence_length_ms = (start - self.timestamp_end_s) * 1000
+                if silence_length_ms > self.max_silence_length_ms:
+                    vad_segment = [self.timestamp_start_s, self.timestamp_end_s]
                     yield vad_segment
 
-                    self.
-                    self.
+                    self.timestamp_start_s = start
+                    self.timestamp_end_s = end
                 else:
-                    self.
+                    self.timestamp_end_s = end
 
     def vad(self, signal: np.ndarray) -> List[list]:
         segments = self.segments_generator(signal)
@@ -202,8 +238,8 @@ class Vad(object):
         else:
             segment = [
                 np.concatenate([f.signal for f in self.voiced_frames]),
-                self.voiced_frames[0].
-                self.voiced_frames[-1].
+                self.voiced_frames[0].timestamp_s,
+                self.voiced_frames[-1].timestamp_s
             ]
            segments = [segment]
@@ -211,17 +247,33 @@ class Vad(object):
         vad_segments = self.vad_segments_generator(segments)
         vad_segments = list(vad_segments)
 
-
+        if self.timestamp_start_s > 1e-5 or self.timestamp_end_s > 1e-5:
+            vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]]
         return vad_segments
 
 
-def
+def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
+    speech_probs_ = list()
+    for p in speech_probs[1:]:
+        speech_probs_.extend([p] * frame_step)
+
+    pad = (signal.shape[0] - len(speech_probs_))
+    speech_probs_ = speech_probs_ + [0.0] * pad
+    speech_probs_ = np.array(speech_probs_, dtype=np.float32)
+
+    if len(speech_probs_) != len(signal):
+        raise AssertionError
+    return speech_probs_
+
+
+def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list):
     time = np.arange(0, len(signal)) / sample_rate
     plt.figure(figsize=(12, 5))
     plt.plot(time, signal / 32768, color='b')
+    plt.plot(time, speech_probs, color='gray')
     for start, end in vad_segments:
-        plt.axvline(x=start, ymin=0.
-        plt.axvline(x=end, ymin=0.
+        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="start endpoint")
+        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="end endpoint")
 
     plt.show()
     return
@@ -231,25 +283,14 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--wav_file",
-        default=(project_path / "data/early_media/
+        default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
         type=str,
     )
     parser.add_argument(
-        "--
+        "--model_path",
         default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
         type=str,
     )
-    parser.add_argument(
-        "--frame_duration_ms",
-        default=30,
-        type=int,
-    )
-    parser.add_argument(
-        "--silence_duration_threshold",
-        default=0.3,
-        type=float,
-        help="minimum silence duration, in seconds."
-    )
     args = parser.parse_args()
     return args
@@ -264,15 +305,17 @@ def main():
     if SAMPLE_RATE != sample_rate:
         raise AssertionError
 
-    # model = SileroVoiceClassifier(
-    model = WebRTCVoiceClassifier(agg=1, sample_rate=SAMPLE_RATE)
+    # model = SileroVoiceClassifier(model_path=args.model_path, sample_rate=SAMPLE_RATE)
+    # model = WebRTCVoiceClassifier(agg=1, sample_rate=SAMPLE_RATE)
+    model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())
 
     vad = Vad(model=model,
-              start_ring_rate=0.
+              start_ring_rate=0.2,
               end_ring_rate=0.1,
-
-
-
+              frame_length_ms=300,
+              frame_step_ms=30,
+              padding_length_ms=300,
+              max_silence_length_ms=300,
              sample_rate=SAMPLE_RATE,
              )
     print(vad)
@@ -290,8 +333,18 @@ def main():
     for segment in segments:
         print(segment)
 
+    print(vad.speech_probs)
+    print(len(vad.speech_probs))
+
+    # speech_probs
+    speech_probs = process_speech_probs(
+        signal=signal,
+        speech_probs=vad.speech_probs,
+        frame_step=vad.frame_step,
+    )
+
     # plot
-    make_visualization(signal, SAMPLE_RATE, vad_segments)
+    make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
     return
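The new `process_speech_probs` turns the per-frame probabilities collected in `Vad.speech_probs` into a per-sample curve: every probability after the first (which it skips) is repeated `frame_step` times, and the tail is zero-padded up to the signal length. A tiny numeric check:

```python
import numpy as np

from toolbox.vad.vad import process_speech_probs

signal = np.zeros(10, dtype=np.int16)
probs = [0.1, 0.9, 0.4]        # the first entry is dropped by the function
print(process_speech_probs(signal=signal, speech_probs=probs, frame_step=4))
# [0.9 0.9 0.9 0.9 0.4 0.4 0.4 0.4 0.  0. ]
```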
trained_models/cnn_voicemail_common_20231130/cnn_voicemail.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f27b715f5c240b56c60bc80c9325bbe0ee1a80311b2e51a8f6e531985f8d8e61
+size 155558
trained_models/cnn_voicemail_common_20231130/labels.json
ADDED
@@ -0,0 +1,10 @@
+[
+    "white_noise",
+    "voicemail",
+    "voice",
+    "noise",
+    "bell",
+    "mute",
+    "noise_mute",
+    "music"
+]
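Note how this label list lines up with the new `CallVoiceClassifier`: `predict` reads `probs[0][2]`, and index 2 here is `"voice"`. A quick cross-check, run from the repository root:

```python
import json

# Index 2 of labels.json is the "voice" class that
# CallVoiceClassifier.predict reads out of probs[0][2].
with open("trained_models/cnn_voicemail_common_20231130/labels.json", "r", encoding="utf-8") as f:
    labels = json.load(f)

print(labels.index("voice"))  # 2
print(labels[2])              # voice
```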