# Source: Hugging Face Space file view (author: JustinLin610)
# Last commit: 9ea02f2 "fix video bug" (file size 5.11 kB)
import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType
# Use the first CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the pretrained ImageBind "huge" model once at import time; the
# zero-shot handlers in this module all reference this single global `model`.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()  # inference only (sets module.training = False)
model.to(device)
def image_text_zeroshot(image, text_list):
image_paths = [image]
labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
inputs = {
ModalityType.TEXT: data.load_and_transform_text(labels, device),
ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
}
with torch.no_grad():
embeddings = model(inputs)
scores = (
torch.softmax(
embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
)
.squeeze(0)
.tolist()
)
score_dict = {label: score for label, score in zip(labels, scores)}
return score_dict
def audio_text_zeroshot(audio, text_list):
audio_paths = [audio]
labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
inputs = {
ModalityType.TEXT: data.load_and_transform_text(labels, device),
ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}
with torch.no_grad():
embeddings = model(inputs)
scores = (
torch.softmax(
embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1
)
.squeeze(0)
.tolist()
)
score_dict = {label: score for label, score in zip(labels, scores)}
return score_dict
def video_text_zeroshot(video, text_list):
video_paths = [video]
labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]
inputs = {
ModalityType.TEXT: data.load_and_transform_text(labels, device),
ModalityType.VISION: data.load_and_transform_video_data(video_paths, device),
}
with torch.no_grad():
embeddings = model(inputs)
scores = (
torch.softmax(
embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1
)
.squeeze(0)
.tolist()
)
score_dict = {label: score for label, score in zip(labels, scores)}
return score_dict
def inference(
    task,
    text_list=None,
    image=None,
    audio=None,
    video=None,
):
    """Route a Gradio request to the zero-shot scorer for ``task``.

    Args:
        task: one of ``"image-text"``, ``"audio-text"`` or ``"video-text"``.
        text_list: candidate texts separated by ``|``.
        image: image file path (used only for ``"image-text"``).
        audio: audio file path (used only for ``"audio-text"``).
        video: video input (used only for ``"video-text"``).

    Returns:
        The label -> score dict produced by the selected ``*_zeroshot`` function.

    Raises:
        NotImplementedError: if ``task`` is not one of the supported tasks.
    """
    if task == "image-text":
        return image_text_zeroshot(image, text_list)
    if task == "audio-text":
        return audio_text_zeroshot(audio, text_list)
    if task == "video-text":
        return video_text_zeroshot(video, text_list)
    # Name the offending task so a misconfigured UI/caller is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported task {task!r}; expected one of "
        "'image-text', 'audio-text', 'video-text'."
    )
def main():
    """Build the Gradio interface for the ImageBind zero-shot demo and launch it."""
    # NOTE(review): ``gr.inputs.*`` components and the ``default=`` keyword are
    # Gradio's legacy input API, which newer Gradio releases no longer provide.
    # This code assumes an older pinned Gradio version; confirm before upgrading.
    task_selector = gr.inputs.Radio(
        choices=["image-text", "audio-text", "video-text"],
        type="value",
        default="image-text",
        label="Task",
    )
    candidate_box = gr.inputs.Textbox(lines=1, label="Candidate texts")
    image_input = gr.inputs.Image(type="filepath", label="Input image")
    audio_input = gr.inputs.Audio(type="filepath", label="Input audio")
    video_input = gr.inputs.Video(type=None, label="Input video")

    # Order mirrors the positional parameters of ``inference``:
    # (task, text_list, image, audio, video).
    components = [
        task_selector,
        candidate_box,
        image_input,
        audio_input,
        video_input,
    ]

    candidates = "A dog|A car|A bird"
    # One row per example; columns follow ``components``, unused media slots are None.
    example_rows = [
        ["image-text", candidates, "assets/dog_image.jpg", None, None],
        ["image-text", candidates, "assets/car_image.jpg", None, None],
        ["audio-text", candidates, None, "assets/bird_audio.wav", None],
        ["video-text", candidates, None, None, "assets/dog_video.mp4"],
    ]

    # NOTE(review): the "Duplicate Space" link below targets the
    # OFA-Sys/chinese-clip-zero-shot-image-classification Space, not this
    # ImageBind demo. It looks copy-pasted; confirm the intended target.
    description = """<p>This is a simple demo of ImageBind for zero-shot cross-modal understanding (now including image classification, audio classification, and video classification). Please refer to the original <a href='https://arxiv.org/abs/2305.05665' target='_blank'>paper</a> and <a href='https://github.com/facebookresearch/ImageBind' target='_blank'>repo</a> for more details.<br>
To test your own cases, you can upload an image, an audio or a video, and provide the candidate texts separated by "|".<br>
You can duplicate this space and run it privately: <a href='https://huggingface.co/spaces/OFA-Sys/chinese-clip-zero-shot-image-classification?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a></p>"""

    demo = gr.Interface(
        inference,
        components,
        "label",
        examples=example_rows,
        description=description,
        title="ImageBind: Zero-shot Cross-modal Understanding",
    )
    demo.launch()


if __name__ == "__main__":
    main()