zhanghaoji
fix tensor
cc32390
raw
history blame
5.9 kB
import torch
import gradio as gr
from flash_vstream.serve.demo import Chat, title_markdown, block_css
from flash_vstream.constants import *
from flash_vstream.conversation import conv_templates, Conversation
import os
from PIL import Image
import tempfile
import imageio
import shutil
# Checkpoint identifier handed to Chat() during setup (format looks like a
# Hugging Face Hub repo id).
model_path: str = "IVGSZ/Flash-VStream-7b"
# Quantized-loading switches, forwarded to Chat() as load_8bit / load_4bit.
load_8bit: bool = False
load_4bit: bool = False
def save_image_to_local(image, directory='temp'):
    """Re-encode an image as JPEG in ``directory`` under a fresh random name.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a binary file
            object).
        directory: Destination folder, created if missing. Defaults to
            ``'temp'`` (relative to the working directory), the folder the
            app creates at startup.

    Returns:
        Path of the written file: ``os.path.join(directory, '<hex>.jpg')``.
    """
    import uuid  # local import: only the temp-file helpers need it

    os.makedirs(directory, exist_ok=True)
    # uuid4 replaces tempfile._get_candidate_names(), a private CPython helper
    # that is not part of the documented tempfile API.
    filename = os.path.join(directory, uuid.uuid4().hex + '.jpg')
    # Pillow opens files lazily; the context manager closes the source handle
    # deterministically instead of leaving it to garbage collection.
    with Image.open(image) as img:
        # JPEG cannot store alpha or palette data (e.g. RGBA / P PNGs); convert
        # such modes to RGB rather than letting save() raise.
        if img.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
            img = img.convert('RGB')
        img.save(filename)
    return filename
def save_video_to_local(video_path, directory='temp'):
    """Copy a video into ``directory`` under a fresh random ``.mp4`` name.

    ``generate`` embeds the returned path in the chat HTML (``./file=...``).

    Args:
        video_path: Path of the source video file; it is left untouched.
        directory: Destination folder, created if missing. Defaults to
            ``'temp'`` (relative to the working directory), the folder the
            app creates at startup.

    Returns:
        Path of the copy: ``os.path.join(directory, '<hex>.mp4')``.
    """
    import uuid  # local import: only the temp-file helpers need it

    os.makedirs(directory, exist_ok=True)
    # uuid4 replaces tempfile._get_candidate_names(), a private CPython helper
    # that is not part of the documented tempfile API.
    filename = os.path.join(directory, uuid.uuid4().hex + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename
def generate(video, textbox_in, first_run, state, state_, images_tensor):
    """Run one chat turn for the Gradio UI and return the updated UI values.

    Two conversations are threaded through the UI: ``state`` is the display
    history (user text plus an HTML ``<video>`` preview, rendered with
    ``to_gradio_chatbot``) and ``state_`` is the prompt history passed to
    ``handler.generate``.

    When ``textbox_in`` is empty (the Regenerate path: the textbox is cleared
    after every turn), the last prompt stored in ``state_`` is taken back out
    and re-sent, and no new user message is added to ``state``.

    Args:
        video: Path of the uploaded video, or a falsy value when there is none.
        textbox_in: Instruction typed by the user; may be empty.
        first_run: Incoming value is ignored; it is recomputed from ``state``.
        state: Display ``Conversation``, or a non-``Conversation`` value
            (``None`` for a fresh ``gr.State()``) before the first turn.
        state_: Prompt-side ``Conversation``, same lifecycle as ``state``.
        images_tensor: Visual input kept from an earlier turn; replaced when a
            video is supplied with this turn.

    Returns:
        A 7-tuple for the wired outputs: (state, state_, chatbot history,
        first_run=False, textbox update, images_tensor, video update).

    Raises:
        gr.Error: If there is no instruction and no previous prompt to re-run.
    """
    reuse_last_prompt = False
    if not textbox_in:
        # state_ may still be the initial gr.State() value (None) on the very
        # first click, so check its type before touching .messages.
        if isinstance(state_, Conversation) and len(state_.messages) > 0:
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            reuse_last_prompt = True
        else:
            # A bare string cannot fill the seven wired outputs; raising
            # gr.Error shows the message to the user instead.
            raise gr.Error("Please enter instruction")

    # Only treat the input as a video when a real path was supplied (no magic
    # "none" sentinel that could collide with a file of that name).
    has_video = bool(video) and os.path.exists(video)

    if not isinstance(state, Conversation):
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []

    first_run = len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    if has_video:
        image_processor = handler.image_processor
        video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
        images_tensor = image_processor(video_tensor, return_tensors='pt')['pixel_values'].to(
            handler.model.device, dtype=torch.float16)
        print("video_tensor", video_tensor.shape)
        # Mark where the visual input belongs in the prompt.
        text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    # Store the full model reply as the last prompt-history entry.
    state_.messages[-1] = (state_.roles[1], text_en_out)

    # Only the text before the first '#' is shown in the chat.
    textbox_out = text_en_out.split('#')[0]

    show_images = ""
    if has_video:
        # Copy the upload into ./temp and reference it from the chat HTML.
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if not reuse_last_prompt:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=None, interactive=True))
def regenerate(state, state_):
    """Drop the last assistant reply so the chained ``generate`` call redoes it.

    The Regenerate button chains this into ``generate`` with the (normally
    empty) textbox, and ``generate`` then re-sends the last prompt left in
    ``state_``.

    Args:
        state: Display ``Conversation``, or ``None`` before the first turn.
        state_: Prompt-side ``Conversation``, or ``None`` before the first turn.

    Returns:
        (state, state_, chatbot history, first_run), where ``first_run`` is
        True when ``state`` has no messages left.
    """
    if not isinstance(state, Conversation) or not state.messages:
        # Nothing to regenerate yet (no turn has run, or history was just
        # cleared); popping here would raise on None or an empty list.
        history = state.to_gradio_chatbot() if isinstance(state, Conversation) else []
        return (state, state_, history, True)

    state.messages.pop(-1)
    if isinstance(state_, Conversation) and state_.messages:
        state_.messages.pop(-1)
    return (state, state_, state.to_gradio_chatbot(), not state.messages)
