makaveli10 committed
Commit e3f7cd8
Parent: 7cfc46e

add history of conversation to prompt

Files changed (2)
  1. llm_service.py +36 -8
  2. main.py +2 -2
llm_service.py CHANGED
@@ -84,7 +84,7 @@ def load_tokenizer(tokenizer_dir: Optional[str] = None,
     return tokenizer, pad_id, end_id
 
 
-class MistralTensorRTLLM:
+class TensorRTLLMEngine:
     def __init__(self):
         pass
 
@@ -165,11 +165,17 @@ class MistralTensorRTLLM:
         output.append(output_text)
         return output
 
-    def format_prompt_qa(self, prompt):
-        return f"Instruct: {prompt}\nOutput:"
+    def format_prompt_qa(self, prompt, conversation_history):
+        formatted_prompt = ""
+        for user_prompt, llm_response in conversation_history:
+            formatted_prompt += f"Instruct: {user_prompt}\nOutput:{llm_response}\n"
+        return f"{formatted_prompt}Instruct: {prompt}\nOutput:"
 
-    def format_prompt_chat(self, prompt):
-        return f"Alice: {prompt}\nBob:"
+    def format_prompt_chat(self, prompt, conversation_history):
+        formatted_prompt = ""
+        for user_prompt, llm_response in conversation_history:
+            formatted_prompt += f"Alice: {user_prompt}\nBob:{llm_response}\n"
+        return f"{formatted_prompt}Alice: {prompt}\nBob:"
 
     def run(
         self,
@@ -192,6 +198,9 @@ class MistralTensorRTLLM:
         )
 
         logging.info("[LLM] loaded: True")
+
+        conversation_history = {}
+
         while True:
 
             # Get the last transcription output from the queue
@@ -199,17 +208,26 @@ class MistralTensorRTLLM:
             if transcription_queue.qsize() != 0:
                 logging.info("[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
                 continue
+
+            if transcription_output["uid"] not in conversation_history:
+                conversation_history[transcription_output["uid"]] = []
 
             prompt = transcription_output['prompt'].strip()
-            input_text=[self.format_prompt_qa(prompt)]
-
+
             # if prompt is same but EOS is True, we need that to send outputs to websockets
             if self.last_prompt == prompt:
                 if self.last_output is not None and transcription_output["eos"]:
                     self.eos = transcription_output["eos"]
                     llm_queue.put({"uid": transcription_output["uid"], "llm_output": self.last_output, "eos": self.eos})
                     audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
+                    conversation_history[transcription_output["uid"]].append(
+                        (transcription_output['prompt'].strip(), self.last_output.strip())
+                    )
+                    print(f"History: {conversation_history}")
                 continue
+
+            input_text=[self.format_prompt_qa(prompt, conversation_history[transcription_output["uid"]])]
+            # print(f"Formatted prompt with history...:\n{input_text}")
 
             self.eos = transcription_output["eos"]
 
@@ -257,6 +275,10 @@ class MistralTensorRTLLM:
 
                 if output is None:
                     break
+
+                if output is not None:
+                    if "Instruct" in output[0]:
+                        output[0] = output[0].split("Instruct")[0]
             # Interrupted by transcription queue
             if output is None:
                 logging.info(f"[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
@@ -276,6 +298,9 @@ class MistralTensorRTLLM:
                 sequence_lengths,
                 transcription_queue
            )
+            if output is not None:
+                if "Instruct" in output[0]:
+                    output[0] = output[0].split("Instruct")[0]
 
             # if self.eos:
             if output is not None:
@@ -285,12 +310,15 @@ class MistralTensorRTLLM:
                 audio_queue.put({"llm_output": output, "eos": self.eos})
 
                 if self.eos:
+                    conversation_history[transcription_output["uid"]].append(
+                        (transcription_output['prompt'].strip(), output[0].strip())
+                    )
                     self.last_prompt = None
                     self.last_output = None
 
 
 if __name__=="__main__":
-    llm = MistralTensorRTLLM()
+    llm = TensorRTLLMEngine()
     llm.initialize_model(
         "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",
         "teknium/OpenHermes-2.5-Mistral-7B",
main.py CHANGED
@@ -9,7 +9,7 @@ import functools
 from multiprocessing import Process, Manager, Value, Queue
 
 from whisper_live.trt_server import TranscriptionServer
-from llm_service import MistralTensorRTLLM
+from llm_service import TensorRTLLMEngine
 from tts_service import WhisperSpeechTTS
 
 
@@ -88,7 +88,7 @@ if __name__ == "__main__":
     )
     whisper_process.start()
 
-    llm_provider = MistralTensorRTLLM()
+    llm_provider = TensorRTLLMEngine()
     # llm_provider = MistralTensorRTLLMProvider()
     llm_process = multiprocessing.Process(
         target=llm_provider.run,
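
For context, a rough sketch of how the renamed engine is driven from main.py. Only the class name, the llm_service import, the engine/tokenizer paths, and target=llm_provider.run come from this diff; the queue names and the exact run() argument order are assumptions based on how llm_service.py uses them:

# Hedged sketch: wiring TensorRTLLMEngine into its own process.
# The args tuple is an assumption; the diff only shows target=llm_provider.run
# and the engine/tokenizer paths used in llm_service.py's __main__ block.
import multiprocessing
from multiprocessing import Queue

from llm_service import TensorRTLLMEngine

if __name__ == "__main__":
    transcription_queue = Queue()  # filled by the Whisper transcription server
    llm_queue = Queue()            # read back to send LLM output over websockets
    audio_queue = Queue()          # consumed by the TTS service

    llm_provider = TensorRTLLMEngine()
    llm_process = multiprocessing.Process(
        target=llm_provider.run,
        args=(
            "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",  # engine dir (from the diff)
            "teknium/OpenHermes-2.5-Mistral-7B",                                        # tokenizer (from the diff)
            transcription_queue,
            llm_queue,
            audio_queue,
        ),
    )
    llm_process.start()
    llm_process.join()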