# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
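
"""Gradio demo for R2-Tuning: given a video and a text query, predict matching
moments (moment retrieval) and per-frame saliency scores (highlight detection)."""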

import random
from functools import partial

import clip
import decord
import gradio as gr
import nncore
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model

TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'

TITLE_MD = '<h1 align="center">🌀R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.'
GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'

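# Model config and the pretrained checkpoint trained on QVHighlights.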
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'

# yapf:disable
EXAMPLES = [
    ('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
    ('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
    ('data/CkWOpyrAXdw_210.0_360.0.mp4', 'Indian girl cleaning her kitchen before cooking.'),
    ('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
    ('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
]
# yapf:enable


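# Clamp negative inputs to zero and format a second count as MM:SS.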
def convert_time(seconds):
    minutes, seconds = divmod(round(max(seconds, 0)), 60)
    return f'{minutes:02d}:{seconds:02d}'


def load_video(video_path, cfg):
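    # Make decord return torch tensors instead of NumPy arrays.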
    decord.bridge.set_bridge('torch')

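    # Subsample frames at the evaluation FPS specified in the config.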
    vr = VideoReader(video_path)
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

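    # Center-crop to a square, resize to the CLIP input resolution, normalize
    # with CLIP's RGB mean/std, then flatten each frame into a vector.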
    size = 336 if '336px' in cfg.model.arch else 224
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]
    video = F.resize(video, size=(size, size))
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
    video = video.reshape(video.size(0), -1).unsqueeze(0)

    return video


def init_model(config, checkpoint):
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True

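    # Download the checkpoint first if a URL was given instead of a local path.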
    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')

    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)

    return model, cfg


def main(video, query, model, cfg):
    if not query:
        raise gr.Error('Text query cannot be empty.')

    try:
        video = load_video(video, cfg)
    except Exception:
        raise gr.Error('Failed to load the video.')

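    # Tokenize the query with CLIP's tokenizer (truncated to its 77-token context).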
    query = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

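    # Run the forward pass without gradient tracking.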
    with torch.inference_mode():
        pred = model(data)

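    # Keep the top-5 predicted moments as (start, end, confidence) rows.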
    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

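    # Min-max normalize saliency scores into (0.05, 0.95) for plotting. The
    # 2-second x-axis step assumes the config samples frames at 0.5 FPS.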
    hd = pred['_out']['saliency'].cpu()
    hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
    hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))

    return mr, hd


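# Build the model once at startup and bind it into the Gradio callback via partial.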
model, cfg = init_model(CONFIG, WEIGHT)

fn = partial(main, model=model, cfg=cfg)

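# Assemble the Gradio UI: video and query inputs on the left, results on the right.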
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(TITLE_MD)
    gr.Markdown(DESCRIPTION_MD)
    gr.Markdown(GUIDE_MD)

    with gr.Row():
        with gr.Column():
            video = gr.Video(label='Video')
            query = gr.Textbox(label='Text Query')

            with gr.Row():
                random_btn = gr.Button(value='🔮 Random')
                gr.ClearButton([video, query], value='🗑️ Reset')
                submit_btn = gr.Button(value='🚀 Submit')

        with gr.Column():
            mr = gr.DataFrame(
                headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
            hd = gr.LinePlot(
                x='x',
                y='y',
                x_title='Time (seconds)',
                y_title='Saliency Score',
                label='Highlight Detection')

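        # Wire up the buttons: fill in a random example, or run inference on submit.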
        random_btn.click(lambda: random.choice(EXAMPLES), None, [video, query])
        submit_btn.click(fn, [video, query], [mr, hd])

demo.launch()