|
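# Gradio demo for MiniGPT4-Video: answers questions about local or YouTube videos,
# optionally using Whisper-generated subtitles as extra context for the model.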
import torch |
|
import webvtt |
|
import os |
|
import cv2 |
|
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, eval_bleu,eval_cider,chat_gpt_eval |
|
from minigpt4.conversation.conversation import CONV_VISION |
|
from torchvision import transforms |
|
import json |
|
from tqdm import tqdm |
|
import soundfile as sf |
|
import argparse |
|
import moviepy.editor as mp |
|
import gradio as gr |
|
from pytubefix import YouTube |
|
import shutil |
|
from PIL import Image |
|
from moviepy.editor import VideoFileClip |
|
from theme import minigptlv_style, custom_css,text_css |
|
|
|
from huggingface_hub import login, hf_hub_download |
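# Authenticate with the Hugging Face Hub (token read from the HF_TKN environment variable)
# and download the MiniGPT4-Video checkpoint into ./checkpoints.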
|
hf_token = os.environ.get('HF_TKN') |
|
login(token=hf_token) |
|
|
|
hf_hub_download( |
|
repo_id='Vision-CAIR/MiniGPT4-Video', |
|
filename='video_llama_checkpoint_last.pth', |
|
local_dir='checkpoints', |
|
local_dir_use_symlinks=False, |
|
) |
|
|
|
import spaces |
|
|
|
def create_video_grid(images, rows, cols, save_path):
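    """Paste a list of equally sized PIL frames into a rows x cols grid image.

    Note: save_path is currently unused; the composed grid is returned rather than written to disk.
    """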
|
image_width, image_height = images[0].size |
|
grid_width = cols * image_width |
|
grid_height = rows * image_height |
|
|
|
new_image = Image.new("RGB", (grid_width, grid_height)) |
|
|
|
for i in range(rows): |
|
for j in range(cols): |
|
index = i * cols + j |
|
if index < len(images): |
|
image = images[index] |
|
x_offset = j * image_width |
|
y_offset = i * image_height |
|
new_image.paste(image, (x_offset, y_offset)) |
|
|
|
return new_image |
|
|
|
def prepare_input(vis_processor, video_path, subtitle_path, instruction):
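    """Sample frames from the video and build the model input.

    Frames are taken every `sampling_interval` frames (at most `max_images_length` in total)
    and preprocessed with `vis_processor`. The instruction is prefixed with one
    '<Img><ImageHere>' placeholder per frame, optionally followed by a '<Cap>...' segment
    holding the subtitle text for that interval (capped at `max_sub_len` words overall).
    Returns (frames_tensor, instruction) or (None, None) if no frames could be read.
    """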
|
cap = cv2.VideoCapture(video_path) |
|
if subtitle_path is not None: |
|
|
|
vtt_file = webvtt.read(subtitle_path) |
|
print("subtitle loaded successfully") |
|
clip = VideoFileClip(video_path) |
|
total_num_frames = int(clip.duration * clip.fps) |
|
|
|
clip.close() |
|
else : |
|
|
|
total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
    if "mistral" in args.ckpt:  # the Mistral checkpoint supports a longer context, so allow more frames and subtitle words
|
max_images_length=90 |
|
max_sub_len = 800 |
|
else: |
|
max_images_length = 45 |
|
max_sub_len = 400 |
|
images = [] |
|
frame_count = 0 |
|
sampling_interval = int(total_num_frames / max_images_length) |
|
if sampling_interval == 0: |
|
sampling_interval = 1 |
|
img_placeholder = "" |
|
subtitle_text_in_interval = "" |
|
history_subtitles = {} |
|
|
|
number_of_words=0 |
|
transform=transforms.Compose([ |
|
transforms.ToPILImage(), |
|
]) |
|
while cap.isOpened(): |
|
ret, frame = cap.read() |
|
if not ret: |
|
break |
|
|
|
|
|
if subtitle_path is not None: |
|
for subtitle in vtt_file: |
|
sub=subtitle.text.replace('\n',' ') |
|
if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval: |
|
if not history_subtitles.get(sub,False): |
|
subtitle_text_in_interval+=sub+" " |
|
history_subtitles[sub]=True |
|
break |
|
if frame_count % sampling_interval == 0: |
|
|
|
            frame = transform(frame[:, :, ::-1])  # flip OpenCV's BGR channel order to RGB before converting to a PIL image
|
frame = vis_processor(frame) |
|
images.append(frame) |
|
img_placeholder += '<Img><ImageHere>' |
|
if subtitle_path is not None and subtitle_text_in_interval != "" and number_of_words< max_sub_len: |
|
img_placeholder+=f'<Cap>{subtitle_text_in_interval}' |
|
number_of_words+=len(subtitle_text_in_interval.split(' ')) |
|
subtitle_text_in_interval = "" |
|
frame_count += 1 |
|
|
|
if len(images) >= max_images_length: |
|
break |
|
cap.release() |
|
cv2.destroyAllWindows() |
|
if len(images) == 0: |
|
|
|
return None,None |
|
|
|
images = torch.stack(images) |
|
instruction = img_placeholder + '\n' + instruction |
|
return images,instruction |
|
def extract_audio(video_path, audio_path): |
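    """Extract the audio track of `video_path` and write it to `audio_path` as an MP3."""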
|
    video_clip = mp.VideoFileClip(video_path)

    audio_clip = video_clip.audio

    audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k")

    video_clip.close()  # release the video reader (and its audio) once the MP3 is written
|
|
|
def generate_subtitles(video_path): |
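    """Transcribe the video's audio to a .vtt subtitle file using the Whisper CLI.

    Previously generated subtitles (tracked in `existed_subtitles`) are reused. Returns the
    subtitle path, or None if audio extraction or transcription fails.
    """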
|
video_id=video_path.split('/')[-1].split('.')[0] |
|
audio_path = f"workspace/inference_subtitles/mp3/{video_id}"+'.mp3' |
|
os.makedirs("workspace/inference_subtitles/mp3",exist_ok=True) |
|
if existed_subtitles.get(video_id,False): |
|
return f"workspace/inference_subtitles/{video_id}"+'.vtt' |
|
try: |
|
extract_audio(video_path,audio_path) |
|
print("successfully extracted") |
|
os.system(f"whisper {audio_path} --language English --model large --output_format vtt --output_dir workspace/inference_subtitles") |
|
|
|
os.system(f"rm {audio_path}") |
|
print("subtitle successfully generated") |
|
return f"workspace/inference_subtitles/{video_id}"+'.vtt' |
|
except Exception as e: |
|
print("error",e) |
|
print("error",video_path) |
|
return None |
|
|
|
@spaces.GPU() |
|
def run(video_path, instruction, model, vis_processor, gen_subtitles=True):
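    """Answer `instruction` about the video at `video_path`.

    Optionally generates subtitles first, samples frames with `prepare_input`, builds the
    conversation prompt, and decodes an answer with beam search (num_beams=2).
    """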
|
if gen_subtitles: |
|
subtitle_path=generate_subtitles(video_path) |
|
else : |
|
subtitle_path=None |
|
prepared_images,prepared_instruction=prepare_input(vis_processor,video_path,subtitle_path,instruction) |
|
if prepared_images is None: |
|
        return "Video can't be opened, please check the video path and try again"
|
length=len(prepared_images) |
|
prepared_images=prepared_images.unsqueeze(0) |
|
conv = CONV_VISION.copy() |
|
conv.system = "" |
|
|
|
conv.append_message(conv.roles[0], prepared_instruction) |
|
conv.append_message(conv.roles[1], None) |
|
prompt = [conv.get_prompt()] |
|
answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens, do_sample=True, lengths=[length],num_beams=2) |
|
|
|
if subtitle_path: |
|
os.system(f"rm {subtitle_path}") |
|
|
|
|
|
return answers[0] |
|
|
|
def run_single_image(image_path, instruction, model, vis_processor):
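    """Answer `instruction` about a single image (one-frame variant of `run`, greedy decoding)."""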
|
    image = Image.open(image_path).convert('RGB')  # force 3-channel RGB so grayscale/RGBA images don't break preprocessing
|
image = vis_processor(image) |
|
prepared_images=torch.stack([image]) |
|
prepared_instruction='<Img><ImageHere>'+instruction |
|
length=len(prepared_images) |
|
prepared_images=prepared_images.unsqueeze(0) |
|
conv = CONV_VISION.copy() |
|
conv.system = "" |
|
|
|
conv.append_message(conv.roles[0], prepared_instruction) |
|
conv.append_message(conv.roles[1], None) |
|
prompt = [conv.get_prompt()] |
|
answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens, do_sample=False, lengths=[length],num_beams=1) |
|
return answers[0] |
|
|
|
def download_video(youtube_url, download_finish): |
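    """Gradio callback: download the YouTube video behind `youtube_url` into workspace/.

    Picks the highest-resolution progressive MP4 stream and returns the local path together
    with the updated `download_finish` state.
    """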
|
video_id=youtube_url.split('v=')[-1].split('&')[0] |
|
|
|
youtube = YouTube(youtube_url) |
|
|
|
video_stream = youtube.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first() |
|
|
|
|
|
print('Downloading video') |
|
video_stream.download(output_path="workspace",filename=f"{video_id}.mp4") |
|
print('Video downloaded successfully') |
|
processed_video_path= f"workspace/{video_id}.mp4" |
|
    download_finish = True  # return the new value for the download_finish gr.State output rather than a new component
|
return processed_video_path, download_finish |
|
|
|
def get_video_url(url,has_subtitles): |
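    """Download the YouTube video and return its local path in workspace/.

    `has_subtitles` is not used here; subtitles are generated later from the audio if requested.
    """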
|
|
|
video_id=url.split('v=')[-1].split('&')[0] |
|
|
|
youtube = YouTube(url) |
|
|
|
video_stream = youtube.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first() |
|
|
|
|
|
print('Downloading video') |
|
video_stream.download(output_path="workspace",filename=f"{video_id}.mp4") |
|
print('Video downloaded successfully') |
|
return f"workspace/{video_id}.mp4" |
|
|
|
|
|
|
|
|
|
def get_arguments(): |
|
parser = argparse.ArgumentParser(description="Inference parameters") |
|
parser.add_argument("--cfg-path", help="path to configuration file.",default="test_configs/llama2_test_config.yaml") |
|
parser.add_argument("--ckpt", type=str,default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint") |
|
parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens") |
|
parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") |
|
parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") |
|
parser.add_argument( |
|
"--options", |
|
nargs="+", |
|
        help="override some settings in the used config; key-value pairs "

        "in xxx=yyy format will be merged into the config file "

        "(deprecated, use --cfg-options instead).",
|
) |
|
return parser.parse_args() |
|
args=get_arguments() |
|
model, vis_processor = init_model(args) |
|
conv = CONV_VISION.copy() |
|
conv.system = "" |
|
inference_subtitles_folder="workspace/inference_subtitles" |
|
os.makedirs(inference_subtitles_folder,exist_ok=True) |
|
existed_subtitles = {}  # video ids that already have generated subtitles, so Whisper is not re-run for them
|
for sub in os.listdir(inference_subtitles_folder): |
|
existed_subtitles[sub.split('.')[0]]=True |
|
|
|
def gradio_demo_local(video_path,has_sub,instruction): |
|
pred=run(video_path,instruction,model,vis_processor,gen_subtitles=has_sub) |
|
return pred |
|
|
|
def gradio_demo_youtube(youtube_url,has_sub,instruction): |
|
video_path=get_video_url(youtube_url,has_sub) |
|
pred=run(video_path,instruction,model,vis_processor,gen_subtitles=has_sub) |
|
return pred |
|
|
|
def use_example(url,has_sub_1,q): |
|
|
|
youtube_link.value=url |
|
has_subtitles.value=has_sub_1 |
|
question.value=q |
|
|
|
|
|
title = """<h1 align="center">MiniGPT4-video 🎞️🍿</h1>"""
|
description = """<h5>This is the demo of the MiniGPT4-video model.</h5>"""
|
project_page = """<p><a href='https://vision-cair.github.io/MiniGPT4-video/'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>""" |
|
code_link="""<p><a href='https://github.com/Vision-CAIR/MiniGPT4-video'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p>""" |
|
paper_link="""<p><a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>""" |
|
|
|
with gr.Blocks(title="MiniGPT4-video 🎞️🍿", css=text_css) as demo:
|
|
|
|
|
gr.Markdown(title) |
|
gr.Markdown(description) |
|
|
with gr.Tab("Local videos"): |
|
|
with gr.Row(): |
|
with gr.Column(): |
|
video_player_local = gr.Video(sources=["upload"]) |
|
question_local = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?") |
|
has_subtitles_local = gr.Checkbox(label="Use subtitles", value=True) |
|
process_button_local = gr.Button("Answer the Question (QA)") |
|
|
|
with gr.Column(): |
|
answer_local=gr.Text("Answer will be here",label="MiniGPT4-video Answer") |
|
|
|
process_button_local.click(fn=gradio_demo_local, inputs=[video_player_local, has_subtitles_local, question_local], outputs=[answer_local]) |
|
|
|
with gr.Tab("Youtube videos"): |
|
|
with gr.Row(): |
|
with gr.Column(): |
|
                youtube_link = gr.Textbox(label="Enter the YouTube link", placeholder="Paste YouTube URL here")
|
video_player = gr.Video(autoplay=False) |
|
download_finish = gr.State(value=False) |
|
youtube_link.change( |
|
fn=download_video, |
|
inputs=[youtube_link, download_finish], |
|
outputs=[video_player, download_finish] |
|
) |
|
question = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?") |
|
has_subtitles = gr.Checkbox(label="Use subtitles", value=True) |
|
process_button = gr.Button("Answer the Question (QA)") |
|
|
|
with gr.Column(): |
|
answer=gr.Text("Answer will be here",label="MiniGPT4-video Answer") |
|
|
|
process_button.click(fn=gradio_demo_youtube, inputs=[youtube_link, has_subtitles, question], outputs=[answer]) |
|
|
if __name__ == "__main__": |
|
demo.queue(max_size=10).launch(share=False,show_error=True, show_api=False) |
|