makaveli committed: add prompt formating

llm_service.py  CHANGED  (+12 -4)
@@ -153,14 +153,21 @@ class MistralTensorRTLLM:
             output.append(output_text)
         return output
 
+    def format_prompt_qa(self, prompt):
+        return f"Instruct: {prompt}\nOutput:"
+
+    def format_prompt_chat(self, prompt):
+        return f"Alice: {prompt}\nBob:"
+
     def run(
         self,
         model_path,
         tokenizer_path,
         transcription_queue=None,
         llm_queue=None,
+        audio_queue=None,
         input_text=None,
-        max_output_len=
+        max_output_len=40,
         max_attention_window_size=4096,
         num_beams=1,
         streaming=True,
@@ -177,10 +184,10 @@ class MistralTensorRTLLM:
 
             # while transcription
             transcription_output = transcription_queue.get()
-
-
+            prompt = transcription_output['prompt'].strip()
+            input_text=[self.format_prompt_qa(prompt)]
 
-            print("Whisper: ",
+            print("Whisper: ", prompt)
             batch_input_ids = self.parse_input(
                 input_text=input_text,
                 add_special_tokens=True,
@@ -234,6 +241,7 @@ class MistralTensorRTLLM:
                 sequence_lengths,
             )
             llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
+            audio_queue.put(output)
 
 
 if __name__=="__main__":
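For context on what the diff does: the two new helpers wrap a raw Whisper transcription in a fixed prompt template before it is tokenized, and run() now accepts an audio_queue so the generated text is pushed to a second consumer in addition to llm_queue. The sketch below imitates that flow with plain queue.Queue objects; run_once and fake_generate are hypothetical stand-ins for MistralTensorRTLLM.run() and the TensorRT-LLM generation call, not code from this repository.

import queue


def format_prompt_qa(prompt: str) -> str:
    # Same template the commit adds: instruction/answer style.
    return f"Instruct: {prompt}\nOutput:"


def format_prompt_chat(prompt: str) -> str:
    # The alternative two-speaker template from the commit.
    return f"Alice: {prompt}\nBob:"


def fake_generate(text: str) -> str:
    # Hypothetical stand-in for the TensorRT-LLM generation step.
    return f"(model reply to {text!r})"


def run_once(transcription_queue, llm_queue, audio_queue):
    # Mirrors one loop iteration of the patched run():
    # take a transcription, wrap it in the QA template, generate,
    # then fan the output out to both downstream queues.
    transcription_output = transcription_queue.get()
    prompt = transcription_output["prompt"].strip()
    input_text = [format_prompt_qa(prompt)]
    print("Whisper: ", prompt)
    output = [fake_generate(t) for t in input_text]
    llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
    audio_queue.put(output)


if __name__ == "__main__":
    transcription_queue = queue.Queue()
    llm_queue = queue.Queue()
    audio_queue = queue.Queue()
    transcription_queue.put({"uid": 0, "prompt": " What is the capital of France? "})
    run_once(transcription_queue, llm_queue, audio_queue)
    print(llm_queue.get())   # {'uid': 0, 'llm_output': [...]}
    print(audio_queue.get()) # same text, available to a second consumer

The extra audio_queue.put(output) is plain fan-out: the same generation result can serve both the text response path and, presumably, a downstream text-to-speech worker.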