Commit 16388cf (parent: 81cb63c)

integrate whisperspeech

Files changed:
- llm_service.py              +32 -19
- main.py                      +4  -4
- tts_service.py              +33 -11
- whisper_live/trt_server.py   +6 -13
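Note: in outline, this commit adds a third stage to the pipeline. WhisperLive's transcription server now tags every prompt it publishes with an end-of-speech flag, the Mistral worker caches its last prompt/output and forwards generations to a new audio queue, and a WhisperSpeech TTS process started from main.py synthesizes that text and streams raw audio back to the client over a websocket. The queue payloads, as they appear in the diffs below (shapes shown for orientation only, not a formal schema):

    # Message shapes inferred from this commit; illustrative only.
    transcription_msg = {"uid": "<client uid>", "prompt": "<partial transcript>", "eos": False}   # trt_server.py -> llm_service.py
    audio_msg = {"llm_output": ["<generated text>"], "eos": False}                                # llm_service.py -> tts_service.py
    llm_msg = {"uid": "<client uid>", "llm_output": ["<generated text>"]}                         # llm_service.py -> llm_queue consumer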
llm_service.py    CHANGED

@@ -1,6 +1,9 @@
 import json
 from pathlib import Path
 from typing import Optional
+import logging
+logging.basicConfig(level = logging.INFO)
+
 import numpy as np
 import torch
 from transformers import AutoTokenizer

@@ -104,6 +107,8 @@ class MistralTensorRTLLM:
             debug_mode=False,
             lora_ckpt_source='hf')
         self.runner = self.runner_cls.from_dir(**self.runner_kwargs)
+        self.last_prompt = None
+        self.last_output = None

     def parse_input(
         self,

@@ -156,7 +161,7 @@ class MistralTensorRTLLM:
                 outputs = output_ids[batch_idx][beam][
                     output_begin:output_end].tolist()
                 output_text = self.tokenizer.decode(outputs)
-
+                logging.info(f"[LLM] output: {output_text}")
                 output.append(output_text)
         return output

@@ -177,7 +182,7 @@ class MistralTensorRTLLM:
         max_output_len=40,
         max_attention_window_size=4096,
         num_beams=1,
-        streaming=
+        streaming=False,
         streaming_interval=4,
         debug=False,
     ):

@@ -186,27 +191,26 @@ class MistralTensorRTLLM:
             tokenizer_path,
         )

-
+        logging.info("[LLM] loaded: True")
         while True:

            # Get the last transcription output from the queue
            transcription_output = transcription_queue.get()
            if transcription_queue.qsize() != 0:
-
+                logging.info("[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
                continue
-            # while True:
-            #     try:
-            #         transcription_output = transcription_queue.get_nowait()
-            #     except Exception as e:
-            #         print("[Queue] exception", e)
-            #         break
-
-            # transcription_output = transcription_queue.get()

            prompt = transcription_output['prompt'].strip()
            input_text=[self.format_prompt_qa(prompt)]
+            self.eos = transcription_output["eos"]

-
+            if self.last_prompt == prompt:
+                if self.last_output is not None:
+                    # logging.info(f"[LLM info:] Same prompt, adding last llm output to audio queue.")
+                    audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
+                    continue
+
+            logging.info(f"[LLM INFO:] WhisperLive prompt: {prompt}, eos: {self.eos}")
            batch_input_ids = self.parse_input(
                input_text=input_text,
                add_special_tokens=True,

@@ -252,15 +256,16 @@ class MistralTensorRTLLM:
                    break
            # Interrupted by transcription queue
            if output is None:
-
+                logging.info(f"[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
                continue
            else:
                output_ids = outputs['output_ids']
                sequence_lengths = outputs['sequence_lengths']
                context_logits = None
                generation_logits = None
-                if runner.
+                if self.runner.gather_context_logits:
                    context_logits = outputs['context_logits']
+                if self.runner.gather_generation_logits:
                    generation_logits = outputs['generation_logits']
                output = self.decode_tokens(
                    output_ids,

@@ -268,8 +273,16 @@
                    sequence_lengths,
                    transcription_queue
                )
-
-
+
+                # if self.eos:
+                if output is not None:
+                    self.last_output = output
+                    self.last_prompt = prompt
+                    llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
+                    audio_queue.put({"llm_output": output, "eos": self.eos})
+
+                if self.eos:
+                    self.last_prompt = None


 if __name__=="__main__":

@@ -278,11 +291,11 @@ if __name__=="__main__":
        "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",
        "teknium/OpenHermes-2.5-Mistral-7B",
    )
-
+    logging.info("intialized")
    for i in range(1):
        output = llm(
            ["Born in north-east France, Soyer trained as a"], streaming=True
        )
-
+        logging.info(output)

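Note: the core of the new llm_service.py control flow is a drain-and-dedupe loop around the two queues. A minimal, self-contained sketch of that pattern follows (the worker function, parameter names, and the `generate` callable are illustrative, not the repo's exact API):

    def llm_worker(transcription_queue, audio_queue, generate):
        # Sketch of the pattern in llm_service.py: always act on the newest
        # transcription, reuse the cached answer for a repeated prompt, and
        # reset the cache once end-of-speech (eos) is reached.
        last_prompt, last_output = None, None
        while True:
            msg = transcription_queue.get()
            if transcription_queue.qsize() != 0:
                continue  # a newer transcription is already waiting; drop this one
            prompt, eos = msg["prompt"].strip(), msg["eos"]

            if prompt == last_prompt and last_output is not None:
                # Same prompt as last time: replay the cached output to TTS.
                audio_queue.put({"llm_output": last_output, "eos": eos})
                continue

            output = generate(prompt)   # stands in for the TensorRT-LLM generation call
            if output is not None:
                last_prompt, last_output = prompt, output
                audio_queue.put({"llm_output": output, "eos": eos})
            if eos:
                last_prompt = None      # utterance finished; start fresh on the next prompt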
main.py    CHANGED

@@ -105,10 +105,10 @@ if __name__ == "__main__":
     llm_process.start()

     # audio process
-
-
-
+    tts_runner = WhisperSpeechTTS()
+    tts_process = multiprocessing.Process(target=tts_runner.run, args=("0.0.0.0", 8888, audio_queue))
+    tts_process.start()

     llm_process.join()
     whisper_process.join()
-
+    tts_process.join()
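Note: the diff shows only the start and join of the new TTS process; `audio_queue` and the `WhisperSpeechTTS` import must already be set up earlier in main.py. A rough sketch of that surrounding wiring, under the assumption that plain multiprocessing queues are used (their actual construction is outside this diff):

    import multiprocessing
    from tts_service import WhisperSpeechTTS    # assumed import, matching the file in this commit

    if __name__ == "__main__":
        # Queues linking the three stages; creation is not shown in this commit.
        transcription_queue = multiprocessing.Queue()
        llm_queue = multiprocessing.Queue()
        audio_queue = multiprocessing.Queue()

        # ... whisper_process and llm_process are created and started as before ...

        tts_runner = WhisperSpeechTTS()
        tts_process = multiprocessing.Process(
            target=tts_runner.run, args=("0.0.0.0", 8888, audio_queue))
        tts_process.start()

        # ... llm_process.join(); whisper_process.join() ...
        tts_process.join()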
tts_service.py    CHANGED

@@ -1,4 +1,7 @@
 import functools
+import time
+import logging
+logging.basicConfig(level = logging.INFO)

 from websockets.sync.server import serve
 from whisperspeech.pipeline import Pipeline

@@ -9,8 +12,13 @@ class WhisperSpeechTTS:

     def initialize_model(self):
         self.pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model')
+        self.last_llm_response = None
+
+    def run(self, host, port, audio_queue=None):
+        # initialize and warmup model
+        self.initialize_model()
+        for i in range(3): self.pipe.vocoder.decode(self.pipe.generate_atoks("Hello, I am warming up."))

-    def run(self, host, port=6080, audio_queue=None):
         with serve(
             functools.partial(self.start_whisperspeech_tts, audio_queue=audio_queue),
             host, port

@@ -18,19 +26,33 @@
             server.serve_forever()

     def start_whisperspeech_tts(self, websocket, audio_queue=None):
-        self.
+        self.eos = False
+        self.output_audio = None

         while True:
            if audio_queue.empty(): continue
-
-
-            audio = self.pipe.vocoder.decode(self.pipe.generate_atoks(llm_output.strip()))
-            audio = audio.cpu().numpy()
-            audio = audio * 32768.0
-
-            # send audio to client on another websocket
+
+            # check if this websocket exists
            try:
-                websocket.
+                websocket.ping()
            except Exception as e:
-
+                del websocket
+                break
+
+            llm_response = audio_queue.get()
+            llm_output = llm_response["llm_output"][0]
+            self.eos = llm_response["eos"]
+
+            # only process if the output updated
+            if self.last_llm_response != llm_output.strip():
+                logging.INFO("[WhisperSpeech INFO:] Tunning TTS inference ...")
+                audio = self.pipe.vocoder.decode(self.pipe.generate_atoks(llm_output.strip()))
+                self.output_audio = audio.cpu().numpy()
+                self.last_llm_response = llm_output.strip()
+
+            if self.eos and self.output_audio is not None:
+                try:
+                    websocket.send(self.output_audio.tobytes())
+                except Exception as e:
+                    logging.error("[WhisperSpeech INFO:] Audio error:", e)

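Note: once eos is seen, the TTS server pushes the synthesized speech to the client as the raw bytes of the float32 NumPy array returned by the vocoder (the earlier `* 32768.0` int16-style scaling was dropped). A minimal receiving-side sketch, assuming the port used in main.py (8888) and float32 samples; the sample rate is whatever the WhisperSpeech vocoder produces and is not specified by this diff:

    import numpy as np
    from websockets.sync.client import connect   # sync client from the same `websockets` package

    with connect("ws://localhost:8888") as ws:
        payload = ws.recv()                                  # bytes from websocket.send(output_audio.tobytes())
        audio = np.frombuffer(payload, dtype=np.float32)     # dtype assumed from audio.cpu().numpy()
        print(f"received {audio.shape[0]} samples")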
whisper_live/trt_server.py    CHANGED

@@ -150,6 +150,7 @@ class TranscriptionServer:
            except Exception as e:
                logging.error(e)
                return
+            print("[WhisperLive INFO:] adding frames ...")
            self.clients[websocket].add_frames(frame_np)

            elapsed_time = time.time() - self.clients_start_time[websocket]

@@ -379,6 +380,7 @@ class ServeClient:
            input_bytes = self.frames_np[int(samples_take):].copy()
            duration = input_bytes.shape[0] / self.RATE
            if duration<0.4:
+                time.sleep(0.01)    # 5ms sleep to wait for some voice active audio to arrive
                continue

            try:

@@ -401,23 +403,14 @@
            )
            logging.info(f"[INFO]: {segments}, eos: {self.eos}")

+            self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt, "eos": self.eos})
            if self.eos:
                # self.append_segment(last_segment)
                self.timestamp_offset += duration
-
-
-                self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
-
-                self.last_prompt = None
-                # self.set_eos(False)
-                logging.info(f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
+                logging.info(
+                    f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
                )
-
-            if self.last_prompt != self.prompt:
-                self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
-
-                self.last_prompt = self.prompt
-
+


        except Exception as e:
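Note: the `time.sleep(0.01)` added before `continue` turns the duration check into a polling gate: if fewer than 0.4 s of new audio have accumulated, the loop backs off briefly instead of spinning. A standalone sketch of that gating pattern (function and parameter names are illustrative; the 0.01 s sleep and 0.4 s threshold mirror the diff):

    import time

    def wait_for_audio_chunk(get_frames, offset_samples, rate, min_duration=0.4):
        # `get_frames` returns the current float32 audio buffer and is expected to
        # grow over time as another thread appends frames, as ServeClient does.
        while True:
            pending = get_frames()[offset_samples:]
            if pending.shape[0] / rate >= min_duration:
                return pending.copy()
            time.sleep(0.01)   # back off briefly instead of busy-waiting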