paralym commited on
Commit
15c4e1f
·
verified ·
1 Parent(s): 52183c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -13
app.py CHANGED
@@ -201,7 +201,7 @@ def is_valid_image_filename(name):
201
  return False
202
 
203
 
204
- def sample_frames_old(video_file, num_frames):
205
  video = cv2.VideoCapture(video_file)
206
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
207
  interval = total_frames // num_frames
@@ -216,7 +216,7 @@ def sample_frames_old(video_file, num_frames):
216
  video.release()
217
  return frames
218
 
219
- def sample_frames(video_path, frame_count=32):
220
  video_frames = []
221
  vr = VideoReader(video_path, ctx=cpu(0))
222
  total_frames = len(vr)
@@ -240,6 +240,22 @@ def sample_frames(video_path, frame_count=32):
240
 
241
  return video_frames
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  def load_image(image_file):
245
  if image_file.startswith("http") or image_file.startswith("https"):
@@ -319,6 +335,7 @@ def bot(history, temperature, top_p, max_output_tokens):
319
  images_this_term = []
320
  text_this_term = ""
321
 
 
322
  num_new_images = 0
323
  # previous_image = False
324
  for i, message in enumerate(history[:-1]):
@@ -332,7 +349,9 @@ def bot(history, temperature, top_p, max_output_tokens):
332
  if is_valid_video_filename(message[0][0]):
333
  # raise ValueError("Video is not supported")
334
  # num_new_images += our_chatbot.num_frames
335
- num_new_images += len(sample_frames(message[0][0], our_chatbot.num_frames))
 
 
336
  elif is_valid_image_filename(message[0][0]):
337
  print("#### Load image from local file",message[0][0])
338
  num_new_images += 1
@@ -343,6 +362,7 @@ def bot(history, temperature, top_p, max_output_tokens):
343
  num_new_images = 0
344
  # previous_image = False
345
 
 
346
  image_list = []
347
  for f in images_this_term:
348
  if is_valid_video_filename(f):
@@ -388,19 +408,21 @@ def bot(history, temperature, top_p, max_output_tokens):
388
  with open(file_path, "rb") as src, open(filename, "wb") as dst:
389
  dst.write(src.read())
390
 
391
-
392
- image_tensor = [
393
- our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
394
- 0
 
 
 
 
395
  ]
396
- .half()
397
- .to(our_chatbot.model.device)
398
- for f in image_list
399
- ]
400
 
401
 
402
- image_tensor = torch.stack(image_tensor)
403
- image_token = DEFAULT_IMAGE_TOKEN * num_new_images
404
 
405
  inp = text
406
  inp = image_token + "\n" + inp
@@ -440,6 +462,7 @@ def bot(history, temperature, top_p, max_output_tokens):
440
  max_new_tokens=max_output_tokens,
441
  use_cache=False,
442
  stopping_criteria=[stopping_criteria],
 
443
  )
444
 
445
  t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
 
201
  return False
202
 
203
 
204
+ def sample_frames_v1(video_file, num_frames):
205
  video = cv2.VideoCapture(video_file)
206
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
207
  interval = total_frames // num_frames
 
216
  video.release()
217
  return frames
218
 
219
+ def sample_frames_v2(video_path, frame_count=32):
220
  video_frames = []
221
  vr = VideoReader(video_path, ctx=cpu(0))
222
  total_frames = len(vr)
 
240
 
241
  return video_frames
242
 
243
+ def sample_frames(video_path, num_frames=8):
244
+ cap = cv2.VideoCapture(video_path)
245
+ frames = []
246
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
247
+ indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
248
+
249
+ for i in indices:
250
+ cap.set(cv2.CAP_PROP_POS_FRAMES, i)
251
+ ret, frame = cap.read()
252
+ if ret:
253
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
254
+ frames.append(Image.fromarray(frame))
255
+
256
+ cap.release()
257
+ return frames
258
+
259
 
260
  def load_image(image_file):
261
  if image_file.startswith("http") or image_file.startswith("https"):
 
335
  images_this_term = []
336
  text_this_term = ""
337
 
338
+ is_video = False
339
  num_new_images = 0
340
  # previous_image = False
341
  for i, message in enumerate(history[:-1]):
 
349
  if is_valid_video_filename(message[0][0]):
350
  # raise ValueError("Video is not supported")
351
  # num_new_images += our_chatbot.num_frames
352
+ # num_new_images += len(sample_frames(message[0][0], our_chatbot.num_frames))
353
+ num_new_images += 1
354
+ is_video = True
355
  elif is_valid_image_filename(message[0][0]):
356
  print("#### Load image from local file",message[0][0])
357
  num_new_images += 1
 
362
  num_new_images = 0
363
  # previous_image = False
364
 
365
+
366
  image_list = []
367
  for f in images_this_term:
368
  if is_valid_video_filename(f):
 
408
  with open(file_path, "rb") as src, open(filename, "wb") as dst:
409
  dst.write(src.read())
410
 
411
+ if not is_video:
412
+ image_tensor = [
413
+ our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
414
+ 0
415
+ ]
416
+ .half()
417
+ .to(our_chatbot.model.device)
418
+ for f in image_list
419
  ]
420
+ image_tensor = torch.stack(image_tensor)
421
+ else:
422
+ image_tensor = our_chatbot.image_processor.preprocess(image_list, return_tensors="pt")["pixel_values"].half().to(our_chatbot.model.device)
 
423
 
424
 
425
+ image_token = DEFAULT_IMAGE_TOKEN * num_new_images if not is_video else DEFAULT_IMAGE_TOKEN * num_new_images
 
426
 
427
  inp = text
428
  inp = image_token + "\n" + inp
 
462
  max_new_tokens=max_output_tokens,
463
  use_cache=False,
464
  stopping_criteria=[stopping_criteria],
465
+ modalities=["video"] if is_video else ["image"]
466
  )
467
 
468
  t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)