---
license: mit
---
|
|
<p align="center" width="100%">
<img src="https://i.postimg.cc/MKmyP9wH/new-banner.png" width="80%" height="80%">
</p>
|
|
|
|
|
<div align="center">

<a href='https://brianboli.com/' target='_blank'>Bo Li*<sup>,1</sup></a>
<a href='https://zhangyuanhan-ai.github.io/' target='_blank'>Yuanhan Zhang*<sup>,1</sup></a>
<a href='https://cliangyu.com/' target='_blank'>Liangyu Chen*<sup>,1</sup></a>
<a href='https://king159.github.io/' target='_blank'>Jinghao Wang*<sup>,1</sup></a>
<a href='https://pufanyi.github.io/' target='_blank'>Fanyi Pu*<sup>,1</sup></a>
<br>
<a href='https://jingkang50.github.io/' target='_blank'>Jingkang Yang<sup>1</sup></a>
<a href='https://chunyuan.li/' target='_blank'>Chunyuan Li<sup>2</sup></a>
<a href='https://liuziwei7.github.io/' target='_blank'>Ziwei Liu<sup>1</sup></a>

</div>

<div align="center">

<sup>1</sup>S-Lab, Nanyang Technological University&nbsp;&nbsp;
<sup>2</sup>Microsoft Research, Redmond

</div>
|
|
|
-----------------

![](https://img.shields.io/badge/otter-v0.2-darkcyan)
![](https://img.shields.io/github/stars/luodian/otter?style=social)
[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FLuodian%2Fotter&count_bg=%23FFA500&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=visitors&edge_flat=false)](https://hits.seeyoufarm.com)
![](https://black.readthedocs.io/en/stable/_static/license.svg)
![](https://img.shields.io/badge/code%20style-black-000000.svg)
|
|
|
Below is an example of running this model on your own video.

First, clone [Otter](https://github.com/Luodian/Otter) to your local disk.

Place the following script inside the `Otter` folder so that it can access `otter/modeling_otter.py`.
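
If you would rather keep the script outside the repository, a minimal sketch (assuming a hypothetical clone location of `~/Otter`) is to put the cloned repository on `sys.path` before importing:

```python
import sys
from pathlib import Path

# Hypothetical clone location; point this at wherever you cloned https://github.com/Luodian/Otter.
OTTER_REPO = Path.home() / "Otter"
sys.path.insert(0, str(OTTER_REPO))

from otter.modeling_otter import OtterForConditionalGeneration
```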
|
|
|
```python
import mimetypes
import os
import sys
from typing import Union

import cv2
import requests
import torch
import transformers
from PIL import Image

# Make sure this script can see the otter package (it should live inside the cloned Otter repository).
from otter.modeling_otter import OtterForConditionalGeneration

# Disable urllib3 warnings triggered by the unverified (verify=False) requests below.
requests.packages.urllib3.disable_warnings()


# ------------------- Utility Functions -------------------


def get_content_type(file_path):
    content_type, _ = mimetypes.guess_type(file_path)
    return content_type


# ------------------- Image and Video Handling Functions -------------------


def extract_frames(video_path, num_frames=16):
    video = cv2.VideoCapture(video_path)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_step = max(total_frames // num_frames, 1)  # guard against videos shorter than num_frames
    frames = []

    for i in range(num_frames):
        video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)
        ret, frame = video.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame).convert("RGB")
            frames.append(frame)

    video.release()
    return frames


def get_image(url: str) -> Union[Image.Image, list]:
    if "://" not in url:  # Local file
        content_type = get_content_type(url)
    else:  # Remote URL
        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")

    if "image" in content_type:
        if "://" not in url:  # Local file
            return Image.open(url)
        else:  # Remote URL
            return Image.open(requests.get(url, stream=True, verify=False).raw)
    elif "video" in content_type:
        video_path = "temp_video.mp4"
        if "://" not in url:  # Local file
            video_path = url
        else:  # Remote URL
            with open(video_path, "wb") as f:
                f.write(requests.get(url, stream=True, verify=False).content)
        frames = extract_frames(video_path)
        if "://" in url:  # Only remove the temporary video file if it was downloaded
            os.remove(video_path)
        return frames
    else:
        raise ValueError("Invalid content type. Expected image or video.")


# ------------------- OTTER Prompt and Response Functions -------------------


def get_formatted_prompt(prompt: str) -> str:
    return f"<image>User: {prompt} GPT:<answer>"


def get_response(input_data, prompt: str, model=None, image_processor=None, tensor_dtype=None) -> str:
    if isinstance(input_data, Image.Image):
        vision_x = image_processor.preprocess([input_data], return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)
    elif isinstance(input_data, list):  # list of video frames
        vision_x = image_processor.preprocess(input_data, return_tensors="pt")["pixel_values"].unsqueeze(0).unsqueeze(0)
    else:
        raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")

    lang_x = model.text_tokenizer(
        [
            get_formatted_prompt(prompt),
        ],
        return_tensors="pt",
    )

    # Keep the model from emitting role tags in its answer.
    bad_words_id = model.text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids
    generated_text = model.generate(
        vision_x=vision_x.to(model.device, dtype=tensor_dtype),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )
    # Keep only the text between <answer> and <|endofchunk|>, and drop surrounding quotes.
    parsed_output = (
        model.text_tokenizer.decode(generated_text[0])
        .split("<answer>")[-1]
        .strip()
        .split("<|endofchunk|>")[0]
        .strip()
        .strip('"')
    )
    return parsed_output


# ------------------- Main -------------------

load_bit = "fp32"
if load_bit == "fp16":
    precision = {"torch_dtype": torch.float16}
elif load_bit == "bf16":
    precision = {"torch_dtype": torch.bfloat16}
elif load_bit == "fp32":
    precision = {"torch_dtype": torch.float32}

# This model version is trained on the MIMIC-IT DC (Dense Captions) dataset.
model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-9B-DenseCaption", device_map="auto", **precision)
tensor_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[load_bit]

model.text_tokenizer.padding_side = "left"
tokenizer = model.text_tokenizer
image_processor = transformers.CLIPImageProcessor()
model.eval()

while True:
    video_url = input("Enter video path: ")  # Path or URL to your video file; most common formats work.

    frames_list = get_image(video_url)

    while True:
        prompts_input = input("Enter prompts: ")

        if prompts_input.lower() == "quit":  # type "quit" to switch to a different video
            break

        print(f"\nPrompt: {prompts_input}")
        response = get_response(frames_list, prompts_input, model, image_processor, tensor_dtype)
        print(f"Response: {response}")
```
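
The script above loops over prompts interactively (type `quit` to pick a new video). For one-off use, a minimal sketch that reuses the model and helper functions defined above (the video path below is a hypothetical example) would be:

```python
# Hypothetical single-shot usage; assumes the model, image_processor, tensor_dtype,
# and helper functions from the script above are already defined.
frames = get_image("./sample_video.mp4")  # local path or URL; a single image also works
answer = get_response(frames, "What is happening in this video?", model, image_processor, tensor_dtype)
print(answer)
```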