fcakyon's picture
update social links
2a4a324
raw
history blame
2.94 kB
import os
import torch
import gradio as gr
from transformers import XCLIPProcessor, XCLIPModel
from utils import convert_frames_to_gif, download_youtube_video, sample_frames_from_video_file
# Zero-shot X-CLIP checkpoint; its processor and model are loaded once,
# at import time, and shared by every request.
model_name = "microsoft/xclip-base-patch16-zero-shot"
processor, model = (
    XCLIPProcessor.from_pretrained(model_name),
    XCLIPModel.from_pretrained(model_name),
)

# Demo examples. Each row is [youtube_url, comma-separated labels],
# in the same order as the inputs of ``predict``.
examples = [
    ["https://www.youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey"],
    ["https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football"],
    ["https://www.youtu.be/x8UAUAuKNcU", "game of thrones,the lord of the rings,vikings"],
]
def predict(youtube_url, labels_text):
    """Classify a YouTube video against comma-separated candidate labels.

    Downloads the video, samples 32 frames from it, renders those frames to
    a GIF for display, and scores each label with the zero-shot X-CLIP model.

    Args:
        youtube_url: URL of the YouTube video to classify.
        labels_text: Candidate labels separated by commas, e.g.
            "sleeping dog, cat fight club". Whitespace around each label is
            ignored, empty entries are dropped, and repeated labels are kept
            once (first occurrence wins).

    Returns:
        A ``(label_to_prob, gif_path)`` tuple:
        - ``label_to_prob``: a dict mapping each label to its softmax
          probability.
        - ``gif_path``: the path returned by ``convert_frames_to_gif`` for
          the sampled frames.

    Raises:
        ValueError: if ``labels_text`` contains no non-empty label.
    """
    # Strip whitespace so "dog, cat" yields "cat" rather than " cat".
    # Filtering drops the empty entries left by stray or trailing commas.
    # dict.fromkeys de-duplicates while keeping first-seen order, so a
    # repeated label neither splits the softmax mass nor vanishes from the
    # returned dict.
    stripped = (part.strip() for part in labels_text.split(","))
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        raise ValueError("Please provide at least one label.")

    video_path = download_youtube_video(youtube_url)
    try:
        # NOTE(review): 32 frames presumably matches the frame count this
        # zero-shot checkpoint expects; confirm against the model config.
        frames = sample_frames_from_video_file(video_path, num_frames=32)
    finally:
        # Delete the downloaded file even if frame sampling fails.
        os.remove(video_path)
    gif_path = convert_frames_to_gif(frames)

    inputs = processor(
        text=labels,
        videos=list(frames),
        return_tensors="pt",
        padding=True,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    # Row 0 holds the scores of the single clip passed above (presumably).
    # Softmax turns its per-label logits into probabilities.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}
    return label_to_prob, gif_path
# Gradio UI, built in this order:
# - a title and social links;
# - a row of three columns (inputs, input clip, predictions);
# - cached examples;
# - the Predict button wiring;
# - a credits footer.
with gr.Blocks() as app:
    gr.Markdown("# **<p align='center'>Zero-shot Video Classification with X-CLIP</p>**")
    gr.Markdown(
        "\n"
        "<p style='text-align: center'>\n"
        "Follow me for more!\n"
        "<br> <a href='https://twitter.com/fcakyon' target='_blank'>twitter</a> | "
        "<a href='https://github.com/fcakyon' target='_blank'>github</a> | "
        "<a href='https://www.linkedin.com/in/fcakyon/' target='_blank'>linkedin</a> | "
        "<a href='https://fcakyon.medium.com/' target='_blank'>medium</a>\n"
        "</p>\n"
    )

    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            gr.Markdown("Provide a Youtube video URL and a list of labels separated by commas")
            youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
            labels_text = gr.Textbox(label="Labels Text:", show_label=True)
            predict_btn = gr.Button(value="Predict")
        # Middle column: the GIF rendered from the sampled frames.
        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)
        # Right column: per-label probabilities.
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    gr.Markdown("**Examples:**")
    # cache_examples=True runs ``predict`` on every example ahead of time.
    gr.Examples(
        examples=examples,
        inputs=[youtube_url, labels_text],
        outputs=[predictions, video_gif],
        fn=predict,
        cache_examples=True,
    )
    predict_btn.click(
        predict,
        inputs=[youtube_url, labels_text],
        outputs=[predictions, video_gif],
    )

    gr.Markdown(
        "\n"
        "\n"
        ' Demo created by: <a href="https://github.com/fcakyon">fcakyon</a>\n'
        '<br> Based on this <a href="https://huggingface.co/microsoft/xclip-base-patch16-zero-shot">HuggingFace model</a>\n'
    )

app.launch()