from functools import partial

import clip
import decord
import gradio as gr
import matplotlib.pyplot as plt
import nncore
import numpy as np
import torch
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model

TITLE = 'R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
DESCRIPTION = 'R2-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.\n\nUser Guide:\n1. Upload a video or record one using your webcam.\n2. Input a text query. A good practice is to use a sentence with 5~10 words.\n3. Click "Submit" and you\'ll see the moment retrieval and highlight detection results on the right.'

CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'


def convert_time(seconds):
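    """Convert a duration in seconds to an MM:SS string."""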
    minutes, seconds = divmod(round(seconds), 60)
    return f'{minutes:02d}:{seconds:02d}'


def load_video(video_path, cfg):
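    """Load a video and preprocess its frames into flattened CLIP-style inputs."""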
    decord.bridge.set_bridge('torch')
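
    # Sample frames at the FPS specified in the validation config.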
    vr = VideoReader(video_path)
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
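
    # Center-crop to a square, resize to the CLIP input resolution,
    # and normalize with CLIP-style channel statistics.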
    size = 336 if '336px' in cfg.model.arch else 224
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]
    video = F.resize(video, size=(size, size))
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
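
    # Flatten each frame and add a batch dimension: (1, num_frames, 3 * size * size).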
    video = video.reshape(video.size(0), -1).unsqueeze(0)

    return video


def init_model(config, checkpoint):
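    """Build the model from a config file and load the pretrained checkpoint."""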
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True
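
    # Download the checkpoint first if a URL is given.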
    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')

    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)

    return model, cfg


def main(video, query, model, cfg):
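    """Run moment retrieval and highlight detection for a video-query pair."""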
    if len(query) == 0:
        raise gr.Error('Text query cannot be empty.')

    try:
        video = load_video(video, cfg)
    except Exception:
        raise gr.Error('Failed to load the video.')

    query = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)
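
    # Keep the first five predicted moments: (start time, end time, confidence score).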
    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
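
    # Min-max normalize the saliency scores to [0, 1] for plotting.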
    hd = pred['_out']['saliency'].cpu()
    hd = ((hd - hd.min()) / (hd.max() - hd.min())).tolist()
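
    # Plot saliency over time, assuming one score every 2 seconds (fps = 0.5).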
    fig, ax = plt.subplots(figsize=(10, 5.5))
    ax.plot(range(0, len(hd) * 2, 2), hd)

    ax.set_xlabel('Time (s)', fontsize=15)
    ax.set_ylabel('Saliency Score', fontsize=15)

    ax.tick_params(labelsize=14)
    plt.tight_layout(rect=(0.02, 0.02, 0.95, 0.885))

    return mr, fig
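

# Initialize the model once at startup and bind it to the Gradio callback.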
model, cfg = init_model(CONFIG, WEIGHT)
main = partial(main, model=model, cfg=cfg)

demo = gr.Interface(
    fn=main,
    inputs=[gr.Video(label='Video'),
            gr.Textbox(label='Text Query')],
    outputs=[
        gr.Dataframe(
            headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval'),
        gr.Plot(label='Highlight Detection')
    ],
    allow_flagging='auto',
    title=TITLE,
    description=DESCRIPTION)

demo.launch()