Commit 7cc82ad by makaveli10
Parent(s): 201054b

add multiprocess communication

Files changed:
- services/llm_service.py → llm_service.py  +67 -58
- main.py  +35 -0
- whisper_live/trt_server.py  +68 -18
- whisper_live/trt_transcriber.py  +2 -2
- whisper_live/vad.py  +5 -1
services/llm_service.py → llm_service.py
RENAMED

@@ -103,6 +103,7 @@ class MistralTensorRTLLM:
             rank=self.runtime_rank,
             debug_mode=False,
             lora_ckpt_source='hf')
+        self.runner = self.runner_cls.from_dir(**self.runner_kwargs)

     def parse_input(
         self,

@@ -152,75 +153,83 @@ class MistralTensorRTLLM:
             output.append(output_text)
         return output

+    def run(
+        self,
+        transcription_queue=None,
+        llm_queue=None,
+        input_text=None,
+        max_output_len=20,
         max_attention_window_size=4096,
         num_beams=1,
         streaming=True,
         streaming_interval=4,
+        debug=False,
     ):
+        self.initialize_model(
+            "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",
+            "teknium/OpenHermes-2.5-Mistral-7B",
         )
+        print("Loaded LLM...")
+        while True:
+
+            # while transcription
+            transcription_output = transcription_queue.get()
+            input_text = transcription_output['prompt'].strip()
+
+            print("Whisper: ", input_text)
+            batch_input_ids = self.parse_input(
+                input_text=input_text,
+                add_special_tokens=True,
+                max_input_length=923,
+                pad_id=None,
+            )

+            input_lengths = [x.size(1) for x in batch_input_ids]
+            with torch.no_grad():
+                outputs = self.runner.generate(
+                    batch_input_ids,
+                    max_new_tokens=max_output_len,
+                    max_attention_window_size=max_attention_window_size,
+                    end_id=self.end_id,
+                    pad_id=self.pad_id,
+                    temperature=1.0,
+                    top_k=1,
+                    top_p=0.0,
+                    num_beams=num_beams,
+                    length_penalty=1.0,
+                    repetition_penalty=1.0,
+                    stop_words_list=None,
+                    bad_words_list=None,
+                    lora_uids=None,
+                    prompt_table_path=None,
+                    prompt_tasks=None,
+                    streaming=streaming,
+                    output_sequence_lengths=True,
+                    return_dict=True)
+                torch.cuda.synchronize()
+            if streaming:
+                for curr_outputs in throttle_generator(outputs, streaming_interval):
+                    output_ids = curr_outputs['output_ids']
+                    sequence_lengths = curr_outputs['sequence_lengths']
+                    output = self.decode_tokens(
+                        output_ids,
+                        input_lengths,
+                        sequence_lengths
+                    )
+            else:
+                output_ids = outputs['output_ids']
+                sequence_lengths = outputs['sequence_lengths']
+                context_logits = None
+                generation_logits = None
+                if runner.gather_all_token_logits:
+                    context_logits = outputs['context_logits']
+                    generation_logits = outputs['generation_logits']
             output = self.decode_tokens(
                 output_ids,
                 input_lengths,
+                sequence_lengths,
             )
-        print(input_text[0] + " " + output[0])
-        return output
+            llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})


 if __name__=="__main__":
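The new run() method is the LLM side of the inter-process pipeline: it blocks on transcription_queue for prompts produced by the Whisper server and pushes decoded text to llm_queue. Below is a minimal, self-contained sketch of that producer/consumer pattern; fake_generate and the None sentinel are illustrative stand-ins for the TensorRT-LLM runner and are not part of the commit.

from multiprocessing import Process, Queue

def fake_generate(prompt):
    # Stand-in for the runner.generate() / decode_tokens() path in MistralTensorRTLLM.
    return f"echo: {prompt}"

def llm_worker(transcription_queue, llm_queue):
    while True:
        item = transcription_queue.get()       # blocks until the Whisper side sends a prompt
        if item is None:                       # sentinel used only to stop this demo worker
            break
        output = fake_generate(item["prompt"].strip())
        llm_queue.put({"uid": item["uid"], "llm_output": output})

if __name__ == "__main__":
    transcription_queue, llm_queue = Queue(), Queue()
    worker = Process(target=llm_worker, args=(transcription_queue, llm_queue))
    worker.start()
    transcription_queue.put({"uid": "client-1", "prompt": "hello there "})
    print(llm_queue.get())                     # {'uid': 'client-1', 'llm_output': 'echo: hello there'}
    transcription_queue.put(None)
    worker.join()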
main.py
ADDED

@@ -0,0 +1,35 @@
+from whisper_live.trt_server import TranscriptionServer
+from llm_service import MistralTensorRTLLM
+import multiprocessing
+import threading
+import ssl
+import time
+import sys
+import functools
+
+from multiprocessing import Process, Manager, Value, Queue
+
+
+if __name__ == "__main__":
+    # multiprocessing.set_start_method('spawn')
+
+    lock = multiprocessing.Lock()
+
+    manager = Manager()
+    shared_output = manager.list()
+
+    transcription_queue = Queue()
+    llm_queue = Queue()
+
+
+    whisper_server = TranscriptionServer()
+    whisper_process = multiprocessing.Process(target=whisper_server.run, args=("0.0.0.0", 6006, transcription_queue, llm_queue))
+    whisper_process.start()
+
+    # llm_provider = MistralTensorRTLLM()
+    # # llm_provider = MistralTensorRTLLMProvider()
+    # llm_process = multiprocessing.Process(target=llm_provider.run, args=(transcription_queue, llm_queue))
+    # llm_process.start()
+
+    # llm_process.join()
+    whisper_process.join()
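As committed, main.py starts only the Whisper process and leaves the LLM process commented out. For reference, the fully wired version implied by those commented lines would look roughly like this; it is a sketch assuming MistralTensorRTLLM.run takes the two queues exactly as added in llm_service.py above.

from multiprocessing import Queue
import multiprocessing

from whisper_live.trt_server import TranscriptionServer
from llm_service import MistralTensorRTLLM

if __name__ == "__main__":
    transcription_queue = Queue()
    llm_queue = Queue()

    # Whisper websocket server process (as in main.py above)
    whisper_server = TranscriptionServer()
    whisper_process = multiprocessing.Process(
        target=whisper_server.run,
        args=("0.0.0.0", 6006, transcription_queue, llm_queue),
    )
    whisper_process.start()

    # LLM process, enabled instead of commented out
    llm_provider = MistralTensorRTLLM()
    llm_process = multiprocessing.Process(
        target=llm_provider.run,
        args=(transcription_queue, llm_queue),
    )
    llm_process.start()

    llm_process.join()
    whisper_process.join()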
whisper_live/trt_server.py
CHANGED

@@ -12,6 +12,8 @@ from websockets.sync.server import serve
 import torch
 import numpy as np
 import time
+import queue
+
 from whisper_live.vad import VoiceActivityDetection
 from whisper_live.trt_transcriber import WhisperTRTLLM

@@ -47,13 +49,13 @@ class TranscriptionServer:

     def __init__(self):
         # voice activity detection model
-        self.vad_threshold = 0.5
+
         self.clients = {}
         self.websockets = {}
         self.clients_start_time = {}
         self.max_clients = 4
         self.max_connection_time = 600
+        print("done loading")

     def get_wait_time(self):
         """

@@ -72,7 +74,7 @@ class TranscriptionServer:

         return wait_time / 60

-    def recv_audio(self, websocket):
+    def recv_audio(self, websocket, transcription_queue=None, llm_queue=None):
         """
         Receive audio chunks from a client in an infinite loop.

@@ -93,6 +95,9 @@ class TranscriptionServer:
         Raises:
             Exception: If there is an error during the audio frame processing.
         """
+        self.vad_model = VoiceActivityDetection()
+        self.vad_threshold = 0.5
+
         logging.info("New client connected")
         options = websocket.recv()
         options = json.loads(options)

@@ -115,22 +120,26 @@ class TranscriptionServer:
             multilingual=options["multilingual"],
             language=options["language"],
             task=options["task"],
-            client_uid=options["uid"]
+            client_uid=options["uid"],
+            transcription_queue=transcription_queue,
+            llm_queue=llm_queue
         )

         self.clients[websocket] = client
         self.clients_start_time[websocket] = time.time()
         no_voice_activity_chunks = 0
+        print()
         while True:
             try:
                 frame_data = websocket.recv()
                 frame_np = np.frombuffer(frame_data, dtype=np.float32)
+                print(frame_np.shape)
                 # VAD
                 try:
                     speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item()
                     if speech_prob < self.vad_threshold:
                         no_voice_activity_chunks += 1
-                        print("No speech", no_voice_activity_chunks)
+                        print("No speech", no_voice_activity_chunks, self.clients[websocket].eos)
                         if no_voice_activity_chunks > 2:
                             if not self.clients[websocket].eos:
                                 self.clients[websocket].set_eos(True)

@@ -164,7 +173,7 @@ class TranscriptionServer:
             del websocket
             break

-    def run(self, host, port=9090):
+    def run(self, host, port=9090, transcription_queue=None, llm_queue=None):
         """
         Run the transcription server.

@@ -172,7 +181,15 @@ class TranscriptionServer:
            host (str): The host address to bind the server.
            port (int): The port number to bind the server.
         """
-        with serve(
+        with serve(
+            functools.partial(
+                self.recv_audio,
+                transcription_queue=transcription_queue,
+                llm_queue=llm_queue,
+            ),
+            host,
+            port
+        ) as server:
             server.serve_forever()

@@ -209,7 +226,17 @@ class ServeClient:
     SERVER_READY = "SERVER_READY"
     DISCONNECT = "DISCONNECT"

-    def __init__(
+    def __init__(
+        self,
+        websocket,
+        task="transcribe",
+        device=None,
+        multilingual=False,
+        language=None,
+        client_uid=None,
+        transcription_queue=None,
+        llm_queue=None,
+    ):
         """
         Initialize a ServeClient instance.
         The Whisper model is initialized based on the client's language and device availability.

@@ -226,6 +253,8 @@ class ServeClient:

         """
         self.client_uid = client_uid
+        self.transcription_queue = transcription_queue
+        self.llm_queue = llm_queue
         self.data = b""
         self.frames = b""
         self.language = language if multilingual else "en"

@@ -246,6 +275,7 @@ class ServeClient:
         self.show_prev_out_thresh = 5  # if pause(no output from whisper) show previous output for 5 seconds
         self.add_pause_thresh = 3  # add a blank to segment list as a pause(no speech) for 3 seconds
         self.transcript = []
+        self.prompt = None
         self.send_last_n_segments = 10

         # text formatting

@@ -318,6 +348,14 @@ class ServeClient:

         """
         while True:
+            try:
+                if self.llm_queue is not None:
+                    llm_output = self.llm_queue.get_nowait()
+                    if llm_output:
+                        self.websocket.send(json.dumps(llm_output))
+            except queue.Empty:
+                pass
+
             if self.exit:
                 logging.info("Exiting speech to text thread")
                 break

@@ -334,28 +372,39 @@ class ServeClient:
             samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
             input_bytes = self.frames_np[int(samples_take):].copy()
             duration = input_bytes.shape[0] / self.RATE
-            if duration<1.0
+            if duration<1.0:
                 continue

             try:
                 input_sample = input_bytes.copy()
-                save_wav(input_sample)
+                # save_wav(input_sample)
                 # whisper transcribe with prompt
                 mel, duration = self.transcriber.log_mel_spectrogram(input_sample)
+                last_segment = self.transcriber.transcribe(mel)
+
+                if len(last_segment):
+                    if len(self.transcript) < self.send_last_n_segments:
+                        segments = self.transcript
+                    else:
+                        segments = self.transcript[-self.send_last_n_segments:]
+                    segments.append({"text": last_segment})
                 try:
+                    print(f"Sending... {segments}")
                     self.websocket.send(
                         json.dumps({
                             "uid": self.client_uid,
-                            "segments": segments
+                            "segments": segments,
+                            "eos": self.eos
                         })
                     )
+                    if self.eos:
+                        self.append_segment(last_segment)
+                        self.timestamp_offset += duration
+                        self.prompt = ' '.join(segment['text'] for segment in segments)
+                        self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+                        self.transcript = []
+                        self.set_eos(False)
+
                 except Exception as e:
                     logging.error(f"[ERROR]: {e}")

@@ -364,6 +413,7 @@ class ServeClient:
             time.sleep(0.01)

     def append_segment(self, result):
+        print("adding to trasncript: ", result)
         if not len(self.transcript):
            self.transcript.append({"text": result + " "})
         else:
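Two patterns in this file carry the multiprocess wiring: run() curries recv_audio with functools.partial so the queues reach every websocket connection handler, and the ServeClient loop polls llm_queue with get_nowait() so relaying LLM output never blocks audio handling. A stripped-down, self-contained sketch of both patterns follows; the handler body and the "demo" uid are illustrative only, not the project's code.

import functools
import queue
from multiprocessing import Queue

from websockets.sync.server import serve

def recv_audio(websocket, transcription_queue=None, llm_queue=None):
    # Illustrative handler: forward each incoming message as a "prompt" and
    # relay any pending LLM output without blocking, mirroring the ServeClient loop.
    while True:
        message = websocket.recv()
        transcription_queue.put({"uid": "demo", "prompt": message})
        try:
            websocket.send(str(llm_queue.get_nowait()))  # non-blocking poll
        except queue.Empty:
            pass

def run(host, port, transcription_queue, llm_queue):
    # functools.partial binds the queues so serve() still sees a one-argument handler.
    with serve(
        functools.partial(
            recv_audio,
            transcription_queue=transcription_queue,
            llm_queue=llm_queue,
        ),
        host,
        port,
    ) as server:
        server.serve_forever()

if __name__ == "__main__":
    run("0.0.0.0", 6006, Queue(), Queue())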
whisper_live/trt_transcriber.py
CHANGED

@@ -339,9 +339,9 @@ def decode_wav_file(

 if __name__=="__main__":
     tensorrt_llm.logger.set_level("error")
-    model = WhisperTRTLLM("
+    model = WhisperTRTLLM("/root/TensorRT-LLM/examples/whisper/whisper_small_en", False, "../assets", device="cuda")
     mel, total_duration = model.log_mel_spectrogram(
-        "/
+        "../assets/1221-135766-0002.wav",
     )
     results = model.transcribe(mel)
     print(results, total_duration)
whisper_live/vad.py
CHANGED

@@ -10,19 +10,23 @@ import onnxruntime
 class VoiceActivityDetection():

     def __init__(self, force_onnx_cpu=True):
+        print("downloading ONNX model...")
         path = self.download()
+        print("loading session")
+
         opts = onnxruntime.SessionOptions()
         opts.log_severity_level = 3

         opts.inter_op_num_threads = 1
         opts.intra_op_num_threads = 1

+        print("loading onnx model")
         if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
             self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
         else:
             self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)

+        print("reset states")
         self.reset_states()
         self.sample_rates = [8000, 16000]