makaveli committed on
Commit
ea46c22
1 Parent(s): 95bcc6a

add prompt formatting

Browse files
Files changed (1) hide show
  1. llm_service.py +12 -4
llm_service.py CHANGED
@@ -153,14 +153,21 @@ class MistralTensorRTLLM:
153
  output.append(output_text)
154
  return output
155
 
 
 
 
 
 
 
156
  def run(
157
  self,
158
  model_path,
159
  tokenizer_path,
160
  transcription_queue=None,
161
  llm_queue=None,
 
162
  input_text=None,
163
- max_output_len=20,
164
  max_attention_window_size=4096,
165
  num_beams=1,
166
  streaming=True,
@@ -177,10 +184,10 @@ class MistralTensorRTLLM:
177
 
178
  # while transcription
179
  transcription_output = transcription_queue.get()
180
- if not debug:
181
- input_text=[transcription_output['prompt'].strip()]
182
 
183
- print("Whisper: ", input_text)
184
  batch_input_ids = self.parse_input(
185
  input_text=input_text,
186
  add_special_tokens=True,
@@ -234,6 +241,7 @@ class MistralTensorRTLLM:
234
  sequence_lengths,
235
  )
236
  llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
 
237
 
238
 
239
  if __name__=="__main__":
 
153
  output.append(output_text)
154
  return output
155
 
156
+ def format_prompt_qa(self, prompt):
157
+ return f"Instruct: {prompt}\nOutput:"
158
+
159
+ def format_prompt_chat(self, prompt):
160
+ return f"Alice: {prompt}\nBob:"
161
+
162
  def run(
163
  self,
164
  model_path,
165
  tokenizer_path,
166
  transcription_queue=None,
167
  llm_queue=None,
168
+ audio_queue=None,
169
  input_text=None,
170
+ max_output_len=40,
171
  max_attention_window_size=4096,
172
  num_beams=1,
173
  streaming=True,
 
184
 
185
  # while transcription
186
  transcription_output = transcription_queue.get()
187
+ prompt = transcription_output['prompt'].strip()
188
+ input_text=[self.format_prompt_qa(prompt)]
189
 
190
+ print("Whisper: ", prompt)
191
  batch_input_ids = self.parse_input(
192
  input_text=input_text,
193
  add_special_tokens=True,
 
241
  sequence_lengths,
242
  )
243
  llm_queue.put({"uid": transcription_output["uid"], "llm_output": output})
244
+ audio_queue.put(output)
245
 
246
 
247
  if __name__=="__main__":