import altair as alt
import cv2
import gradio as gr
import numpy as np
import open_clip
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_pil_image, to_tensor
def run(
path: str,
model_key: str,
text_search: str,
image_search: Image.Image,
thresh: float,
stride: int,
batch_size: int,
center_crop: bool,
):
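    """Search a video for frames matching a text or image query.

    Every `stride`-th frame is encoded with the selected OpenCLIP model and
    compared against the encoded query; frames whose similarity exceeds
    `thresh` are collected (at most 30) and returned together with an Altair
    plot of similarity over time.
    """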
assert path, "An input video should be provided"
    assert (
        text_search or image_search is not None
    ), "A text or image query should be provided"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
name, weights = MODELS[model_key]
model, _, preprocess = open_clip.create_model_and_transforms(
name, pretrained=weights, device=device
)
model.eval()
# Remove center crop transform
if not center_crop:
del preprocess.transforms[1]
# Load video
dataset = LoadVideo(path, transforms=preprocess, vid_stride=stride)
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=0
)
# Get text query features
if text_search:
# Tokenize search phrase
tokenizer = open_clip.get_tokenizer(name)
text = tokenizer([text_search]).to(device)
# Encode text query
with torch.no_grad():
query_features = model.encode_text(text)
query_features /= query_features.norm(dim=-1, keepdim=True)
# Get image query features
else:
image = preprocess(image_search).unsqueeze(0).to(device)
with torch.no_grad():
query_features = model.encode_image(image)
query_features /= query_features.norm(dim=-1, keepdim=True)
# Encode each frame and compare with query features
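    # Both the query and frame features are L2-normalized, so the matrix
    # product of the two yields cosine similarities in [-1, 1]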
matches = []
res = pd.DataFrame(columns=["Frame", "Timestamp", "Similarity"])
for image, orig, frame, timestamp in dataloader:
with torch.no_grad():
image = image.to(device)
image_features = model.encode_image(image)
image_features /= image_features.norm(dim=-1, keepdim=True)
probs = query_features.cpu().numpy() @ image_features.cpu().numpy().T
probs = probs[0]
# Save frame similarity values
df = pd.DataFrame(
{
"Frame": frame.tolist(),
"Timestamp": torch.round(timestamp / 1000, decimals=2).tolist(),
"Similarity": probs.tolist(),
}
)
res = pd.concat([res, df])
# Check if frame is over threshold
for i, p in enumerate(probs):
if p > thresh:
matches.append(to_pil_image(orig[i]))
print(f"Frames: {frame.tolist()} - Probs: {probs}")
# Create plot of similarity values
lines = (
alt.Chart(res)
.mark_line(color="firebrick")
.encode(
alt.X("Timestamp", title="Timestamp (seconds)"),
alt.Y("Similarity", scale=alt.Scale(zero=False)),
)
).properties(width=600)
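    # Dashed horizontal rule marking the similarity threshold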
rule = alt.Chart().mark_rule(strokeDash=[6, 3], size=2).encode(y=alt.datum(thresh))
    return lines + rule, matches[:30]  # Return at most 30 frames to avoid crashing the UI
class LoadVideo(Dataset):
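    """Dataset over a video file that yields every `vid_stride`-th frame as a
    (preprocessed tensor, raw tensor, frame index, timestamp in ms) tuple."""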
def __init__(self, path, transforms, vid_stride=1):
self.transforms = transforms
self.vid_stride = vid_stride
self.cur_frame = 0
self.cap = cv2.VideoCapture(path)
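        # Number of strided samples; CAP_PROP_FRAME_COUNT comes from container
        # metadata and may be approximate for some formats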
self.total_frames = int(
self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride
)
def __getitem__(self, _):
        # Skip ahead vid_stride frames; grab() advances the stream without
        # decoding a frame, so striding through the video is cheap
for _ in range(self.vid_stride):
self.cap.grab()
self.cur_frame += 1
        # Decode the most recently grabbed frame
        _, img = self.cap.retrieve()
timestamp = self.cap.get(cv2.CAP_PROP_POS_MSEC)
        # Convert BGR (OpenCV) to RGB and wrap as a PIL image
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(np.uint8(img))
# Apply transforms
img_t = self.transforms(img)
return img_t, to_tensor(img), self.cur_frame, timestamp
def __len__(self):
return self.total_frames
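# Mapping from dropdown label to an (open_clip model name, pretrained weights tag) pair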
MODELS = {
"convnext_base - laion400m_s13b_b51k": ("convnext_base", "laion400m_s13b_b51k"),
"convnext_base_w - laion2b_s13b_b82k": (
"convnext_base_w",
"laion2b_s13b_b82k",
),
"convnext_base_w - laion2b_s13b_b82k_augreg": (
"convnext_base_w",
"laion2b_s13b_b82k_augreg",
),
"convnext_base_w - laion_aesthetic_s13b_b82k": (
"convnext_base_w",
"laion_aesthetic_s13b_b82k",
),
"convnext_base_w_320 - laion_aesthetic_s13b_b82k": (
"convnext_base_w_320",
"laion_aesthetic_s13b_b82k",
),
"convnext_base_w_320 - laion_aesthetic_s13b_b82k_augreg": (
"convnext_base_w_320",
"laion_aesthetic_s13b_b82k_augreg",
),
"convnext_large_d - laion2b_s26b_b102k_augreg": (
"convnext_large_d",
"laion2b_s26b_b102k_augreg",
),
"convnext_large_d_320 - laion2b_s29b_b131k_ft": (
"convnext_large_d_320",
"laion2b_s29b_b131k_ft",
),
"convnext_large_d_320 - laion2b_s29b_b131k_ft_soup": (
"convnext_large_d_320",
"laion2b_s29b_b131k_ft_soup",
),
"convnext_xxlarge - laion2b_s34b_b82k_augreg": (
"convnext_xxlarge",
"laion2b_s34b_b82k_augreg",
),
"convnext_xxlarge - laion2b_s34b_b82k_augreg_rewind": (
"convnext_xxlarge",
"laion2b_s34b_b82k_augreg_rewind",
),
"convnext_xxlarge - laion2b_s34b_b82k_augreg_soup": (
"convnext_xxlarge",
"laion2b_s34b_b82k_augreg_soup",
),
}
if __name__ == "__main__":
desc_text = """
    Search the contents of a video with a text description. This application uses ConvNeXt CLIP models from [OpenCLIP](https://github.com/mlfoundations/open_clip) to compare video frames against the feature representation of a user's text or image query. Code can be found at [this repo](https://github.com/bwconrad/video-content-search).
__Note__: Long videos (over a few minutes) may cause UI performance issues.
"""
text_app = gr.Interface(
description=desc_text,
fn=run,
inputs=[
gr.Video(label="Video"),
gr.Dropdown(
label="Model",
choices=list(MODELS.keys()),
value="convnext_base_w - laion2b_s13b_b82k",
),
gr.Textbox(label="Text Search Query"),
gr.Image(label="Image Search Query", visible=False),
gr.Slider(label="Threshold", maximum=1.0, value=0.3),
gr.Slider(label="Frame-rate Stride", value=4, step=1),
gr.Slider(label="Batch Size", value=4, step=1),
gr.Checkbox(label="Center Crop"),
],
outputs=[
gr.Plot(label="Similarity Plot"),
gr.Gallery(label="Matched Frames").style(
columns=2, object_fit="contain", height="auto"
),
],
allow_flagging="never",
)
desc_image = """
    Search the contents of a video with an image query. This application uses ConvNeXt CLIP models from [OpenCLIP](https://github.com/mlfoundations/open_clip) to compare video frames against the feature representation of a user's text or image query. Code can be found at [this repo](https://github.com/bwconrad/video-content-search).
__Note__: Long videos (over a few minutes) may cause UI performance issues.
"""
image_app = gr.Interface(
description=desc_image,
fn=run,
inputs=[
gr.Video(label="Video"),
gr.Dropdown(
label="Model",
choices=list(MODELS.keys()),
value="convnext_base_w - laion2b_s13b_b82k",
),
gr.Textbox(label="Text Search Query", visible=False),
gr.Image(label="Image Search Query", type="pil"),
gr.Slider(label="Threshold", maximum=1.0, value=0.3),
gr.Slider(label="Frame-rate Stride", value=4, step=1),
gr.Slider(label="Batch Size", value=4, step=1),
gr.Checkbox(label="Center Crop"),
],
outputs=[
gr.Plot(label="Similarity Plot"),
gr.Gallery(label="Matched Frames").style(
columns=2, object_fit="contain", height="auto"
),
],
allow_flagging="never",
)
app = gr.TabbedInterface(
interface_list=[text_app, image_app],
tab_names=["Text Query Search", "Image Query Search"],
title="CLIP Video Content Search",
)
app.launch()