import shutil
import subprocess
import torch
import gradio as gr
from fastapi import FastAPI
import os
from PIL import Image
import tempfile
from decord import VideoReader, cpu
from transformers import TextStreamer
import argparse
import sys

from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.mm_utils import process_images
from infer_utils import load_video_into_frames
from utils import load_image, image_ext, video_ext
from gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css


def save_image_to_local(image):
    # Save an uploaded image under ./temp with a random filename.
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
    image = Image.open(image)
    image.save(filename)
    return filename


def save_video_to_local(video_path):
    # Copy an uploaded video under ./temp with a random filename.
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename


def generate(video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
    flag = 1
    if not textbox_in:
        if len(state_.messages) > 0:
            # Regeneration path: reuse the most recent prompt.
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            return "Please enter an instruction."

    video = video if video else "none"

    if type(state) is not Conversation:
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []

    first_run = len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    image_processor = handler.image_processor

    assert os.path.exists(video)
    if os.path.splitext(video)[-1].lower() in video_ext:  # video file
        video_decode_backend = 'opencv'
    elif os.path.splitext(os.listdir(video)[0])[-1].lower() in image_ext:  # folder of frames
        video_decode_backend = 'frames'
    else:
        raise ValueError(f'Supported inputs are videos with extensions {video_ext} '
                         f'or frame folders with image extensions {image_ext}, '
                         f'but got {os.path.splitext(video)[-1].lower()}')
    frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
    tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
    tensor = tensor.to(handler.model.device, dtype=dtype)
    images_tensor = tensor

    # Prepend the image token(s) the model expects before the user text.
    if handler.model.config.mm_use_im_start_end:
        text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
    else:
        text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    text_en_out = text_en_out.split('#')[0]
    textbox_out = text_en_out

    show_images = ""
    if os.path.exists(video):
        filename = save_video_to_local(video)
        # Embed the saved video in the chat message so it replays inline.
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if flag:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=video if os.path.exists(video) else None, interactive=True))


def regenerate(state, state_):
    state.messages.pop(-1)
    state_.messages.pop(-1)
    if len(state.messages) > 0:
        return state, state_, state.to_gradio_chatbot(), False
    return state, state_, state.to_gradio_chatbot(), True


def clear_history(state, state_):
    state = conv_templates[conv_mode].copy()
    state_ = conv_templates[conv_mode].copy()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            True, state,
            state_, state.to_gradio_chatbot(), [])


# ==== CHANGE HERE ====
conv_mode = "llava_v0"
model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
cache_dir = './cache_dir'
device = 'cuda'
load_8bit = True
load_4bit = False
dtype = torch.float16
# =============

# Load the model once at startup; Chat wraps the model, tokenizer, and image processor.
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit,
               load_4bit=load_4bit, device=device, cache_dir=cache_dir)

if not os.path.exists("temp"):
    os.makedirs("temp")

app = FastAPI()

textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)

with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/examples/sample_demo_1.mp4",
                        "Why is this video funny?",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_3.mp4",
                        "Can you identify any safety hazards in this video?",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_9.mp4",
                        "Describe the video.",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_22.mp4",
                        "Describe the activity in the video.",
                    ],
                ],
                inputs=[video, textbox],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)

    # gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)

    submit_btn.click(
        generate,
        [video, textbox, first_run, state, state_, images_tensor],
        [state, state_, chatbot, first_run, textbox, images_tensor, video])

    # Regenerate pops the last exchange, then reruns generate with an empty
    # textbox so the previous prompt is reused.
    regenerate_btn.click(
        regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
        generate,
        [video, textbox, first_run, state, state_, images_tensor],
        [state, state_, chatbot, first_run, textbox, images_tensor, video])

demo.launch()
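# NOTE: demo.launch() binds to 127.0.0.1:7860 by default. If the demo needs to
# be reachable from other machines, Gradio's standard launch options can be
# used instead, e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7860, share=False)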