vlm-rlaif-demo / gradio_web_server.py
dcahn12
Edit LICENSE
e07ba8d
raw
history blame
7.11 kB
import shutil
import subprocess
import torch
import gradio as gr
from fastapi import FastAPI
import os
from PIL import Image
import tempfile
from decord import VideoReader, cpu
from transformers import TextStreamer
import argparse
import sys
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.mm_utils import process_images
from infer_utils import load_video_into_frames
from utils import load_image, image_ext, video_ext
from gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
def save_image_to_local(image, out_dir='temp'):
    """Re-encode ``image`` as a JPEG under a fresh, unique name in ``out_dir``.

    Args:
        image: Anything accepted by ``PIL.Image.open`` (a path or file object).
        out_dir: Existing directory to write into (default ``'temp'``).

    Returns:
        ``os.path.join(out_dir, <unique name>.jpg)``.
    """
    with Image.open(image) as img:
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P); saving those
        # would raise, so convert them to RGB. Modes JPEG supports are kept as-is.
        if img.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
            img = img.convert('RGB')
        # Reserve a unique name through the public tempfile API instead of the
        # private tempfile._get_candidate_names(). The path is rebuilt from
        # out_dir so it keeps the same relative/absolute form as out_dir.
        with tempfile.NamedTemporaryFile(dir=out_dir, suffix='.jpg', delete=False) as tmp:
            filename = os.path.join(out_dir, os.path.basename(tmp.name))
        img.save(filename)
    return filename
def save_video_to_local(video_path, out_dir='temp'):
    """Copy ``video_path`` into ``out_dir`` under a fresh, unique ``.mp4`` name.

    Args:
        video_path: Path of the source video file.
        out_dir: Existing directory that receives the copy (default ``'temp'``).

    Returns:
        ``os.path.join(out_dir, <unique name>.mp4)``.
    """
    # NamedTemporaryFile(delete=False) atomically creates a file with a unique
    # name through the public API. The private tempfile._get_candidate_names()
    # only proposes a random name and never checks that it is free. The path is
    # rebuilt from out_dir so it stays relative when out_dir is relative (newer
    # Python versions may return an absolute tmp.name).
    with tempfile.NamedTemporaryFile(dir=out_dir, suffix='.mp4', delete=False) as tmp:
        filename = os.path.join(out_dir, os.path.basename(tmp.name))
    shutil.copyfile(video_path, filename)
    return filename
def _video_decode_backend(video):
    """Pick the ``load_video_into_frames`` backend for ``video``.

    Returns ``'opencv'`` when the path's extension is listed in ``video_ext``,
    or ``'frames'`` when ``video`` is a directory whose first ``os.listdir``
    entry has an extension listed in ``image_ext``.

    Raises:
        ValueError: if neither rule matches.
    """
    found = os.path.splitext(video)[-1].lower()
    if found in video_ext:
        return 'opencv'
    if os.path.isdir(video):
        entries = os.listdir(video)
        if entries:
            # Only the first entry is inspected; os.listdir order is arbitrary.
            found = os.path.splitext(entries[0])[-1].lower()
            if found in image_ext:
                return 'frames'
    raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {found}')


def generate(video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
    """Run one chat turn of the video model and update both conversations.

    Two conversations are kept:
    - ``state`` is what the chatbot displays: the user's text plus an embedded
      ``<video>`` player.
    - ``state_`` is the history passed to ``handler.generate``.

    Args:
        video: Path of the uploaded video file, or of a folder of frames. May be
            ``None`` or empty when nothing was uploaded.
        textbox_in: The user's instruction. When it is empty, the last message
            of ``state_`` is popped and reused as the instruction, and no new
            user message is added to ``state``. The Regenerate chain relies on
            this path.
        first_run: Ignored; recomputed from ``state``.
        state: Display conversation. Before the first turn it is not a
            ``Conversation`` (e.g. ``None``); in that case both conversations
            are re-initialised from ``conv_templates[conv_mode]``.
        state_: Model-side conversation.
        images_tensor: Ignored; recomputed from ``video``.
        num_frames: Number of frames requested from ``load_video_into_frames``.

    Returns:
        A 7-tuple for the outputs
        ``[state, state_, chatbot, first_run, textbox, images_tensor, video]``.

    Raises:
        gr.Error: if no video is available, or if the instruction is empty and
            there is no previous message to reuse.
        ValueError: if ``video`` is neither a supported video file nor a folder
            of supported frames.
    """
    # Validate before touching any conversation state, so a failed call leaves
    # the history intact. gr.Error shows the message in the UI, whereas a bare
    # assert is stripped under -O.
    if not video or not os.path.exists(video):
        raise gr.Error("Please upload a video")

    append_user_turn = True
    if not textbox_in:
        if state_ is not None and len(state_.messages) > 0:
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            append_user_turn = False
        else:
            # Returning a plain string would not match the 7 declared outputs.
            raise gr.Error("Please enter instruction")

    if not isinstance(state, Conversation):
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()

    first_run = len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    frames = load_video_into_frames(video, video_decode_backend=_video_decode_backend(video), num_frames=num_frames)
    images_tensor = process_images(frames, handler.image_processor, argparse.Namespace(image_aspect_ratio='pad'))
    images_tensor = images_tensor.to(handler.model.device, dtype=dtype)

    if handler.model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    text_en_in = image_token + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    # The model-side history keeps the untruncated reply.
    state_.messages[-1] = (state_.roles[1], text_en_out)

    # Only the text before the first '#' is shown. Presumably this strips a
    # '###'-style conversation separator the model may emit — TODO confirm.
    textbox_out = text_en_out.split('#')[0]

    show_images = ""
    if os.path.isfile(video):
        # A folder of frames can be neither copied with copyfile nor played by
        # a <video> tag, so only real files are embedded.
        filename = save_video_to_local(video)
        show_images = f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
    if append_user_turn:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=video, interactive=True))
