metek7 commited on
Commit
f0272e1
·
verified ·
1 Parent(s): 1178e6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -51
app.py CHANGED
@@ -1,63 +1,95 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
3
 
 
 
 
 
4
  """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
 
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
27
 
28
- response = ""
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from llava.model.builder import load_pretrained_model
4
+ from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
5
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
6
+ from llava.conversation import conv_templates
7
+ import copy
8
+ from decord import VideoReader, cpu
9
+ import numpy as np
10
 
11
+ title = "# 🎥 Instagram Short Video Analyzer with LLaVA-Video"
12
+ description = """
13
+ This application uses the LLaVA-Video-7B-Qwen2 model to analyze Instagram short videos.
14
+ Upload your Instagram short video and ask questions about its content!
15
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ def load_video(video_path, max_frames_num=64, fps=1):
18
+ vr = VideoReader(video_path, ctx=cpu(0))
19
+ total_frame_num = len(vr)
20
+ video_time = total_frame_num / vr.get_avg_fps()
21
+ fps = round(vr.get_avg_fps()/fps)
22
+ frame_idx = list(range(0, len(vr), fps))
23
+ if len(frame_idx) > max_frames_num:
24
+ frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
25
+ frame_time = [i/vr.get_avg_fps() for i in frame_idx]
26
+ frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
27
+ spare_frames = vr.get_batch(frame_idx).asnumpy()
28
+ return spare_frames, frame_time, video_time
29
 
30
+ # Load the model
31
+ pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
32
+ model_name = "llava_qwen"
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ device_map = "auto"
35
 
36
+ print("Loading model...")
37
+ tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
38
+ model.eval()
39
+ print("Model loaded successfully!")
40
 
41
+ def process_instagram_short(video_path, question):
42
+ max_frames_num = 64
43
+ video, frame_time, video_time = load_video(video_path, max_frames_num)
44
+ video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
45
+ video = [video]
 
 
 
46
 
47
+ time_instruction = f"This is an Instagram short video lasting {video_time:.2f} seconds. {len(video[0])} frames were sampled at {frame_time}. Analyze this short video and answer the following question:"
48
+
49
+ full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
50
+
51
+ conv = copy.deepcopy(conv_templates["qwen_1_5"])
52
+ conv.append_message(conv.roles[0], full_question)
53
+ conv.append_message(conv.roles[1], None)
54
+ prompt_question = conv.get_prompt()
55
+
56
+ input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
57
+
58
+ with torch.no_grad():
59
+ output = model.generate(
60
+ input_ids,
61
+ images=video,
62
+ modalities=["video"],
63
+ do_sample=False,
64
+ temperature=0,
65
+ max_new_tokens=4096,
66
+ )
67
+
68
+ response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
69
+ return response
70
 
71
+ def gradio_interface(video_file, question):
72
+ if video_file is None:
73
+ return "Please upload an Instagram short video."
74
+ response = process_instagram_short(video_file, question)
75
+ return response
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ with gr.Blocks() as demo:
78
+ gr.Markdown(title)
79
+ gr.Markdown(description)
80
+
81
+ with gr.Row():
82
+ with gr.Column():
83
+ video_input = gr.Video(label="Upload Instagram Short Video")
84
+ question_input = gr.Textbox(label="Ask a question about the video", placeholder="What's happening in this Instagram short?")
85
+ submit_button = gr.Button("Analyze Short Video")
86
+ output = gr.Textbox(label="Analysis Result")
87
+
88
+ submit_button.click(
89
+ fn=gradio_interface,
90
+ inputs=[video_input, question_input],
91
+ outputs=output
92
+ )
93
 
94
  if __name__ == "__main__":
95
+ demo.launch(show_error=True)