Commit f1e930a
Parent: e5e84e9

add whisper output to queue for llm process

Files changed:
- llm_service.py (+5 -2)
- main.py (+51 -4)
- whisper_live/trt_server.py (+25 -131)
llm_service.py CHANGED

@@ -155,6 +155,8 @@ class MistralTensorRTLLM:
 
     def run(
         self,
+        model_path,
+        tokenizer_path,
         transcription_queue=None,
         llm_queue=None,
         input_text=None,
@@ -166,9 +168,10 @@ class MistralTensorRTLLM:
         debug=False,
     ):
         self.initialize_model(
-            …
-            …
+            model_path,
+            tokenizer_path,
         )
+
         print("Loaded LLM...")
         while True:
 
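The signature change means the LLM worker no longer hard-codes where its engine and tokenizer live; the parent process passes both paths in alongside the two queues. A minimal sketch of the consumer loop this sets up, assuming only what the diffs show: trt_server.py puts {"uid", "prompt"} dicts on transcription_queue, and whatever lands on llm_queue is forwarded to the client. The generation call and the "llm_output" key below are placeholders, not the module's actual API.

from multiprocessing import Queue

def llm_run(model_path, tokenizer_path, transcription_queue: Queue, llm_queue: Queue):
    # initialize_model(model_path, tokenizer_path) would run here
    print("Loaded LLM...")
    while True:
        item = transcription_queue.get()  # blocks until whisper publishes a prompt
        reply = f"(generated continuation of: {item['prompt']})"  # stand-in for TensorRT-LLM generation
        llm_queue.put({"uid": item["uid"], "llm_output": reply})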
main.py CHANGED

@@ -1,6 +1,5 @@
-from whisper_live.trt_server import TranscriptionServer
-from llm_service import MistralTensorRTLLM
 import multiprocessing
+import argparse
 import threading
 import ssl
 import time
@@ -9,8 +8,35 @@ import functools
 
 from multiprocessing import Process, Manager, Value, Queue
 
+from whisper_live.trt_server import TranscriptionServer
+from llm_service import MistralTensorRTLLM
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--whisper_tensorrt_path',
+                        type=str,
+                        default=None,
+                        help='Whisper TensorRT model path')
+    parser.add_argument('--mistral_tensorrt_path',
+                        type=str,
+                        default=None,
+                        help='Mistral TensorRT model path')
+    parser.add_argument('--mistral_tokenizer_path',
+                        type=str,
+                        default="teknium/OpenHermes-2.5-Mistral-7B",
+                        help='Mistral tokenizer path')
+    return parser.parse_args()
+
 
 if __name__ == "__main__":
+    args = parse_arguments()
+    if not args.whisper_tensorrt_path:
+        raise ValueError("Please provide whisper_tensorrt_path to run the pipeline.")
+
+    if not args.mistral_tensorrt_path or not args.mistral_tokenizer_path:
+        raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")
+
     multiprocessing.set_start_method('spawn')
 
     lock = multiprocessing.Lock()
@@ -23,12 +49,29 @@ if __name__ == "__main__":
 
 
     whisper_server = TranscriptionServer()
-    whisper_process = multiprocessing.Process(…)
+    whisper_process = multiprocessing.Process(
+        target=whisper_server.run,
+        args=(
+            "0.0.0.0",
+            6006,
+            transcription_queue,
+            llm_queue,
+            args.whisper_tensorrt_path
+        )
+    )
     whisper_process.start()
 
     llm_provider = MistralTensorRTLLM()
     # llm_provider = MistralTensorRTLLMProvider()
-    llm_process = multiprocessing.Process(…)
+    llm_process = multiprocessing.Process(
+        target=llm_provider.run,
+        args=(
+            args.mistral_tensorrt_path,
+            args.mistral_tokenizer_path,
+            transcription_queue,
+            llm_queue,
+        )
+    )
    llm_process.start()

    llm_process.join()
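With the argparse block above, the pipeline is started with something like `python main.py --whisper_tensorrt_path <engine dir> --mistral_tensorrt_path <engine dir>` (both paths are placeholders and both flags are required; the tokenizer defaults to teknium/OpenHermes-2.5-Mistral-7B). The essential structure, two spawned processes sharing a pair of queues, reduces to the runnable sketch below; the worker bodies are stand-ins, only the wiring mirrors the diff.

import multiprocessing
from multiprocessing import Queue

def whisper_worker(transcription_queue: Queue, llm_queue: Queue):
    # stand-in for TranscriptionServer.run(): pretend we transcribed a prompt
    transcription_queue.put({"uid": "client-1", "prompt": "hello world"})

def llm_worker(transcription_queue: Queue, llm_queue: Queue):
    # stand-in for MistralTensorRTLLM.run(): echo a reply back
    item = transcription_queue.get()
    llm_queue.put({"uid": item["uid"], "llm_output": f"reply to: {item['prompt']}"})

if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    transcription_queue, llm_queue = Queue(), Queue()
    whisper_process = multiprocessing.Process(target=whisper_worker, args=(transcription_queue, llm_queue))
    llm_process = multiprocessing.Process(target=llm_worker, args=(transcription_queue, llm_queue))
    whisper_process.start()
    llm_process.start()
    print(llm_queue.get())  # {'uid': 'client-1', 'llm_output': 'reply to: hello world'}
    whisper_process.join()
    llm_process.join()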
whisper_live/trt_server.py CHANGED

@@ -73,7 +73,7 @@ class TranscriptionServer:
 
         return wait_time / 60
 
-    def recv_audio(self, websocket, transcription_queue=None, llm_queue=None):
+    def recv_audio(self, websocket, transcription_queue=None, llm_queue=None, whisper_tensorrt_path=None):
         """
         Receive audio chunks from a client in an infinite loop.
 
@@ -121,7 +121,8 @@
                 task=options["task"],
                 client_uid=options["uid"],
                 transcription_queue=transcription_queue,
-                llm_queue=llm_queue
+                llm_queue=llm_queue,
+                model_path=whisper_tensorrt_path
             )
 
             self.clients[websocket] = client
@@ -132,16 +133,16 @@
             try:
                 frame_data = websocket.recv()
                 frame_np = np.frombuffer(frame_data, dtype=np.float32)
-
+
                 # VAD
                 try:
                     speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item()
                     if speech_prob < self.vad_threshold:
                         no_voice_activity_chunks += 1
-                        # print("No speech", no_voice_activity_chunks, self.clients[websocket].eos)
                         if no_voice_activity_chunks > 2:
                             if not self.clients[websocket].eos:
                                 self.clients[websocket].set_eos(True)
+                            time.sleep(0.25)  # EOS: stop receiving frames for 250 ms (to send output to the LLM)
                             continue
                     no_voice_activity_chunks = 0
                     self.clients[websocket].set_eos(False)
@@ -172,7 +173,7 @@
                 del websocket
                 break
 
-    def run(self, host, port=9090, transcription_queue=None, llm_queue=None):
+    def run(self, host, port=9090, transcription_queue=None, llm_queue=None, whisper_tensorrt_path=None):
         """
         Run the transcription server.
 
@@ -181,7 +182,12 @@
             port (int): The port number to bind the server.
         """
         with serve(
-            functools.partial(…),
+            functools.partial(
+                self.recv_audio,
+                transcription_queue=transcription_queue,
+                llm_queue=llm_queue,
+                whisper_tensorrt_path=whisper_tensorrt_path
+            ),
             host,
             port
         ) as server:
@@ -231,6 +237,7 @@ class ServeClient:
         client_uid=None,
         transcription_queue=None,
         llm_queue=None,
+        model_path=None
     ):
         """
         Initialize a ServeClient instance.
@@ -254,9 +261,7 @@
         self.frames = b""
         self.language = language if multilingual else "en"
         self.task = task
-
-        self.transcriber = WhisperTRTLLM(
-            "whisper_small_en", False, "assets", device="cuda")
+        self.transcriber = WhisperTRTLLM(model_path, False, "assets", device="cuda")
 
         self.timestamp_offset = 0.0
         self.frames_np = None
@@ -295,8 +300,6 @@
 
     def set_eos(self, eos):
        self.lock.acquire()
-        # if self.eos != eos:
-        #     logging.info(f"[WhisperLive:] setting eos: {eos}")
        self.eos = eos
        self.lock.release()
 
@@ -345,13 +348,10 @@
         """
         while True:
             try:
-                start = time.time()
                 if self.llm_queue is not None:
                     llm_output = self.llm_queue.get_nowait()
                     if llm_output:
                         self.websocket.send(json.dumps(llm_output))
-                end = time.time()
-                # print(f"Time to check LLM output {end - start}")
             except queue.Empty:
                 pass
 
@@ -360,8 +360,7 @@
                 break
 
             if self.frames_np is None:
-                #
-                time.sleep(0.05)
+                time.sleep(0.01)  # wait for any audio to arrive
                 continue
 
             # clip audio if the current chunk exceeds 30 seconds, this basically implies that
@@ -373,24 +372,19 @@
             samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
             input_bytes = self.frames_np[int(samples_take):].copy()
             duration = input_bytes.shape[0] / self.RATE
-            if duration<…:
+            if duration<0.4:
                 continue
 
             try:
                 input_sample = input_bytes.copy()
-
-                # whisper transcribe with prompt
+
                 mel, duration = self.transcriber.log_mel_spectrogram(input_sample)
                 last_segment = self.transcriber.transcribe(mel)
-
+                segments = []
                 if len(last_segment):
-                    if len(self.transcript) < self.send_last_n_segments:
-                        segments = self.transcript
-                    else:
-                        segments = self.transcript[-self.send_last_n_segments:]
                     segments.append({"text": last_segment})
                     try:
-
+                        print(f"Sending... {segments}")
                         self.websocket.send(
                             json.dumps({
                                 "uid": self.client_uid,
@@ -399,115 +393,22 @@
                             })
                         )
                         if self.eos:
-                            self.append_segment(last_segment)
+                            # self.append_segment(last_segment)
                             self.timestamp_offset += duration
                             self.prompt = ' '.join(segment['text'] for segment in segments)
                             self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
-                            self.transcript = []
                             self.set_eos(False)
 
+                            logging.info(
+                                f"[INFO:] \
+                                Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
+                            )
+
                     except Exception as e:
                         logging.error(f"[ERROR]: {e}")
 
             except Exception as e:
                 logging.error(f"[ERROR]: {e}")
-            time.sleep(0.01)
-
-    def append_segment(self, result):
-        # print("adding to transcript: ", result)
-        if not len(self.transcript):
-            self.transcript.append({"text": result + " "})
-        else:
-            if self.transcript[-1]["text"].strip()[-1] == ".":
-                if result[0] >= "a" and result[0] <= "z":
-                    self.transcript[-1]["text"] = replace_last_occurrence(
-                        self.transcript[-1]["text"], ".", ","
-                    )
-            elif self.transcript[-1]["text"].strip()[-1] == "?":
-                if result[0] >= "a" and result[0] <= "z":
-                    self.transcript[-1]["text"] = replace_last_occurrence(
-                        self.transcript[-1]["text"], "?", ","
-                    )
-
-            self.transcript.append({"text": result + " "})
-
-
-    def update_segments(self, segments, duration):
-        """
-        Processes the segments from whisper. Appends all the segments to the list
-        except for the last segment assuming that it is incomplete.
-
-        Updates the ongoing transcript with transcribed segments, including their start and end times.
-        Complete segments are appended to the transcript in chronological order. Incomplete segments
-        (assumed to be the last one) are processed to identify repeated content. If the same incomplete
-        segment is seen multiple times, it updates the offset and appends the segment to the transcript.
-        A threshold is used to detect repeated content and ensure it is only included once in the transcript.
-        The timestamp offset is updated based on the duration of processed segments. The method returns the
-        last processed segment, allowing it to be sent to the client for real-time updates.
-
-        Args:
-            segments(dict) : dictionary of segments as returned by whisper
-            duration(float): duration of the current chunk
-
-        Returns:
-            dict or None: The last processed segment with its start time, end time, and transcribed text.
-            Returns None if there are no valid segments to process.
-        """
-        offset = None
-        self.current_out = ''
-        last_segment = None
-        # process complete segments
-        if len(segments) > 1:
-            for i, s in enumerate(segments[:-1]):
-                text_ = s.text
-                self.text.append(text_)
-                start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end)
-                self.transcript.append(
-                    {
-                        'start': start,
-                        'end': end,
-                        'text': text_
-                    }
-                )
-
-                offset = min(duration, s.end)
-
-        self.current_out += segments[-1].text
-        last_segment = {
-            'start': self.timestamp_offset + segments[-1].start,
-            'end': self.timestamp_offset + min(duration, segments[-1].end),
-            'text': self.current_out
-        }
-
-        # if same incomplete segment is seen multiple times then update the offset
-        # and append the segment to the list
-        if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
-            self.same_output_threshold += 1
-        else:
-            self.same_output_threshold = 0
-
-        if self.same_output_threshold > 5:
-            if not len(self.text) or self.text[-1].strip().lower()!=self.current_out.strip().lower():
-                self.text.append(self.current_out)
-                self.transcript.append(
-                    {
-                        'start': self.timestamp_offset,
-                        'end': self.timestamp_offset + duration,
-                        'text': self.current_out
-                    }
-                )
-            self.current_out = ''
-            offset = duration
-            self.same_output_threshold = 0
-            last_segment = None
-        else:
-            self.prev_out = self.current_out
-
-        # update offset
-        if offset is not None:
-            self.timestamp_offset += offset
-
-        return last_segment
 
     def disconnect(self):
         """
@@ -538,10 +439,3 @@
         logging.info("Cleaning up.")
         self.exit = True
         self.transcriber.destroy()
-
-def replace_last_occurrence(input_str, old_char, new_char):
-    parts = input_str.rsplit(old_char, 1)
-    if len(parts) == 2:
-        return parts[0] + new_char + parts[1]
-    else:
-        return input_str
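`serve` from `websockets` calls its handler with just the connection, so the diff uses `functools.partial` to pre-bind the queues and engine path onto `recv_audio`. A standalone illustration of that pattern; every name and value here is hypothetical:

import functools

def recv_audio(websocket, transcription_queue=None, llm_queue=None, whisper_tensorrt_path=None):
    print(websocket, transcription_queue, llm_queue, whisper_tensorrt_path)

handler = functools.partial(
    recv_audio,
    transcription_queue="<transcription-queue>",
    llm_queue="<llm-queue>",
    whisper_tensorrt_path="/path/to/whisper_engine",  # placeholder path
)
handler("<websocket-connection>")  # the serving loop supplies only this argument

The queue side of the wiring is deliberately non-blocking: in speech_to_text, `llm_queue.get_nowait()` raises `queue.Empty` when there is nothing to forward, so the transcription loop keeps streaming segments instead of stalling on the LLM.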