def regenerate(state, state_):
    """Drop the latest message so the chained ``generate`` call can redo it.

    The last message is removed from both the display conversation ``state``
    and the model-side conversation ``state_``.

    Returns:
        ``(state, state_, chatbot, first_run)``, where ``first_run`` is ``True``
        when ``state`` has no messages left.
    """
    if state is None or state_ is None:
        # No conversation has started yet, so there is nothing to drop. Without
        # this guard, ``None.messages`` would raise AttributeError.
        return state, state_, [], True
    # Guard each pop so that an empty history does not raise IndexError.
    if state.messages:
        state.messages.pop(-1)
    if state_.messages:
        state_.messages.pop(-1)
    return state, state_, state.to_gradio_chatbot(), len(state.messages) == 0
def clear_history(state, state_):
    """Start a brand-new conversation, discarding the current one.

    Both arguments are ignored; two fresh copies of ``conv_templates[conv_mode]``
    take their place.

    Returns:
        An 8-tuple, in order:
        - three ``gr.update(value=None, interactive=True)`` resets,
        - ``True``,
        - the new display conversation,
        - the new model-side conversation,
        - the chatbot rendering of the new display conversation,
        - an empty list.
    """
    # NOTE(review): not wired to any event in the visible UI — check the
    # output-slot order before connecting it to a button.
    template = conv_templates[conv_mode]
    fresh_display, fresh_model = template.copy(), template.copy()
    resets = tuple(gr.update(value=None, interactive=True) for _ in range(3))
    return resets + (True, fresh_display, fresh_model, fresh_display.to_gradio_chatbot(), [])
# ==== CHANGE HERE ====
conv_mode = "llava_v0"  # key into conv_templates; used by generate / clear_history
model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
cache_dir = './cache_dir'
device = 'cuda'
load_8bit = True
load_4bit = False
dtype = torch.float16  # dtype the frame tensor is cast to in generate
# =============
# Each quantisation flag is passed through unchanged. Previously load_4bit
# received load_8bit, which also sent True for 4-bit loading.
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_4bit,
               device=device, cache_dir=cache_dir)
# Scratch directory for the per-turn media copies (save_video_to_local etc.).
os.makedirs("temp", exist_ok=True)
app = FastAPI()  # NOTE(review): never mounted or served in this file — confirm it is still needed.

# Created outside the Blocks context so the Examples component can reference it;
# it is placed into the layout later with textbox.render().
textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)
with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    # Per-session values threaded through generate/regenerate.
    state = gr.State()          # display conversation
    state_ = gr.State()         # model-side conversation
    first_run = gr.State()
    images_tensor = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/examples/sample_demo_1.mp4",
                        "Why is this video funny?",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_3.mp4",
                        "Can you identify any safety hazards in this video?"
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_9.mp4",
                        "Describe the video.",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_22.mp4",
                        "Describe the activity in the video.",
                    ],
                ],
                inputs=[video, textbox],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )
            with gr.Row(elem_id="buttons") as button_row:
                # Labels restored from mis-decoded UTF-8 (e.g. "πŸ‘" -> "👍").
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)

    # gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)

    generate_inputs = [video, textbox, first_run, state, state_, images_tensor]
    generate_outputs = [state, state_, chatbot, first_run, textbox, images_tensor, video]
    submit_btn.click(generate, generate_inputs, generate_outputs)
    # The placeholder promises "press ENTER", so Enter triggers the same handler as Send.
    textbox.submit(generate, generate_inputs, generate_outputs)
    # regenerate drops the last message; generate then re-runs with the (cleared)
    # textbox, which makes it reuse the last prompt stored in state_.
    regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
        generate, generate_inputs, generate_outputs)

demo.launch()