Commit e3f7cd8
Parent(s): 7cfc46e

add history of conversation to prompt

Files changed:
  llm_service.py  +36 -8
  main.py         +2 -2
llm_service.py  (CHANGED)

@@ -84,7 +84,7 @@ def load_tokenizer(tokenizer_dir: Optional[str] = None,
     return tokenizer, pad_id, end_id
 
 
-class MistralTensorRTLLM:
+class TensorRTLLMEngine:
     def __init__(self):
         pass
 
@@ -165,11 +165,17 @@ class MistralTensorRTLLM:
         output.append(output_text)
         return output
 
-    def format_prompt_qa(self, prompt):
-        ...
+    def format_prompt_qa(self, prompt, conversation_history):
+        formatted_prompt = ""
+        for user_prompt, llm_response in conversation_history:
+            formatted_prompt += f"Instruct: {user_prompt}\nOutput:{llm_response}\n"
+        return f"{formatted_prompt}Instruct: {prompt}\nOutput:"
 
-    def format_prompt_chat(self, prompt):
-        ...
+    def format_prompt_chat(self, prompt, conversation_history):
+        formatted_prompt = ""
+        for user_prompt, llm_response in conversation_history:
+            formatted_prompt += f"Alice: {user_prompt}\nBob:{llm_response}\n"
+        return f"{formatted_prompt}Alice: {prompt}\nBob:"
 
     def run(
         self,
@@ -192,6 +198,9 @@ class MistralTensorRTLLM:
         )
 
         logging.info("[LLM] loaded: True")
+
+        conversation_history = {}
+
         while True:
 
             # Get the last transcription output from the queue
@@ -199,17 +208,26 @@ class MistralTensorRTLLM:
             if transcription_queue.qsize() != 0:
                 logging.info("[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
                 continue
+
+            if transcription_output["uid"] not in conversation_history:
+                conversation_history[transcription_output["uid"]] = []
 
             prompt = transcription_output['prompt'].strip()
-            ...
-            ...
+
             # if prompt is same but EOS is True, we need that to send outputs to websockets
             if self.last_prompt == prompt:
                 if self.last_output is not None and transcription_output["eos"]:
                     self.eos = transcription_output["eos"]
                     llm_queue.put({"uid": transcription_output["uid"], "llm_output": self.last_output, "eos": self.eos})
                     audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
+                    conversation_history[transcription_output["uid"]].append(
+                        (transcription_output['prompt'].strip(), self.last_output.strip())
+                    )
+                    print(f"History: {conversation_history}")
                 continue
+
+            input_text=[self.format_prompt_qa(prompt, conversation_history[transcription_output["uid"]])]
+            # print(f"Formatted prompt with history...:\n{input_text}")
 
             self.eos = transcription_output["eos"]
 
@@ -257,6 +275,10 @@ class MistralTensorRTLLM:
 
                 if output is None:
                     break
+
+                if output is not None:
+                    if "Instruct" in output[0]:
+                        output[0] = output[0].split("Instruct")[0]
                 # Interrupted by transcription queue
                 if output is None:
                     logging.info(f"[LLM] interrupted by transcription queue!!!!!!!!!!!!!!!!!!!!!!!!")
@@ -276,6 +298,9 @@ class MistralTensorRTLLM:
                 sequence_lengths,
                 transcription_queue
             )
+            if output is not None:
+                if "Instruct" in output[0]:
+                    output[0] = output[0].split("Instruct")[0]
 
             # if self.eos:
             if output is not None:
@@ -285,12 +310,15 @@ class MistralTensorRTLLM:
                 audio_queue.put({"llm_output": output, "eos": self.eos})
 
                 if self.eos:
+                    conversation_history[transcription_output["uid"]].append(
+                        (transcription_output['prompt'].strip(), output[0].strip())
+                    )
                     self.last_prompt = None
                     self.last_output = None
 
 
 if __name__=="__main__":
-    llm = ...
+    llm = TensorRTLLMEngine()
     llm.initialize_model(
         "/root/TensorRT-LLM/examples/llama/tmp/mistral/7B/trt_engines/fp16/1-gpu",
         "teknium/OpenHermes-2.5-Mistral-7B",
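Taken together, the llm_service.py changes keep one list of (user_prompt, llm_response) pairs per client uid and prepend that history to each new request before it reaches the model; once end-of-speech is reached, the finished turn is appended back onto the same list, and any trailing "Instruct" text the model starts to generate is trimmed off before the output is queued. Below is a minimal standalone sketch of the new format_prompt_qa helper copied from the diff; the client uid and the history entry are invented only to show the prompt layout it produces.

    # Sketch: the new history-aware prompt builder from this commit.
    def format_prompt_qa(prompt, conversation_history):
        formatted_prompt = ""
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"Instruct: {user_prompt}\nOutput:{llm_response}\n"
        return f"{formatted_prompt}Instruct: {prompt}\nOutput:"

    # Hypothetical per-client history, keyed by uid as run() now does.
    conversation_history = {"client-1": [("Hello", "Hi! How can I help you today?")]}

    print(format_prompt_qa("What did I just say?", conversation_history["client-1"]))
    # Instruct: Hello
    # Output:Hi! How can I help you today?
    # Instruct: What did I just say?
    # Output: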
main.py  (CHANGED)

@@ -9,7 +9,7 @@ import functools
 from multiprocessing import Process, Manager, Value, Queue
 
 from whisper_live.trt_server import TranscriptionServer
-from llm_service import ...
+from llm_service import TensorRTLLMEngine
 from tts_service import WhisperSpeechTTS
 
 
@@ -88,7 +88,7 @@ if __name__ == "__main__":
     )
     whisper_process.start()
 
-    llm_provider = ...
+    llm_provider = TensorRTLLMEngine()
     # llm_provider = MistralTensorRTLLMProvider()
     llm_process = multiprocessing.Process(
         target=llm_provider.run,
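main.py only needs the rename: the engine is still constructed in the parent process and handed to multiprocessing.Process as the run target. A rough sketch of that wiring follows, assuming manager-backed queues; the queue names and the run() argument order are placeholders, since the actual call site is outside the hunks shown above.

    # Rough sketch of launching the renamed engine in its own process.
    # The queue names and the run() argument order are placeholders; the
    # real call site in main.py is not part of the hunks above.
    import multiprocessing

    from llm_service import TensorRTLLMEngine

    if __name__ == "__main__":
        manager = multiprocessing.Manager()
        transcription_queue = manager.Queue()
        llm_queue = manager.Queue()
        audio_queue = manager.Queue()

        llm_provider = TensorRTLLMEngine()
        llm_process = multiprocessing.Process(
            target=llm_provider.run,
            args=(transcription_queue, llm_queue, audio_queue),  # placeholder signature
        )
        llm_process.start()
        llm_process.join()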