metek7 committed
Commit 3b421e3 · verified · 1 Parent(s): f0272e1

Update app.py

Files changed (1)
  1. app.py +51 -68
app.py CHANGED
@@ -1,82 +1,65 @@
+import sys
+import subprocess
+import pkg_resources
+
+required_packages = {
+    'torch': 'torch',
+    'gradio': 'gradio',
+    'transformers': 'transformers',
+    'decord': 'decord',
+    'numpy': 'numpy'
+}
+
+def install_packages(packages):
+    for package in packages:
+        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+def check_and_install_packages():
+    installed_packages = {pkg.key for pkg in pkg_resources.working_set}
+    missing_packages = [required_packages[pkg] for pkg in required_packages if pkg not in installed_packages]
+
+    if missing_packages:
+        print("Installing missing packages...")
+        install_packages(missing_packages)
+        print("Packages installed successfully.")
+    else:
+        print("All required packages are already installed.")
+
+# Check and install required packages
+check_and_install_packages()
+
+# Now import the required modules
 import gradio as gr
 import torch
-from llava.model.builder import load_pretrained_model
-from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
-from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
-from llava.conversation import conv_templates
-import copy
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from decord import VideoReader, cpu
 import numpy as np
 
-title = "# 🎥 Instagram Short Video Analyzer with LLaVA-Video"
-description = """
-This application uses the LLaVA-Video-7B-Qwen2 model to analyze Instagram short videos.
-Upload your Instagram short video and ask questions about its content!
-"""
-
-def load_video(video_path, max_frames_num=64, fps=1):
+# Define a simple video processing function (placeholder for LLaVA-Video)
+def process_video(video_path, max_frames=64):
     vr = VideoReader(video_path, ctx=cpu(0))
-    total_frame_num = len(vr)
-    video_time = total_frame_num / vr.get_avg_fps()
-    fps = round(vr.get_avg_fps()/fps)
-    frame_idx = list(range(0, len(vr), fps))
-    if len(frame_idx) > max_frames_num:
-        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
-    frame_time = [i/vr.get_avg_fps() for i in frame_idx]
-    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
-    spare_frames = vr.get_batch(frame_idx).asnumpy()
-    return spare_frames, frame_time, video_time
-
-# Load the model
-pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
-model_name = "llava_qwen"
-device = "cuda" if torch.cuda.is_available() else "cpu"
-device_map = "auto"
-
-print("Loading model...")
-tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
-model.eval()
-print("Model loaded successfully!")
+    total_frames = len(vr)
+    frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)
+    frames = vr.get_batch(frame_indices).asnumpy()
+    return frames
 
-def process_instagram_short(video_path, question):
-    max_frames_num = 64
-    video, frame_time, video_time = load_video(video_path, max_frames_num)
-    video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
-    video = [video]
+# Define a simple text generation function (placeholder for actual model)
+def generate_response(video_frames, question):
+    # This is a placeholder. In reality, you'd use the LLaVA-Video model here.
+    return f"Analyzed {len(video_frames)} frames. Your question was: {question}"
 
-    time_instruction = f"This is an Instagram short video lasting {video_time:.2f} seconds. {len(video[0])} frames were sampled at {frame_time}. Analyze this short video and answer the following question:"
-
-    full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
-
-    conv = copy.deepcopy(conv_templates["qwen_1_5"])
-    conv.append_message(conv.roles[0], full_question)
-    conv.append_message(conv.roles[1], None)
-    prompt_question = conv.get_prompt()
-
-    input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
-
-    with torch.no_grad():
-        output = model.generate(
-            input_ids,
-            images=video,
-            modalities=["video"],
-            do_sample=False,
-            temperature=0,
-            max_new_tokens=4096,
-        )
-
-    response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
-    return response
-
-def gradio_interface(video_file, question):
+def analyze_instagram_short(video_file, question):
     if video_file is None:
         return "Please upload an Instagram short video."
-    response = process_instagram_short(video_file, question)
+
+    video_frames = process_video(video_file)
+    response = generate_response(video_frames, question)
     return response
 
+# Create Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown(title)
-    gr.Markdown(description)
+    gr.Markdown("# 🎥 Instagram Short Video Analyzer")
+    gr.Markdown("Upload your Instagram short video and ask questions about its content!")
 
     with gr.Row():
         with gr.Column():
@@ -86,10 +69,10 @@ with gr.Blocks() as demo:
             output = gr.Textbox(label="Analysis Result")
 
     submit_button.click(
-        fn=gradio_interface,
+        fn=analyze_instagram_short,
        inputs=[video_input, question_input],
        outputs=output
    )
 
 if __name__ == "__main__":
-    demo.launch(show_error=True)
+    demo.launch()
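
Note: the new app.py imports AutoTokenizer and AutoModelForCausalLM but never uses them; generate_response stays a stub that only echoes the frame count and question. A minimal sketch of one way that stub could be backed by those transformers imports follows; the checkpoint name, prompt format, and generation settings are illustrative assumptions, not part of this commit.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Illustrative checkpoint only; any instruction-tuned causal LM could be substituted.
MODEL_ID = "Qwen/Qwen2-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")

def generate_response(video_frames, question):
    # A text-only LM cannot see the frames; like the stub in the commit,
    # this only folds the frame count into the prompt.
    prompt = (
        f"A video was sampled into {len(video_frames)} frames. "
        f"Question: {question}\nAnswer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    # Decode only the newly generated tokens, not the prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

Feeding the sampled frames to an actual video-language model (as the placeholder comment in the commit suggests) would require a multimodal checkpoint and processor rather than a plain causal LM.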