Add files

- .gitignore +10 -0
- README.md +3 -5
- app.py +112 -0
- configs/_base_/datasets/qvhighlights.py +38 -0
- configs/_base_/models/model.py +44 -0
- configs/qvhighlights/r2_tuning_qvhighlights.py +1 -0
- models/__init__.py +6 -0
- models/adapter.py +99 -0
- models/blocks.py +98 -0
- models/generator.py +62 -0
- models/loss.py +210 -0
- models/model.py +206 -0
- requirements.txt +6 -0
- setup.cfg +15 -0
.gitignore
ADDED
@@ -0,0 +1,10 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Temporary data
/checkpoints
/flagged
.DS_Store
._*

README.md
CHANGED
@@ -1,13 +1,11 @@
 ---
 title: R2 Tuning
-emoji:
-colorFrom:
-colorTo:
+emoji: 🌀
+colorFrom: blue
+colorTo: purple
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
 pinned: false
 license: bsd-3-clause
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app.py
ADDED
@@ -0,0 +1,112 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

from functools import partial

import clip
import decord
import nncore
import torch
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model

TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'  # noqa
DESCRIPTION = 'R2-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.\n\nUser Guide:\n1. Upload or record a video using the web camera.\n2. Input a text query. A good practice is to use a sentence with 5~10 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'  # noqa

CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'  # noqa


def convert_time(seconds):
    minutes, seconds = divmod(round(seconds), 60)
    return f'{minutes:02d}:{seconds:02d}'


def load_video(video_path, cfg):
    decord.bridge.set_bridge('torch')

    vr = VideoReader(video_path)
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

    size = 336 if '336px' in cfg.model.arch else 224
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]
    video = F.resize(video, size=(size, size))
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
    video = video.reshape(video.size(0), -1).unsqueeze(0)

    return video


def init_model(config, checkpoint):
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True

    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')

    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)

    return model, cfg


def main(video, query, model, cfg):
    if len(query) == 0:
        raise gr.Error('Text query cannot be empty.')

    try:
        video = load_video(video, cfg)
    except Exception:
        raise gr.Error('Failed to load the video.')

    query = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

    hd = pred['_out']['saliency'].cpu()
    hd = ((hd - hd.min()) / (hd.max() - hd.min())).tolist()

    fig, ax = plt.subplots(figsize=(10, 5.5))
    ax.plot(range(0, len(hd) * 2, 2), hd)

    ax.set_xlabel('Time (s)', fontsize=15)
    ax.set_ylabel('Saliency Score', fontsize=15)

    ax.tick_params(labelsize=14)
    plt.tight_layout(rect=(0.02, 0.02, 0.95, 0.885))

    return mr, fig


model, cfg = init_model(CONFIG, WEIGHT)
main = partial(main, model=model, cfg=cfg)

demo = gr.Interface(
    fn=main,
    inputs=[gr.Video(label='Video'),
            gr.Textbox(label='Text Query')],
    outputs=[
        gr.Dataframe(
            headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval'),
        gr.Plot(label='Highlight Detection')
    ],
    allow_flagging='auto',
    title=TITLE,
    description=DESCRIPTION)
demo.launch()

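For reference, a minimal sketch of running the same checkpoint without the Gradio UI. The random video tensor is only a stand-in for load_video() output and the query string is a placeholder; the model construction mirrors init_model() above:

import clip
import nncore
import torch
from nncore.engine import load_checkpoint
from nncore.nn import build_model

# Build the model the same way init_model() does.
cfg = nncore.Config.from_file('configs/qvhighlights/r2_tuning_qvhighlights.py')
cfg.model.init = True
ckpt = nncore.download(
    'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth',
    out_dir='checkpoints')
model = build_model(cfg.model, dist=False).eval()
model = load_checkpoint(model, ckpt, warning=False)

# Dummy stand-in for load_video(): 16 clips, each a flattened 3x224x224 frame.
video = torch.randn(1, 16, 3 * 224 * 224)
query = clip.tokenize('a man is speaking to the camera', truncate=True)

data = dict(video=video, query=query, fps=[cfg.data.val.fps])
with torch.inference_mode():
    pred = model(data)

print(pred['_out']['boundary'][:5])    # top-5 (start, end, score) rows, in seconds
print(pred['_out']['saliency'].shape)  # one saliency score per clip
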
configs/_base_/datasets/qvhighlights.py
ADDED
@@ -0,0 +1,38 @@
# dataset settings
data_type = 'Grounding'
data_root = 'data/qvhighlights/'
data = dict(
    train=dict(
        type='RepeatDataset',
        times=4,
        dataset=dict(
            type=data_type,
            label_path=data_root + 'qvhighlights_train.jsonl',
            video_path=data_root + 'frames_224_0.5fps',
            cache_path=data_root + 'clip_b32_vid_k4',
            query_path=data_root + 'clip_b32_txt_k4',
            use_cache=True,
            min_video_len=5,
            fps=0.5,
            unit=2),
        loader=dict(batch_size=128, num_workers=4, pin_memory=True, shuffle=True)),
    val=dict(
        type=data_type,
        label_path=data_root + 'qvhighlights_val.jsonl',
        video_path=data_root + 'frames_224_0.5fps',
        cache_path=data_root + 'clip_b32_vid_k4',
        query_path=data_root + 'clip_b32_txt_k4',
        use_cache=True,
        fps=0.5,
        unit=2,
        loader=dict(batch_size=1, num_workers=4, pin_memory=True, shuffle=False)),
    test=dict(
        type=data_type,
        label_path=data_root + 'qvhighlights_test.jsonl',
        video_path=data_root + 'frames_224_0.5fps',
        cache_path=data_root + 'clip_b32_vid_k4',
        query_path=data_root + 'clip_b32_txt_k4',
        use_cache=True,
        fps=0.5,
        unit=2,
        loader=dict(batch_size=1, num_workers=4, pin_memory=True, shuffle=False)))

configs/_base_/models/model.py
ADDED
@@ -0,0 +1,44 @@
_base_ = ['models']
# model settings
model = dict(
    type='R2Tuning',
    arch='ViT-B/32',
    init=False,
    dims=256,
    strides=(1, 2, 4, 8),
    buffer_size=1024,
    max_num_moment=50,
    adapter_cfg=dict(
        type='R2Block',
        k=4,
        dropout=0.5,
        use_tef=True,
        pos_cfg=dict(type='PositionalEncoding', normalize=True, max_len=1024),
        tem_cfg=dict(
            type='TransformerDecoderLayer',
            heads=8,
            ratio=4,
            att_dropout=0.0,
            ffn_dropout=0.0,
            att_out_dropout=0.0,
            ffn_out_dropout=0.0,
            droppath=0.1,
            pre_norm=False,
            bias=True,
            norm_cfg=dict(type='LN'),
            act_cfg=dict(type='ReLU', inplace=True),
            order=('cross_att', 'self_att', 'ffn'),
            att_init_cfg=dict(type='xavier', distribution='uniform'),
            ffn_init_cfg=dict(type='kaiming'))),
    pyramid_cfg=dict(type='ConvPyramid'),
    pooling_cfg=dict(type='AdaPooling'),
    class_head_cfg=dict(type='ConvHead', kernal_size=3),
    coord_head_cfg=dict(type='ConvHead', kernal_size=3),
    loss_cfg=dict(
        type='BundleLoss',
        sample_radius=1.5,
        loss_cls=dict(type='FocalLoss', loss_weight=1.0),
        loss_reg=dict(type='L1Loss', loss_weight=0.2),
        loss_sal=dict(type='SampledNCELoss', loss_weight=0.1),
        loss_video_cal=dict(type='InfoNCELoss', loss_weight=0.1),
        loss_layer_cal=dict(type='InfoNCELoss', loss_weight=0.1)))

configs/qvhighlights/r2_tuning_qvhighlights.py
ADDED
@@ -0,0 +1 @@
_base_ = ['../_base_/models/model.py', '../_base_/datasets/qvhighlights.py']

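This one-liner works because nncore merges every file listed in `_base_` into a single config. A quick sketch of loading and inspecting the merged result (field names come from the two base files above):

import nncore

cfg = nncore.Config.from_file('configs/qvhighlights/r2_tuning_qvhighlights.py')

print(cfg.model.type)    # 'R2Tuning', from configs/_base_/models/model.py
print(cfg.data.val.fps)  # 0.5, from configs/_base_/datasets/qvhighlights.py

# Fields can be overridden after loading; app.py does this before building:
cfg.model.init = True    # instantiate the frozen CLIP backbone
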
models/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .adapter import R2Block
from .blocks import AdaPooling, ConvHead, ConvPyramid
from .loss import BundleLoss
from .model import R2Tuning

__all__ = ['R2Block', 'AdaPooling', 'ConvHead', 'ConvPyramid', 'BundleLoss', 'R2Tuning']

models/adapter.py
ADDED
@@ -0,0 +1,99 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import torch
import torch.nn as nn
from nncore.nn import MODELS, build_model


@MODELS.register()
class R2Block(nn.Module):

    def __init__(self,
                 dims,
                 in_dims,
                 k=4,
                 dropout=0.5,
                 use_tef=True,
                 pos_cfg=None,
                 tem_cfg=None):
        super(R2Block, self).__init__()

        # yapf:disable
        self.video_map = nn.Sequential(
            nn.LayerNorm((in_dims[0] + 2) if use_tef else in_dims[0]),
            nn.Dropout(dropout),
            nn.Linear((in_dims[0] + 2) if use_tef else in_dims[0], dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dims),
            nn.Dropout(dropout),
            nn.Linear(dims, dims))

        self.query_map = nn.Sequential(
            nn.LayerNorm(in_dims[1]),
            nn.Dropout(dropout),
            nn.Linear(in_dims[1], dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dims),
            nn.Dropout(dropout),
            nn.Linear(dims, dims))
        # yapf:enable

        if k > 1:
            self.gate = nn.Parameter(torch.zeros([k - 1]))

        self.v_map = nn.Linear(dims, dims)
        self.q_map = nn.Linear(dims, dims)
        self.scale = nn.Parameter(torch.zeros([k]))

        self.pos = build_model(pos_cfg, dims=dims)
        self.tem = build_model(tem_cfg, dims=dims)

        self.dims = dims
        self.in_dims = in_dims
        self.k = k
        self.dropout = dropout
        self.use_tef = use_tef

    def forward(self, video_emb, query_emb, video_msk, query_msk):
        video_emb = video_emb[-self.k:]
        query_emb = query_emb[-self.k:]

        _, b, t, p, _ = video_emb.size()

        if self.use_tef:
            tef_s = torch.arange(0, 1, 1 / t, device=video_emb.device)
            tef_e = tef_s + 1.0 / t
            tef = torch.stack((tef_s, tef_e), dim=1)
            tef = tef.unsqueeze(1).unsqueeze(0).unsqueeze(0).repeat(self.k, b, 1, p, 1)
            video_emb = torch.cat((video_emb, tef[:, :, :video_emb.size(2)]), dim=-1)

        coll_v, coll_q, last = [], [], None
        for i in range(self.k - 1, -1, -1):
            v_emb = self.video_map(video_emb[i])  # B * T * P * C
            q_emb = self.query_map(query_emb[i])  # B * L * C

            coll_v.append(v_emb[:, :, 0])
            coll_q.append(q_emb)

            v_pool = v_emb.view(b * t, -1, self.dims)   # BT * P * C
            q_pool = q_emb.repeat_interleave(t, dim=0)  # BT * L * C

            v_pool_map = self.v_map(v_pool)  # BT * P * C
            q_pool_map = self.q_map(q_pool)  # BT * L * C

            att = torch.bmm(q_pool_map, v_pool_map.transpose(1, 2)) / self.dims**0.5
            att = att.softmax(-1)  # BT * L * P

            o_pool = torch.bmm(att, v_pool) + q_pool   # BT * L * C
            o_pool = o_pool.amax(dim=1, keepdim=True)  # BT * 1 * C
            v_emb = v_pool[:, 0, None] + o_pool * self.scale[i].tanh()
            v_emb = v_emb.view(b, t, self.dims)  # B * T * C

            if i < self.k - 1:
                gate = self.gate[i].sigmoid()
                v_emb = gate * v_emb + (1 - gate) * last

            v_pe = self.pos(v_emb)
            last = self.tem(v_emb, q_emb, q_pe=v_pe, q_mask=video_msk, k_mask=query_msk)

        return last, q_emb, coll_v, coll_q

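The core of R2Block.forward is the query-conditioned spatial pooling over patch tokens. A shape-level sketch with dummy tensors (the learnable v_map/q_map projections are omitted and all sizes are illustrative, not taken from a real run):

import torch

# Dummy sizes: B videos, T clips, P patch tokens per frame, L query tokens, C dims.
B, T, P, L, C = 2, 8, 50, 12, 256

v_pool = torch.randn(B * T, P, C)  # patch tokens of one CLIP layer (index 0 is [CLS])
q_pool = torch.randn(B * T, L, C)  # query tokens, repeated for every clip

# Cross-attention from query tokens to patch tokens.
att = torch.bmm(q_pool, v_pool.transpose(1, 2)) / C**0.5  # BT * L * P
att = att.softmax(-1)

o_pool = torch.bmm(att, v_pool) + q_pool   # query-conditioned patch summary, BT * L * C
o_pool = o_pool.amax(dim=1, keepdim=True)  # max over query tokens, BT * 1 * C

# The summary is added onto the [CLS] token, scaled by a tanh gate (zero-initialized).
scale = torch.tensor(0.)
v_emb = v_pool[:, 0, None] + o_pool * scale.tanh()
print(v_emb.view(B, T, C).shape)  # torch.Size([2, 8, 256])
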
models/blocks.py
ADDED
@@ -0,0 +1,98 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import MODELS


class Permute(nn.Module):

    def __init__(self):
        super(Permute, self).__init__()

    def forward(self, x):
        return x.transpose(-1, -2)


@MODELS.register()
class ConvPyramid(nn.Module):

    def __init__(self, dims, strides):
        super(ConvPyramid, self).__init__()

        self.blocks = nn.ModuleList()
        for s in strides:
            p = int(math.log2(s))
            if p == 0:
                layers = nn.ReLU(inplace=True)
            else:
                layers = nn.Sequential()
                conv_cls = nn.Conv1d if p > 0 else nn.ConvTranspose1d
                for _ in range(abs(p)):
                    layers.extend([
                        Permute(),
                        conv_cls(dims, dims, 2, stride=2),
                        Permute(),
                        nn.LayerNorm(dims),
                        nn.ReLU(inplace=True)
                    ])
            self.blocks.append(layers)

        self.strides = strides

    def forward(self, x, mask, return_mask=False):
        pymid, pymid_msk = [], []

        for s, blk in zip(self.strides, self.blocks):
            if x.size(1) < s:
                continue

            pymid.append(blk(x))

            if return_mask:
                if s > 1:
                    msk = F.max_pool1d(mask.float(), s, stride=s).long()
                elif s < 1:
                    msk = mask.repeat_interleave(int(1 / s), dim=1)
                else:
                    msk = mask
                pymid_msk.append(msk)

        return pymid, pymid_msk


@MODELS.register()
class AdaPooling(nn.Module):

    def __init__(self, dims):
        super(AdaPooling, self).__init__()
        self.att = nn.Linear(dims, 1, bias=False)

    def forward(self, x, mask):
        a = self.att(x) + torch.where(mask.unsqueeze(2) == 1, .0, float('-inf'))
        a = a.softmax(dim=1)
        x = torch.matmul(x.transpose(1, 2), a)
        x = x.squeeze(2).unsqueeze(1)
        return x


@MODELS.register()
class ConvHead(nn.Module):

    def __init__(self, dims, out_dims, kernal_size=3):
        super(ConvHead, self).__init__()

        # yapf:disable
        self.module = nn.Sequential(
            Permute(),
            nn.Conv1d(dims, dims, kernal_size, padding=kernal_size // 2),
            nn.ReLU(inplace=True),
            nn.Conv1d(dims, out_dims, kernal_size, padding=kernal_size // 2),
            Permute())
        # yapf:enable

    def forward(self, x):
        return self.module(x)

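A quick shape check of ConvPyramid on dummy features. This assumes the repo root is on PYTHONPATH (so the package imports as models.blocks) and its dependencies are installed:

import torch
from models.blocks import ConvPyramid

pyramid = ConvPyramid(dims=256, strides=(1, 2, 4, 8))

x = torch.randn(2, 64, 256)      # B * T * C clip features
mask = torch.ones(2, 64).long()  # 1 marks valid clips

pymid, pymid_msk = pyramid(x, mask, return_mask=True)
print([p.size(1) for p in pymid])  # [64, 32, 16, 8] -- one temporal level per stride
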
models/generator.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import torch
import torch.nn as nn


class BufferList(nn.Module):

    def __init__(self, buffers):
        super(BufferList, self).__init__()
        for i, buffer in enumerate(buffers):
            self.register_buffer(str(i), buffer, persistent=False)

    def __len__(self):
        return len(self._buffers)

    def __iter__(self):
        return iter(self._buffers.values())


class PointGenerator(nn.Module):

    def __init__(self, strides, buffer_size, offset=False):
        super(PointGenerator, self).__init__()

        reg_range, last = [], 0
        for stride in strides[1:]:
            reg_range.append((last, stride))
            last = stride
        reg_range.append((last, float('inf')))

        self.strides = strides
        self.reg_range = reg_range
        self.buffer_size = buffer_size
        self.offset = offset

        self.buffer = self._cache_points()

    def _cache_points(self):
        buffer_list = []
        for stride, reg_range in zip(self.strides, self.reg_range):
            reg_range = torch.Tensor([reg_range])
            lv_stride = torch.Tensor([stride])
            points = torch.arange(0, self.buffer_size, stride)[:, None]
            if self.offset:
                points += 0.5 * stride
            reg_range = reg_range.repeat(points.size(0), 1)
            lv_stride = lv_stride.repeat(points.size(0), 1)
            buffer_list.append(torch.cat((points, reg_range, lv_stride), dim=1))
        buffer = BufferList(buffer_list)
        return buffer

    def forward(self, pymid):
        points = []
        sizes = [p.size(1) for p in pymid] + [0] * (len(self.buffer) - len(pymid))
        for size, buffer in zip(sizes, self.buffer):
            if size == 0:
                continue
            assert size <= buffer.size(0), 'reached max buffer size'
            points.append(buffer[:size, :])
        points = torch.cat(points)
        return points

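PointGenerator pairs every pyramid location with a regression range and a stride. A small sketch of its output on dummy pyramid features (same import assumption as above):

import torch
from models.generator import PointGenerator

gen = PointGenerator((1, 2, 4, 8), 1024)

# One dummy feature map per pyramid level: B * T_lvl * C.
pymid = [torch.zeros(1, t, 256) for t in (64, 32, 16, 8)]
points = gen(pymid)

print(points.shape)  # torch.Size([120, 4]) -- 64 + 32 + 16 + 8 points
print(points[0])     # tensor([0., 0., 2., 1.]): (coordinate, reg_min, reg_max, stride)
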
models/loss.py
ADDED
@@ -0,0 +1,210 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import LOSSES, Parameter, build_loss


@LOSSES.register()
class SampledNCELoss(nn.Module):

    def __init__(self,
                 temperature=0.07,
                 max_scale=100,
                 learnable=False,
                 direction=('row', 'col'),
                 loss_weight=1.0):
        super(SampledNCELoss, self).__init__()

        scale = torch.Tensor([math.log(1 / temperature)])

        if learnable:
            self.scale = Parameter(scale)
        else:
            self.register_buffer('scale', scale)

        self.temperature = temperature
        self.max_scale = max_scale
        self.learnable = learnable
        self.direction = (direction, ) if isinstance(direction, str) else direction
        self.loss_weight = loss_weight

    def extra_repr(self):
        return ('temperature={}, max_scale={}, learnable={}, direction={}, loss_weight={}'
                .format(self.temperature, self.max_scale, self.learnable, self.direction,
                        self.loss_weight))

    def forward(self, video_emb, query_emb, video_msk, saliency, pos_clip):
        batch_inds = torch.arange(video_emb.size(0), device=video_emb.device)

        pos_scores = saliency[batch_inds, pos_clip].unsqueeze(-1)
        loss_msk = (saliency <= pos_scores) * video_msk

        scale = self.scale.exp().clamp(max=self.max_scale)
        i_sim = F.cosine_similarity(video_emb, query_emb, dim=-1) * scale
        i_sim = i_sim + torch.where(loss_msk > 0, .0, float('-inf'))

        loss = 0

        if 'row' in self.direction:
            i_met = F.log_softmax(i_sim, dim=1)[batch_inds, pos_clip]
            loss = loss - i_met.sum() / i_met.size(0)

        if 'col' in self.direction:
            j_sim = i_sim.t()
            j_met = F.log_softmax(j_sim, dim=1)[pos_clip, batch_inds]
            loss = loss - j_met.sum() / j_met.size(0)

        loss = loss * self.loss_weight
        return loss


@LOSSES.register()
class BundleLoss(nn.Module):

    def __init__(self,
                 sample_radius=1.5,
                 loss_cls=None,
                 loss_reg=None,
                 loss_sal=None,
                 loss_video_cal=None,
                 loss_layer_cal=None):
        super(BundleLoss, self).__init__()

        self._loss_cls = build_loss(loss_cls)
        self._loss_reg = build_loss(loss_reg)
        self._loss_sal = build_loss(loss_sal)
        self._loss_video_cal = build_loss(loss_video_cal)
        self._loss_layer_cal = build_loss(loss_layer_cal)

        self.sample_radius = sample_radius

    def get_target_single(self, point, gt_bnd, gt_cls):
        num_pts, num_gts = point.size(0), gt_bnd.size(0)

        lens = gt_bnd[:, 1] - gt_bnd[:, 0]
        lens = lens[None, :].repeat(num_pts, 1)

        gt_seg = gt_bnd[None].expand(num_pts, num_gts, 2)
        s = point[:, 0, None] - gt_seg[:, :, 0]
        e = gt_seg[:, :, 1] - point[:, 0, None]
        r_tgt = torch.stack((s, e), dim=-1)

        if self.sample_radius > 0:
            center = (gt_seg[:, :, 0] + gt_seg[:, :, 1]) / 2
            t_mins = center - point[:, 3, None] * self.sample_radius
            t_maxs = center + point[:, 3, None] * self.sample_radius
            dist_s = point[:, 0, None] - torch.maximum(t_mins, gt_seg[:, :, 0])
            dist_e = torch.minimum(t_maxs, gt_seg[:, :, 1]) - point[:, 0, None]
            center = torch.stack((dist_s, dist_e), dim=-1)
            cls_msk = center.min(-1)[0] >= 0
        else:
            cls_msk = r_tgt.min(-1)[0] >= 0

        reg_dist = r_tgt.max(-1)[0]
        reg_msk = torch.logical_and((reg_dist >= point[:, 1, None]),
                                    (reg_dist <= point[:, 2, None]))

        lens.masked_fill_(cls_msk == 0, float('inf'))
        lens.masked_fill_(reg_msk == 0, float('inf'))
        min_len, min_len_inds = lens.min(dim=1)

        min_len_mask = torch.logical_and((lens <= (min_len[:, None] + 1e-3)),
                                         (lens < float('inf'))).to(r_tgt.dtype)

        label = F.one_hot(gt_cls[:, 0], 2).to(r_tgt.dtype)
        c_tgt = torch.matmul(min_len_mask, label).clamp(min=0.0, max=1.0)[:, 1]
        r_tgt = r_tgt[range(num_pts), min_len_inds] / point[:, 3, None]

        return c_tgt, r_tgt

    def get_target(self, data):
        cls_tgt, reg_tgt = [], []

        for i in range(data['boundary'].size(0)):
            gt_bnd = data['boundary'][i] * data['fps'][i]
            gt_cls = gt_bnd.new_ones(gt_bnd.size(0), 1).long()

            c_tgt, r_tgt = self.get_target_single(data['point'], gt_bnd, gt_cls)

            cls_tgt.append(c_tgt)
            reg_tgt.append(r_tgt)

        cls_tgt = torch.stack(cls_tgt)
        reg_tgt = torch.stack(reg_tgt)

        return cls_tgt, reg_tgt

    def loss_cls(self, data, output, cls_tgt):
        src = data['out_class'].squeeze(-1)
        msk = torch.cat(data['pymid_msk'], dim=1)

        loss_cls = self._loss_cls(src, cls_tgt, weight=msk, avg_factor=msk.sum())

        output['loss_cls'] = loss_cls
        return output

    def loss_reg(self, data, output, cls_tgt, reg_tgt):
        src = data['out_coord']
        msk = cls_tgt.unsqueeze(2).repeat(1, 1, 2).bool()

        loss_reg = self._loss_reg(src, reg_tgt, weight=msk, avg_factor=msk.sum())

        output['loss_reg'] = loss_reg
        return output

    def loss_sal(self, data, output):
        video_emb = data['video_emb']
        query_emb = data['query_emb']
        video_msk = data['video_msk']

        saliency = data['saliency']
        pos_clip = data['pos_clip'][:, 0]

        output['loss_sal'] = self._loss_sal(video_emb, query_emb, video_msk, saliency,
                                            pos_clip)
        return output

    def loss_cal(self, data, output):
        pos_clip = data['pos_clip'][:, 0]

        batch_inds = torch.arange(pos_clip.size(0), device=pos_clip.device)

        coll_v_emb, coll_q_emb = [], []
        for v_emb, q_emb in zip(data['coll_v'], data['coll_q']):
            v_emb_pos = v_emb[batch_inds, pos_clip]
            q_emb_pos = q_emb[:, 0]

            coll_v_emb.append(v_emb_pos)
            coll_q_emb.append(q_emb_pos)

        v_emb = torch.stack(coll_v_emb)
        q_emb = torch.stack(coll_q_emb)
        output['loss_video_cal'] = self._loss_video_cal(v_emb, q_emb)

        v_emb = torch.stack(coll_v_emb, dim=1)
        q_emb = torch.stack(coll_q_emb, dim=1)
        output['loss_layer_cal'] = self._loss_layer_cal(v_emb, q_emb)

        return output

    def forward(self, data, output):
        if self._loss_reg is not None:
            cls_tgt, reg_tgt = self.get_target(data)
            output = self.loss_reg(data, output, cls_tgt, reg_tgt)
        else:
            cls_tgt = data['saliency']

        if self._loss_cls is not None:
            output = self.loss_cls(data, output, cls_tgt)

        if self._loss_sal is not None:
            output = self.loss_sal(data, output)

        if self._loss_video_cal is not None or self._loss_layer_cal is not None:
            output = self.loss_cal(data, output)

        return output

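A minimal sketch of SampledNCELoss on random inputs, to make the expected shapes concrete (one pooled query embedding and one annotated positive clip per video; all values are dummies):

import torch
from models.loss import SampledNCELoss

loss_fn = SampledNCELoss()  # defaults: temperature 0.07, both 'row' and 'col' directions

B, T, C = 4, 32, 256
video_emb = torch.randn(B, T, C)
query_emb = torch.randn(B, 1, C)   # pooled sentence embedding per sample
video_msk = torch.ones(B, T)       # all clips valid
saliency = torch.rand(B, T)        # per-clip saliency labels
pos_clip = saliency.argmax(dim=1)  # index of the most salient clip per video

# Contrasts each positive clip against the clips with lower saliency scores.
print(loss_fn(video_emb, query_emb, video_msk, saliency, pos_clip))
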
models/model.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import math

import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import MODELS, build_loss, build_model

from .generator import PointGenerator

_CLIP_ARCHS = {
    'ViT-B/32': (768, 512, 50),
    'ViT-B/16': (768, 512, 197),
    'ViT-L/14': (1024, 768, 50),
    'ViT-L/14-336px': (1024, 768, 577)
}


@MODELS.register()
class R2Tuning(nn.Module):

    def __init__(self,
                 arch='ViT-B/32',
                 init=True,
                 dims=256,
                 strides=(1, 2, 4, 8),
                 buffer_size=1024,
                 max_num_moment=50,
                 merge_cls_sal=True,
                 adapter_cfg=None,
                 pyramid_cfg=None,
                 pooling_cfg=None,
                 class_head_cfg=None,
                 coord_head_cfg=None,
                 loss_cfg=None):
        super(R2Tuning, self).__init__()

        if init:
            self.clip, _ = clip.load(arch, device='cpu')
            for param in self.clip.parameters():
                param.requires_grad = False

        self.cfg = _CLIP_ARCHS[arch]
        self.adapter = build_model(adapter_cfg, dims, self.cfg[:2])
        self.pyramid = build_model(pyramid_cfg, dims, strides)
        self.pooling = build_model(pooling_cfg, dims)

        self.class_head = build_model(class_head_cfg, dims, 1)
        self.coord_head = build_model(coord_head_cfg, dims, 2)

        self.generator = PointGenerator(strides, buffer_size)

        self.coef = nn.Parameter(torch.ones(len(strides)))
        self.loss = build_loss(loss_cfg)

        self.max_num_moment = max_num_moment
        self.merge_cls_sal = merge_cls_sal

    def train(self, mode=True):
        super(R2Tuning, self).train(mode=mode)
        if hasattr(self, 'clip'):
            self.clip.eval()

    @torch.no_grad
    def clip_video_tower(self, video):
        video = video.type(self.clip.dtype)
        video = self.clip.visual.conv1(video)
        video = video.reshape(video.size(0), video.size(1), -1).permute(0, 2, 1)
        c_emb = video.new_zeros(video.size(0), 1, video.size(-1))
        c_emb = self.clip.visual.class_embedding.to(video.dtype) + c_emb
        video = torch.cat((c_emb, video), dim=1)
        video = video + self.clip.visual.positional_embedding.to(video.dtype)
        video = self.clip.visual.ln_pre(video).permute(1, 0, 2)
        emb = [video]
        for blk in self.clip.visual.transformer.resblocks:
            emb.append(blk(emb[-1]))
        video = torch.stack([e.permute(1, 0, 2) for e in emb])
        return video

    @torch.no_grad
    def clip_query_tower(self, query):
        query = self.clip.token_embedding(query).type(self.clip.dtype)
        query = query + self.clip.positional_embedding.type(self.clip.dtype)
        query = query.permute(1, 0, 2)
        emb = [query]
        for blk in self.clip.transformer.resblocks:
            emb.append(blk(emb[-1]))
        query = torch.stack([e.permute(1, 0, 2) for e in emb])
        return query

    def forward(self, data, mode='test'):
        video, query = data['video'], data['query']

        if hasattr(self, 'clip'):
            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
            query_msk = torch.where(query == 0, 0, 1)

            video[~video.isfinite()] = 0

            (b, t), d = video.size()[:2], int(math.sqrt(video.size(2) / 3))
            video = video.view(b * t, 3, d, d)

            video_emb = self.clip_video_tower(video)
            query_emb = self.clip_query_tower(query)

            n, _, p, c = video_emb.size()
            video_emb = video_emb.view(n, b, t, p, c)
        else:
            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
            query_msk = torch.where(query[:, :, 0].isfinite(), 1, 0)

            video[~video.isfinite()] = 0
            query[~query.isfinite()] = 0

            (b, t), l = video.size()[:2], query.size(1)
            video = video.view(b, t, -1, self.cfg[2], self.cfg[0]).permute(2, 0, 1, 3, 4)
            query = query.view(b, l, -1, self.cfg[1]).permute(2, 0, 1, 3)

            video_emb = video.float()
            query_emb = query.float()

        # video_emb: N * B * T * P * C
        # query_emb: N * B * L * C

        video_emb, query_emb, coll_v, coll_q = self.adapter(video_emb, query_emb,
                                                            video_msk, query_msk)

        pymid, pymid_msk = self.pyramid(video_emb, video_msk, return_mask=mode != 'test')
        point = self.generator(pymid)

        with torch.autocast('cuda', enabled=False):
            video_emb = video_emb.float()
            query_emb = self.pooling(query_emb.float(), query_msk)

            out_class = [self.class_head(e.float()) for e in pymid]
            out_class = torch.cat(out_class, dim=1)

            if self.coord_head is not None:
                out_coord = [
                    self.coord_head(e.float()).exp() * self.coef[i]
                    for i, e in enumerate(pymid)
                ]
                out_coord = torch.cat(out_coord, dim=1)
            else:
                out_coord = None

            output = dict(_avg_factor=b)

            if mode != 'test':
                data['coll_v'] = [e.float() for e in coll_v]
                data['coll_q'] = [self.pooling(e.float(), query_msk) for e in coll_q]

                data['point'] = point
                data['video_emb'] = video_emb
                data['query_emb'] = query_emb
                data['video_msk'] = video_msk
                data['pymid_msk'] = pymid_msk
                data['out_class'] = out_class
                data['out_coord'] = out_coord

                output = self.loss(data, output)

            if mode != 'train':
                assert b == 1, 'batch size larger than 1 is not supported for inference'
                out_class = out_class.sigmoid()
                out_score = F.cosine_similarity(video_emb, query_emb, dim=-1)

                output['_out'] = dict(label=data.get('label', [None])[0])

                pyd_shape = [e.size(1) for e in pymid]
                pyd_class = out_class[0, :, 0].split(pyd_shape)

                saliency = []
                for shape, score in zip(pyd_shape, pyd_class):
                    if t >= shape:
                        score = score.repeat_interleave(int(t / shape))
                        postfix = score[-1:].repeat(t - score.size(0))
                        score = torch.cat((score, postfix))
                    else:
                        scale = int(shape / t)
                        score = F.max_pool1d(score.unsqueeze(0), scale, stride=scale)[0]
                    saliency.append(score)

                saliency = torch.stack(saliency).amax(dim=0)

                if self.merge_cls_sal:
                    saliency *= out_score[0]

                output['_out']['saliency'] = saliency

                if self.coord_head is not None:
                    boundary = out_coord[0]
                    boundary[:, 0] *= -1
                    boundary *= point[:, 3, None].repeat(1, 2)
                    boundary += point[:, 0, None].repeat(1, 2)
                    boundary /= data['fps'][0]
                    boundary = torch.cat((boundary, out_class[0]), dim=-1)

                    _, inds = out_class[0, :, 0].sort(descending=True)
                    boundary = boundary[inds[:self.max_num_moment]]

                    output['_out']['boundary'] = boundary

        return output

requirements.txt
ADDED
@@ -0,0 +1,6 @@
git+https://github.com/openai/CLIP.git@a1d0717
decord==0.6.0
matplotlib==3.9.0
nncore==0.4.3
torch==2.2.1
torchvision==0.17.1

setup.cfg
ADDED
@@ -0,0 +1,15 @@
[yapf]
column_limit = 90
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true

[isort]
line_length = 90
multi_line_output = 0
known_third_party = clip,decord,gradio,nncore,numpy,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = FIRSTPARTY

[flake8]
max-line-length = 90