Tonic committed on
Commit 17d4493 • 1 Parent(s): e1f0679
Files changed (2)
  1. app.py +101 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import gradio as gr
+ import torch
+ from llava.model.builder import load_pretrained_model
+ from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
+ from llava.conversation import conv_templates, SeparatorStyle
+ import copy
+ import warnings
+ from decord import VideoReader, cpu
+ import numpy as np
+
+ warnings.filterwarnings("ignore")
+
+ def load_video(video_path, max_frames_num, fps=1, force_sample=False):
+     if max_frames_num == 0:
+         return np.zeros((1, 336, 336, 3)), "", 0  # dummy frame plus placeholder metadata, so callers can always unpack three values
+     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+     total_frame_num = len(vr)
+     video_time = total_frame_num / vr.get_avg_fps()
+     fps = round(vr.get_avg_fps() / fps)
+     frame_idx = [i for i in range(0, len(vr), fps)]
+     frame_time = [i / fps for i in frame_idx]
+     if len(frame_idx) > max_frames_num or force_sample:
+         sample_fps = max_frames_num
+         uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
+         frame_idx = uniform_sampled_frames.tolist()
+         frame_time = [i / vr.get_avg_fps() for i in frame_idx]
+     frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+     spare_frames = vr.get_batch(frame_idx).asnumpy()
+     return spare_frames, frame_time, video_time
+
+ # Load the model
+ pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
+ model_name = "llava_qwen"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device_map = "auto"
+
+ print("Loading model...")
+ tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
+ model.eval()
+ print("Model loaded successfully!")
+
+ def process_video(video_path, question):
+     max_frames_num = 64
+     video, frame_time, video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
+     video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
+     video = [video]
+
+     conv_template = "qwen_1_5"
+     time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
+
+     full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
+
+     conv = copy.deepcopy(conv_templates[conv_template])
+     conv.append_message(conv.roles[0], full_question)
+     conv.append_message(conv.roles[1], None)
+     prompt_question = conv.get_prompt()
+
+     input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         output = model.generate(
+             input_ids,
+             images=video,
+             modalities=["video"],
+             do_sample=False,
+             temperature=0,
+             max_new_tokens=4096,
+         )
+
+     response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
+     return response
+
+ # Gradio interface
+ def gradio_interface(video_file, question):
+     if video_file is None:
+         return "Please upload a video file."
+     response = process_video(video_file if isinstance(video_file, str) else video_file.name, question)  # gr.Video returns a filepath string in recent Gradio; older versions return a file object
+     return response
+
+ # Create Gradio app
+ with gr.Blocks() as demo:
+     gr.Markdown("# LLaVA-Video-7B-Qwen2 Demo")
+     gr.Markdown("Upload a video and ask a question about it.")
+
+     with gr.Row():
+         video_input = gr.Video()
+         question_input = gr.Textbox(label="Question", placeholder="Ask a question about the video...")
+
+     submit_button = gr.Button("Submit")
+     output = gr.Textbox(label="Response")
+
+     submit_button.click(
+         fn=gradio_interface,
+         inputs=[video_input, question_input],
+         outputs=output
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ pillow
+ numpy
+ transformers
+ torch
+ torchvision
+ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git