clifs / app.py
ncoop57
Reorganize interface and code to be more modular and add necessary debian packages to install
fd2744e
raw
history blame
2.65 kB
import ffmpeg
import torch
import youtube_dl
import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer, util, models
from clip import CLIPModel
from PIL import Image
@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Build the CLIP-backed SentenceTransformer used for text/frame search.

    Cached with ``st.cache`` (single entry, output mutation allowed) so that
    Streamlit reruns reuse the same model object instead of rebuilding it.

    Returns:
        A ``SentenceTransformer`` wrapping ``CLIPModel``, moved to the CPU in
        float32.
    """
    encoder = SentenceTransformer(modules=[CLIPModel()])
    cpu = torch.device('cpu')
    return encoder.to(dtype=torch.float32, device=cpu)
def get_embedding(model, query, video):
    """Embed a text query and a sequence of video frames with the same model.

    Args:
        model: Encoder exposing ``encode(inputs, device=...)``.
        query: Text to embed.
        video: Iterable of frame arrays; each is converted with
            ``Image.fromarray`` before encoding.

    Returns:
        Tuple ``(text_embedding, frame_embeddings)`` exactly as produced by
        ``model.encode`` on the query and on the list of frame images.
    """
    query_embedding = model.encode(query, device='cpu')
    # The model encodes PIL images, so convert the raw frame arrays first.
    frame_images = [Image.fromarray(frame) for frame in video]
    frame_embeddings = model.encode(frame_images, device='cpu')
    return query_embedding, frame_embeddings
def _read_frames(filename, frame_step=10, max_frames=200):
    """Decode every ``frame_step``-th frame (at most ``max_frames``) as RGB.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape
        ``(n_frames, height, width, 3)``.

    Raises:
        ValueError: if ``filename`` contains no video stream.
    """
    probe = ffmpeg.probe(filename)
    video_stream = next(
        (stream for stream in probe['streams'] if stream['codec_type'] == 'video'),
        None,
    )
    if video_stream is None:
        raise ValueError(f"No video stream found in {filename!r}")
    width = int(video_stream['width'])
    height = int(video_stream['height'])
    if max_frames <= 0:
        return np.empty((0, height, width, 3), dtype=np.uint8)
    # Only frames 0, frame_step, ..., (max_frames - 1) * frame_step survive the
    # sampling below, so stop decoding after the last of them instead of
    # buffering the entire video as raw RGB in memory.
    n_decode = (max_frames - 1) * frame_step + 1
    out, _ = (
        ffmpeg
        .input(filename)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24', vframes=n_decode)
        .run(capture_stdout=True)
    )
    frames = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
    return frames[::frame_step][:max_frames]


def _top_k_indices(scores, top_k):
    """Return indices of the ``top_k`` highest ``scores``, best match first."""
    # argsort is ascending; reverse it so the most similar frame comes first.
    return np.argsort(scores)[::-1][:max(top_k, 0)]


def my_hook(d, model, desc, top_k, text, frame_step=10, max_frames=200):
    """youtube_dl progress hook that searches the downloaded video for ``desc``.

    Ignores every progress event except ``status == 'finished'``. It then
    samples frames from ``d["filename"]``, scores them by cosine similarity
    against the text ``desc`` and displays the ``top_k`` best frames with
    ``st.image``, most similar first.

    Args:
        d: youtube_dl progress dict (``status`` and ``filename`` are read).
        model: Encoder passed through to ``get_embedding``.
        desc: Text description to search for.
        top_k: Number of frames to display.
        text: Streamlit element used as a status line; cleared on success.
        frame_step: Keep every ``frame_step``-th decoded frame.
        max_frames: Maximum number of sampled frames to embed.
    """
    if d['status'] != 'finished':
        return
    text.text("Processing video...")
    video = _read_frames(d["filename"], frame_step, max_frames)
    if len(video) == 0:
        text.text("No frames could be decoded from the video.")
        return
    txt_embd, img_embds = get_embedding(model, desc, video)
    # cos_sim yields a (1, n_frames) matrix for the single text query.
    cos_scores = np.array(util.cos_sim(txt_embd, img_embds))[0]
    imgs = [Image.fromarray(video[i]) for i in _top_k_indices(cos_scores, top_k)]
    text.empty()
    st.image(imgs)
def run():
    """Build the Streamlit UI and search a YouTube video when "Search" is pressed.

    The sidebar collects the number of results, a text description and a
    YouTube URL. On submit, the video is downloaded with youtube_dl in the
    ``mp4[height=360]`` format; the frame search itself runs in ``my_hook``,
    which youtube_dl invokes as a progress hook.
    """
    st.set_page_config(page_title="Youtube CLIFS")
    # main body
    model = get_model()
    st.sidebar.markdown("### Controls:")
    top_k = st.sidebar.slider(
        "Top K",
        min_value=1,
        max_value=5,
        step=1,
    )
    desc = st.sidebar.text_input(
        "Search Description",
        value="Two white puppies",
        help="Text description of what you want to find in the video",
    )
    url = st.sidebar.text_input(
        "Youtube Video URL",
        value='https://youtu.be/I3AaW9ZevIU',
        help="Youtube video you'd like to search through",
    )
    submit_button = st.sidebar.button("Search")
    if not submit_button:
        return
    if not desc.strip() or not url.strip():
        st.warning("Please enter both a search description and a YouTube video URL.")
        return

    text = st.text("Downloading video...")

    def hook(d):
        my_hook(d, model, desc, top_k, text)

    ydl_opts = {"format": "mp4[height=360]", "progress_hooks": [hook]}
    try:
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
    except youtube_dl.utils.DownloadError as err:
        # youtube_dl reports extraction/download failures (e.g. unavailable
        # video or requested format) as DownloadError; show them in the app
        # instead of a raw traceback.
        text.error(f"Could not download the video: {err}")
# Build and render the app when this file is executed as the main script.
if __name__ == "__main__":
    run()