import os
import shutil
import tempfile

import imageio
import torch
import gradio as gr
from PIL import Image

from flash_vstream.serve.demo import Chat, title_markdown, block_css
from flash_vstream.constants import *
from flash_vstream.conversation import conv_templates, Conversation

model_path = "IVGSZ/Flash-VStream-7b"
load_8bit = False
load_4bit = False


def save_image_to_local(image):
    # Copy an uploaded image into temp/ under a random filename.
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
    image = Image.open(image)
    image.save(filename)
    return filename


def save_video_to_local(video_path):
    # Copy an uploaded video into temp/ under a random filename.
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename


def generate(video, textbox_in, first_run, state, state_, images_tensor):
    # `state` is the conversation rendered in the UI; `state_` is the
    # model-side conversation that accumulates raw prompts and outputs.
    flag = 1
    if not textbox_in:
        if len(state_.messages) > 0:
            # Regenerate path: reuse the last prompt and pop it so it is re-appended below.
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            # Abort with an error toast; returning a bare string here would not
            # match the seven outputs declared on submit_btn.click.
            raise gr.Error("Please enter instruction")

    video = video if video else "none"

    if type(state) is not Conversation:
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []

    first_run = False if len(state.messages) > 0 else True

    text_en_in = textbox_in.replace("picture", "image")

    image_processor = handler.image_processor
    if os.path.exists(video):
        video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
        images_tensor = image_processor(video_tensor, return_tensors='pt')['pixel_values'].to(
            handler.model.device, dtype=torch.float16)
        print("video_tensor", video_tensor.shape)
        text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    text_en_out = text_en_out.split('#')[0]
    textbox_out = text_en_out

    show_images = ""
    if os.path.exists(video):
        filename = save_video_to_local(video)
        # Embed the saved video in the chat message (HTML markup assumed; the
        # tag body was lost in the source and follows the usual Gradio demo pattern).
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if flag:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=None, interactive=True))


def regenerate(state, state_):
    # Drop the last assistant reply; `generate` then reuses the previous prompt.
    state.messages.pop(-1)
    state_.messages.pop(-1)
    if len(state.messages) > 0:
        return state, state_, state.to_gradio_chatbot(), False
    return state, state_, state.to_gradio_chatbot(), True


def clear_history(state, state_):
    # Reset both conversations and clear all inputs.
    state = conv_templates[conv_mode].copy()
    state_ = conv_templates[conv_mode].copy()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            True, state, state_, state.to_gradio_chatbot(), [])


conv_mode = "vicuna_v1"
handler = Chat(model_path, conv_mode=conv_mode, load_4bit=load_4bit, load_8bit=load_8bit)

if not os.path.exists("temp"):
    os.makedirs("temp")

print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())

with gr.Blocks(title='Flash-VStream', theme=gr.themes.Soft(), css=block_css) as demo:
    gr.Markdown(title_markdown)

    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Flash-VStream", bubble_full_width=True).style(height=700)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox = gr.Textbox(show_label=False,
                                         placeholder="Enter text and press Send",
                                         container=False)
                with gr.Column(scale=2, min_width=50):
gr.Button(value="Send", variant="primary", interactive=True) with gr.Row(visible=True) as button_row: regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) cur_dir = os.path.dirname(os.path.abspath(__file__)) with gr.Row(): gr.Examples( examples=[ [ f"{cur_dir}/examples/video2.mp4", "Describe the video briefly.", ] ], inputs=[video, textbox], ) gr.Examples( examples=[ [ f"{cur_dir}/examples/video4.mp4", "What is the boy doing?", ] ], inputs=[video, textbox], ) gr.Examples( examples=[ [ f"{cur_dir}/examples/video5.mp4", "Why is this video funny?", ] ], inputs=[video, textbox], ) submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video]) regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then( generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video]) clear_btn.click(clear_history, [state, state_], [video, textbox, first_run, state, state_, chatbot, images_tensor]) # app = gr.mount_gradio_app(app, demo, path="/") demo.launch()