phamngoctukts committed
Commit 39a832e · verified · 1 Parent(s): 0e97050

Upload 6 files

Chatweb.py ADDED
@@ -0,0 +1,153 @@
import speech_recognition as sr
import ollama
from gtts import gTTS
import gradio as gr
from io import BytesIO
import numpy as np
from dataclasses import dataclass, field
import time
import traceback
from pydub import AudioSegment
import librosa
from utils.vad import get_speech_timestamps, collect_chunks, VadOptions

r = sr.Recognizer()

@dataclass
class AppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    stopped: bool = False
    conversation: list = field(default_factory=list)

def run_vad(ori_audio, sr):
    # Note: the `sr` parameter here is the sampling rate and shadows the
    # speech_recognition alias inside this function.
    _st = time.time()
    try:
        audio = ori_audio
        audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
        vad_parameters = {}
        vad_parameters = VadOptions(**vad_parameters)
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate
        if sr != sampling_rate:
            # resample back to the original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()
        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)

def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Take in the stream and determine whether a pause happened."""
    temp_audio = audio
    dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate)
    duration = len(audio) / sampling_rate
    if dur_vad > 0.5 and not state.started_talking:
        print("started talking")
        state.started_talking = True
        return False
    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
    # A pause is detected once more than one second of the buffered audio is non-speech.
    return (duration - dur_vad) > 1

def process_audio(audio: tuple, state: AppState):
    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))
    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
    state.pause_detected = pause_detected
    if state.pause_detected and state.started_talking:
        return gr.Audio(recording=False), state
    return None, state

def response(state: AppState):
    if not state.pause_detected and not state.started_talking:
        # This function is a generator, so yield (rather than return) the empty result.
        yield None, AppState()
        return
    audio_buffer = BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
    )
    segment.export(audio_buffer, format="wav")
    textin = ""
    srr = None  # synthesized audio bytes; stays None if nothing was recognized
    with sr.AudioFile(audio_buffer) as source:
        audio_data = r.record(source)
    try:
        textin = r.recognize_google(audio_data, language='vi')
    except (sr.UnknownValueError, sr.RequestError):
        textin = ""
    state.conversation.append({"role": "user", "content": "Bạn: " + textin})  # "Bạn" = "You"
    if textin != "":
        print("Đang nghĩ...")  # "Thinking..."
        response = ollama.chat(model='llama3.2', messages=[
            {
                'role': 'user',
                'content': textin,
            },
        ])
        textout = response['message']['content']
        textout = textout.replace('*', '')
        state.conversation.append({"role": "assistant", "content": "Trợ lý: " + textout})  # "Trợ lý" = "Assistant"
        if textout != "":
            print("Đang đọc...")  # "Speaking..."
            mp3 = gTTS(textout, tld='com.vn', lang='vi', slow=False)
            mp3_fp = BytesIO()
            mp3.write_to_fp(mp3_fp)
            srr = mp3_fp.getvalue()
            mp3_fp.close()
    # yield srr, state
    yield srr, AppState(conversation=state.conversation)

def start_recording_user(state: AppState):
    if not state.stopped:
        return gr.Audio(recording=True)

title = "vietnamese by tuphamkts"
description = "A Vietnamese text-to-speech demo."

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Nói cho tôi nghe nào", sources="microphone", type="numpy")  # "Speak to me"
        with gr.Column():
            chatbot = gr.Chatbot(label="Nội dung trò chuyện", type="messages")  # "Conversation"
            output_audio = gr.Audio(label="Trợ lý", autoplay=True)  # "Assistant"
    state = gr.State(value=AppState())

    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [input_audio, state],
        stream_every=0.50,
        time_limit=30,
    )
    respond = input_audio.stop_recording(
        response,
        [state],
        [output_audio, state],
    )
    respond.then(lambda s: s.conversation, [state], [chatbot])

    restart = output_audio.stop(
        start_recording_user,
        [state],
        [input_audio],
    )
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None,
                 [state, input_audio], cancels=[respond, restart])
demo.launch(server_name="0.0.0.0", server_port=7860, debug=False, share=True)
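
The response() handler above converts the buffered microphone stream to WAV, transcribes it with the Google Web Speech API, queries the local llama3.2 model through ollama, and synthesizes the reply with gTTS before handing the MP3 bytes to gr.Audio. The snippet below is a minimal standalone sketch of that last text-to-speech step only, assuming network access for gTTS; the sample sentence and output filename are illustrative and not taken from the app.

# Sketch of the gTTS playback path used in response(); the text is a placeholder.
from io import BytesIO
from gtts import gTTS

text = "Xin chào, tôi là trợ lý ảo."  # "Hello, I am a virtual assistant."
mp3 = gTTS(text, tld="com.vn", lang="vi", slow=False)

buf = BytesIO()
mp3.write_to_fp(buf)           # write the MP3 stream into the in-memory buffer
audio_bytes = buf.getvalue()   # raw MP3 bytes, the same form gr.Audio receives in the app

with open("reply.mp3", "wb") as f:  # save locally so the sketch can be checked by ear
    f.write(audio_bytes)
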
utils/__pycache__/snac_utils.cpython-310.pyc ADDED
Binary file (3.74 kB).
 
utils/__pycache__/vad.cpython-310.pyc ADDED
Binary file (7.84 kB).
 
utils/assets/silero_vad.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:591f853590d11ddde2f2a54f9e7ccecb2533a8af7716330e8adfa6f3849787a9
size 1807524
utils/snac_utils.py ADDED
@@ -0,0 +1,146 @@
import torch
import time
import numpy as np


class SnacConfig:
    audio_vocab_size = 4096
    padded_vocab_size = 4160
    end_of_audio = 4097


snac_config = SnacConfig()


def get_time_str():
    time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    return time_str


def layershift(input_id, layer, stride=4160, shift=152000):
    return input_id + shift + layer * stride


def generate_audio_data(snac_tokens, snacmodel, device=None):
    audio = reconstruct_tensors(snac_tokens, device)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
    audio_data = audio_data.astype(np.int16)
    audio_data = audio_data.tobytes()
    return audio_data


def get_snac(list_output, index, nums_generate):
    snac = []
    start = index
    for i in range(nums_generate):
        snac.append("#")
        for j in range(7):
            snac.append(list_output[j][start - nums_generate - 5 + j + i])
    return snac


def reconscruct_snac(output_list):
    if len(output_list) == 8:
        output_list = output_list[:-1]
    output = []
    for i in range(7):
        output_list[i] = output_list[i][i + 1 :]
    for i in range(len(output_list[-1])):
        output.append("#")
        for j in range(7):
            output.append(output_list[j][i])
    return output


def reconstruct_tensors(flattened_output, device=None):
    """Reconstructs the list of tensors from the flattened output."""

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def count_elements_between_hashes(lst):
        try:
            # Find the index of the first '#'
            first_index = lst.index("#")
            # Find the index of the second '#' after the first
            second_index = lst.index("#", first_index + 1)
            # Count the elements between the two indices
            return second_index - first_index - 1
        except ValueError:
            # Handle the case where there aren't enough '#' symbols
            return "List does not contain two '#' symbols"

    def remove_elements_before_hash(flattened_list):
        try:
            # Find the index of the first '#'
            first_hash_index = flattened_list.index("#")
            # Return the list starting from the first '#'
            return flattened_list[first_hash_index:]
        except ValueError:
            # Handle the case where there is no '#'
            return "List does not contain the symbol '#'"

    def list_to_torch_tensor(tensor1):
        # Convert the list to a torch tensor
        tensor = torch.tensor(tensor1)
        # Reshape the tensor to have size (1, n)
        tensor = tensor.unsqueeze(0)
        return tensor

    flattened_output = remove_elements_before_hash(flattened_output)
    codes = []
    tensor1 = []
    tensor2 = []
    tensor3 = []
    tensor4 = []

    n_tensors = count_elements_between_hashes(flattened_output)
    # Three-codebook layout: each frame is "#" followed by 1 + 2 + 4 codes.
    if n_tensors == 7:
        for i in range(0, len(flattened_output), 8):
            tensor1.append(flattened_output[i + 1])
            tensor2.append(flattened_output[i + 2])
            tensor3.append(flattened_output[i + 3])
            tensor3.append(flattened_output[i + 4])

            tensor2.append(flattened_output[i + 5])
            tensor3.append(flattened_output[i + 6])
            tensor3.append(flattened_output[i + 7])
        codes = [
            list_to_torch_tensor(tensor1).to(device),
            list_to_torch_tensor(tensor2).to(device),
            list_to_torch_tensor(tensor3).to(device),
        ]

    # Four-codebook layout: each frame is "#" followed by 1 + 2 + 4 + 8 codes.
    if n_tensors == 15:
        for i in range(0, len(flattened_output), 16):
            tensor1.append(flattened_output[i + 1])
            tensor2.append(flattened_output[i + 2])
            tensor3.append(flattened_output[i + 3])
            tensor4.append(flattened_output[i + 4])
            tensor4.append(flattened_output[i + 5])
            tensor3.append(flattened_output[i + 6])
            tensor4.append(flattened_output[i + 7])
            tensor4.append(flattened_output[i + 8])

            tensor2.append(flattened_output[i + 9])
            tensor3.append(flattened_output[i + 10])
            tensor4.append(flattened_output[i + 11])
            tensor4.append(flattened_output[i + 12])
            tensor3.append(flattened_output[i + 13])
            tensor4.append(flattened_output[i + 14])
            tensor4.append(flattened_output[i + 15])

        codes = [
            list_to_torch_tensor(tensor1).to(device),
            list_to_torch_tensor(tensor2).to(device),
            list_to_torch_tensor(tensor3).to(device),
            list_to_torch_tensor(tensor4).to(device),
        ]

    return codes
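
reconscruct_snac() and reconstruct_tensors() above undo the flattening applied during generation: each frame is a "#" marker followed by 7 (or 15) codebook indices, which are regrouped into 3 (or 4) per-layer tensors of shape (1, n_codes) for snacmodel.decode(). A minimal sketch of the 7-codes-per-frame case, assuming utils.snac_utils is importable and using made-up integer codes:

# Two hypothetical frames in the "#", 1+2+4 layout; the values are placeholders.
import torch
from utils.snac_utils import reconstruct_tensors

flattened = ["#", 1, 2, 3, 4, 5, 6, 7,
             "#", 8, 9, 10, 11, 12, 13, 14]
codes = reconstruct_tensors(flattened, device=torch.device("cpu"))

# Three layers carrying 1, 2 and 4 codes per frame respectively.
print([tuple(c.shape) for c in codes])  # [(1, 2), (1, 4), (1, 8)]
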
utils/vad.py ADDED
@@ -0,0 +1,290 @@
import bisect
import functools
import os
import warnings

from typing import List, NamedTuple, Optional

import numpy as np


# The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple):
    """VAD options.

    Attributes:
      threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk;
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
      min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
      max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.
      min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
        before separating it.
      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
        Values other than these may affect model performance!!
      speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
    """

    threshold: float = 0.5
    min_speech_duration_ms: int = 250
    max_speech_duration_s: float = float("inf")
    min_silence_duration_ms: int = 2000
    window_size_samples: int = 1024
    speech_pad_ms: int = 400


def get_speech_timestamps(
    audio: np.ndarray,
    vad_options: Optional[VadOptions] = None,
    **kwargs,
) -> List[dict]:
    """This method is used for splitting long audios into speech chunks using silero VAD.

    Args:
      audio: One dimensional float array.
      vad_options: Options for VAD processing.
      kwargs: VAD options passed as keyword arguments for backward compatibility.

    Returns:
      List of dicts containing begin and end samples of each speech chunk.
    """
    if vad_options is None:
        vad_options = VadOptions(**kwargs)

    threshold = vad_options.threshold
    min_speech_duration_ms = vad_options.min_speech_duration_ms
    max_speech_duration_s = vad_options.max_speech_duration_s
    min_silence_duration_ms = vad_options.min_silence_duration_ms
    window_size_samples = vad_options.window_size_samples
    speech_pad_ms = vad_options.speech_pad_ms

    if window_size_samples not in [512, 1024, 1536]:
        warnings.warn(
            "Unusual window_size_samples! Supported window_size_samples:\n"
            " - [512, 1024, 1536] for 16000 sampling_rate"
        )

    sampling_rate = 16000
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = (
        sampling_rate * max_speech_duration_s
        - window_size_samples
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)

    model = get_vad_model()
    state = model.get_initial_state(batch_size=1)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample : current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob, state = model(chunk, state, sampling_rate)
        speech_probs.append(speech_prob)

    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15

    # to save potential segment end (and tolerate some silence)
    temp_end = 0
    # to save potential segment limits in case of maximum segment size reached
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech["start"] = window_size_samples * i
            continue

        if (
            triggered
            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
        ):
            if prev_end:
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                # previously reached silence (< neg_thres) and is still not speech (< thres)
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                current_speech["end"] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            # condition to avoid cutting in very short silence
            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                if (
                    current_speech["end"] - current_speech["start"]
                ) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    if (
        current_speech
        and (audio_length_samples - current_speech["start"]) > min_speech_samples
    ):
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - silence_duration // 2)
                )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - speech_pad_samples)
                )
        else:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )

    return speeches


def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
    """Collects and concatenates audio chunks."""
    if not chunks:
        return np.array([], dtype=np.float32)

    return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])


class SpeechTimestampsMap:
    """Helper class to restore original speech timestamps."""

    def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
        self.sampling_rate = sampling_rate
        self.time_precision = time_precision
        self.chunk_end_sample = []
        self.total_silence_before = []

        previous_end = 0
        silent_samples = 0

        for chunk in chunks:
            silent_samples += chunk["start"] - previous_end
            previous_end = chunk["end"]

            self.chunk_end_sample.append(chunk["end"] - silent_samples)
            self.total_silence_before.append(silent_samples / sampling_rate)

    def get_original_time(
        self,
        time: float,
        chunk_index: Optional[int] = None,
    ) -> float:
        if chunk_index is None:
            chunk_index = self.get_chunk_index(time)

        total_silence_before = self.total_silence_before[chunk_index]
        return round(total_silence_before + time, self.time_precision)

    def get_chunk_index(self, time: float) -> int:
        sample = int(time * self.sampling_rate)
        return min(
            bisect.bisect(self.chunk_end_sample, sample),
            len(self.chunk_end_sample) - 1,
        )


@functools.lru_cache
def get_vad_model():
    """Returns the VAD model instance."""
    asset_dir = os.path.join(os.path.dirname(__file__), "assets")
    path = os.path.join(asset_dir, "silero_vad.onnx")
    return SileroVADModel(path)


class SileroVADModel:
    def __init__(self, path):
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        opts.log_severity_level = 4

        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

    def get_initial_state(self, batch_size: int):
        h = np.zeros((2, batch_size, 64), dtype=np.float32)
        c = np.zeros((2, batch_size, 64), dtype=np.float32)
        return h, c

    def __call__(self, x, state, sr: int):
        if len(x.shape) == 1:
            x = np.expand_dims(x, 0)
        if len(x.shape) > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {len(x.shape)}"
            )
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        h, c = state

        ort_inputs = {
            "input": x,
            "h": h,
            "c": c,
            "sr": np.array(sr, dtype="int64"),
        }

        out, h, c = self.session.run(None, ort_inputs)
        state = (h, c)

        return out, state
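
get_speech_timestamps() expects mono float audio at 16 kHz and returns start/end sample indices that collect_chunks() concatenates back into a voiced-only signal; this is how Chatweb.py's run_vad() measures speech duration. A minimal sketch, assuming onnxruntime is installed and utils/assets/silero_vad.onnx is present; with a synthetic noise burst the detector may return no segments, whereas real speech produces one dict per utterance:

# Synthetic input: one second of silence followed by one second of noise as a speech stand-in.
import numpy as np
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps

sr = 16000
audio = np.concatenate([
    np.zeros(sr, dtype=np.float32),
    (0.1 * np.random.randn(sr)).astype(np.float32),
])

opts = VadOptions(threshold=0.5, min_silence_duration_ms=500)
timestamps = get_speech_timestamps(audio, opts)
print(timestamps)  # list of {"start": ..., "end": ...} dicts; may be empty for pure noise

voiced = collect_chunks(audio, timestamps)
print(voiced.shape[0] / sr, "seconds kept after VAD")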