# Scraped Hugging Face Spaces page header ("Spaces: Sleeping Sleeping") — not part of the program.
import os | |
import csv | |
import json | |
import torch | |
import argparse | |
import pandas as pd | |
import torch.nn as nn | |
from tqdm import tqdm | |
from collections import defaultdict | |
from transformers.models.llama.tokenization_llama import LlamaTokenizer | |
from torch.utils.data import DataLoader | |
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration | |
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor | |
from peft import LoraConfig, get_peft_model | |
from data_utils.xgpt3_dataset import MultiModalDataset | |
from utils import batchify | |
import gradio as gr | |
from entailment_inference import get_scores | |
from nle_inference import VideoCaptionDataset, get_nle | |
import re | |
# Matches the original (non-LoRA) projection weight of every LoRA-targeted
# linear layer inside the language model, e.g.
# "...language_model.model.layers.0.self_attn.q_proj.weight".
# The dot before "weight" is escaped so only a literal ".weight" matches
# (an unescaped "." would also accept e.g. "q_projXweight").
_LORA_BASE_WEIGHT_PATTERN = re.compile(
    r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)\.weight'
)


def modify_keys(state_dict):
    """Rename LoRA-targeted base weights so they load into the PEFT-wrapped model.

    For every key matching ``_LORA_BASE_WEIGHT_PATTERN``, a ``base_layer``
    segment is inserted before the final dotted component, e.g.
    ``...q_proj.weight`` -> ``...q_proj.base_layer.weight``. This is the layout
    the PEFT-wrapped model below expects for the wrapped projection's weight.
    All other keys (LoRA adapter weights, non-language-model weights, ...) are
    copied unchanged. Values are passed through untouched, and key order is
    preserved.

    Args:
        state_dict: Mapping of parameter name -> value, e.g. as returned by
            ``torch.load`` on a ``pytorch_model.bin`` file.

    Returns:
        A new ``dict`` with the renamed keys; the input mapping is not modified.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        if _LORA_BASE_WEIGHT_PATTERN.match(key):
            # The pattern requires a '.', so `parent` is never empty here.
            parent, _, leaf = key.rpartition('.')
            key = f'{parent}.base_layer.{leaf}'
        new_state_dict[key] = value
    return new_state_dict
# Base mPLUG-Owl video checkpoint directory: tokenizer, image processor and
# base model weights are all loaded from here.
pretrained_ckpt = "mplugowl7bvideo/"
# Fine-tuned Owl-Con weights. They are loaded (strictly) into the PEFT-wrapped
# model below, so this file must contain every key of that model after the
# renaming done by modify_keys.
trained_ckpt = "owl-con/checkpoint-5178/pytorch_model.bin"

tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)

# Instantiate the base model on CPU first; it is moved to the GPU only after
# the fine-tuned weights have been loaded into it.
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
    device_map={'': 'cpu'},
)

# Wrap the language-model projections with LoRA adapters so the module layout
# matches the fine-tuned checkpoint. These hyper-parameters presumably mirror
# the training config — `r` in particular must, or the adapter weight shapes
# will not load.
peft_config = LoraConfig(
    target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
    inference_mode=True,
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# NOTE(review): torch.load unpickles the file, so only point trained_ckpt at a
# trusted checkpoint.
with open(trained_ckpt, 'rb') as f:
    ckpt = torch.load(f, map_location=torch.device("cpu"))
ckpt = modify_keys(ckpt)
model.load_state_dict(ckpt)
# load_state_dict copies the tensors into the model's parameters, so the
# state dict is now a redundant full copy of the weights; release it instead
# of keeping it alive as a module global for the lifetime of the app.
del ckpt

model = model.to("cuda:0").to(torch.bfloat16)
# from_pretrained leaves the base model in eval mode, but get_peft_model adds
# newly constructed LoRA modules afterwards (including dropout with
# p=lora_dropout), and new nn.Modules start in training mode. Switch the whole
# model to eval mode explicitly so inference is deterministic.
model.eval()
# Entailment prompt; `{caption}` is filled with the user's caption per request.
ENTAILMENT_PROMPT = """The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: Does this video entail the description: "{caption}"?
AI: """

# Scores below this value trigger natural-language-explanation (NLE) generation.
NLE_THRESHOLD = 0.5


def inference(videopath, text):
    """Score whether a video entails a caption, explaining low-scoring pairs.

    Args:
        videopath: Path to the input video, as passed by the ``gr.Video``
            input (presumably ``None`` when nothing was uploaded).
        text: Caption to check against the video.

    Returns:
        A ``(score, nle)`` tuple. ``score`` is whatever ``get_scores`` returns
        for the (video, prompt) pair. ``nle`` is the explanation produced by
        ``get_nle`` when ``score < NLE_THRESHOLD``; otherwise it is a
        placeholder string.

    Raises:
        gr.Error: If the video or the caption is missing, so Gradio shows a
            readable message instead of a traceback from the data pipeline.
    """
    if not videopath:
        raise gr.Error("Please upload an input video.")
    if not text or not text.strip():
        raise gr.Error("Please enter a caption to check against the video.")

    prompt = ENTAILMENT_PROMPT.format(caption=text)
    valid_data = MultiModalDataset(
        videopath, prompt, tokenizer, processor,
        max_length=256, loss_objective='sequential',
    )
    dataloader = DataLoader(valid_data, pin_memory=True, collate_fn=batchify)
    score = get_scores(model, tokenizer, dataloader)

    if score < NLE_THRESHOLD:
        # Low entailment: generate an explanation from a separate dataset
        # built directly from the raw video path and caption.
        nle_loader = DataLoader(VideoCaptionDataset(videopath, text))
        nle = get_nle(model, processor, tokenizer, nle_loader)
    else:
        nle = f"None (NLE is only triggered when entailment score < {NLE_THRESHOLD})"
    return score, nle
# Example (video, caption) rows shown beneath the interface. Each clip appears
# twice, with captions that differ in a single detail.
_EXAMPLES = [
    ["examples/820.mp4", "We see the group making cookies."],
    ["examples/820.mp4", "We see the group eating cookies."],
    ["examples/244.mp4", "She throws a bowling ball while talking on the phone."],
    ["examples/244.mp4", "She throws a baseball while talking on the phone."],
]

# Gradio UI: (video, caption) in, (entailment score, explanation) out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Video(label='input_video'),
        gr.Textbox(label='input_caption'),
    ],
    outputs=[
        gr.Number(label='Entailment Score'),
        gr.Textbox(label='Natural Language Explanation'),
    ],
    examples=_EXAMPLES,
    title="Owl-Con Demo",
    description="Owl-Con Demo (Code: https://github.com/Hritikbansal/videocon | Paper: https://arxiv.org/abs/2311.10111)",
)

if __name__ == "__main__":
    demo.launch()