#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://pytorch.org/hub/snakers4_silero-vad_vad/
https://github.com/snakers4/silero-vad
"""
import argparse

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/early_media/3300999628164249998.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_name",
        default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
        type=str,
    )
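    # VAD tuning parameters: decision threshold, minimum speech/silence
    # durations (ms), padding added around detected segments (ms), maximum
    # segment length (s), and the analysis window size in samples.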
    parser.add_argument("--threshold", default=0.5, type=float)
    parser.add_argument("--min_speech_duration_ms", default=250, type=int)
    parser.add_argument("--speech_pad_ms", default=30, type=int)
    parser.add_argument("--max_speech_duration_s", default=float("inf"), type=float)
    parser.add_argument("--window_size_samples", default=512, type=int)
    parser.add_argument("--min_silence_duration_ms", default=100, type=int)

    args = parser.parse_args()
    return args


def make_visualization(probs, step):
    # Optional helper (not called in main): area plot of the per-window speech
    # probabilities, with the x axis in seconds (step = window duration in seconds).
    import pandas as pd
    pd.DataFrame(
        {"probs": probs},
        index=[x * step for x in range(len(probs))],
    ).plot(
        figsize=(16, 8),
        kind="area",
        ylim=[0, 1.05],
        xlim=[0, len(probs) * step],
        xlabel="seconds",
        ylabel="speech probability",
        colormap="tab20",
    )


def plot(signal, sample_rate, speeches):
    # signal is already scaled to [-1, 1) in main(); speeches hold start/end in seconds.
    time = np.arange(0, len(signal)) / sample_rate

    plt.figure(figsize=(12, 5))
    plt.plot(time, signal, color="b")

    # green dashed line = segment start, red dashed line = segment end
    for speech in speeches:
        start = speech["start"]
        end = speech["end"]
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color="g", linestyle="--")
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color="r", linestyle="--")

    plt.xlabel("seconds")
    plt.show()
    return


def main():
    args = get_args()

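    # Load the Silero VAD TorchScript model on CPU and reset its internal state.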
    with open(args.model_name, "rb") as f:
        model = torch.jit.load(f, map_location="cpu")
    model.reset_states()

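    # Read the WAV file and scale the samples to [-1, 1); 16-bit PCM mono audio is assumed.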
    sample_rate, signal = wavfile.read(args.wav_file)
    signal = signal / 32768
    signal = torch.tensor(signal, dtype=torch.float32)

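    # Convert the duration-based settings to sample counts; the 98 ms value is
    # the silence tolerance used when deciding where to split a segment that
    # exceeds max_speech_duration_s.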
    min_speech_samples = sample_rate * args.min_speech_duration_ms / 1000
    speech_pad_samples = sample_rate * args.speech_pad_ms / 1000
    max_speech_samples = sample_rate * args.max_speech_duration_s - args.window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sample_rate * args.min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sample_rate * 98 / 1000

    audio_length_samples = len(signal)

    # per-window speech probabilities: run the model on consecutive chunks of window_size_samples
    speech_probs = []
    for start in range(0, audio_length_samples, args.window_size_samples):
        chunk = signal[start: start + args.window_size_samples]
        if len(chunk) < args.window_size_samples:
            chunk = torch.nn.functional.pad(chunk, (0, int(args.window_size_samples - len(chunk))))

        speech_prob = model(chunk, sample_rate).item()
        speech_probs.append(speech_prob)

    # segments: hysteresis over the probabilities (threshold opens a segment,
    # neg_threshold plus min_silence_duration_ms closes it; prev_end/next_start
    # track where to split segments that exceed max_speech_duration_s)
    triggered = False
    speeches = list()
    current_speech = dict()
    neg_threshold = args.threshold - 0.15
    temp_end = 0
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
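        # Speech resumed after a dip below the threshold: drop the tentative end
        # and remember where the next segment could start after a forced split.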
        if (speech_prob >= args.threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = args.window_size_samples * i

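        # Onset: the probability crossed the threshold while no segment was open.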
        if (speech_prob >= args.threshold) and not triggered:
            triggered = True
            current_speech["start"] = args.window_size_samples * i
            continue

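        # The open segment exceeded max_speech_duration_s: close it at the last
        # confirmed silence (prev_end) if one exists, otherwise cut it here.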
        if triggered and (args.window_size_samples * i) - current_speech["start"] > max_speech_samples:
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = args.window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

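        # The probability fell below the lower hysteresis threshold: tentatively
        # end the segment and confirm once the silence lasts at least
        # min_silence_duration_ms.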
        if speech_prob < neg_threshold and triggered:
            if not temp_end:
                temp_end = args.window_size_samples * i
            if ((args.window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:
                prev_end = temp_end

            if (args.window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                if (current_speech["end"] - current_speech["start"]) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

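    # Flush a segment that is still open when the audio ends, if it is long enough.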
    if current_speech and (audio_length_samples - current_speech["start"]) > min_speech_samples:
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

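    # Pad each segment by speech_pad_ms on both sides; when two segments are
    # closer than twice the padding, split the gap evenly between them.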
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i+1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i+1]["start"] = int(max(0, speeches[i+1]["start"] - silence_duration // 2))
            else:
                speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))
                speeches[i+1]["start"] = int(max(0, speeches[i+1]["start"] - speech_pad_samples))
        else:
            speech["end"] = int(min(audio_length_samples, speech["end"] + speech_pad_samples))

    # convert segment boundaries from samples to seconds
    for speech_dict in speeches:
        speech_dict["start"] = round(speech_dict["start"] / sample_rate, 1)
        speech_dict["end"] = round(speech_dict["end"] / sample_rate, 1)

    print(speeches)
    plot(signal, sample_rate, speeches)

    return


if __name__ == '__main__':
    main()