prithivMLmods committed
Commit 696cd59 · verified · parent: 709c732

Update app.py

Files changed (1)
app.py: +81 -17
app.py CHANGED
@@ -12,6 +12,7 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
+import cv2
 
 from transformers import (
     AutoModelForCausalLM,
@@ -149,6 +150,28 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
+def downsample_video(video_path):
+    """
+    Downsamples the video to 10 evenly spaced frames.
+    Each frame is returned as a PIL image along with its timestamp.
+    """
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    # Sample 10 evenly spaced frames.
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
+
 @spaces.GPU(duration=60, enable_queue=True)
 def generate_image_fn(
     prompt: str,
@@ -213,14 +236,16 @@ def generate(
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
    - "@image": triggers image generation using the SDXL pipeline.
+    - "@qwen2vl-video": triggers video processing using Qwen2VL.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
+    lower_text = text.strip().lower()
 
-    if text.strip().lower().startswith("@image"):
+    # Branch for image generation.
+    if lower_text.startswith("@image"):
         # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
-        # Show animated progress bar for image generation
         yield progress_bar_html("Generating Image")
         image_paths, used_seed = generate_image_fn(
             prompt=prompt,
@@ -235,10 +260,57 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
-        # Once done, yield the generated image
         yield gr.Image(image_paths[0])
-        return  # Exit early
+        return
+
+    # New branch for video processing with Qwen2VL.
+    if lower_text.startswith("@qwen2vl-video"):
+        prompt = text[len("@qwen2vl-video"):].strip()
+        if files:
+            # Assume the first file is a video.
+            video_path = files[0]
+            frames = downsample_video(video_path)
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt}]}
+            ]
+            # Append each frame with its timestamp.
+            for frame in frames:
+                image, timestamp = frame
+                image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                image.save(image_path)
+                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                messages[1]["content"].append({"type": "image", "url": image_path})
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt}]}
+            ]
+        inputs = processor.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        ).to("cuda")
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing video with Qwen2VL")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
 
+    # Determine if TTS is requested.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -246,11 +318,9 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
-        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
@@ -269,15 +339,13 @@ def generate(
                 {"type": "text", "text": text},
             ]
         }]
-        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
         thread.start()
-
         buffer = ""
-        # Show animated progress bar for multimodal generation
         yield progress_bar_html("Thinking...")
         for new_text in streamer:
             buffer += new_text
@@ -304,18 +372,13 @@ def generate(
         }
         t = Thread(target=model.generate, kwargs=generation_kwargs)
         t.start()
-
         outputs = []
-        # Show animated progress bar for text generation
-        yield progress_bar_html("Thinking...")
+        yield progress_bar_html("Processing with Qwen2VL OCR")
         for new_text in streamer:
             outputs.append(new_text)
             yield "".join(outputs)
-
         final_response = "".join(outputs)
         yield final_response
-
-        # If TTS was requested, convert the final response to speech.
        if is_tts and voice:
            output_file = asyncio.run(text_to_speech(final_response, voice))
            yield gr.Audio(output_file, autoplay=True)
@@ -330,6 +393,7 @@ demo = gr.ChatInterface(
         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
     examples=[
+        [{"text": "@qwen2vl-video Summarize the events in this video", "files": ["examples/sky.mp4"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
         ["Python Program for Array Rotation"],
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
@@ -342,7 +406,7 @@ demo = gr.ChatInterface(
     description=DESCRIPTION,
     css=css,
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="‎ @tts1, @tts2-voices, @image-image gen, default [text, vision]"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="‎ @tts1, @tts2-voices, @image for image gen, @qwen2vl-video for video, default [text, vision]"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
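
A minimal sketch for sanity-checking the new downsample_video helper in isolation; it assumes the function and its imports (cv2, numpy, PIL) are in scope, and "sample.mp4" is a placeholder for any local video file, not part of this commit:

frames = downsample_video("sample.mp4")
print(f"sampled {len(frames)} frames")  # at most 10 frames are returned
for pil_image, timestamp in frames:
    # each entry pairs a PIL.Image with its timestamp in seconds
    print(pil_image.size, timestamp)

One caveat worth noting: vidcap.get(cv2.CAP_PROP_FPS) can return 0 for some containers, in which case the timestamp division would fail; guarding fps before dividing may be worth a follow-up.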
 
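
For reference, the new @qwen2vl-video branch interleaves the user prompt, per-frame timestamp labels, and frame images in a single user turn before calling processor.apply_chat_template. A sketch of the payload shape for a two-frame video (the frame file names and timestamps below are illustrative, not produced by this commit):

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [
        {"type": "text", "text": "Summarize the events in this video"},
        # one text label plus one image entry per sampled frame
        {"type": "text", "text": "Frame 0.0:"},
        {"type": "image", "url": "video_frame_ab12.png"},
        {"type": "text", "text": "Frame 3.2:"},
        {"type": "image", "url": "video_frame_cd34.png"},
    ]},
]

In the UI, the feature is invoked by typing "@qwen2vl-video <prompt>" in the textbox and attaching a video file; the branch treats the first attached file as the video.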