makaveli10 committed
Commit 7cc82ad • Parent: 201054b

add multiprocess communication

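This commit wires the Whisper transcription server and the Mistral TensorRT-LLM engine together as separate OS processes that talk over two multiprocessing queues: the server puts {"uid", "prompt"} dicts on transcription_queue, the LLM process blocks on that queue and puts {"uid", "llm_output"} dicts on llm_queue, and the server's speech-to-text loop drains llm_queue and forwards replies to the websocket client. A minimal, self-contained sketch of that pattern (the worker name and the None shutdown sentinel are illustrative, not part of the commit):

    import multiprocessing
    from multiprocessing import Queue

    def llm_worker(transcription_queue: Queue, llm_queue: Queue):
        # stand-in for MistralTensorRTLLM.run: block on prompts, answer each one
        while True:
            item = transcription_queue.get()   # blocks until a prompt arrives
            if item is None:                   # sentinel: shut down (illustrative)
                break
            llm_queue.put({"uid": item["uid"], "llm_output": [item["prompt"].upper()]})

    if __name__ == "__main__":
        tq, lq = Queue(), Queue()
        worker = multiprocessing.Process(target=llm_worker, args=(tq, lq))
        worker.start()
        # stand-in for TranscriptionServer: emit prompts, read the replies back
        for prompt in ["hello", "how are you"]:
            tq.put({"uid": "client-1", "prompt": prompt})
        for _ in range(2):
            print(lq.get())
        tq.put(None)
        worker.join()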
services/llm_service.py → llm_service.py RENAMED
@@ -103,6 +103,7 @@ class MistralTensorRTLLM:
             rank=self.runtime_rank,
             debug_mode=False,
             lora_ckpt_source='hf')
+        self.runner = self.runner_cls.from_dir(**self.runner_kwargs)
 
     def parse_input(
         self,
@@ -152,75 +153,83 @@ class MistralTensorRTLLM:
         output.append(output_text)
         return output
 
-    def __call__(
-        self,
-        input_text,
-        max_output_len=100,
-        max_attention_window_size=4096,
-        num_beams=1,
-        streaming=True,
-        streaming_interval=4,
-    ):
-        import time
-        start = time.time()
-        batch_input_ids = self.parse_input(
-            input_text=input_text,
-            add_special_tokens=True,
-            max_input_length=923,
-            pad_id=None,
-        )
-
-        input_lengths = [x.size(1) for x in batch_input_ids]
-        print(self.runner_kwargs)
-        runner = self.runner_cls.from_dir(**self.runner_kwargs)
-        with torch.no_grad():
-            outputs = runner.generate(
-                batch_input_ids,
-                max_new_tokens=max_output_len,
-                max_attention_window_size=max_attention_window_size,
-                end_id=self.end_id,
-                pad_id=self.pad_id,
-                temperature=1.0,
-                top_k=1,
-                top_p=0.0,
-                num_beams=num_beams,
-                length_penalty=1.0,
-                repetition_penalty=1.0,
-                stop_words_list=None,
-                bad_words_list=None,
-                lora_uids=None,
-                prompt_table_path=None,
-                prompt_tasks=None,
-                streaming=streaming,
-                output_sequence_lengths=True,
-                return_dict=True)
-        torch.cuda.synchronize()
-        print(outputs)
-        if streaming:
-            for curr_outputs in throttle_generator(outputs, streaming_interval):
-                output_ids = curr_outputs['output_ids']
-                sequence_lengths = curr_outputs['sequence_lengths']
-                output = self.decode_tokens(
-                    output_ids,
-                    input_lengths,
-                    sequence_lengths
-                )
-            print(time.time() - start)
-            print(input_text[0] + " " + output[0])
-        else:
-            output_ids = outputs['output_ids']
-            sequence_lengths = outputs['sequence_lengths']
-            context_logits = None
-            generation_logits = None
-            if runner.gather_all_token_logits:
-                context_logits = outputs['context_logits']
-                generation_logits = outputs['generation_logits']
-            output = self.decode_tokens(
-                output_ids,
-                input_lengths,
-                sequence_lengths,
-            )
-        return output
+    def run(
+        self,
+        transcription_queue=None,
+        llm_queue=None,
+        input_text=None,
+        max_output_len=20,
+        max_attention_window_size=4096,
+        num_beams=1,
+        streaming=True,
+        streaming_interval=4,
+        debug=False,
+    ):
+        self.initialize_model(
+            "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",
+            "teknium/OpenHermes-2.5-Mistral-7B",
+        )
+        print("Loaded LLM...")
+        while True:
+
+            # block until the transcription process produces a prompt
+            transcription_output = transcription_queue.get()
+            input_text = transcription_output['prompt'].strip()
+
+            print("Whisper: ", input_text)
+            batch_input_ids = self.parse_input(
+                input_text=input_text,
+                add_special_tokens=True,
+                max_input_length=923,
+                pad_id=None,
+            )
+
+            input_lengths = [x.size(1) for x in batch_input_ids]
+            with torch.no_grad():
+                outputs = self.runner.generate(
+                    batch_input_ids,
+                    max_new_tokens=max_output_len,
+                    max_attention_window_size=max_attention_window_size,
+                    end_id=self.end_id,
+                    pad_id=self.pad_id,
+                    temperature=1.0,
+                    top_k=1,
+                    top_p=0.0,
+                    num_beams=num_beams,
+                    length_penalty=1.0,
+                    repetition_penalty=1.0,
+                    stop_words_list=None,
+                    bad_words_list=None,
+                    lora_uids=None,
+                    prompt_table_path=None,
+                    prompt_tasks=None,
+                    streaming=streaming,
+                    output_sequence_lengths=True,
+                    return_dict=True)
+            torch.cuda.synchronize()
+            if streaming:
+                for curr_outputs in throttle_generator(outputs, streaming_interval):
+                    output_ids = curr_outputs['output_ids']
+                    sequence_lengths = curr_outputs['sequence_lengths']
+                    output = self.decode_tokens(
+                        output_ids,
+                        input_lengths,
+                        sequence_lengths
+                    )
+            else:
+                output_ids = outputs['output_ids']
+                sequence_lengths = outputs['sequence_lengths']
+                context_logits = None
+                generation_logits = None
+                if self.runner.gather_all_token_logits:
+                    context_logits = outputs['context_logits']
+                    generation_logits = outputs['generation_logits']
+                output = self.decode_tokens(
+                    output_ids,
+                    input_lengths,
+                    sequence_lengths,
+                )
+            llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
 
 
 if __name__=="__main__":
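The net effect of this diff: the expensive runner construction (self.runner_cls.from_dir) moves out of the per-call path into __init__, and the one-shot __call__ becomes a long-lived run loop that blocks on transcription_queue.get(), generates greedily (top_k=1, top_p=0.0), and publishes each result. The queue messages are plain dicts; their shapes, as used in the diff:

    # what TranscriptionServer puts on transcription_queue:
    {"uid": "<client uid>", "prompt": "<latest transcript text>"}
    # what MistralTensorRTLLM.run puts back on llm_queue:
    {"uid": "<client uid>", "llm_output": ["<decoded completion>"]}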
main.py ADDED
@@ -0,0 +1,35 @@
+from whisper_live.trt_server import TranscriptionServer
+from llm_service import MistralTensorRTLLM
+import multiprocessing
+import threading
+import ssl
+import time
+import sys
+import functools
+
+from multiprocessing import Process, Manager, Value, Queue
+
+
+if __name__ == "__main__":
+    # multiprocessing.set_start_method('spawn')
+
+    lock = multiprocessing.Lock()
+
+    manager = Manager()
+    shared_output = manager.list()
+
+    transcription_queue = Queue()
+    llm_queue = Queue()
+
+
+    whisper_server = TranscriptionServer()
+    whisper_process = multiprocessing.Process(target=whisper_server.run, args=("0.0.0.0", 6006, transcription_queue, llm_queue))
+    whisper_process.start()
+
+    # llm_provider = MistralTensorRTLLM()
+    # # llm_provider = MistralTensorRTLLMProvider()
+    # llm_process = multiprocessing.Process(target=llm_provider.run, args=(transcription_queue, llm_queue))
+    # llm_process.start()
+
+    # llm_process.join()
+    whisper_process.join()
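As committed, main.py starts only the transcription server process; the LLM consumer stays commented out (and threading, ssl, sys, and functools are imported but unused here). Enabling both sides would look roughly like this, replacing the commented block above (a sketch assembled from those commented-out lines, which assume a no-argument MistralTensorRTLLM constructor):

    llm_provider = MistralTensorRTLLM()
    llm_process = multiprocessing.Process(target=llm_provider.run, args=(transcription_queue, llm_queue))
    llm_process.start()

    llm_process.join()
    whisper_process.join()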
whisper_live/trt_server.py CHANGED
@@ -12,6 +12,8 @@ from websockets.sync.server import serve
 import torch
 import numpy as np
 import time
+import queue
+
 from whisper_live.vad import VoiceActivityDetection
 from whisper_live.trt_transcriber import WhisperTRTLLM
 
@@ -47,13 +49,13 @@ class TranscriptionServer:
 
     def __init__(self):
         # voice activity detection model
-        self.vad_model = VoiceActivityDetection()
-        self.vad_threshold = 0.5
+
         self.clients = {}
         self.websockets = {}
         self.clients_start_time = {}
         self.max_clients = 4
         self.max_connection_time = 600
+        print("done loading")
 
     def get_wait_time(self):
         """
@@ -72,7 +74,7 @@ class TranscriptionServer:
 
         return wait_time / 60
 
-    def recv_audio(self, websocket):
+    def recv_audio(self, websocket, transcription_queue=None, llm_queue=None):
         """
         Receive audio chunks from a client in an infinite loop.
 
@@ -93,6 +95,9 @@ class TranscriptionServer:
         Raises:
             Exception: If there is an error during the audio frame processing.
         """
+        self.vad_model = VoiceActivityDetection()
+        self.vad_threshold = 0.5
+
         logging.info("New client connected")
         options = websocket.recv()
         options = json.loads(options)
@@ -115,22 +120,26 @@ class TranscriptionServer:
             multilingual=options["multilingual"],
             language=options["language"],
             task=options["task"],
-            client_uid=options["uid"]
+            client_uid=options["uid"],
+            transcription_queue=transcription_queue,
+            llm_queue=llm_queue
         )
 
         self.clients[websocket] = client
         self.clients_start_time[websocket] = time.time()
         no_voice_activity_chunks = 0
+        print()
         while True:
             try:
                 frame_data = websocket.recv()
                 frame_np = np.frombuffer(frame_data, dtype=np.float32)
+                print(frame_np.shape)
                 # VAD
                 try:
                     speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item()
                     if speech_prob < self.vad_threshold:
                         no_voice_activity_chunks += 1
-                        print("No speech", no_voice_activity_chunks)
+                        print("No speech", no_voice_activity_chunks, self.clients[websocket].eos)
                         if no_voice_activity_chunks > 2:
                             if not self.clients[websocket].eos:
                                 self.clients[websocket].set_eos(True)
@@ -164,7 +173,7 @@ class TranscriptionServer:
             del websocket
             break
 
-    def run(self, host, port=9090):
+    def run(self, host, port=9090, transcription_queue=None, llm_queue=None):
         """
         Run the transcription server.
 
@@ -172,7 +181,15 @@ class TranscriptionServer:
             host (str): The host address to bind the server.
             port (int): The port number to bind the server.
         """
-        with serve(self.recv_audio, host, port) as server:
+        with serve(
+            functools.partial(
+                self.recv_audio,
+                transcription_queue=transcription_queue,
+                llm_queue=llm_queue,
+            ),
+            host,
+            port
+        ) as server:
             server.serve_forever()
 
 
@@ -209,7 +226,17 @@ class ServeClient:
     SERVER_READY = "SERVER_READY"
     DISCONNECT = "DISCONNECT"
 
-    def __init__(self, websocket, task="transcribe", device=None, multilingual=False, language=None, client_uid=None):
+    def __init__(
+        self,
+        websocket,
+        task="transcribe",
+        device=None,
+        multilingual=False,
+        language=None,
+        client_uid=None,
+        transcription_queue=None,
+        llm_queue=None,
+    ):
         """
         Initialize a ServeClient instance.
         The Whisper model is initialized based on the client's language and device availability.
@@ -226,6 +253,8 @@ class ServeClient:
 
         """
         self.client_uid = client_uid
+        self.transcription_queue = transcription_queue
+        self.llm_queue = llm_queue
         self.data = b""
         self.frames = b""
         self.language = language if multilingual else "en"
@@ -246,6 +275,7 @@ class ServeClient:
         self.show_prev_out_thresh = 5   # if pause(no output from whisper) show previous output for 5 seconds
         self.add_pause_thresh = 3   # add a blank to segment list as a pause(no speech) for 3 seconds
         self.transcript = []
+        self.prompt = None
         self.send_last_n_segments = 10
 
         # text formatting
@@ -318,6 +348,14 @@ class ServeClient:
 
         """
         while True:
+            try:
+                if self.llm_queue is not None:
+                    llm_output = self.llm_queue.get_nowait()
+                    if llm_output:
+                        self.websocket.send(json.dumps(llm_output))
+            except queue.Empty:
+                pass
+
             if self.exit:
                 logging.info("Exiting speech to text thread")
                 break
@@ -334,28 +372,39 @@ class ServeClient:
             samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE)
             input_bytes = self.frames_np[int(samples_take):].copy()
             duration = input_bytes.shape[0] / self.RATE
-            if duration<1.0 or not self.eos:
+            if duration<1.0:
                 continue
 
             try:
                 input_sample = input_bytes.copy()
-                save_wav(input_sample)
+                # save_wav(input_sample)
                 # whisper transcribe with prompt
                 mel, duration = self.transcriber.log_mel_spectrogram(input_sample)
-                print(mel.shape, duration)
-                result = self.transcriber.transcribe(mel)
-                self.append_segment(result)
-                self.set_eos(False)
-                self.timestamp_offset += duration
-                if len(result):
-                    segments = self.transcript[-self.send_last_n_segments:]
+                last_segment = self.transcriber.transcribe(mel)
+
+                if len(last_segment):
+                    if len(self.transcript) < self.send_last_n_segments:
+                        segments = self.transcript
+                    else:
+                        segments = self.transcript[-self.send_last_n_segments:]
+                    segments.append({"text": last_segment})
                     try:
+                        print(f"Sending... {segments}")
                         self.websocket.send(
                             json.dumps({
                                 "uid": self.client_uid,
-                                "segments": segments
+                                "segments": segments,
+                                "eos": self.eos
                             })
                         )
+                        if self.eos:
+                            self.append_segment(last_segment)
+                            self.timestamp_offset += duration
+                            self.prompt = ' '.join(segment['text'] for segment in segments)
+                            self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt})
+                            self.transcript = []
+                            self.set_eos(False)
+
                     except Exception as e:
                         logging.error(f"[ERROR]: {e}")
 
@@ -364,6 +413,7 @@ class ServeClient:
             time.sleep(0.01)
 
     def append_segment(self, result):
+        print("adding to transcript: ", result)
         if not len(self.transcript):
             self.transcript.append({"text": result + " "})
         else:
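Two details in this file are easy to miss. First, serve() from websockets.sync.server invokes its handler with just the websocket, so functools.partial is used to bind the two queues as keyword arguments (this assumes functools is imported in the module; the hunk above does not show that import). Second, multiprocessing.Queue.get_nowait() raises Empty from the standard-library queue module, which is why import queue was added: the speech-to-text loop polls the LLM queue without ever blocking. A condensed sketch of that non-blocking drain (illustrative; the short sleep only gives the queue's feeder thread time to flush):

    import queue
    import time
    from multiprocessing import Queue

    llm_queue = Queue()

    def forward_llm_output(send):
        # poll without blocking; transcription work continues if nothing is ready
        try:
            llm_output = llm_queue.get_nowait()  # raises queue.Empty when empty
            if llm_output:
                send(llm_output)
        except queue.Empty:
            pass

    forward_llm_output(print)  # queue empty: does nothing
    llm_queue.put({"uid": "client-1", "llm_output": ["hi"]})
    time.sleep(0.1)            # multiprocessing queues flush asynchronously
    forward_llm_output(print)  # prints the queued reply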
whisper_live/trt_transcriber.py CHANGED
@@ -339,9 +339,9 @@ def decode_wav_file(
 
 if __name__=="__main__":
     tensorrt_llm.logger.set_level("error")
-    model = WhisperTRTLLM("../whisper_small_en", False, "../assets", device="cuda")
+    model = WhisperTRTLLM("/root/TensorRT-LLM/examples/whisper/whisper_small_en", False, "../assets", device="cuda")
     mel, total_duration = model.log_mel_spectrogram(
-        "/root/Code/outputs/output3.wav",
+        "../assets/1221-135766-0002.wav",
     )
     results = model.transcribe(mel)
     print(results, total_duration)
whisper_live/vad.py CHANGED
@@ -10,19 +10,23 @@ import onnxruntime
 class VoiceActivityDetection():
 
     def __init__(self, force_onnx_cpu=True):
+        print("downloading ONNX model...")
         path = self.download()
+        print("loading session")
+
         opts = onnxruntime.SessionOptions()
         opts.log_severity_level = 3
 
         opts.inter_op_num_threads = 1
         opts.intra_op_num_threads = 1
 
+        print("loading onnx model")
         if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
             self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
         else:
             self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
 
-
+        print("reset states")
         self.reset_states()
         self.sample_rates = [8000, 16000]
 