TuringsSolutions committed
Commit 40652ca · verified · 1 Parent(s): e09ac93

Update app.py

Files changed (1):
  app.py +88 -81

app.py CHANGED
@@ -1,94 +1,101 @@
 
 
  import gradio as gr
  import torch
- from llava.model.builder import load_pretrained_model
- from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
- from llava.conversation import conv_templates
- import copy
- from decord import VideoReader, cpu
- import numpy as np
-
- # Load the model
- pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
- model_name = "llava_qwen"
- device = "cuda" if torch.cuda.is_available() else "cpu"
- device_map = "auto"

  print("Loading model...")
- tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
- model.eval()
  print("Model loaded successfully!")

- def load_video(video_path, max_frames_num, fps=1, force_sample=False):
-     if max_frames_num == 0:
-         return np.zeros((1, 336, 336, 3))
-     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
-     total_frame_num = len(vr)
-     video_time = total_frame_num / vr.get_avg_fps()
-     fps = round(vr.get_avg_fps() / fps)
-     frame_idx = [i for i in range(0, len(vr), fps)]
-     frame_time = [i / fps for i in frame_idx]
-     if len(frame_idx) > max_frames_num or force_sample:
-         sample_fps = max_frames_num
-         uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
-         frame_idx = uniform_sampled_frames.tolist()
-         frame_time = [i / vr.get_avg_fps() for i in frame_idx]
-     frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
-     spare_frames = vr.get_batch(frame_idx).asnumpy()
-     return spare_frames, frame_time, video_time
-
- def process_video(video_path, question):
-     max_frames_num = 64
-     video, frame_time, video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
-     video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
-     video = [video]
-
-     conv_template = "qwen_1_5"
-     time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
-
-     full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
-
-     conv = copy.deepcopy(conv_templates[conv_template])
-     conv.append_message(conv.roles[0], full_question)
-     conv.append_message(conv.roles[1], None)
-     prompt_question = conv.get_prompt()
-
-     input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
-
-     with torch.no_grad():
-         output = model.generate(
-             input_ids,
-             images=video,
-             modalities=["video"],
-             do_sample=False,
-             temperature=0,
-             max_new_tokens=4096,
-         )
-
-     response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
-     return response
-
- def gradio_interface(video_file, question):
-     if video_file is None:
-         return "Please upload a video file."
-     response = process_video(video_file, question)
-     return response
-
- # Set up Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# 🌋📹 LLaVA-Video Chatbot")
-     with gr.Row():
-         with gr.Column():
-             video_input = gr.Video()
-             question_input = gr.Textbox(label="User Question", placeholder="Ask a question about the video...")
-             submit_button = gr.Button("Ask LLaVA-Video")
-             output = gr.Textbox(label="LLaVA-Video Response")
-
-     submit_button.click(
-         fn=gradio_interface,
-         inputs=[video_input, question_input],
-         outputs=output
      )

- if __name__ == "__main__":
-     demo.launch(show_error=True)
 
 
+ import time
+ from threading import Thread
  import gradio as gr
  import torch
+ from PIL import Image
+ from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer

+ # Model Configuration
+ model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

  print("Loading model...")
+ processor = AutoProcessor.from_pretrained(model_id)
+ model = LlavaForConditionalGeneration.from_pretrained(
+     model_id,
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True
+ )
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
+ model.generation_config.eos_token_id = 128009
  print("Model loaded successfully!")

+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+     <img src="https://cdn-uploads.huggingface.co/production/uploads/64ccdc322e592905f922a06e/DDIW0kbWmdOQWwy4XMhwX.png"
+          style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
+     <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-Llama-3-8B</h1>
+     <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">
+         Llava-Llama-3-8B is fine-tuned from Meta-Llama-3-8B-Instruct and CLIP-ViT-Large-patch14-336
+         using ShareGPT4V-PT and InternVL-SFT by XTuner.
+     </p>
+ </div>
+ """

+ def bot_streaming(message, history):
+     """Handles message processing with image and text streaming."""
+     try:
+         image = None
+
+         # Extract image from message or history
+         if message["files"]:
+             image = message["files"][-1]["path"] if isinstance(message["files"][-1], dict) else message["files"][-1]
+         else:
+             for hist in history:
+                 if isinstance(hist[0], tuple):
+                     image = hist[0][0]
+
+         if not image:
+             # Yield rather than return: a value returned from a generator is discarded.
+             yield "Error: Please upload an image for LLaVA to work."
+             return
+
+         # Prepare inputs in the Llama-3 chat format; the trailing assistant
+         # header gives the model a turn to complete.
+         image = Image.open(image)
+         prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+         inputs = processor(prompt, image, return_tensors="pt").to(device=model.device, dtype=torch.float16)
+
+         # Stream text generation
+         streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
+         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         buffer = ""
+         time.sleep(0.5)  # Allow some time for initial generation
+
+         # Stream the generated response
+         for new_text in streamer:
+             if "<|eot_id|>" in new_text:
+                 new_text = new_text.split("<|eot_id|>")[0]
+             buffer += new_text
+             yield buffer
+
+     except Exception as e:
+         yield f"Error: {str(e)}"
+
+ # Define Gradio interface components
+ chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
+ chat_input = gr.MultimodalTextbox(
+     interactive=True, file_types=["image"], placeholder="Enter message or upload a file...", show_label=False
+ )
+
+ with gr.Blocks(fill_height=True) as demo:
+     gr.ChatInterface(
+         fn=bot_streaming,
+         title="LLaVA Llama-3-8B",
+         examples=[
+             {"text": "What is on the flower?", "files": ["./bee.jpg"]},
+             {"text": "How to make this pastry?", "files": ["./baklava.png"]}
+         ],
+         description=(
+             "Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). "
+             "Upload an image and start chatting about it, or simply try one of the examples below. "
+             "If you don't upload an image, you will receive an error."
+         ),
+         stop_btn="Stop Generation",
+         multimodal=True,
+         textbox=chat_input,
+         chatbot=chatbot,
      )

+ # Launch the Gradio app
+ demo.queue(api_open=False)
+ demo.launch(show_api=False, share=False)
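
A note on what bot_streaming receives: with multimodal=True, Gradio passes each user turn to fn as a dict carrying the typed text and any uploaded files. The value below is hypothetical, just to show the shape; the dict-vs-string check on files in bot_streaming exists because Gradio releases have represented uploads both as plain paths and as {"path": ...} dicts.

# Hypothetical `message` argument when the user uploads bee.jpg and asks a question:
message = {
    "text": "What is on the flower?",
    "files": ["./bee.jpg"],  # some Gradio versions deliver [{"path": "./bee.jpg", ...}] instead
}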
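The prompt string follows the Llama-3 instruct chat layout: each turn is framed by header tokens and closed with <|eot_id|> (token id 128009, which is why the code pins generation_config.eos_token_id to it), and generation begins right after an assistant header so the model has a turn to complete. A minimal sketch with a hypothetical question, split across lines for readability:

prompt = (
    "<|start_header_id|>user<|end_header_id|>\n\n"        # user turn header
    "<image>\nWhat is on the flower?<|eot_id|>"           # <image> is the placeholder the model fills with projected image features
    "<|start_header_id|>assistant<|end_header_id|>\n\n"   # the model completes from here
)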
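The streaming side uses the standard TextIteratorStreamer pattern: model.generate() blocks until generation finishes, so it runs on a worker thread while the caller iterates the streamer, which yields decoded text as tokens are produced. A minimal text-only sketch of the same pattern; gpt2 stands in for the Llava model here purely to keep the example small.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming generation works by", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so run it on a worker thread and consume
# decoded chunks from the streamer on the main thread.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=40))
thread.start()

buffer = ""
for chunk in streamer:
    buffer += chunk  # each chunk is newly decoded text
    print(buffer)
thread.join()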