yeliudev committed
Commit bc120ce · 1 Parent(s): 084c942
.gitignore ADDED
@@ -0,0 +1,10 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Temporary data
+/checkpoints
+/flagged
+.DS_Store
+._*
README.md CHANGED
@@ -1,13 +1,11 @@
 ---
 title: R2 Tuning
-emoji:
-colorFrom: pink
-colorTo: green
+emoji: 🌀
+colorFrom: blue
+colorTo: purple
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
 pinned: false
 license: bsd-3-clause
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,112 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+from functools import partial
+
+import clip
+import decord
+import nncore
+import torch
+import gradio as gr
+import matplotlib.pyplot as plt
+import numpy as np
+import torchvision.transforms.functional as F
+from decord import VideoReader
+from nncore.engine import load_checkpoint
+from nncore.nn import build_model
+
+TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'  # noqa
+DESCRIPTION = 'R2-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.\n\nUser Guide:\n1. Upload or record a video using a web camera.\n2. Input a text query. A good practice is to use a sentence with 5~10 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'  # noqa
+
+CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
+WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'  # noqa
+
+
+def convert_time(seconds):
+    minutes, seconds = divmod(round(seconds), 60)
+    return f'{minutes:02d}:{seconds:02d}'
+
+
+def load_video(video_path, cfg):
+    decord.bridge.set_bridge('torch')
+
+    vr = VideoReader(video_path)
+    stride = vr.get_avg_fps() / cfg.data.val.fps
+    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
+    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
+
+    size = 336 if '336px' in cfg.model.arch else 224
+    h, w = video.size(-2), video.size(-1)
+    s = min(h, w)
+    x, y = round((h - s) / 2), round((w - s) / 2)
+    video = video[..., x:x + s, y:y + s]
+    video = F.resize(video, size=(size, size))
+    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
+    video = video.reshape(video.size(0), -1).unsqueeze(0)
+
+    return video
+
+
+def init_model(config, checkpoint):
+    cfg = nncore.Config.from_file(config)
+    cfg.model.init = True
+
+    if checkpoint.startswith('http'):
+        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')
+
+    model = build_model(cfg.model, dist=False).eval()
+    model = load_checkpoint(model, checkpoint, warning=False)
+
+    return model, cfg
+
+
+def main(video, query, model, cfg):
+    if len(query) == 0:
+        raise gr.Error('Text query can not be empty.')
+
+    try:
+        video = load_video(video, cfg)
+    except Exception:
+        raise gr.Error('Failed to load the video.')
+
+    query = clip.tokenize(query, truncate=True)
+
+    device = next(model.parameters()).device
+    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
+
+    with torch.inference_mode():
+        pred = model(data)
+
+    mr = pred['_out']['boundary'][:5].cpu().tolist()
+    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
+
+    hd = pred['_out']['saliency'].cpu()
+    hd = ((hd - hd.min()) / (hd.max() - hd.min())).tolist()
+
+    fig, ax = plt.subplots(figsize=(10, 5.5))
+    ax.plot(range(0, len(hd) * 2, 2), hd)
+
+    ax.set_xlabel('Time (s)', fontsize=15)
+    ax.set_ylabel('Saliency Score', fontsize=15)
+
+    ax.tick_params(labelsize=14)
+    plt.tight_layout(rect=(0.02, 0.02, 0.95, 0.885))
+
+    return mr, fig
+
+
+model, cfg = init_model(CONFIG, WEIGHT)
+main = partial(main, model=model, cfg=cfg)
+
+demo = gr.Interface(
+    fn=main,
+    inputs=[gr.Video(label='Video'),
+            gr.Textbox(label='Text Query')],
+    outputs=[
+        gr.Dataframe(
+            headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval'),
+        gr.Plot(label='Highlight Detection')
+    ],
+    allow_flagging='auto',
+    title=TITLE,
+    description=DESCRIPTION)
+demo.launch()
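The Space drives everything through the Gradio UI above, but the same building blocks can be exercised offline. The sketch below mirrors app.py step by step under stated assumptions: it reuses this commit's config path and checkpoint URL, and it substitutes a random tensor for the output of `load_video()` (feeding a real clip through `load_video` works the same way). It is an illustration, not part of the commit.

```python
# Offline sketch mirroring app.py (assumptions: repo layout of this commit,
# dependencies from requirements.txt installed; the random tensor below
# stands in for load_video()'s preprocessed frames).
import clip
import nncore
import torch
from nncore.engine import load_checkpoint
from nncore.nn import build_model

CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'  # noqa

cfg = nncore.Config.from_file(CONFIG)
cfg.model.init = True

checkpoint = nncore.download(WEIGHT, out_dir='checkpoints')
model = build_model(cfg.model, dist=False).eval()
model = load_checkpoint(model, checkpoint, warning=False)

video = torch.randn(1, 32, 3 * 224 * 224)  # placeholder for load_video(path, cfg)
query = clip.tokenize('a person is cooking in the kitchen', truncate=True)

with torch.inference_mode():
    pred = model(dict(video=video, query=query, fps=[cfg.data.val.fps]))

print(pred['_out']['boundary'][:5])    # top moments as (start, end, score), seconds
print(pred['_out']['saliency'].shape)  # per-clip saliency scores
```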
configs/_base_/datasets/qvhighlights.py ADDED
@@ -0,0 +1,38 @@
+# dataset settings
+data_type = 'Grounding'
+data_root = 'data/qvhighlights/'
+data = dict(
+    train=dict(
+        type='RepeatDataset',
+        times=4,
+        dataset=dict(
+            type=data_type,
+            label_path=data_root + 'qvhighlights_train.jsonl',
+            video_path=data_root + 'frames_224_0.5fps',
+            cache_path=data_root + 'clip_b32_vid_k4',
+            query_path=data_root + 'clip_b32_txt_k4',
+            use_cache=True,
+            min_video_len=5,
+            fps=0.5,
+            unit=2),
+        loader=dict(batch_size=128, num_workers=4, pin_memory=True, shuffle=True)),
+    val=dict(
+        type=data_type,
+        label_path=data_root + 'qvhighlights_val.jsonl',
+        video_path=data_root + 'frames_224_0.5fps',
+        cache_path=data_root + 'clip_b32_vid_k4',
+        query_path=data_root + 'clip_b32_txt_k4',
+        use_cache=True,
+        fps=0.5,
+        unit=2,
+        loader=dict(batch_size=1, num_workers=4, pin_memory=True, shuffle=False)),
+    test=dict(
+        type=data_type,
+        label_path=data_root + 'qvhighlights_test.jsonl',
+        video_path=data_root + 'frames_224_0.5fps',
+        cache_path=data_root + 'clip_b32_vid_k4',
+        query_path=data_root + 'clip_b32_txt_k4',
+        use_cache=True,
+        fps=0.5,
+        unit=2,
+        loader=dict(batch_size=1, num_workers=4, pin_memory=True, shuffle=False)))
configs/_base_/models/model.py ADDED
@@ -0,0 +1,44 @@
+_base_ = ['models']
+# model settings
+model = dict(
+    type='R2Tuning',
+    arch='ViT-B/32',
+    init=False,
+    dims=256,
+    strides=(1, 2, 4, 8),
+    buffer_size=1024,
+    max_num_moment=50,
+    adapter_cfg=dict(
+        type='R2Block',
+        k=4,
+        dropout=0.5,
+        use_tef=True,
+        pos_cfg=dict(type='PositionalEncoding', normalize=True, max_len=1024),
+        tem_cfg=dict(
+            type='TransformerDecoderLayer',
+            heads=8,
+            ratio=4,
+            att_dropout=0.0,
+            ffn_dropout=0.0,
+            att_out_dropout=0.0,
+            ffn_out_dropout=0.0,
+            droppath=0.1,
+            pre_norm=False,
+            bias=True,
+            norm_cfg=dict(type='LN'),
+            act_cfg=dict(type='ReLU', inplace=True),
+            order=('cross_att', 'self_att', 'ffn'),
+            att_init_cfg=dict(type='xavier', distribution='uniform'),
+            ffn_init_cfg=dict(type='kaiming'))),
+    pyramid_cfg=dict(type='ConvPyramid'),
+    pooling_cfg=dict(type='AdaPooling'),
+    class_head_cfg=dict(type='ConvHead', kernal_size=3),
+    coord_head_cfg=dict(type='ConvHead', kernal_size=3),
+    loss_cfg=dict(
+        type='BundleLoss',
+        sample_radius=1.5,
+        loss_cls=dict(type='FocalLoss', loss_weight=1.0),
+        loss_reg=dict(type='L1Loss', loss_weight=0.2),
+        loss_sal=dict(type='SampledNCELoss', loss_weight=0.1),
+        loss_video_cal=dict(type='InfoNCELoss', loss_weight=0.1),
+        loss_layer_cal=dict(type='InfoNCELoss', loss_weight=0.1)))
configs/qvhighlights/r2_tuning_qvhighlights.py ADDED
@@ -0,0 +1 @@
+_base_ = ['../_base_/models/model.py', '../_base_/datasets/qvhighlights.py']
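This one-liner only composes the two `_base_` files above; nothing else lives in it. For reference, app.py loads the merged result with `nncore.Config.from_file`, so the combined settings can be inspected like this (a minimal sketch; the printed values come from the base configs above):

```python
import nncore

cfg = nncore.Config.from_file('configs/qvhighlights/r2_tuning_qvhighlights.py')

print(cfg.model.arch)     # 'ViT-B/32' (from configs/_base_/models/model.py)
print(cfg.model.strides)  # (1, 2, 4, 8)
print(cfg.data.val.fps)   # 0.5, i.e. one clip every 2 seconds
print(cfg.data.val.unit)  # 2
```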
models/__init__.py ADDED
@@ -0,0 +1,6 @@
+from .adapter import R2Block
+from .blocks import AdaPooling, ConvHead, ConvPyramid
+from .loss import BundleLoss
+from .model import R2Tuning
+
+__all__ = ['R2Block', 'AdaPooling', 'ConvHead', 'ConvPyramid', 'BundleLoss', 'R2Tuning']
models/adapter.py ADDED
@@ -0,0 +1,99 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import torch
+import torch.nn as nn
+from nncore.nn import MODELS, build_model
+
+
+@MODELS.register()
+class R2Block(nn.Module):
+
+    def __init__(self,
+                 dims,
+                 in_dims,
+                 k=4,
+                 dropout=0.5,
+                 use_tef=True,
+                 pos_cfg=None,
+                 tem_cfg=None):
+        super(R2Block, self).__init__()
+
+        # yapf:disable
+        self.video_map = nn.Sequential(
+            nn.LayerNorm((in_dims[0] + 2) if use_tef else in_dims[0]),
+            nn.Dropout(dropout),
+            nn.Linear((in_dims[0] + 2) if use_tef else in_dims[0], dims),
+            nn.ReLU(inplace=True),
+            nn.LayerNorm(dims),
+            nn.Dropout(dropout),
+            nn.Linear(dims, dims))
+
+        self.query_map = nn.Sequential(
+            nn.LayerNorm(in_dims[1]),
+            nn.Dropout(dropout),
+            nn.Linear(in_dims[1], dims),
+            nn.ReLU(inplace=True),
+            nn.LayerNorm(dims),
+            nn.Dropout(dropout),
+            nn.Linear(dims, dims))
+        # yapf:enable
+
+        if k > 1:
+            self.gate = nn.Parameter(torch.zeros([k - 1]))
+
+        self.v_map = nn.Linear(dims, dims)
+        self.q_map = nn.Linear(dims, dims)
+        self.scale = nn.Parameter(torch.zeros([k]))
+
+        self.pos = build_model(pos_cfg, dims=dims)
+        self.tem = build_model(tem_cfg, dims=dims)
+
+        self.dims = dims
+        self.in_dims = in_dims
+        self.k = k
+        self.dropout = dropout
+        self.use_tef = use_tef
+
+    def forward(self, video_emb, query_emb, video_msk, query_msk):
+        video_emb = video_emb[-self.k:]
+        query_emb = query_emb[-self.k:]
+
+        _, b, t, p, _ = video_emb.size()
+
+        if self.use_tef:
+            tef_s = torch.arange(0, 1, 1 / t, device=video_emb.device)
+            tef_e = tef_s + 1.0 / t
+            tef = torch.stack((tef_s, tef_e), dim=1)
+            tef = tef.unsqueeze(1).unsqueeze(0).unsqueeze(0).repeat(self.k, b, 1, p, 1)
+            video_emb = torch.cat((video_emb, tef[:, :, :video_emb.size(2)]), dim=-1)
+
+        coll_v, coll_q, last = [], [], None
+        for i in range(self.k - 1, -1, -1):
+            v_emb = self.video_map(video_emb[i])  # B * T * P * C
+            q_emb = self.query_map(query_emb[i])  # B * L * C
+
+            coll_v.append(v_emb[:, :, 0])
+            coll_q.append(q_emb)
+
+            v_pool = v_emb.view(b * t, -1, self.dims)   # BT * P * C
+            q_pool = q_emb.repeat_interleave(t, dim=0)  # BT * L * C
+
+            v_pool_map = self.v_map(v_pool)  # BT * P * C
+            q_pool_map = self.q_map(q_pool)  # BT * L * C
+
+            att = torch.bmm(q_pool_map, v_pool_map.transpose(1, 2)) / self.dims**0.5
+            att = att.softmax(-1)  # BT * L * P
+
+            o_pool = torch.bmm(att, v_pool) + q_pool    # BT * L * C
+            o_pool = o_pool.amax(dim=1, keepdim=True)   # BT * 1 * C
+            v_emb = v_pool[:, 0, None] + o_pool * self.scale[i].tanh()
+            v_emb = v_emb.view(b, t, self.dims)  # B * T * C
+
+            if i < self.k - 1:
+                gate = self.gate[i].sigmoid()
+                v_emb = gate * v_emb + (1 - gate) * last
+
+            v_pe = self.pos(v_emb)
+            last = self.tem(v_emb, q_emb, q_pe=v_pe, q_mask=video_msk, k_mask=query_msk)
+
+        return last, q_emb, coll_v, coll_q
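When `use_tef=True`, the block appends temporal endpoint features, i.e. the normalized start and end time of each clip, to every frame token before the projections above. A standalone illustration of that computation (plain torch, mirroring the lines in `forward`; not part of the commit):

```python
import torch

t = 4  # number of clips in the video
tef_s = torch.arange(0, 1, 1 / t)  # normalized clip start times: 0, .25, .5, .75
tef_e = tef_s + 1.0 / t            # normalized clip end times:  .25, .5, .75, 1
tef = torch.stack((tef_s, tef_e), dim=1)  # (t, 2) start/end pair appended per clip
print(tef.shape)  # torch.Size([4, 2])
```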
models/blocks.py ADDED
@@ -0,0 +1,98 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nncore.nn import MODELS
+
+
+class Permute(nn.Module):
+
+    def __init__(self):
+        super(Permute, self).__init__()
+
+    def forward(self, x):
+        return x.transpose(-1, -2)
+
+
+@MODELS.register()
+class ConvPyramid(nn.Module):
+
+    def __init__(self, dims, strides):
+        super(ConvPyramid, self).__init__()
+
+        self.blocks = nn.ModuleList()
+        for s in strides:
+            p = int(math.log2(s))
+            if p == 0:
+                layers = nn.ReLU(inplace=True)
+            else:
+                layers = nn.Sequential()
+                conv_cls = nn.Conv1d if p > 0 else nn.ConvTranspose1d
+                for _ in range(abs(p)):
+                    layers.extend([
+                        Permute(),
+                        conv_cls(dims, dims, 2, stride=2),
+                        Permute(),
+                        nn.LayerNorm(dims),
+                        nn.ReLU(inplace=True)
+                    ])
+            self.blocks.append(layers)
+
+        self.strides = strides
+
+    def forward(self, x, mask, return_mask=False):
+        pymid, pymid_msk = [], []
+
+        for s, blk in zip(self.strides, self.blocks):
+            if x.size(1) < s:
+                continue
+
+            pymid.append(blk(x))
+
+            if return_mask:
+                if s > 1:
+                    msk = F.max_pool1d(mask.float(), s, stride=s).long()
+                elif s < 1:
+                    msk = mask.repeat_interleave(int(1 / s), dim=1)
+                else:
+                    msk = mask
+                pymid_msk.append(msk)
+
+        return pymid, pymid_msk
+
+
+@MODELS.register()
+class AdaPooling(nn.Module):
+
+    def __init__(self, dims):
+        super(AdaPooling, self).__init__()
+        self.att = nn.Linear(dims, 1, bias=False)
+
+    def forward(self, x, mask):
+        a = self.att(x) + torch.where(mask.unsqueeze(2) == 1, .0, float('-inf'))
+        a = a.softmax(dim=1)
+        x = torch.matmul(x.transpose(1, 2), a)
+        x = x.squeeze(2).unsqueeze(1)
+        return x
+
+
+@MODELS.register()
+class ConvHead(nn.Module):
+
+    def __init__(self, dims, out_dims, kernal_size=3):
+        super(ConvHead, self).__init__()
+
+        # yapf:disable
+        self.module = nn.Sequential(
+            Permute(),
+            nn.Conv1d(dims, dims, kernal_size, padding=kernal_size // 2),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(dims, out_dims, kernal_size, padding=kernal_size // 2),
+            Permute())
+        # yapf:enable
+
+    def forward(self, x):
+        return self.module(x)
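Each pyramid level applies log2(stride) strided Conv1d blocks, so with the default strides a T-clip sequence yields feature maps of length T, T/2, T/4 and T/8. A quick shape check with assumed toy sizes (an illustration only, not part of the commit):

```python
import torch

from models.blocks import ConvPyramid

pyramid = ConvPyramid(dims=256, strides=(1, 2, 4, 8))

x = torch.randn(1, 64, 256)                 # B * T * C clip features
mask = torch.ones(1, 64, dtype=torch.long)  # all 64 clips are valid

pymid, pymid_msk = pyramid(x, mask, return_mask=True)
print([p.size(1) for p in pymid])      # [64, 32, 16, 8]
print([m.size(1) for m in pymid_msk])  # [64, 32, 16, 8]
```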
models/generator.py ADDED
@@ -0,0 +1,62 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import torch
+import torch.nn as nn
+
+
+class BufferList(nn.Module):
+
+    def __init__(self, buffers):
+        super(BufferList, self).__init__()
+        for i, buffer in enumerate(buffers):
+            self.register_buffer(str(i), buffer, persistent=False)
+
+    def __len__(self):
+        return len(self._buffers)
+
+    def __iter__(self):
+        return iter(self._buffers.values())
+
+
+class PointGenerator(nn.Module):
+
+    def __init__(self, strides, buffer_size, offset=False):
+        super(PointGenerator, self).__init__()
+
+        reg_range, last = [], 0
+        for stride in strides[1:]:
+            reg_range.append((last, stride))
+            last = stride
+        reg_range.append((last, float('inf')))
+
+        self.strides = strides
+        self.reg_range = reg_range
+        self.buffer_size = buffer_size
+        self.offset = offset
+
+        self.buffer = self._cache_points()
+
+    def _cache_points(self):
+        buffer_list = []
+        for stride, reg_range in zip(self.strides, self.reg_range):
+            reg_range = torch.Tensor([reg_range])
+            lv_stride = torch.Tensor([stride])
+            points = torch.arange(0, self.buffer_size, stride)[:, None]
+            if self.offset:
+                points += 0.5 * stride
+            reg_range = reg_range.repeat(points.size(0), 1)
+            lv_stride = lv_stride.repeat(points.size(0), 1)
+            buffer_list.append(torch.cat((points, reg_range, lv_stride), dim=1))
+        buffer = BufferList(buffer_list)
+        return buffer
+
+    def forward(self, pymid):
+        points = []
+        sizes = [p.size(1) for p in pymid] + [0] * (len(self.buffer) - len(pymid))
+        for size, buffer in zip(sizes, self.buffer):
+            if size == 0:
+                continue
+            assert size <= buffer.size(0), 'reached max buffer size'
+            points.append(buffer[:size, :])
+        points = torch.cat(points)
+        return points
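Each row of the generated point tensor packs (clip coordinate, regression-range lower bound, regression-range upper bound, stride), which is how loss.py and model.py index it later. A toy run with the default strides and an assumed 64-clip pyramid (illustration only):

```python
import torch

from models.generator import PointGenerator

generator = PointGenerator(strides=(1, 2, 4, 8), buffer_size=1024)

# Fake pyramid features with temporal lengths 64, 32, 16 and 8.
pymid = [torch.zeros(1, n, 256) for n in (64, 32, 16, 8)]

points = generator(pymid)
print(points.shape)  # torch.Size([120, 4])
print(points[0])     # coordinate 0, regression range (0, 2), stride 1
print(points[-1])    # coordinate 56, regression range (8, inf), stride 8
```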
models/loss.py ADDED
@@ -0,0 +1,210 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nncore.nn import LOSSES, Parameter, build_loss
+
+
+@LOSSES.register()
+class SampledNCELoss(nn.Module):
+
+    def __init__(self,
+                 temperature=0.07,
+                 max_scale=100,
+                 learnable=False,
+                 direction=('row', 'col'),
+                 loss_weight=1.0):
+        super(SampledNCELoss, self).__init__()
+
+        scale = torch.Tensor([math.log(1 / temperature)])
+
+        if learnable:
+            self.scale = Parameter(scale)
+        else:
+            self.register_buffer('scale', scale)
+
+        self.temperature = temperature
+        self.max_scale = max_scale
+        self.learnable = learnable
+        self.direction = (direction, ) if isinstance(direction, str) else direction
+        self.loss_weight = loss_weight
+
+    def extra_repr(self):
+        return ('temperature={}, max_scale={}, learnable={}, direction={}, loss_weight={}'
+                .format(self.temperature, self.max_scale, self.learnable, self.direction,
+                        self.loss_weight))
+
+    def forward(self, video_emb, query_emb, video_msk, saliency, pos_clip):
+        batch_inds = torch.arange(video_emb.size(0), device=video_emb.device)
+
+        pos_scores = saliency[batch_inds, pos_clip].unsqueeze(-1)
+        loss_msk = (saliency <= pos_scores) * video_msk
+
+        scale = self.scale.exp().clamp(max=self.max_scale)
+        i_sim = F.cosine_similarity(video_emb, query_emb, dim=-1) * scale
+        i_sim = i_sim + torch.where(loss_msk > 0, .0, float('-inf'))
+
+        loss = 0
+
+        if 'row' in self.direction:
+            i_met = F.log_softmax(i_sim, dim=1)[batch_inds, pos_clip]
+            loss = loss - i_met.sum() / i_met.size(0)
+
+        if 'col' in self.direction:
+            j_sim = i_sim.t()
+            j_met = F.log_softmax(j_sim, dim=1)[pos_clip, batch_inds]
+            loss = loss - j_met.sum() / j_met.size(0)
+
+        loss = loss * self.loss_weight
+        return loss
+
+
+@LOSSES.register()
+class BundleLoss(nn.Module):
+
+    def __init__(self,
+                 sample_radius=1.5,
+                 loss_cls=None,
+                 loss_reg=None,
+                 loss_sal=None,
+                 loss_video_cal=None,
+                 loss_layer_cal=None):
+        super(BundleLoss, self).__init__()
+
+        self._loss_cls = build_loss(loss_cls)
+        self._loss_reg = build_loss(loss_reg)
+        self._loss_sal = build_loss(loss_sal)
+        self._loss_video_cal = build_loss(loss_video_cal)
+        self._loss_layer_cal = build_loss(loss_layer_cal)
+
+        self.sample_radius = sample_radius
+
+    def get_target_single(self, point, gt_bnd, gt_cls):
+        num_pts, num_gts = point.size(0), gt_bnd.size(0)
+
+        lens = gt_bnd[:, 1] - gt_bnd[:, 0]
+        lens = lens[None, :].repeat(num_pts, 1)
+
+        gt_seg = gt_bnd[None].expand(num_pts, num_gts, 2)
+        s = point[:, 0, None] - gt_seg[:, :, 0]
+        e = gt_seg[:, :, 1] - point[:, 0, None]
+        r_tgt = torch.stack((s, e), dim=-1)
+
+        if self.sample_radius > 0:
+            center = (gt_seg[:, :, 0] + gt_seg[:, :, 1]) / 2
+            t_mins = center - point[:, 3, None] * self.sample_radius
+            t_maxs = center + point[:, 3, None] * self.sample_radius
+            dist_s = point[:, 0, None] - torch.maximum(t_mins, gt_seg[:, :, 0])
+            dist_e = torch.minimum(t_maxs, gt_seg[:, :, 1]) - point[:, 0, None]
+            center = torch.stack((dist_s, dist_e), dim=-1)
+            cls_msk = center.min(-1)[0] >= 0
+        else:
+            cls_msk = r_tgt.min(-1)[0] >= 0
+
+        reg_dist = r_tgt.max(-1)[0]
+        reg_msk = torch.logical_and((reg_dist >= point[:, 1, None]),
+                                    (reg_dist <= point[:, 2, None]))
+
+        lens.masked_fill_(cls_msk == 0, float('inf'))
+        lens.masked_fill_(reg_msk == 0, float('inf'))
+        min_len, min_len_inds = lens.min(dim=1)
+
+        min_len_mask = torch.logical_and((lens <= (min_len[:, None] + 1e-3)),
+                                         (lens < float('inf'))).to(r_tgt.dtype)
+
+        label = F.one_hot(gt_cls[:, 0], 2).to(r_tgt.dtype)
+        c_tgt = torch.matmul(min_len_mask, label).clamp(min=0.0, max=1.0)[:, 1]
+        r_tgt = r_tgt[range(num_pts), min_len_inds] / point[:, 3, None]
+
+        return c_tgt, r_tgt
+
+    def get_target(self, data):
+        cls_tgt, reg_tgt = [], []
+
+        for i in range(data['boundary'].size(0)):
+            gt_bnd = data['boundary'][i] * data['fps'][i]
+            gt_cls = gt_bnd.new_ones(gt_bnd.size(0), 1).long()
+
+            c_tgt, r_tgt = self.get_target_single(data['point'], gt_bnd, gt_cls)
+
+            cls_tgt.append(c_tgt)
+            reg_tgt.append(r_tgt)
+
+        cls_tgt = torch.stack(cls_tgt)
+        reg_tgt = torch.stack(reg_tgt)
+
+        return cls_tgt, reg_tgt
+
+    def loss_cls(self, data, output, cls_tgt):
+        src = data['out_class'].squeeze(-1)
+        msk = torch.cat(data['pymid_msk'], dim=1)
+
+        loss_cls = self._loss_cls(src, cls_tgt, weight=msk, avg_factor=msk.sum())
+
+        output['loss_cls'] = loss_cls
+        return output
+
+    def loss_reg(self, data, output, cls_tgt, reg_tgt):
+        src = data['out_coord']
+        msk = cls_tgt.unsqueeze(2).repeat(1, 1, 2).bool()
+
+        loss_reg = self._loss_reg(src, reg_tgt, weight=msk, avg_factor=msk.sum())
+
+        output['loss_reg'] = loss_reg
+        return output
+
+    def loss_sal(self, data, output):
+        video_emb = data['video_emb']
+        query_emb = data['query_emb']
+        video_msk = data['video_msk']
+
+        saliency = data['saliency']
+        pos_clip = data['pos_clip'][:, 0]
+
+        output['loss_sal'] = self._loss_sal(video_emb, query_emb, video_msk, saliency,
+                                            pos_clip)
+        return output
+
+    def loss_cal(self, data, output):
+        pos_clip = data['pos_clip'][:, 0]
+
+        batch_inds = torch.arange(pos_clip.size(0), device=pos_clip.device)
+
+        coll_v_emb, coll_q_emb = [], []
+        for v_emb, q_emb in zip(data['coll_v'], data['coll_q']):
+            v_emb_pos = v_emb[batch_inds, pos_clip]
+            q_emb_pos = q_emb[:, 0]
+
+            coll_v_emb.append(v_emb_pos)
+            coll_q_emb.append(q_emb_pos)
+
+        v_emb = torch.stack(coll_v_emb)
+        q_emb = torch.stack(coll_q_emb)
+        output['loss_video_cal'] = self._loss_video_cal(v_emb, q_emb)
+
+        v_emb = torch.stack(coll_v_emb, dim=1)
+        q_emb = torch.stack(coll_q_emb, dim=1)
+        output['loss_layer_cal'] = self._loss_layer_cal(v_emb, q_emb)
+
+        return output
+
+    def forward(self, data, output):
+        if self._loss_reg is not None:
+            cls_tgt, reg_tgt = self.get_target(data)
+            output = self.loss_reg(data, output, cls_tgt, reg_tgt)
+        else:
+            cls_tgt = data['saliency']
+
+        if self._loss_cls is not None:
+            output = self.loss_cls(data, output, cls_tgt)
+
+        if self._loss_sal is not None:
+            output = self.loss_sal(data, output)
+
+        if self._loss_video_cal is not None or self._loss_layer_cal is not None:
+            output = self.loss_cal(data, output)
+
+        return output
models/model.py ADDED
@@ -0,0 +1,206 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import math
+
+import clip
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nncore.nn import MODELS, build_loss, build_model
+
+from .generator import PointGenerator
+
+_CLIP_ARCHS = {
+    'ViT-B/32': (768, 512, 50),
+    'ViT-B/16': (768, 512, 197),
+    'ViT-L/14': (1024, 768, 50),
+    'ViT-L/14-336px': (1024, 768, 577)
+}
+
+
+@MODELS.register()
+class R2Tuning(nn.Module):
+
+    def __init__(self,
+                 arch='ViT-B/32',
+                 init=True,
+                 dims=256,
+                 strides=(1, 2, 4, 8),
+                 buffer_size=1024,
+                 max_num_moment=50,
+                 merge_cls_sal=True,
+                 adapter_cfg=None,
+                 pyramid_cfg=None,
+                 pooling_cfg=None,
+                 class_head_cfg=None,
+                 coord_head_cfg=None,
+                 loss_cfg=None):
+        super(R2Tuning, self).__init__()
+
+        if init:
+            self.clip, _ = clip.load(arch, device='cpu')
+            for param in self.clip.parameters():
+                param.requires_grad = False
+
+        self.cfg = _CLIP_ARCHS[arch]
+        self.adapter = build_model(adapter_cfg, dims, self.cfg[:2])
+        self.pyramid = build_model(pyramid_cfg, dims, strides)
+        self.pooling = build_model(pooling_cfg, dims)
+
+        self.class_head = build_model(class_head_cfg, dims, 1)
+        self.coord_head = build_model(coord_head_cfg, dims, 2)
+
+        self.generator = PointGenerator(strides, buffer_size)
+
+        self.coef = nn.Parameter(torch.ones(len(strides)))
+        self.loss = build_loss(loss_cfg)
+
+        self.max_num_moment = max_num_moment
+        self.merge_cls_sal = merge_cls_sal
+
+    def train(self, mode=True):
+        super(R2Tuning, self).train(mode=mode)
+        if hasattr(self, 'clip'):
+            self.clip.eval()
+
+    @torch.no_grad
+    def clip_video_tower(self, video):
+        video = video.type(self.clip.dtype)
+        video = self.clip.visual.conv1(video)
+        video = video.reshape(video.size(0), video.size(1), -1).permute(0, 2, 1)
+        c_emb = video.new_zeros(video.size(0), 1, video.size(-1))
+        c_emb = self.clip.visual.class_embedding.to(video.dtype) + c_emb
+        video = torch.cat((c_emb, video), dim=1)
+        video = video + self.clip.visual.positional_embedding.to(video.dtype)
+        video = self.clip.visual.ln_pre(video).permute(1, 0, 2)
+        emb = [video]
+        for blk in self.clip.visual.transformer.resblocks:
+            emb.append(blk(emb[-1]))
+        video = torch.stack([e.permute(1, 0, 2) for e in emb])
+        return video
+
+    @torch.no_grad
+    def clip_query_tower(self, query):
+        query = self.clip.token_embedding(query).type(self.clip.dtype)
+        query = query + self.clip.positional_embedding.type(self.clip.dtype)
+        query = query.permute(1, 0, 2)
+        emb = [query]
+        for blk in self.clip.transformer.resblocks:
+            emb.append(blk(emb[-1]))
+        query = torch.stack([e.permute(1, 0, 2) for e in emb])
+        return query
+
+    def forward(self, data, mode='test'):
+        video, query = data['video'], data['query']
+
+        if hasattr(self, 'clip'):
+            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
+            query_msk = torch.where(query == 0, 0, 1)
+
+            video[~video.isfinite()] = 0
+
+            (b, t), d = video.size()[:2], int(math.sqrt(video.size(2) / 3))
+            video = video.view(b * t, 3, d, d)
+
+            video_emb = self.clip_video_tower(video)
+            query_emb = self.clip_query_tower(query)
+
+            n, _, p, c = video_emb.size()
+            video_emb = video_emb.view(n, b, t, p, c)
+        else:
+            video_msk = torch.where(video[:, :, 0].isfinite(), 1, 0)
+            query_msk = torch.where(query[:, :, 0].isfinite(), 1, 0)
+
+            video[~video.isfinite()] = 0
+            query[~query.isfinite()] = 0
+
+            (b, t), l = video.size()[:2], query.size(1)
+            video = video.view(b, t, -1, self.cfg[2], self.cfg[0]).permute(2, 0, 1, 3, 4)
+            query = query.view(b, l, -1, self.cfg[1]).permute(2, 0, 1, 3)
+
+            video_emb = video.float()
+            query_emb = query.float()
+
+        # video_emb: N * B * T * P * C
+        # query_emb: N * B * L * C
+
+        video_emb, query_emb, coll_v, coll_q = self.adapter(video_emb, query_emb,
+                                                            video_msk, query_msk)
+
+        pymid, pymid_msk = self.pyramid(video_emb, video_msk, return_mask=mode != 'test')
+        point = self.generator(pymid)
+
+        with torch.autocast('cuda', enabled=False):
+            video_emb = video_emb.float()
+            query_emb = self.pooling(query_emb.float(), query_msk)
+
+            out_class = [self.class_head(e.float()) for e in pymid]
+            out_class = torch.cat(out_class, dim=1)
+
+            if self.coord_head is not None:
+                out_coord = [
+                    self.coord_head(e.float()).exp() * self.coef[i]
+                    for i, e in enumerate(pymid)
+                ]
+                out_coord = torch.cat(out_coord, dim=1)
+            else:
+                out_coord = None
+
+        output = dict(_avg_factor=b)
+
+        if mode != 'test':
+            data['coll_v'] = [e.float() for e in coll_v]
+            data['coll_q'] = [self.pooling(e.float(), query_msk) for e in coll_q]
+
+            data['point'] = point
+            data['video_emb'] = video_emb
+            data['query_emb'] = query_emb
+            data['video_msk'] = video_msk
+            data['pymid_msk'] = pymid_msk
+            data['out_class'] = out_class
+            data['out_coord'] = out_coord
+
+            output = self.loss(data, output)
+
+        if mode != 'train':
+            assert b == 1, 'batch size larger than 1 is not supported for inference'
+            out_class = out_class.sigmoid()
+            out_score = F.cosine_similarity(video_emb, query_emb, dim=-1)
+
+            output['_out'] = dict(label=data.get('label', [None])[0])
+
+            pyd_shape = [e.size(1) for e in pymid]
+            pyd_class = out_class[0, :, 0].split(pyd_shape)
+
+            saliency = []
+            for shape, score in zip(pyd_shape, pyd_class):
+                if t >= shape:
+                    score = score.repeat_interleave(int(t / shape))
+                    postfix = score[-1:].repeat(t - score.size(0))
+                    score = torch.cat((score, postfix))
+                else:
+                    scale = int(shape / t)
+                    score = F.max_pool1d(score.unsqueeze(0), scale, stride=scale)[0]
+                saliency.append(score)
+
+            saliency = torch.stack(saliency).amax(dim=0)
+
+            if self.merge_cls_sal:
+                saliency *= out_score[0]
+
+            output['_out']['saliency'] = saliency
+
+            if self.coord_head is not None:
+                boundary = out_coord[0]
+                boundary[:, 0] *= -1
+                boundary *= point[:, 3, None].repeat(1, 2)
+                boundary += point[:, 0, None].repeat(1, 2)
+                boundary /= data['fps'][0]
+                boundary = torch.cat((boundary, out_class[0]), dim=-1)
+
+                _, inds = out_class[0, :, 0].sort(descending=True)
+                boundary = boundary[inds[:self.max_num_moment]]
+
+                output['_out']['boundary'] = boundary
+
+        return output
requirements.txt ADDED
@@ -0,0 +1,6 @@
+git+https://github.com/openai/CLIP.git@a1d0717
+decord==0.6.0
+matplotlib==3.9.0
+nncore==0.4.3
+torch==2.2.1
+torchvision==0.17.1
setup.cfg ADDED
@@ -0,0 +1,15 @@
+[yapf]
+column_limit = 90
+based_on_style = pep8
+blank_line_before_nested_class_or_def = true
+split_before_expression_after_opening_paren = true
+
+[isort]
+line_length = 90
+multi_line_output = 0
+known_third_party = clip,decord,gradio,nncore,numpy,torch,torchvision
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = FIRSTPARTY
+
+[flake8]
+max-line-length = 90