KingNish commited on
Commit
4936b31
1 Parent(s): 4e975f4

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +27 -15
chatbot.py CHANGED
@@ -151,7 +151,15 @@ def video_gen(prompt):
151
  image_extensions = Image.registered_extensions()
152
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
153
 
154
- def qwen_inference(user_prompt, chat_history):
 
 
 
 
 
 
 
 
155
  images = []
156
  text_input = user_prompt["text"]
157
 
@@ -194,20 +202,12 @@ def qwen_inference(user_prompt, chat_history):
194
  })
195
 
196
  return messages
197
-
198
- # Initialize inference clients for different models
199
- client_mistral = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
200
- client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
201
- client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
202
- client_mistral_nemo = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
203
-
204
- @spaces.GPU(duration=60, queue=False)
205
- def model_inference( user_prompt, chat_history):
206
  if user_prompt["files"]:
207
  messages = qwen_inference(user_prompt, chat_history)
208
  text = processor.apply_chat_template(
209
- messages, tokenize=False, add_generation_prompt=True
210
- )
211
  image_inputs, video_inputs = process_vision_info(messages)
212
  inputs = processor(
213
  text=[text],
@@ -218,8 +218,8 @@ def model_inference( user_prompt, chat_history):
218
  ).to("cuda")
219
 
220
  streamer = TextIteratorStreamer(
221
- processor, skip_prompt=True, **{"skip_special_tokens": True}
222
- )
223
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
224
 
225
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -325,7 +325,19 @@ def model_inference( user_prompt, chat_history):
325
  yield gr.Video(video)
326
 
327
  elif json_data["name"] == "image_qna":
328
- inputs = llava(user_prompt, chat_history)
 
 
 
 
 
 
 
 
 
 
 
 
329
  streamer = TextIteratorStreamer(processor, skip_prompt=True, **{"skip_special_tokens": True})
330
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
331
 
 
151
  image_extensions = Image.registered_extensions()
152
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
153
 
154
+ # Initialize inference clients for different models
155
+ client_mistral = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
156
+ client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
157
+ client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
158
+ client_mistral_nemo = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
159
+
160
+ def model_inference(user_prompt, chat_history):
161
+ @spaces.GPU(duration=60, queue=False)
162
+ def qwen_inference(user_prompt, chat_history):
163
  images = []
164
  text_input = user_prompt["text"]
165
 
 
202
  })
203
 
204
  return messages
205
+
 
 
 
 
 
 
 
 
206
  if user_prompt["files"]:
207
  messages = qwen_inference(user_prompt, chat_history)
208
  text = processor.apply_chat_template(
209
+ messages, tokenize=False, add_generation_prompt=True
210
+ )
211
  image_inputs, video_inputs = process_vision_info(messages)
212
  inputs = processor(
213
  text=[text],
 
218
  ).to("cuda")
219
 
220
  streamer = TextIteratorStreamer(
221
+ processor, skip_prompt=True, **{"skip_special_tokens": True}
222
+ )
223
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
224
 
225
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
325
  yield gr.Video(video)
326
 
327
  elif json_data["name"] == "image_qna":
328
+ messages = qwen_inference(user_prompt, chat_history)
329
+ text = processor.apply_chat_template(
330
+ messages, tokenize=False, add_generation_prompt=True
331
+ )
332
+ image_inputs, video_inputs = process_vision_info(messages)
333
+ inputs = processor(
334
+ text=[text],
335
+ images=image_inputs,
336
+ videos=video_inputs,
337
+ padding=True,
338
+ return_tensors="pt",
339
+ ).to("cuda")
340
+
341
  streamer = TextIteratorStreamer(processor, skip_prompt=True, **{"skip_special_tokens": True})
342
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
343