# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import random
from functools import partial

import clip
import decord
import gradio as gr
import nncore
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model
TITLE = 'R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
TITLE_MD = '<h1 align="center">R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.'
GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'

CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'
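
# Example videos paired with text queries, used by the "Random" button below.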
# yapf:disable
EXAMPLES = [
    ('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
    ('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
    ('data/CkWOpyrAXdw_210.0_360.0.mp4', 'Indian girl cleaning her kitchen before cooking.'),
    ('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
    ('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
]
# yapf:enable
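

# Format a timestamp in seconds as an 'MM:SS' string for the results table.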
def convert_time(seconds):
    minutes, seconds = divmod(round(max(seconds, 0)), 60)
    return f'{minutes:02d}:{seconds:02d}'
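

# Decode the video with Decord, sample frames at the FPS expected by the
# config, center-crop to a square, resize to the CLIP input resolution, and
# normalize with the CLIP mean/std before flattening each frame.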
def load_video(video_path, cfg):
    decord.bridge.set_bridge('torch')

    vr = VideoReader(video_path)
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

    size = 336 if '336px' in cfg.model.arch else 224
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]

    video = F.resize(video, size=(size, size))
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
    video = video.reshape(video.size(0), -1).unsqueeze(0)

    return video
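

# Build the model from the nncore config and load the pretrained weights,
# downloading the checkpoint first when a URL is given.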
def init_model(config, checkpoint):
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True

    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')

    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)

    return model, cfg
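

# Run one grounding query: preprocess the video, tokenize the text with CLIP,
# run the model, then format the top-5 predicted moments and the per-clip
# saliency scores (rescaled into [0.05, 0.95] and plotted at 2-second steps).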
def main(video, query, model, cfg):
    if len(query) == 0:
        raise gr.Error('Text query cannot be empty.')

    try:
        video = load_video(video, cfg)
    except Exception:
        raise gr.Error('Failed to load the video.')

    query = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

    hd = pred['_out']['saliency'].cpu()
    hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
    hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))

    return mr, hd
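

# Initialize the model once at startup, then build the Gradio interface:
# inputs on the left, moment retrieval table and highlight plot on the right.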
model, cfg = init_model(CONFIG, WEIGHT)
fn = partial(main, model=model, cfg=cfg)

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(TITLE_MD)
    gr.Markdown(DESCRIPTION_MD)
    gr.Markdown(GUIDE_MD)

    with gr.Row():
        with gr.Column():
            video = gr.Video(label='Video')
            query = gr.Textbox(label='Text Query')

            with gr.Row():
                random_btn = gr.Button(value='🔮 Random')
                gr.ClearButton([video, query], value='🗑️ Reset')
                submit_btn = gr.Button(value='🚀 Submit')

        with gr.Column():
            mr = gr.DataFrame(
                headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
            hd = gr.LinePlot(
                x='x',
                y='y',
                x_title='Time (seconds)',
                y_title='Saliency Score',
                label='Highlight Detection')

    random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
    submit_btn.click(fn, [video, query], [mr, hd])

demo.launch()