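"""Gradio web demo for VLM-RLAIF: chat about an uploaded video with a
LLaVA-style vision-language model (SNUMPR/vlm_rlaif_video_llava_7b)."""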
import argparse
import os
import shutil
import tempfile

import torch
import gradio as gr
from fastapi import FastAPI
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, Conversation
from llava.mm_utils import process_images
from infer_utils import load_video_into_frames
from utils import image_ext, video_ext
from gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
def save_image_to_local(image):
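    """Save an uploaded image into ./temp under a unique name and return the path.

    Note: tempfile._get_candidate_names() is a private CPython helper, used here
    only to mint a unique filename.
    """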
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
image = Image.open(image)
image.save(filename)
return filename
def save_video_to_local(video_path):
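    """Copy the uploaded video into ./temp under a unique name and return the path."""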
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
shutil.copyfile(video_path, filename)
return filename
def generate(video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
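    """Run one chat turn: preprocess the video, query the model, and update the UI.

    `state` is the conversation rendered in the chatbot; `state_` is the raw
    conversation passed to the model. Returns a 7-tuple matching the wired
    outputs: (state, state_, chatbot, first_run, textbox, images_tensor, video).
    """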
    flag = 1  # 1: new user message; 0: regenerating the previous turn
    if not textbox_in:
        if len(state_.messages) > 0:
            # Regeneration: re-use the last user message.
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            # Nothing to answer; return a full 7-tuple so the Gradio outputs stay valid.
            chatbot = state.to_gradio_chatbot() if isinstance(state, Conversation) else []
            return (state, state_, chatbot, first_run,
                    gr.update(value="Please enter instruction", interactive=True),
                    images_tensor, gr.update(interactive=True))
    video = video if video else "none"
    if not isinstance(state, Conversation):
        # First call: initialize both conversation states and the frame cache.
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []
    first_run = len(state.messages) == 0
    # Normalize wording so prompts consistently say "image".
    text_en_in = textbox_in.replace("picture", "image")
    image_processor = handler.image_processor
    assert os.path.exists(video), f'Video path not found: {video}'
    if os.path.splitext(video)[-1].lower() in video_ext:  # single video file
        video_decode_backend = 'opencv'
    elif os.path.splitext(os.listdir(video)[0])[-1].lower() in image_ext:  # folder of extracted frames
        video_decode_backend = 'frames'
    else:
        raise ValueError(f'Supported inputs are videos ({video_ext}) or frame folders ({image_ext}), but found {os.path.splitext(video)[-1].lower()}')
    # Sample num_frames frames and preprocess them into the model's image tensor.
    frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
    tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
    tensor = tensor.to(handler.model.device, dtype=dtype)
    images_tensor = tensor
    # Prepend the image token(s) the multimodal model expects before the prompt.
    if handler.model.config.mm_use_im_start_end:
        text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
    else:
        text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)
    # Keep only the text before any '#' marker the model may emit.
    text_en_out = text_en_out.split('#')[0]
    textbox_out = text_en_out
show_images = ""
if os.path.exists(video):
filename = save_video_to_local(video)
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
    if flag:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)
    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=video if os.path.exists(video) else None, interactive=True))
def regenerate(state, state_):
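    """Drop the last assistant reply; `generate` is chained after this to redo the turn."""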
    # Drop the last assistant reply from both states so it can be regenerated.
    state.messages.pop(-1)
    state_.messages.pop(-1)
    if len(state.messages) > 0:
        return (state, state_, state.to_gradio_chatbot(), False)
    return (state, state_, state.to_gradio_chatbot(), True)
def clear_history(state, state_):
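    """Reset both conversation states and clear the video, textbox, and chatbot."""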
    state = conv_templates[conv_mode].copy()
    state_ = conv_templates[conv_mode].copy()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            True, state, state_, state.to_gradio_chatbot(), [])
# ==== CHANGE HERE ====
conv_mode = "llava_v0"
model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
cache_dir = './cache_dir'
device = 'cuda'
load_8bit = True    # 8-bit and 4-bit quantization should not both be enabled
load_4bit = False
dtype = torch.float16
# =============
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=cache_dir)
os.makedirs("temp", exist_ok=True)
app = FastAPI()  # unused here; kept in case the demo is mounted onto a FastAPI server
textbox = gr.Textbox(
show_label=False, placeholder="Enter text and press ENTER", container=False
)
with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
gr.Markdown(title_markdown)
state = gr.State()
state_ = gr.State()
first_run = gr.State()
images_tensor = gr.State()
with gr.Row():
with gr.Column(scale=3):
video = gr.Video(label="Input Video")
cur_dir = os.path.dirname(os.path.abspath(__file__))
gr.Examples(
examples=[
[
f"{cur_dir}/examples/sample_demo_1.mp4",
"Why is this video funny?",
],
[
f"{cur_dir}/examples/sample_demo_3.mp4",
"Can you identify any safety hazards in this video?"
],
[
f"{cur_dir}/examples/sample_demo_9.mp4",
"Describe the video.",
],
[
f"{cur_dir}/examples/sample_demo_22.mp4",
"Describe the activity in the video.",
],
],
inputs=[video, textbox],
)
with gr.Column(scale=7):
chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
with gr.Column(scale=1, min_width=50):
submit_btn = gr.Button(
value="Send", variant="primary", interactive=True
)
with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
# gr.Markdown(tos_markdown)
gr.Markdown(learn_more_markdown)
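    # "Send" runs a full turn; "Regenerate" pops the last reply and then re-runs
    # generate, which re-uses the popped user message (the textbox is empty).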
submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
[state, state_, chatbot, first_run, textbox, images_tensor, video])
regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
demo.launch()