prithivMLmods committed
Commit d418457 · verified · 1 parent: 9b873c2

Update app.py

Files changed (1)
  1. app.py +47 -104
app.py CHANGED
@@ -1,11 +1,5 @@
 import gradio as gr
-from transformers import (
-    AutoProcessor,
-    Qwen2_5_VLForConditionalGeneration,
-    TextIteratorStreamer,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
 from transformers.image_utils import load_image
 from threading import Thread
 import time
@@ -15,9 +9,6 @@ import cv2
 import numpy as np
 from PIL import Image

-# -----------------------
-# Progress Bar Helper
-# -----------------------
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -38,9 +29,6 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''

-# -----------------------
-# Video Processing Helper
-# -----------------------
 def downsample_video(video_path):
     """
     Downsamples the video to 10 evenly spaced frames.
@@ -66,60 +54,45 @@ def downsample_video(video_path):
     vidcap.release()
     return frames

-# -----------------------
-# Qwen2.5-VL Model (Multimodal)
-# -----------------------
-MODEL_ID_VL = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
-vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_VL,
+MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16
 ).to("cuda").eval()

-# -----------------------
-# Text Generation Setup (DeepHermes)
-# -----------------------
-TG_MODEL_ID = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
-tg_tokenizer = AutoTokenizer.from_pretrained(TG_MODEL_ID)
-tg_model = AutoModelForCausalLM.from_pretrained(
-    TG_MODEL_ID,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-tg_model.eval()
-
-# -----------------------
-# Main Inference Function
-# -----------------------
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"]
     files = input_dict["files"]

-    # Video inference branch
     if text.strip().lower().startswith("@video-infer"):
+        # Remove the tag from the query.
         text = text[len("@video-infer"):].strip()
         if not files:
-            yield gr.Error("Please upload a video file along with your @video-infer query.")
+            gr.Error("Please upload a video file along with your @video-infer query.")
             return
+        # Assume the first file is a video.
         video_path = files[0]
         frames = downsample_video(video_path)
         if not frames:
-            yield gr.Error("Could not process video.")
+            gr.Error("Could not process video.")
             return
-        # Build messages starting with the text prompt and then add each frame with its timestamp.
+        # Build messages: start with the text prompt.
         messages = [
             {
                 "role": "user",
                 "content": [{"type": "text", "text": text}]
             }
         ]
+        # Append each frame with a timestamp label.
         for image, timestamp in frames:
             messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
             messages[0]["content"].append({"type": "image", "image": image})
-        # Collect images from the frames.
+        # Collect only the images from the frames.
         video_images = [image for image, _ in frames]
+        # Prepare the prompt.
         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         inputs = processor(
             text=[prompt],
@@ -127,9 +100,10 @@ def model_inference(input_dict, history):
             return_tensors="pt",
             padding=True,
         ).to("cuda")
+        # Set up streaming generation.
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-        thread = Thread(target=vl_model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
         yield progress_bar_html("Processing video with Qwen2.5VL Model")
@@ -139,82 +113,52 @@ def model_inference(input_dict, history):
             yield buffer
         return

-    # Multimodal branch if images are provided (non-video)
-    if files:
-        # If more than one file is provided, load them as images.
-        if len(files) > 1:
-            images = [load_image(image) for image in files]
-        elif len(files) == 1:
-            images = [load_image(files[0])]
-        else:
-            images = []
+    if len(files) > 1:
+        images = [load_image(image) for image in files]
+    elif len(files) == 1:
+        images = [load_image(files[0])]
+    else:
+        images = []

-        if text == "":
-            yield gr.Error("Please input a text query along with the image(s).")
-            return
-
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": text},
-                ],
-            }
-        ]
-        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(
-            text=[prompt],
-            images=images,
-            return_tensors="pt",
-            padding=True,
-        ).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-        thread = Thread(target=vl_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing with Qwen2.5VL Model")
-        for new_text in streamer:
-            buffer += new_text
-            time.sleep(0.01)
-            yield buffer
+    if text == "" and not images:
+        gr.Error("Please input a query and optionally image(s).")
         return
-
-    # Text-only branch using DeepHermes text generation.
-    if text.strip() == "":
-        yield gr.Error("Please input a query.")
+    if text == "" and images:
+        gr.Error("Please input a text query along with the image(s).")
         return

-    input_ids = tg_tokenizer(text, return_tensors="pt").to(tg_model.device)
-    streamer = TextIteratorStreamer(tg_tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        "input_ids": input_ids,
-        "streamer": streamer,
-        "max_new_tokens": 2048,
-        "do_sample": True,
-        "top_p": 0.9,
-        "top_k": 50,
-        "temperature": 0.6,
-        "repetition_penalty": 1.2,
-    }
-    thread = Thread(target=tg_model.generate, kwargs=generation_kwargs)
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                *[{"type": "image", "image": image} for image in images],
+                {"type": "text", "text": text},
+            ],
+        }
+    ]
+    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
+        text=[prompt],
+        images=images if images else None,
+        return_tensors="pt",
+        padding=True,
+    ).to("cuda")
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
-    yield progress_bar_html("Processing text with DeepHermes Model")
+    yield progress_bar_html("Processing with Qwen2.5VL Model")
     for new_text in streamer:
         buffer += new_text
         time.sleep(0.01)
         yield buffer

-# -----------------------
-# Gradio Chat Interface
-# -----------------------
 examples = [
     [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
-    [{"text": "Tell me a story about a brave knight."}],
     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
     [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
+    [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
 ]

 demo = gr.ChatInterface(
@@ -228,5 +172,4 @@ demo = gr.ChatInterface(
     cache_examples=False,
 )

-if __name__ == "__main__":
-    demo.launch(debug=True)
+demo.launch(debug=True)