makaveli10 commited on
Commit
3a0cae9
1 Parent(s): 3d1dc04

send llm response with eos

Browse files
Files changed (2) hide show
  1. llm_service.py +9 -5
  2. whisper_live/trt_server.py +15 -12
llm_service.py CHANGED
@@ -202,13 +202,16 @@ class MistralTensorRTLLM:
202
 
203
  prompt = transcription_output['prompt'].strip()
204
  input_text=[self.format_prompt_qa(prompt)]
205
- self.eos = transcription_output["eos"]
206
-
207
  if self.last_prompt == prompt:
208
- if self.last_output is not None:
209
- # logging.info(f"[LLM info:] Same prompt, adding last llm output to audio queue.")
 
210
  audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
211
  continue
 
 
212
 
213
  logging.info(f"[LLM INFO:] WhisperLive prompt: {prompt}, eos: {self.eos}")
214
  batch_input_ids = self.parse_input(
@@ -278,11 +281,12 @@ class MistralTensorRTLLM:
278
  if output is not None:
279
  self.last_output = output
280
  self.last_prompt = prompt
281
- llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
282
  audio_queue.put({"llm_output": output, "eos": self.eos})
283
 
284
  if self.eos:
285
  self.last_prompt = None
 
286
 
287
 
288
  if __name__=="__main__":
 
202
 
203
  prompt = transcription_output['prompt'].strip()
204
  input_text=[self.format_prompt_qa(prompt)]
205
+
206
+ # if prompt is same but EOS is True, we need that to send outputs to websockets
207
  if self.last_prompt == prompt:
208
+ if self.last_output is not None and transcription_output["eos"]:
209
+ self.eos = transcription_output["eos"]
210
+ llm_queue.put({"uid": transcription_output["uid"], "llm_output": self.last_output, "eos": self.eos})
211
  audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
212
  continue
213
+
214
+ self.eos = transcription_output["eos"]
215
 
216
  logging.info(f"[LLM INFO:] WhisperLive prompt: {prompt}, eos: {self.eos}")
217
  batch_input_ids = self.parse_input(
 
281
  if output is not None:
282
  self.last_output = output
283
  self.last_prompt = prompt
284
+ llm_queue.put({"uid": transcription_output["uid"], "llm_output": output, "eos": self.eos})
285
  audio_queue.put({"llm_output": output, "eos": self.eos})
286
 
287
  if self.eos:
288
  self.last_prompt = None
289
+ self.last_output = None
290
 
291
 
292
  if __name__=="__main__":
whisper_live/trt_server.py CHANGED
@@ -350,16 +350,19 @@ class ServeClient:
350
 
351
  """
352
  while True:
353
- if self.eos:
354
- try:
355
- llm_output = None
356
- if self.llm_queue is not None:
357
- while not self.llm_queue.empty():
358
- llm_output = self.llm_queue.get_nowait()
359
- if llm_output:
360
- self.websocket.send(json.dumps(llm_output))
361
- except queue.Empty:
362
- pass
 
 
 
363
 
364
  if self.exit:
365
  logging.info("Exiting speech to text thread")
@@ -400,12 +403,12 @@ class ServeClient:
400
  "eos": self.eos
401
  })
402
  )
403
- logging.info(f"[INFO]: {segments}, eos: {self.eos}")
404
-
405
  self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt, "eos": self.eos})
406
  if self.eos:
407
  # self.append_segment(last_segment)
408
  self.timestamp_offset += duration
 
409
  logging.info(
410
  f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
411
  )
 
350
 
351
  """
352
  while True:
353
+ # send the LLM outputs
354
+ try:
355
+ llm_response = None
356
+ if self.llm_queue is not None:
357
+ while not self.llm_queue.empty():
358
+ llm_response = self.llm_queue.get()
359
+
360
+ if llm_response:
361
+ # eos = llm_response["eos"]
362
+ # if eos:
363
+ self.websocket.send(json.dumps(llm_response))
364
+ except queue.Empty:
365
+ pass
366
 
367
  if self.exit:
368
  logging.info("Exiting speech to text thread")
 
403
  "eos": self.eos
404
  })
405
  )
406
+
 
407
  self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt, "eos": self.eos})
408
  if self.eos:
409
  # self.append_segment(last_segment)
410
  self.timestamp_offset += duration
411
+ logging.info(f"[INFO]: {segments}, eos: {self.eos}")
412
  logging.info(
413
  f"[INFO:] Processed : {self.timestamp_offset} seconds / {self.frames_np.shape[0] / self.RATE} seconds"
414
  )