def clear_history(state, state_):
    """Reset the whole chat session.

    The incoming conversations are discarded. Two fresh copies of the
    ``conv_mode`` template replace them, the video and textbox inputs are
    emptied, and the cached visual input is reset to an empty list.

    Returns:
        (video update, textbox update, first_run=True, state, state_,
        chatbot history, images_tensor) in the order wired to Clear history.
    """
    template = conv_templates[conv_mode]
    state = template.copy()
    state_ = template.copy()

    emptied_video = gr.update(value=None, interactive=True)
    emptied_textbox = gr.update(value=None, interactive=True)
    history = state.to_gradio_chatbot()
    return (emptied_video, emptied_textbox, True, state, state_, history, [])
# Conversation template used for both the display and the prompt histories.
conv_mode = "vicuna_v1"
# Chat wrapper shared by every callback: generate() uses its image_processor,
# model.device, _get_rawvideo_dec and generate (presumably loads weights here).
handler = Chat(model_path, conv_mode=conv_mode, load_4bit=load_4bit, load_8bit=load_8bit)

# Working folder for copied uploads; exist_ok avoids the check-then-create race.
os.makedirs("temp", exist_ok=True)

# Report GPU memory after setup: bytes currently allocated, then the peak.
print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())
with gr.Blocks(title='Flash-VStream', theme=gr.themes.Soft(), css=block_css) as demo:
    gr.Markdown(title_markdown)

    # Per-session values threaded through the callbacks.
    state = gr.State()          # display Conversation shown in the chatbot
    state_ = gr.State()         # prompt Conversation passed to the model
    first_run = gr.State()
    images_tensor = gr.State()  # visual input reused across turns

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Flash-VStream", bubble_full_width=True).style(height=700)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press Send",
                        container=False,
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
            with gr.Row(visible=True) as button_row:
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    # Example clips live in ./examples next to this script; each (clip, prompt)
    # pair becomes its own single-row gr.Examples panel, in this order.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    example_pairs = (
        ("video2.mp4", "Describe the video briefly."),
        ("video4.mp4", "What is the boy doing?"),
        ("video5.mp4", "Why is this video funny?"),
    )
    with gr.Row():
        for clip_name, prompt_text in example_pairs:
            gr.Examples(
                examples=[[f"{cur_dir}/examples/{clip_name}", prompt_text]],
                inputs=[video, textbox],
            )

    # Send and Regenerate drive generate() with the same inputs/outputs.
    generate_inputs = [video, textbox, first_run, state, state_, images_tensor]
    generate_outputs = [state, state_, chatbot, first_run, textbox, images_tensor, video]

    submit_btn.click(generate, generate_inputs, generate_outputs)
    regenerate_btn.click(
        regenerate, [state, state_], [state, state_, chatbot, first_run]
    ).then(generate, generate_inputs, generate_outputs)
    clear_btn.click(
        clear_history,
        [state, state_],
        [video, textbox, first_run, state, state_, chatbot, images_tensor],
    )

demo.launch()