makaveli10 committed on
Commit 16388cf
Parent: 81cb63c

integrate whisperspeech

Files changed (4)
  1. llm_service.py +32 -19
  2. main.py +4 -4
  3. tts_service.py +33 -11
  4. whisper_live/trt_server.py +6 -13
llm_service.py CHANGED
@@ -1,6 +1,9 @@
 import json
 from pathlib import Path
 from typing import Optional
+import logging
+logging.basicConfig(level=logging.INFO)
+
 import numpy as np
 import torch
 from transformers import AutoTokenizer
@@ -104,6 +107,8 @@ class MistralTensorRTLLM:
             debug_mode=False,
             lora_ckpt_source='hf')
         self.runner = self.runner_cls.from_dir(**self.runner_kwargs)
+        self.last_prompt = None
+        self.last_output = None
 
     def parse_input(
         self,
@@ -156,7 +161,7 @@ class MistralTensorRTLLM:
                 outputs = output_ids[batch_idx][beam][
                     output_begin:output_end].tolist()
                 output_text = self.tokenizer.decode(outputs)
-                print("[LLM] output:", output_text)
+                logging.info(f"[LLM] output: {output_text}")
                 output.append(output_text)
         return output
 
@@ -177,7 +182,7 @@ class MistralTensorRTLLM:
         max_output_len=40,
         max_attention_window_size=4096,
         num_beams=1,
-        streaming=True,
+        streaming=False,
         streaming_interval=4,
         debug=False,
     ):
@@ -186,27 +191,26 @@
             tokenizer_path,
         )
 
-        print("[LLM] loaded: True")
+        logging.info("[LLM] loaded: True")
         while True:
 
             # Get the last transcription output from the queue
             transcription_output = transcription_queue.get()
             if transcription_queue.qsize() != 0:
-                print("[LLM] transcription queue size:", transcription_queue.qsize())
+                logging.info("[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
                 continue
-            # while True:
-            #     try:
-            #         transcription_output = transcription_queue.get_nowait()
-            #     except Exception as e:
-            #         print("[Queue] exception", e)
-            #         break
-
-            # transcription_output = transcription_queue.get()
 
             prompt = transcription_output['prompt'].strip()
             input_text=[self.format_prompt_qa(prompt)]
+            self.eos = transcription_output["eos"]
 
-            print("[Whisper] prompt:", prompt)
+            if self.last_prompt == prompt:
+                if self.last_output is not None:
+                    # Same prompt as last time: resend the cached LLM output to the audio queue.
+                    audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
+                    continue
+
+            logging.info(f"[LLM INFO:] WhisperLive prompt: {prompt}, eos: {self.eos}")
             batch_input_ids = self.parse_input(
                 input_text=input_text,
                 add_special_tokens=True,
@@ -252,15 +256,16 @@ class MistralTensorRTLLM:
                     break
             # Interrupted by transcription queue
             if output is None:
-                print("[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!", transcription_queue.qsize())
+                logging.info(f"[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
                 continue
             else:
                 output_ids = outputs['output_ids']
                 sequence_lengths = outputs['sequence_lengths']
                 context_logits = None
                 generation_logits = None
-                if runner.gather_all_token_logits:
+                if self.runner.gather_context_logits:
                     context_logits = outputs['context_logits']
+                if self.runner.gather_generation_logits:
                     generation_logits = outputs['generation_logits']
                 output = self.decode_tokens(
                     output_ids,
@@ -268,8 +273,16 @@ class MistralTensorRTLLM:
                     sequence_lengths,
                     transcription_queue
                 )
-                llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
-                audio_queue.put(output)
+
+                # if self.eos:
+                if output is not None:
+                    self.last_output = output
+                    self.last_prompt = prompt
+                    llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
+                    audio_queue.put({"llm_output": output, "eos": self.eos})
+
+                if self.eos:
+                    self.last_prompt = None
 
 
 if __name__=="__main__":
@@ -278,11 +291,11 @@ if __name__=="__main__":
         "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",
         "teknium/OpenHermes-2.5-Mistral-7B",
     )
-    print("intialized")
+    logging.info("initialized")
     for i in range(1):
         output = llm(
             ["Born in north-east France, Soyer trained as a"], streaming=True
         )
-    print(output)
+    logging.info(output)
 
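Taken together, the llm_service.py changes replace raw-string hand-offs with small dicts on each queue. A minimal sketch of the message shapes this diff assumes; the queue names come from the surrounding code, while the uid, prompt, and output values are purely illustrative:

    import multiprocessing

    transcription_queue = multiprocessing.Queue()   # WhisperLive server -> LLM worker
    llm_queue = multiprocessing.Queue()             # LLM worker -> response path
    audio_queue = multiprocessing.Queue()           # LLM worker -> WhisperSpeech TTS worker

    # WhisperLive now publishes the running prompt plus an end-of-speech flag
    transcription_queue.put({"uid": "client-1", "prompt": "What is the capital of France?", "eos": True})

    # the LLM worker republishes its (list-valued) output downstream
    output = ["Paris is the capital of France."]
    llm_queue.put({"uid": "client-1", "llm_output": output})
    audio_queue.put({"llm_output": output, "eos": True})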
main.py CHANGED
@@ -105,10 +105,10 @@ if __name__ == "__main__":
105
  llm_process.start()
106
 
107
  # audio process
108
- # tts_runner = WhisperSpeechTTS()
109
- # tts_process = multiprocessing.Process(target=tts_runner.run, args=("0.0.0.0", 8888, audio_queue))
110
- # tts_process.start()
111
 
112
  llm_process.join()
113
  whisper_process.join()
114
- # tts_process.join()
 
105
  llm_process.start()
106
 
107
  # audio process
108
+ tts_runner = WhisperSpeechTTS()
109
+ tts_process = multiprocessing.Process(target=tts_runner.run, args=("0.0.0.0", 8888, audio_queue))
110
+ tts_process.start()
111
 
112
  llm_process.join()
113
  whisper_process.join()
114
+ tts_process.join()
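Note that the commit also drops the port=6080 default from WhisperSpeechTTS.run (see tts_service.py below), so the port must now be passed explicitly, as main.py does with 8888. A minimal direct-call sketch; the host and port values are only examples:

    import multiprocessing
    from tts_service import WhisperSpeechTTS

    audio_queue = multiprocessing.Queue()
    tts = WhisperSpeechTTS()
    # port is now required; main.py supplies it positionally through the Process args tuple
    tts.run("0.0.0.0", 8888, audio_queue=audio_queue)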
tts_service.py CHANGED
@@ -1,4 +1,7 @@
 import functools
+import time
+import logging
+logging.basicConfig(level=logging.INFO)
 
 from websockets.sync.server import serve
 from whisperspeech.pipeline import Pipeline
@@ -9,8 +12,13 @@ class WhisperSpeechTTS:
 
     def initialize_model(self):
         self.pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model')
+        self.last_llm_response = None
+
+    def run(self, host, port, audio_queue=None):
+        # initialize and warm up the model
+        self.initialize_model()
+        for i in range(3): self.pipe.vocoder.decode(self.pipe.generate_atoks("Hello, I am warming up."))
 
-    def run(self, host, port=6080, audio_queue=None):
         with serve(
             functools.partial(self.start_whisperspeech_tts, audio_queue=audio_queue),
             host, port
@@ -18,19 +26,33 @@ class WhisperSpeechTTS:
             server.serve_forever()
 
     def start_whisperspeech_tts(self, websocket, audio_queue=None):
-        self.initialize_model()
+        self.eos = False
+        self.output_audio = None
 
         while True:
             if audio_queue.empty(): continue
-
-            llm_output = audio_queue.get()[0]
-            audio = self.pipe.vocoder.decode(self.pipe.generate_atoks(llm_output.strip()))
-            audio = audio.cpu().numpy()
-            audio = audio * 32768.0
-
-            # send audio to client on another websocket
+
+            # check if this websocket still exists
             try:
-                websocket.send(audio.astype('int16').tobytes())
+                websocket.ping()
             except Exception as e:
-                print("Audio error:", e)
+                del websocket
+                break
+
+            llm_response = audio_queue.get()
+            llm_output = llm_response["llm_output"][0]
+            self.eos = llm_response["eos"]
+
+            # only run TTS inference if the output has changed
+            if self.last_llm_response != llm_output.strip():
+                logging.info("[WhisperSpeech INFO:] Running TTS inference ...")
+                audio = self.pipe.vocoder.decode(self.pipe.generate_atoks(llm_output.strip()))
+                self.output_audio = audio.cpu().numpy()
+                self.last_llm_response = llm_output.strip()
+
+            if self.eos and self.output_audio is not None:
+                try:
+                    websocket.send(self.output_audio.tobytes())
+                except Exception as e:
+                    logging.error(f"[WhisperSpeech ERROR:] Audio error: {e}")
 
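Because the * 32768.0 scaling and astype('int16') conversion were removed, the websocket now carries raw float32 samples from the WhisperSpeech vocoder. A small client sketch that receives one payload and saves it to a WAV file; the endpoint URL and the 24 kHz sample rate are assumptions, not part of this commit:

    import numpy as np
    import soundfile as sf
    from websockets.sync.client import connect

    SAMPLE_RATE = 24000  # assumed WhisperSpeech vocoder output rate

    with connect("ws://localhost:8888") as websocket:
        payload = websocket.recv()                        # bytes of a float32 numpy array
        audio = np.frombuffer(payload, dtype=np.float32)  # server no longer converts to int16
        sf.write("tts_output.wav", audio, SAMPLE_RATE)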
whisper_live/trt_server.py CHANGED
@@ -150,6 +150,7 @@ class TranscriptionServer:
         except Exception as e:
             logging.error(e)
             return
+        print("[WhisperLive INFO:] adding frames ...")
         self.clients[websocket].add_frames(frame_np)
 
         elapsed_time = time.time() - self.clients_start_time[websocket]
@@ -379,6 +380,7 @@ class ServeClient:
             input_bytes = self.frames_np[int(samples_take):].copy()
             duration = input_bytes.shape[0] / self.RATE
             if duration<0.4:
+                time.sleep(0.01)    # sleep 10 ms to wait for some voice-active audio to arrive
                continue
 
            try:
@@ -401,23 +403,14 @@ class ServeClient:
                )
                logging.info(f"[INFO]: {segments}, eos: {self.eos}")
 
+                self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt, "eos": self.eos})
                if self.eos:
                    # self.append_segment(last_segment)
                    self.timestamp_offset += duration
-
-                    if self.last_prompt != self.prompt:
-                        self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
-
-                    self.last_prompt = None
-                    # self.set_eos(False)
-                    logging.info(f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
+                    logging.info(
+                        f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
                    )
-                else:
-                    if self.last_prompt != self.prompt:
-                        self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
-
-                    self.last_prompt = self.prompt
-
+
 
 
            except Exception as e:
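With this change the server puts a {"uid", "prompt", "eos"} message on the transcription queue on every decode pass, not only when the prompt changes, so the LLM worker's qsize() != 0 check is what keeps it working on the freshest prompt. A compressed sketch of that drain-to-latest pattern, written as a standalone helper for illustration only (the actual worker skips stale prompts with continue rather than calling a helper):

    import queue

    def latest_transcription(transcription_queue: "queue.Queue") -> dict:
        """Block for one message, then discard anything stale queued behind it."""
        message = transcription_queue.get()
        while transcription_queue.qsize() != 0:   # newer prompts already waiting; drop this one
            message = transcription_queue.get()
        return message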