Spaces:
Runtime error
Runtime error
switch to PyTorchModelHubMixin
Browse files- app.py +3 -13
- requirements.txt +1 -1
- uniformerv2.py +0 -510
app.py
CHANGED
@@ -8,7 +8,7 @@ import torchvision.transforms as T
|
|
8 |
from PIL import Image
|
9 |
from decord import VideoReader
|
10 |
from decord import cpu
|
11 |
-
from
|
12 |
from kinetics_class_index import kinetics_classnames
|
13 |
from transforms import (
|
14 |
GroupNormalize, GroupScale, GroupCenterCrop,
|
@@ -16,24 +16,14 @@ from transforms import (
|
|
16 |
)
|
17 |
|
18 |
import gradio as gr
|
19 |
-
from huggingface_hub import hf_hub_download
|
20 |
|
21 |
-
|
22 |
-
def __init__(self, model):
|
23 |
-
super().__init__()
|
24 |
-
self.backbone = model
|
25 |
-
|
26 |
-
def forward(self, x):
|
27 |
-
return self.backbone(x)
|
28 |
|
29 |
# Device on which to run the model
|
30 |
# Set to cuda to load on GPU
|
31 |
device = "cpu"
|
32 |
-
model_path = hf_hub_download(repo_id="Andy1621/uniformerv2", filename="k400+k710_uniformerv2_b16_8x224.pyth")
|
33 |
# Pick a pretrained model
|
34 |
-
model =
|
35 |
-
state_dict = torch.load(model_path, map_location='cpu')
|
36 |
-
model.load_state_dict(state_dict)
|
37 |
|
38 |
# Set to eval mode and move to desired device
|
39 |
model = model.to(device)
|
|
|
8 |
from PIL import Image
|
9 |
from decord import VideoReader
|
10 |
from decord import cpu
|
11 |
+
from slowfast.models.uniformerv2_model import VisionTransformer
|
12 |
from kinetics_class_index import kinetics_classnames
|
13 |
from transforms import (
|
14 |
GroupNormalize, GroupScale, GroupCenterCrop,
|
|
|
16 |
)
|
17 |
|
18 |
import gradio as gr
|
|
|
19 |
|
20 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# Device on which to run the model
|
23 |
# Set to cuda to load on GPU
|
24 |
device = "cpu"
|
|
|
25 |
# Pick a pretrained model
|
26 |
+
model = VisionTransformer.from_pretrained("not-lain/uniformerv2_b16")
|
|
|
|
|
27 |
|
28 |
# Set to eval mode and move to desired device
|
29 |
model = model.to(device)
|
requirements.txt
CHANGED
@@ -3,4 +3,4 @@ torchvision
|
|
3 |
einops
|
4 |
timm
|
5 |
Pillow
|
6 |
-
|
|
|
3 |
einops
|
4 |
timm
|
5 |
Pillow
|
6 |
+
git+https://github.com/not-lain/UniFormerV2.git@integrate-with-huggingface
|
uniformerv2.py
DELETED
@@ -1,510 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
import os
|
3 |
-
from collections import OrderedDict
|
4 |
-
|
5 |
-
from timm.models.layers import DropPath
|
6 |
-
import torch
|
7 |
-
from torch import nn
|
8 |
-
from torch.nn import MultiheadAttention
|
9 |
-
import torch.nn.functional as F
|
10 |
-
import torch.utils.checkpoint as checkpoint
|
11 |
-
|
12 |
-
|
13 |
-
MODEL_PATH = './'
|
14 |
-
_MODELS = {
|
15 |
-
"ViT-B/16": os.path.join(MODEL_PATH, "vit_b16.pth"),
|
16 |
-
"ViT-L/14": os.path.join(MODEL_PATH, "vit_l14.pth"),
|
17 |
-
"ViT-L/14_336": os.path.join(MODEL_PATH, "vit_l14_336.pth"),
|
18 |
-
}
|
19 |
-
|
20 |
-
|
21 |
-
class LayerNorm(nn.LayerNorm):
|
22 |
-
"""Subclass torch's LayerNorm to handle fp16."""
|
23 |
-
|
24 |
-
def forward(self, x):
|
25 |
-
orig_type = x.dtype
|
26 |
-
ret = super().forward(x.type(torch.float32))
|
27 |
-
return ret.type(orig_type)
|
28 |
-
|
29 |
-
|
30 |
-
class QuickGELU(nn.Module):
|
31 |
-
def forward(self, x):
|
32 |
-
return x * torch.sigmoid(1.702 * x)
|
33 |
-
|
34 |
-
|
35 |
-
class Local_MHRA(nn.Module):
|
36 |
-
def __init__(self, d_model, dw_reduction=1.5, pos_kernel_size=3):
|
37 |
-
super().__init__()
|
38 |
-
|
39 |
-
padding = pos_kernel_size // 2
|
40 |
-
re_d_model = int(d_model // dw_reduction)
|
41 |
-
self.pos_embed = nn.Sequential(
|
42 |
-
nn.BatchNorm3d(d_model),
|
43 |
-
nn.Conv3d(d_model, re_d_model, kernel_size=1, stride=1, padding=0),
|
44 |
-
nn.Conv3d(re_d_model, re_d_model, kernel_size=(pos_kernel_size, 1, 1), stride=(1, 1, 1), padding=(padding, 0, 0), groups=re_d_model),
|
45 |
-
nn.Conv3d(re_d_model, d_model, kernel_size=1, stride=1, padding=0),
|
46 |
-
)
|
47 |
-
|
48 |
-
# init zero
|
49 |
-
print('Init zero for Conv in pos_emb')
|
50 |
-
nn.init.constant_(self.pos_embed[3].weight, 0)
|
51 |
-
nn.init.constant_(self.pos_embed[3].bias, 0)
|
52 |
-
|
53 |
-
def forward(self, x):
|
54 |
-
return self.pos_embed(x)
|
55 |
-
|
56 |
-
|
57 |
-
class ResidualAttentionBlock(nn.Module):
|
58 |
-
def __init__(
|
59 |
-
self, d_model, n_head, attn_mask=None, drop_path=0.0,
|
60 |
-
dw_reduction=1.5, no_lmhra=False, double_lmhra=True
|
61 |
-
):
|
62 |
-
super().__init__()
|
63 |
-
|
64 |
-
self.n_head = n_head
|
65 |
-
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
66 |
-
print(f'Drop path rate: {drop_path}')
|
67 |
-
|
68 |
-
self.no_lmhra = no_lmhra
|
69 |
-
self.double_lmhra = double_lmhra
|
70 |
-
print(f'No L_MHRA: {no_lmhra}')
|
71 |
-
print(f'Double L_MHRA: {double_lmhra}')
|
72 |
-
if not no_lmhra:
|
73 |
-
self.lmhra1 = Local_MHRA(d_model, dw_reduction=dw_reduction)
|
74 |
-
if double_lmhra:
|
75 |
-
self.lmhra2 = Local_MHRA(d_model, dw_reduction=dw_reduction)
|
76 |
-
|
77 |
-
# spatial
|
78 |
-
self.attn = MultiheadAttention(d_model, n_head)
|
79 |
-
self.ln_1 = LayerNorm(d_model)
|
80 |
-
self.mlp = nn.Sequential(OrderedDict([
|
81 |
-
("c_fc", nn.Linear(d_model, d_model * 4)),
|
82 |
-
("gelu", QuickGELU()),
|
83 |
-
("c_proj", nn.Linear(d_model * 4, d_model))
|
84 |
-
]))
|
85 |
-
self.ln_2 = LayerNorm(d_model)
|
86 |
-
self.attn_mask = attn_mask
|
87 |
-
|
88 |
-
def attention(self, x):
|
89 |
-
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
90 |
-
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
91 |
-
|
92 |
-
def forward(self, x, T=8, use_checkpoint=False):
|
93 |
-
# x: 1+HW, NT, C
|
94 |
-
if not self.no_lmhra:
|
95 |
-
# Local MHRA
|
96 |
-
tmp_x = x[1:, :, :]
|
97 |
-
L, NT, C = tmp_x.shape
|
98 |
-
N = NT // T
|
99 |
-
H = W = int(L ** 0.5)
|
100 |
-
tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
|
101 |
-
tmp_x = tmp_x + self.drop_path(self.lmhra1(tmp_x))
|
102 |
-
tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
|
103 |
-
x = torch.cat([x[:1, :, :], tmp_x], dim=0)
|
104 |
-
# MHSA
|
105 |
-
if use_checkpoint:
|
106 |
-
attn_out = checkpoint.checkpoint(self.attention, self.ln_1(x))
|
107 |
-
x = x + self.drop_path(attn_out)
|
108 |
-
else:
|
109 |
-
x = x + self.drop_path(self.attention(self.ln_1(x)))
|
110 |
-
# Local MHRA
|
111 |
-
if not self.no_lmhra and self.double_lmhra:
|
112 |
-
tmp_x = x[1:, :, :]
|
113 |
-
tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
|
114 |
-
tmp_x = tmp_x + self.drop_path(self.lmhra2(tmp_x))
|
115 |
-
tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
|
116 |
-
x = torch.cat([x[:1, :, :], tmp_x], dim=0)
|
117 |
-
# FFN
|
118 |
-
if use_checkpoint:
|
119 |
-
mlp_out = checkpoint.checkpoint(self.mlp, self.ln_2(x))
|
120 |
-
x = x + self.drop_path(mlp_out)
|
121 |
-
else:
|
122 |
-
x = x + self.drop_path(self.mlp(self.ln_2(x)))
|
123 |
-
return x
|
124 |
-
|
125 |
-
|
126 |
-
class Extractor(nn.Module):
|
127 |
-
def __init__(
|
128 |
-
self, d_model, n_head, attn_mask=None,
|
129 |
-
mlp_factor=4.0, dropout=0.0, drop_path=0.0,
|
130 |
-
):
|
131 |
-
super().__init__()
|
132 |
-
|
133 |
-
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
134 |
-
print(f'Drop path rate: {drop_path}')
|
135 |
-
self.attn = nn.MultiheadAttention(d_model, n_head)
|
136 |
-
self.ln_1 = nn.LayerNorm(d_model)
|
137 |
-
d_mlp = round(mlp_factor * d_model)
|
138 |
-
self.mlp = nn.Sequential(OrderedDict([
|
139 |
-
("c_fc", nn.Linear(d_model, d_mlp)),
|
140 |
-
("gelu", QuickGELU()),
|
141 |
-
("dropout", nn.Dropout(dropout)),
|
142 |
-
("c_proj", nn.Linear(d_mlp, d_model))
|
143 |
-
]))
|
144 |
-
self.ln_2 = nn.LayerNorm(d_model)
|
145 |
-
self.ln_3 = nn.LayerNorm(d_model)
|
146 |
-
self.attn_mask = attn_mask
|
147 |
-
|
148 |
-
# zero init
|
149 |
-
nn.init.xavier_uniform_(self.attn.in_proj_weight)
|
150 |
-
nn.init.constant_(self.attn.out_proj.weight, 0.)
|
151 |
-
nn.init.constant_(self.attn.out_proj.bias, 0.)
|
152 |
-
nn.init.xavier_uniform_(self.mlp[0].weight)
|
153 |
-
nn.init.constant_(self.mlp[-1].weight, 0.)
|
154 |
-
nn.init.constant_(self.mlp[-1].bias, 0.)
|
155 |
-
|
156 |
-
def attention(self, x, y):
|
157 |
-
d_model = self.ln_1.weight.size(0)
|
158 |
-
q = (x @ self.attn.in_proj_weight[:d_model].T) + self.attn.in_proj_bias[:d_model]
|
159 |
-
|
160 |
-
k = (y @ self.attn.in_proj_weight[d_model:-d_model].T) + self.attn.in_proj_bias[d_model:-d_model]
|
161 |
-
v = (y @ self.attn.in_proj_weight[-d_model:].T) + self.attn.in_proj_bias[-d_model:]
|
162 |
-
Tx, Ty, N = q.size(0), k.size(0), q.size(1)
|
163 |
-
q = q.view(Tx, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
|
164 |
-
k = k.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
|
165 |
-
v = v.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
|
166 |
-
aff = (q @ k.transpose(-2, -1) / (self.attn.head_dim ** 0.5))
|
167 |
-
|
168 |
-
aff = aff.softmax(dim=-1)
|
169 |
-
out = aff @ v
|
170 |
-
out = out.permute(2, 0, 1, 3).flatten(2)
|
171 |
-
out = self.attn.out_proj(out)
|
172 |
-
return out
|
173 |
-
|
174 |
-
def forward(self, x, y):
|
175 |
-
x = x + self.drop_path(self.attention(self.ln_1(x), self.ln_3(y)))
|
176 |
-
x = x + self.drop_path(self.mlp(self.ln_2(x)))
|
177 |
-
return x
|
178 |
-
|
179 |
-
|
180 |
-
class Transformer(nn.Module):
|
181 |
-
def __init__(
|
182 |
-
self, width, layers, heads, attn_mask=None, backbone_drop_path_rate=0.,
|
183 |
-
use_checkpoint=False, checkpoint_num=[0], t_size=8, dw_reduction=2,
|
184 |
-
no_lmhra=False, double_lmhra=True,
|
185 |
-
return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
186 |
-
n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
|
187 |
-
mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
|
188 |
-
cls_dropout=0.5, num_classes=400,
|
189 |
-
):
|
190 |
-
super().__init__()
|
191 |
-
self.T = t_size
|
192 |
-
self.return_list = return_list
|
193 |
-
# backbone
|
194 |
-
b_dpr = [x.item() for x in torch.linspace(0, backbone_drop_path_rate, layers)]
|
195 |
-
self.resblocks = nn.ModuleList([
|
196 |
-
ResidualAttentionBlock(
|
197 |
-
width, heads, attn_mask,
|
198 |
-
drop_path=b_dpr[i],
|
199 |
-
dw_reduction=dw_reduction,
|
200 |
-
no_lmhra=no_lmhra,
|
201 |
-
double_lmhra=double_lmhra,
|
202 |
-
) for i in range(layers)
|
203 |
-
])
|
204 |
-
# checkpoint
|
205 |
-
self.use_checkpoint = use_checkpoint
|
206 |
-
self.checkpoint_num = checkpoint_num
|
207 |
-
self.n_layers = n_layers
|
208 |
-
print(f'Use checkpoint: {self.use_checkpoint}')
|
209 |
-
print(f'Checkpoint number: {self.checkpoint_num}')
|
210 |
-
|
211 |
-
# global block
|
212 |
-
assert n_layers == len(return_list)
|
213 |
-
if n_layers > 0:
|
214 |
-
self.temporal_cls_token = nn.Parameter(torch.zeros(1, 1, n_dim))
|
215 |
-
self.dpe = nn.ModuleList([
|
216 |
-
nn.Conv3d(n_dim, n_dim, kernel_size=3, stride=1, padding=1, bias=True, groups=n_dim)
|
217 |
-
for i in range(n_layers)
|
218 |
-
])
|
219 |
-
for m in self.dpe:
|
220 |
-
nn.init.constant_(m.bias, 0.)
|
221 |
-
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
|
222 |
-
self.dec = nn.ModuleList([
|
223 |
-
Extractor(
|
224 |
-
n_dim, n_head, mlp_factor=mlp_factor,
|
225 |
-
dropout=mlp_dropout[i], drop_path=dpr[i],
|
226 |
-
) for i in range(n_layers)
|
227 |
-
])
|
228 |
-
self.balance = nn.Parameter(torch.zeros((n_dim)))
|
229 |
-
self.sigmoid = nn.Sigmoid()
|
230 |
-
# projection
|
231 |
-
self.proj = nn.Sequential(
|
232 |
-
nn.LayerNorm(n_dim),
|
233 |
-
nn.Dropout(cls_dropout),
|
234 |
-
nn.Linear(n_dim, num_classes),
|
235 |
-
)
|
236 |
-
|
237 |
-
def forward(self, x):
|
238 |
-
T_down = self.T
|
239 |
-
L, NT, C = x.shape
|
240 |
-
N = NT // T_down
|
241 |
-
H = W = int((L - 1) ** 0.5)
|
242 |
-
|
243 |
-
if self.n_layers > 0:
|
244 |
-
cls_token = self.temporal_cls_token.repeat(1, N, 1)
|
245 |
-
|
246 |
-
j = -1
|
247 |
-
for i, resblock in enumerate(self.resblocks):
|
248 |
-
if self.use_checkpoint and i < self.checkpoint_num[0]:
|
249 |
-
x = resblock(x, self.T, use_checkpoint=True)
|
250 |
-
else:
|
251 |
-
x = resblock(x, T_down)
|
252 |
-
if i in self.return_list:
|
253 |
-
j += 1
|
254 |
-
tmp_x = x.clone()
|
255 |
-
tmp_x = tmp_x.view(L, N, T_down, C)
|
256 |
-
# dpe
|
257 |
-
_, tmp_feats = tmp_x[:1], tmp_x[1:]
|
258 |
-
tmp_feats = tmp_feats.permute(1, 3, 2, 0).reshape(N, C, T_down, H, W)
|
259 |
-
tmp_feats = self.dpe[j](tmp_feats).view(N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous()
|
260 |
-
tmp_x[1:] = tmp_x[1:] + tmp_feats
|
261 |
-
# global block
|
262 |
-
tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1) # T * L, N, C
|
263 |
-
cls_token = self.dec[j](cls_token, tmp_x)
|
264 |
-
|
265 |
-
if self.n_layers > 0:
|
266 |
-
weight = self.sigmoid(self.balance)
|
267 |
-
residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C
|
268 |
-
return self.proj((1 - weight) * cls_token[0, :, :] + weight * residual)
|
269 |
-
else:
|
270 |
-
residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C
|
271 |
-
return self.proj(residual)
|
272 |
-
|
273 |
-
|
274 |
-
class VisionTransformer(nn.Module):
|
275 |
-
def __init__(
|
276 |
-
self,
|
277 |
-
# backbone
|
278 |
-
input_resolution, patch_size, width, layers, heads, output_dim, backbone_drop_path_rate=0.,
|
279 |
-
use_checkpoint=False, checkpoint_num=[0], t_size=8, kernel_size=3, dw_reduction=1.5,
|
280 |
-
temporal_downsample=True,
|
281 |
-
no_lmhra=-False, double_lmhra=True,
|
282 |
-
# global block
|
283 |
-
return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
284 |
-
n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
|
285 |
-
mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
|
286 |
-
cls_dropout=0.5, num_classes=400,
|
287 |
-
):
|
288 |
-
super().__init__()
|
289 |
-
self.input_resolution = input_resolution
|
290 |
-
self.output_dim = output_dim
|
291 |
-
padding = (kernel_size - 1) // 2
|
292 |
-
if temporal_downsample:
|
293 |
-
self.conv1 = nn.Conv3d(3, width, (kernel_size, patch_size, patch_size), (2, patch_size, patch_size), (padding, 0, 0), bias=False)
|
294 |
-
t_size = t_size // 2
|
295 |
-
else:
|
296 |
-
self.conv1 = nn.Conv3d(3, width, (1, patch_size, patch_size), (1, patch_size, patch_size), (0, 0, 0), bias=False)
|
297 |
-
|
298 |
-
scale = width ** -0.5
|
299 |
-
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
300 |
-
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
301 |
-
self.ln_pre = LayerNorm(width)
|
302 |
-
|
303 |
-
self.transformer = Transformer(
|
304 |
-
width, layers, heads, dw_reduction=dw_reduction,
|
305 |
-
backbone_drop_path_rate=backbone_drop_path_rate,
|
306 |
-
use_checkpoint=use_checkpoint, checkpoint_num=checkpoint_num, t_size=t_size,
|
307 |
-
no_lmhra=no_lmhra, double_lmhra=double_lmhra,
|
308 |
-
return_list=return_list, n_layers=n_layers, n_dim=n_dim, n_head=n_head,
|
309 |
-
mlp_factor=mlp_factor, drop_path_rate=drop_path_rate, mlp_dropout=mlp_dropout,
|
310 |
-
cls_dropout=cls_dropout, num_classes=num_classes,
|
311 |
-
)
|
312 |
-
|
313 |
-
def forward(self, x):
|
314 |
-
x = self.conv1(x) # shape = [*, width, grid, grid]
|
315 |
-
N, C, T, H, W = x.shape
|
316 |
-
x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)
|
317 |
-
|
318 |
-
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
319 |
-
x = x + self.positional_embedding.to(x.dtype)
|
320 |
-
x = self.ln_pre(x)
|
321 |
-
|
322 |
-
x = x.permute(1, 0, 2) # NLD -> LND
|
323 |
-
out = self.transformer(x)
|
324 |
-
return out
|
325 |
-
|
326 |
-
|
327 |
-
def inflate_weight(weight_2d, time_dim, center=True):
|
328 |
-
print(f'Init center: {center}')
|
329 |
-
if center:
|
330 |
-
weight_3d = torch.zeros(*weight_2d.shape)
|
331 |
-
weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
|
332 |
-
middle_idx = time_dim // 2
|
333 |
-
weight_3d[:, :, middle_idx, :, :] = weight_2d
|
334 |
-
else:
|
335 |
-
weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
|
336 |
-
weight_3d = weight_3d / time_dim
|
337 |
-
return weight_3d
|
338 |
-
|
339 |
-
|
340 |
-
def load_state_dict(model, state_dict):
|
341 |
-
state_dict_3d = model.state_dict()
|
342 |
-
for k in state_dict.keys():
|
343 |
-
if state_dict[k].shape != state_dict_3d[k].shape:
|
344 |
-
if len(state_dict_3d[k].shape) <= 2:
|
345 |
-
print(f'Ignore: {k}')
|
346 |
-
continue
|
347 |
-
print(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
|
348 |
-
time_dim = state_dict_3d[k].shape[2]
|
349 |
-
state_dict[k] = inflate_weight(state_dict[k], time_dim)
|
350 |
-
model.load_state_dict(state_dict, strict=False)
|
351 |
-
|
352 |
-
|
353 |
-
def uniformerv2_b16(
|
354 |
-
pretrained=True, use_checkpoint=False, checkpoint_num=[0],
|
355 |
-
t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
|
356 |
-
temporal_downsample=True,
|
357 |
-
no_lmhra=False, double_lmhra=True,
|
358 |
-
return_list=[8, 9, 10, 11],
|
359 |
-
n_layers=4, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
|
360 |
-
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
|
361 |
-
cls_dropout=0.5, num_classes=400,
|
362 |
-
):
|
363 |
-
model = VisionTransformer(
|
364 |
-
input_resolution=224,
|
365 |
-
patch_size=16,
|
366 |
-
width=768,
|
367 |
-
layers=12,
|
368 |
-
heads=12,
|
369 |
-
output_dim=512,
|
370 |
-
use_checkpoint=use_checkpoint,
|
371 |
-
checkpoint_num=checkpoint_num,
|
372 |
-
t_size=t_size,
|
373 |
-
dw_reduction=dw_reduction,
|
374 |
-
backbone_drop_path_rate=backbone_drop_path_rate,
|
375 |
-
temporal_downsample=temporal_downsample,
|
376 |
-
no_lmhra=no_lmhra,
|
377 |
-
double_lmhra=double_lmhra,
|
378 |
-
return_list=return_list,
|
379 |
-
n_layers=n_layers,
|
380 |
-
n_dim=n_dim,
|
381 |
-
n_head=n_head,
|
382 |
-
mlp_factor=mlp_factor,
|
383 |
-
drop_path_rate=drop_path_rate,
|
384 |
-
mlp_dropout=mlp_dropout,
|
385 |
-
cls_dropout=cls_dropout,
|
386 |
-
num_classes=num_classes,
|
387 |
-
)
|
388 |
-
|
389 |
-
if pretrained:
|
390 |
-
print('load pretrained weights')
|
391 |
-
state_dict = torch.load(_MODELS["ViT-B/16"], map_location='cpu')
|
392 |
-
load_state_dict(model, state_dict)
|
393 |
-
return model.eval()
|
394 |
-
|
395 |
-
|
396 |
-
def uniformerv2_l14(
|
397 |
-
pretrained=True, use_checkpoint=False, checkpoint_num=[0],
|
398 |
-
t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
|
399 |
-
temporal_downsample=True,
|
400 |
-
no_lmhra=False, double_lmhra=True,
|
401 |
-
return_list=[20, 21, 22, 23],
|
402 |
-
n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
|
403 |
-
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
|
404 |
-
cls_dropout=0.5, num_classes=400,
|
405 |
-
):
|
406 |
-
model = VisionTransformer(
|
407 |
-
input_resolution=224,
|
408 |
-
patch_size=14,
|
409 |
-
width=1024,
|
410 |
-
layers=24,
|
411 |
-
heads=16,
|
412 |
-
output_dim=768,
|
413 |
-
use_checkpoint=use_checkpoint,
|
414 |
-
checkpoint_num=checkpoint_num,
|
415 |
-
t_size=t_size,
|
416 |
-
dw_reduction=dw_reduction,
|
417 |
-
backbone_drop_path_rate=backbone_drop_path_rate,
|
418 |
-
temporal_downsample=temporal_downsample,
|
419 |
-
no_lmhra=no_lmhra,
|
420 |
-
double_lmhra=double_lmhra,
|
421 |
-
return_list=return_list,
|
422 |
-
n_layers=n_layers,
|
423 |
-
n_dim=n_dim,
|
424 |
-
n_head=n_head,
|
425 |
-
mlp_factor=mlp_factor,
|
426 |
-
drop_path_rate=drop_path_rate,
|
427 |
-
mlp_dropout=mlp_dropout,
|
428 |
-
cls_dropout=cls_dropout,
|
429 |
-
num_classes=num_classes,
|
430 |
-
)
|
431 |
-
|
432 |
-
if pretrained:
|
433 |
-
print('load pretrained weights')
|
434 |
-
state_dict = torch.load(_MODELS["ViT-L/14"], map_location='cpu')
|
435 |
-
load_state_dict(model, state_dict)
|
436 |
-
return model.eval()
|
437 |
-
|
438 |
-
|
439 |
-
def uniformerv2_l14_336(
|
440 |
-
pretrained=True, use_checkpoint=False, checkpoint_num=[0],
|
441 |
-
t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
|
442 |
-
no_temporal_downsample=True,
|
443 |
-
no_lmhra=False, double_lmhra=True,
|
444 |
-
return_list=[20, 21, 22, 23],
|
445 |
-
n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
|
446 |
-
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
|
447 |
-
cls_dropout=0.5, num_classes=400,
|
448 |
-
):
|
449 |
-
model = VisionTransformer(
|
450 |
-
input_resolution=336,
|
451 |
-
patch_size=14,
|
452 |
-
width=1024,
|
453 |
-
layers=24,
|
454 |
-
heads=16,
|
455 |
-
output_dim=768,
|
456 |
-
use_checkpoint=use_checkpoint,
|
457 |
-
checkpoint_num=checkpoint_num,
|
458 |
-
t_size=t_size,
|
459 |
-
dw_reduction=dw_reduction,
|
460 |
-
backbone_drop_path_rate=backbone_drop_path_rate,
|
461 |
-
no_temporal_downsample=no_temporal_downsample,
|
462 |
-
no_lmhra=no_lmhra,
|
463 |
-
double_lmhra=double_lmhra,
|
464 |
-
return_list=return_list,
|
465 |
-
n_layers=n_layers,
|
466 |
-
n_dim=n_dim,
|
467 |
-
n_head=n_head,
|
468 |
-
mlp_factor=mlp_factor,
|
469 |
-
drop_path_rate=drop_path_rate,
|
470 |
-
mlp_dropout=mlp_dropout,
|
471 |
-
cls_dropout=cls_dropout,
|
472 |
-
num_classes=num_classes,
|
473 |
-
)
|
474 |
-
|
475 |
-
if pretrained:
|
476 |
-
print('load pretrained weights')
|
477 |
-
state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
|
478 |
-
load_state_dict(model, state_dict)
|
479 |
-
return model.eval()
|
480 |
-
|
481 |
-
|
482 |
-
if __name__ == '__main__':
|
483 |
-
import time
|
484 |
-
from fvcore.nn import FlopCountAnalysis
|
485 |
-
from fvcore.nn import flop_count_table
|
486 |
-
import numpy as np
|
487 |
-
|
488 |
-
seed = 4217
|
489 |
-
np.random.seed(seed)
|
490 |
-
torch.manual_seed(seed)
|
491 |
-
torch.cuda.manual_seed(seed)
|
492 |
-
torch.cuda.manual_seed_all(seed)
|
493 |
-
num_frames = 16
|
494 |
-
|
495 |
-
model = uniformerv2_l14(
|
496 |
-
pretrained=False,
|
497 |
-
t_size=num_frames, backbone_drop_path_rate=0., drop_path_rate=0.,
|
498 |
-
dw_reduction=1.5,
|
499 |
-
no_lmhra=False,
|
500 |
-
temporal_downsample=True,
|
501 |
-
return_list=[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
|
502 |
-
mlp_dropout=[0.5]*16,
|
503 |
-
n_layers=16
|
504 |
-
)
|
505 |
-
print(model)
|
506 |
-
|
507 |
-
flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
|
508 |
-
s = time.time()
|
509 |
-
print(flop_count_table(flops, max_depth=1))
|
510 |
-
print(time.time()-s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|