amazinghaha committed
Commit b4d6f1e
Parent: 5c4b9bd

Upload 106 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set.
Files changed (50)
  1. model/CoordAttention.py +110 -0
  2. model/Vision_Transformer_with_mask.py +990 -0
  3. model/__pycache__/CoordAttention.cpython-38.pyc +0 -0
  4. model/__pycache__/Vision_Transformer_with_mask.cpython-38.pyc +0 -0
  5. model/__pycache__/features.cpython-38.pyc +0 -0
  6. model/__pycache__/helpers.cpython-38.pyc +0 -0
  7. model/__pycache__/hub.cpython-38.pyc +0 -0
  8. model/__pycache__/registry.cpython-38.pyc +0 -0
  9. model/features.py +284 -0
  10. model/helpers.py +508 -0
  11. model/hub.py +96 -0
  12. model/layers/__init__.py +40 -0
  13. model/layers/__pycache__/__init__.cpython-38.pyc +0 -0
  14. model/layers/__pycache__/activations.cpython-38.pyc +0 -0
  15. model/layers/__pycache__/activations_jit.cpython-38.pyc +0 -0
  16. model/layers/__pycache__/activations_me.cpython-38.pyc +0 -0
  17. model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc +0 -0
  18. model/layers/__pycache__/blur_pool.cpython-38.pyc +0 -0
  19. model/layers/__pycache__/bottleneck_attn.cpython-38.pyc +0 -0
  20. model/layers/__pycache__/cbam.cpython-38.pyc +0 -0
  21. model/layers/__pycache__/classifier.cpython-38.pyc +0 -0
  22. model/layers/__pycache__/cond_conv2d.cpython-38.pyc +0 -0
  23. model/layers/__pycache__/config.cpython-38.pyc +0 -0
  24. model/layers/__pycache__/conv2d_same.cpython-38.pyc +0 -0
  25. model/layers/__pycache__/conv_bn_act.cpython-38.pyc +0 -0
  26. model/layers/__pycache__/create_act.cpython-38.pyc +0 -0
  27. model/layers/__pycache__/create_attn.cpython-38.pyc +0 -0
  28. model/layers/__pycache__/create_conv2d.cpython-38.pyc +0 -0
  29. model/layers/__pycache__/create_norm_act.cpython-38.pyc +0 -0
  30. model/layers/__pycache__/drop.cpython-38.pyc +0 -0
  31. model/layers/__pycache__/eca.cpython-38.pyc +0 -0
  32. model/layers/__pycache__/evo_norm.cpython-38.pyc +0 -0
  33. model/layers/__pycache__/gather_excite.cpython-38.pyc +0 -0
  34. model/layers/__pycache__/global_context.cpython-38.pyc +0 -0
  35. model/layers/__pycache__/halo_attn.cpython-38.pyc +0 -0
  36. model/layers/__pycache__/helpers.cpython-38.pyc +0 -0
  37. model/layers/__pycache__/inplace_abn.cpython-38.pyc +0 -0
  38. model/layers/__pycache__/involution.cpython-38.pyc +0 -0
  39. model/layers/__pycache__/lambda_layer.cpython-38.pyc +0 -0
  40. model/layers/__pycache__/linear.cpython-38.pyc +0 -0
  41. model/layers/__pycache__/mixed_conv2d.cpython-38.pyc +0 -0
  42. model/layers/__pycache__/mlp.cpython-38.pyc +0 -0
  43. model/layers/__pycache__/non_local_attn.cpython-38.pyc +0 -0
  44. model/layers/__pycache__/norm.cpython-38.pyc +0 -0
  45. model/layers/__pycache__/norm_act.cpython-38.pyc +0 -0
  46. model/layers/__pycache__/padding.cpython-38.pyc +0 -0
  47. model/layers/__pycache__/patch_embed.cpython-38.pyc +0 -0
  48. model/layers/__pycache__/pool2d_same.cpython-38.pyc +0 -0
  49. model/layers/__pycache__/selective_kernel.cpython-38.pyc +0 -0
  50. model/layers/__pycache__/separable_conv.cpython-38.pyc +0 -0
model/CoordAttention.py ADDED
@@ -0,0 +1,110 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class h_sigmoid(nn.Module):
+     def __init__(self, inplace=True):
+         super(h_sigmoid, self).__init__()
+         self.relu = nn.ReLU6(inplace=inplace)
+
+     def forward(self, x):
+         return self.relu(x + 3) / 6
+
+
+ class h_swish(nn.Module):
+     def __init__(self, inplace=True):
+         super(h_swish, self).__init__()
+         self.sigmoid = h_sigmoid(inplace=inplace)
+
+     def forward(self, x):
+         return x * self.sigmoid(x)
+
+
+ class CoordAtt(nn.Module):
+     def __init__(self, inp, oup, reduction=32):
+         super(CoordAtt, self).__init__()
+         self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
+         self.pool_w = nn.AdaptiveAvgPool2d((1, None))
+
+         mip = max(8, inp // reduction)
+
+         self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
+         self.bn1 = nn.BatchNorm2d(mip)
+
+         self.bn2 = nn.BatchNorm2d(1)
+         self.bn3 = nn.BatchNorm2d(1)
+         self.act = h_swish()
+
+         self.bn4 = nn.BatchNorm2d(mip)
+         self.bn5 = nn.BatchNorm2d(mip)
+
+         self.bn6 = nn.BatchNorm2d(1)
+         self.bn7 = nn.BatchNorm2d(1)
+
+         self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+         self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         x = torch.unsqueeze(x, 1)  # e.g. 2 x 1 x 2304 x 196
+         identity = x
+
+         n, c, h, w = x.size()  # 2 x 1 x 2304 x 196
+         x_h = self.bn2(self.pool_h(x))  # 2 x 1 x 2304 x 1
+         x_w = self.bn3(self.pool_w(x).permute(0, 1, 3, 2))  # 2 x 1 x 196 x 1
+         identity_x_w = x_w
+         identity_x_h = x_h
+         y = torch.cat([x_h, x_w], dim=2)
+         y = self.conv1(y)  # 2 x 8 x 2500 x 1
+         y = self.bn1(y)
+         y = self.act(y)
+
+         x_h, x_w = torch.split(y, [h, w], dim=2)  # 2 x 8 x 2304 x 1 | 2 x 8 x 196 x 1
+         x_h = self.bn4(x_h) + identity_x_h
+         x_w = self.bn5(x_w) + identity_x_w
+         x_w = x_w.permute(0, 1, 3, 2)
+
+         a_h = self.bn6(self.conv_h(x_h)).sigmoid()  # 2 x 1 x 2304 x 1
+         a_w = self.bn7(self.conv_w(x_w)).sigmoid()  # 2 x 1 x 1 x 196
+
+         out = identity * a_w * a_h  # element-wise product
+         out = torch.squeeze(out, 1)
+         return out
+
+ class CoordAtt_ori(nn.Module):
+     def __init__(self, inp, oup, reduction=32):
+         super(CoordAtt_ori, self).__init__()
+         self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
+         self.pool_w = nn.AdaptiveAvgPool2d((1, None))
+
+         mip = max(8, inp // reduction)
+
+         self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
+         self.bn1 = nn.BatchNorm2d(mip)
+         self.act = h_swish()
+
+         self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+         self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         x = torch.unsqueeze(x, 1)
+         identity = x
+
+         n, c, h, w = x.size()
+         x_h = self.pool_h(x)
+         x_w = self.pool_w(x).permute(0, 1, 3, 2)
+
+         y = torch.cat([x_h, x_w], dim=2)
+         y = self.conv1(y)
+         y = self.bn1(y)
+         y = self.act(y)
+
+         x_h, x_w = torch.split(y, [h, w], dim=2)
+         x_w = x_w.permute(0, 1, 3, 2)
+
+         a_h = self.conv_h(x_h).sigmoid()
+         a_w = self.conv_w(x_w).sigmoid()
+
+         out = identity * a_w * a_h
+         out = torch.squeeze(out, 1)
+         return out
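A minimal sketch of how CoordAtt might be exercised on its own, assuming the inp=oup=1 configuration implied by the BatchNorm2d(1) layers and the shape comments in forward(); the tensor sizes below are illustrative, not taken from the training code, and the import assumes the repository root is on the Python path:

import torch
from model.CoordAttention import CoordAtt

att = CoordAtt(inp=1, oup=1, reduction=32)  # internal width mip = max(8, 1 // 32) = 8
x = torch.randn(2, 2304, 196)               # 3-D input; a singleton channel dim is added inside forward()
out = att(x)                                # attention factors applied along both axes, then squeezed back
print(out.shape)                            # torch.Size([2, 2304, 196])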
model/Vision_Transformer_with_mask.py ADDED
@@ -0,0 +1,990 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+
3
+ A PyTorch implementation of Vision Transformers as described in:
4
+
5
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
6
+ - https://arxiv.org/abs/2010.11929
7
+
8
+ `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
9
+ - https://arxiv.org/abs/2106.10270
10
+
11
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
12
+
13
+ DeiT model defs and weights from https://github.com/facebookresearch/deit,
14
+ paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
15
+
16
+ Acknowledgments:
17
+ * The paper authors for releasing code and weights, thanks!
18
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
19
+ for some einops/einsum fun
20
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
21
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
22
+
23
+ Hacked together by / Copyright 2021 Ross Wightman
24
+ """
25
+ import math
26
+ import logging
27
+ from functools import partial
28
+ from collections import OrderedDict
29
+ from copy import deepcopy
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
35
+
36
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
37
+ from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
38
+ from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
39
+ from .registry import register_model
40
+
41
+ _logger = logging.getLogger(__name__)
42
+
43
+
44
+ def _cfg(url='', **kwargs):
45
+ return {
46
+ 'url': url,
47
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
48
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
49
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
50
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
51
+ **kwargs
52
+ }
53
+
54
+
55
+ default_cfgs = {
56
+ # patch models (weights from official Google JAX impl)
57
+ 'vit_tiny_patch16_224': _cfg(
58
+ url='https://storage.googleapis.com/vit_models/augreg/'
59
+ 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
60
+ 'vit_tiny_patch16_384': _cfg(
61
+ url='https://storage.googleapis.com/vit_models/augreg/'
62
+ 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
63
+ input_size=(3, 384, 384), crop_pct=1.0),
64
+ 'vit_small_patch32_224': _cfg(
65
+ url='https://storage.googleapis.com/vit_models/augreg/'
66
+ 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
67
+ 'vit_small_patch32_384': _cfg(
68
+ url='https://storage.googleapis.com/vit_models/augreg/'
69
+ 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
70
+ input_size=(3, 384, 384), crop_pct=1.0),
71
+ 'vit_small_patch16_224': _cfg(
72
+ url='https://storage.googleapis.com/vit_models/augreg/'
73
+ 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
74
+ 'vit_small_patch16_384': _cfg(
75
+ url='https://storage.googleapis.com/vit_models/augreg/'
76
+ 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
77
+ input_size=(3, 384, 384), crop_pct=1.0),
78
+ 'vit_base_patch32_224': _cfg(
79
+ url='https://storage.googleapis.com/vit_models/augreg/'
80
+ 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
81
+ 'vit_base_patch32_384': _cfg(
82
+ url='https://storage.googleapis.com/vit_models/augreg/'
83
+ 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
84
+ input_size=(3, 384, 384), crop_pct=1.0),
85
+ 'vit_base_patch16_224': _cfg(
86
+ url='https://storage.googleapis.com/vit_models/augreg/'
87
+ 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
88
+ 'vit_base_patch16_384': _cfg(
89
+ url='https://storage.googleapis.com/vit_models/augreg/'
90
+ 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
91
+ input_size=(3, 384, 384), crop_pct=1.0),
92
+ 'vit_large_patch32_224': _cfg(
93
+ url='', # no official model weights for this combo, only for in21k
94
+ ),
95
+ 'vit_large_patch32_384': _cfg(
96
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
97
+ input_size=(3, 384, 384), crop_pct=1.0),
98
+ 'vit_large_patch16_224': _cfg(
99
+ url='https://storage.googleapis.com/vit_models/augreg/'
100
+ 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
101
+ 'vit_large_patch16_384': _cfg(
102
+ url='https://storage.googleapis.com/vit_models/augreg/'
103
+ 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
104
+ input_size=(3, 384, 384), crop_pct=1.0),
105
+
106
+ # patch models, imagenet21k (weights from official Google JAX impl)
107
+ 'vit_tiny_patch16_224_in21k': _cfg(
108
+ url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
109
+ num_classes=21843),
110
+ 'vit_small_patch32_224_in21k': _cfg(
111
+ url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
112
+ num_classes=21843),
113
+ 'vit_small_patch16_224_in21k': _cfg(
114
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
115
+ num_classes=21843),
116
+ 'vit_base_patch32_224_in21k': _cfg(
117
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
118
+ num_classes=21843),
119
+ 'vit_base_patch16_224_in21k': _cfg(
120
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
121
+ num_classes=21843),
122
+ 'vit_large_patch32_224_in21k': _cfg(
123
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
124
+ num_classes=21843),
125
+ 'vit_large_patch16_224_in21k': _cfg(
126
+ url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
127
+ num_classes=21843),
128
+ 'vit_huge_patch14_224_in21k': _cfg(
129
+ url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
130
+ hf_hub='timm/vit_huge_patch14_224_in21k',
131
+ num_classes=21843),
132
+
133
+ # deit models (FB weights)
134
+ 'deit_tiny_patch16_224': _cfg(
135
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
136
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
137
+ 'deit_small_patch16_224': _cfg(
138
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
139
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
140
+ 'deit_base_patch16_224': _cfg(
141
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
142
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
143
+ 'deit_base_patch16_384': _cfg(
144
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
145
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
146
+ 'deit_tiny_distilled_patch16_224': _cfg(
147
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
148
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
149
+ 'deit_small_distilled_patch16_224': _cfg(
150
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
151
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
152
+ 'deit_base_distilled_patch16_224': _cfg(
153
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
154
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
155
+ 'deit_base_distilled_patch16_384': _cfg(
156
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
157
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
158
+ classifier=('head', 'head_dist')),
159
+
160
+ # ViT ImageNet-21K-P pretraining by MIIL
161
+ 'vit_base_patch16_224_miil_in21k': _cfg(
162
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
163
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
164
+ ),
165
+ 'vit_base_patch16_224_miil': _cfg(
166
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
167
+ '/vit_base_patch16_224_1k_miil_84_4.pth',
168
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
169
+ ),
170
+ }
171
+
172
+
173
+ class CrossAttention(nn.Module):
174
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
175
+ super().__init__()
176
+ self.num_heads = num_heads
177
+ head_dim = dim // num_heads
178
+ self.scale = qk_scale or head_dim ** -0.5  # this line adds the extra qk_scale option; 0.125 here
179
+
180
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
181
+ self.wk = nn.Linear(dim, dim, bias=qkv_bias)
182
+ self.wv = nn.Linear(dim, dim, bias=qkv_bias)
183
+ self.attn_drop = nn.Dropout(attn_drop)
184
+ self.proj = nn.Linear(dim, dim)
185
+ self.proj_drop = nn.Dropout(proj_drop)
186
+
187
+ def forward(self, x):
188
+
189
+ B, N, C = x.shape #2 512 768
190
+ q = self.wq(x[:, 0:int(N/2), ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)#2 12 256 64
191
+ k = self.wk(x[:, (int(N/2)):, ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
192
+ v = self.wv(x[:, (int(N/2)):, ...]).reshape(B, int(N/2), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
193
+
194
+ attn = (q @ k.transpose(-2, -1)) * self.scale
195
+ attn = attn.softmax(dim=-1)
196
+ attn = self.attn_drop(attn)
197
+
198
+ x = (attn @ v).transpose(1, 2).reshape(B, int(N/2), C)  # becomes (B, N/2, C): 2 256 768
199
+ x = self.proj(x)
200
+ x = self.proj_drop(x)
201
+ return x
202
+
203
+
204
+
205
+ class Attention(nn.Module):
206
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0., proj_drop=0.):
207
+ super().__init__()
208
+ self.num_heads = num_heads
209
+ head_dim = dim // num_heads
210
+ self.scale = qk_scale or head_dim ** -0.5
211
+
212
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
213
+ self.attn_drop = nn.Dropout(attn_drop)
214
+ self.proj = nn.Linear(dim, dim)
215
+ self.proj_drop = nn.Dropout(proj_drop)
216
+
217
+ def forward(self, data):
218
+ b,c,h = data.shape
219
+ x,atten_mask = data[:,0:int(c/2),...],data[:,int(c/2):,...]
220
+ B, N, C = x.shape
221
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
222
+ q, k, v = qkv[0], qkv[1], qkv[2] #2,12,49,64 # make torchscript happy (cannot use tensor as tuple)
223
+
224
+
225
+ attn = (q @ k.transpose(-2, -1)) * self.scale #2,12,49,49 #mask 2,1,49,49
226
+ if atten_mask.sum() != 0:
227
+ atten_mask = atten_mask.unsqueeze(1) # 2,1,49,49
228
+ attn = attn + atten_mask
229
+ attn = attn.softmax(dim=-1)
230
+ attn = self.attn_drop(attn)
231
+
232
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
233
+ x = self.proj(x)
234
+ x = self.proj_drop(x)
235
+ return x
236
+
237
+ class Attention_ori(nn.Module):
238
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,attn_drop=0., proj_drop=0.):
239
+ super().__init__()
240
+ self.num_heads = num_heads
241
+ head_dim = dim // num_heads
242
+ self.scale = qk_scale or head_dim ** -0.5
243
+
244
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
245
+ self.attn_drop = nn.Dropout(attn_drop)
246
+ self.proj = nn.Linear(dim, dim)
247
+ self.proj_drop = nn.Dropout(proj_drop)
248
+
249
+ def forward(self, x):
250
+
251
+ B, N, C = x.shape
252
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
253
+ q, k, v = qkv[0], qkv[1], qkv[2] #2,12,49,64 # make torchscript happy (cannot use tensor as tuple)
254
+ attn = (q @ k.transpose(-2, -1)) * self.scale #2,12,49,49 #mask 2,1,49,49
255
+
256
+ attn = attn.softmax(dim=-1)
257
+ attn = self.attn_drop(attn)
258
+
259
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
260
+ x = self.proj(x)
261
+ x = self.proj_drop(x)
262
+ return x
263
+
264
+
265
+
266
+
267
+ class Block(nn.Module):
268
+
269
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
270
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
271
+ super().__init__()
272
+ self.norm1 = norm_layer(dim)
273
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
274
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
276
+ self.norm2 = norm_layer(dim)
277
+ mlp_hidden_dim = int(dim * mlp_ratio)
278
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
279
+
280
+ def forward(self, data):
281
+ b,c,h = data.shape
282
+ x,mask = data[:,0:int(c/2),...],data[:,int(c/2):,...]
283
+ x = x + self.drop_path(self.attn(torch.cat([self.norm1(x),mask],dim=1)))
284
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
285
+ return torch.cat([x,mask],dim=1)
286
+
287
+ class mask_PatchEmbed(nn.Module):
288
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, norm_layer=None, flatten=True):
289
+ super().__init__()
290
+ img_size = to_2tuple(img_size)
291
+ patch_size = to_2tuple(patch_size)
292
+ self.img_size = img_size
293
+ self.patch_size = patch_size
294
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
295
+ self.flatten = flatten
296
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
297
+ self.proj = nn.Conv2d(in_chans, 1, kernel_size=patch_size, stride=patch_size).requires_grad_(False)
298
+ nn.init.ones_(self.proj.weight)
299
+ nn.init.zeros_(self.proj.bias)
300
+ def forward(self, x):
301
+ B, C, H, W = x.shape
302
+ assert H == self.img_size[0] and W == self.img_size[1], \
303
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
304
+ x = self.proj(x)
305
+ if self.flatten:
306
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
307
+ return x
308
+
309
+ class VisionTransformer(nn.Module):
310
+ """ Vision Transformer
311
+
312
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
313
+ - https://arxiv.org/abs/2010.11929
314
+
315
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
316
+ - https://arxiv.org/abs/2012.12877
317
+ """
318
+
319
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
320
+ num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
321
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
322
+ act_layer=None,as_backbone=True, weight_init=''):
323
+ """
324
+ Args:
325
+ img_size (int, tuple): input image size
326
+ patch_size (int, tuple): patch size
327
+ in_chans (int): number of input channels
328
+ num_classes (int): number of classes for classification head
329
+ embed_dim (int): embedding dimension
330
+ depth (int): depth of transformer
331
+ num_heads (int): number of attention heads
332
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
333
+ qkv_bias (bool): enable bias for qkv if True
334
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
335
+ distilled (bool): model includes a distillation token and head as in DeiT models
336
+ drop_rate (float): dropout rate
337
+ attn_drop_rate (float): attention dropout rate
338
+ drop_path_rate (float): stochastic depth rate
339
+ embed_layer (nn.Module): patch embedding layer
340
+ norm_layer: (nn.Module): normalization layer
341
+ weight_init: (str): weight init scheme
342
+ """
343
+ super().__init__()
344
+ self.num_classes = num_classes
345
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
346
+ self.num_tokens = 2 if distilled else 1
347
+ self.num_heads = num_heads
348
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
349
+ act_layer = act_layer or nn.GELU
350
+ self.as_backbone = as_backbone  # whether to act as a backbone; if it is not a classification task, the class token is not added
351
+ self.patch_embed = embed_layer(
352
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
353
+ self.mask_embed = mask_PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans)
354
+ num_patches = self.patch_embed.num_patches
355
+
356
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
357
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
358
+ if not self.as_backbone:
359
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
360
+ else:
361
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
362
+ self.pos_drop = nn.Dropout(p=drop_rate)
363
+
364
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
365
+ self.blocks = nn.Sequential(*[
366
+ Block(
367
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
368
+ attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
369
+ for i in range(depth)])
370
+ self.norm = norm_layer(embed_dim)
371
+
372
+ # Representation layer
373
+ if representation_size and not distilled:
374
+ self.num_features = representation_size
375
+ self.pre_logits = nn.Sequential(OrderedDict([
376
+ ('fc', nn.Linear(embed_dim, representation_size)),
377
+ ('act', nn.Tanh())
378
+ ]))
379
+ else:
380
+ self.pre_logits = nn.Identity()
381
+ if not self.as_backbone:
382
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
383
+ # Classifier head(s)
384
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
385
+ self.head_dist = None
386
+ if distilled:
387
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
388
+
389
+ self.init_weights(weight_init)
390
+
391
+ def init_weights(self, mode=''):
392
+ assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
393
+ head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
394
+ trunc_normal_(self.pos_embed, std=.02)
395
+ if self.dist_token is not None:
396
+ trunc_normal_(self.dist_token, std=.02)
397
+ if mode.startswith('jax'):
398
+ # leave cls token as zeros to match jax impl
399
+ named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
400
+ else:
401
+ trunc_normal_(self.cls_token, std=.02)
402
+ self.apply(_init_vit_weights)
403
+
404
+ def _init_weights(self, m):
405
+ # this fn left here for compat with downstream users
406
+ _init_vit_weights(m)
407
+
408
+ @torch.jit.ignore()
409
+ def load_pretrained(self, checkpoint_path, prefix=''):
410
+ _load_weights(self, checkpoint_path, prefix)
411
+
412
+ @torch.jit.ignore
413
+ def no_weight_decay(self):
414
+ return {'pos_embed', 'cls_token', 'dist_token'}
415
+
416
+ def get_classifier(self):
417
+ if self.dist_token is None:
418
+ return self.head
419
+ else:
420
+ return self.head, self.head_dist
421
+
422
+ def reset_classifier(self, num_classes, global_pool=''):
423
+ self.num_classes = num_classes
424
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
425
+ if self.num_tokens == 2:
426
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
427
+
428
+ def forward_features(self, data):
429
+ x,mask = data[:,0,:,:].unsqueeze(1),data[:,1,:,:].unsqueeze(1)
430
+ x = self.patch_embed(x)#B N C
431
+ atten_mask = torch.zeros_like(x) # 2 49 768
432
+ if mask.sum() != 0:
433
+ mask = self.mask_embed(mask) ###
434
+ mask.squeeze_(dim=2)
435
+ mask[mask != 0] = 1  # H*W is the number of tokens, C is the encoding length
436
+ k1 = mask[:, None, :]
437
+ k2 = torch.ones_like(mask)[:, :, None]
438
+ k3 = k1 * k2
439
+ atten_mask = (1.0 - k3) * (-1e6)
440
+ atten_mask.requires_grad_(True)
441
+ self.atten_mask = atten_mask
442
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
443
+ if not self.as_backbone:
444
+ if self.dist_token is None:
445
+ x = torch.cat((cls_token, x), dim=1)
446
+ else:
447
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
448
+ x = self.pos_drop(x + self.pos_embed) #2 49 768
449
+ x = self.blocks(torch.cat([x,atten_mask],dim=1))
450
+ b,c,h = x.shape
451
+ x = x[:,0:int(c/2),...]
452
+ x = self.norm(x)
453
+ if self.as_backbone:
454
+ # x = self.avgpool(x.transpose(1, 2)) # B C 1
455
+ # x = torch.flatten(x, 1)
456
+ return x
457
+ if self.dist_token is None:
458
+ return self.pre_logits(x[:, 0])
459
+ else:
460
+ return x[:, 0], x[:, 1]
461
+
462
+ def forward(self, data):
463
+ x = self.forward_features(data) #2 49 768
464
+ if self.as_backbone:
465
+ return x
466
+ else:
467
+ if self.head_dist is not None:
468
+ x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
469
+ if self.training and not torch.jit.is_scripting():
470
+ # return both classifier predictions during training; at inference their average is returned below
471
+ return x, x_dist
472
+ else:
473
+ return (x + x_dist) / 2
474
+ else:
475
+ x = self.head(x)
476
+ return x
477
+
478
+
479
+ def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
480
+ """ ViT weight initialization
481
+ * When called without n, head_bias, jax_impl args it will behave exactly the same
482
+ as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
483
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
484
+ """
485
+ if isinstance(module, nn.Linear):
486
+ if name.startswith('head'):
487
+ nn.init.zeros_(module.weight)
488
+ nn.init.constant_(module.bias, head_bias)
489
+ elif name.startswith('pre_logits'):
490
+ lecun_normal_(module.weight)
491
+ nn.init.zeros_(module.bias)
492
+ else:
493
+ if jax_impl:
494
+ nn.init.xavier_uniform_(module.weight)
495
+ if module.bias is not None:
496
+ if 'mlp' in name:
497
+ nn.init.normal_(module.bias, std=1e-6)
498
+ else:
499
+ nn.init.zeros_(module.bias)
500
+ else:
501
+ trunc_normal_(module.weight, std=.02)
502
+ if module.bias is not None:
503
+ nn.init.zeros_(module.bias)
504
+ elif jax_impl and isinstance(module, nn.Conv2d):
505
+ # NOTE conv was left to pytorch default in my original init
506
+ lecun_normal_(module.weight)
507
+ if module.bias is not None:
508
+ nn.init.zeros_(module.bias)
509
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
510
+ nn.init.zeros_(module.bias)
511
+ nn.init.ones_(module.weight)
512
+
513
+
514
+ @torch.no_grad()
515
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
516
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
517
+ """
518
+ import numpy as np
519
+
520
+ def _n2p(w, t=True):
521
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
522
+ w = w.flatten()
523
+ if t:
524
+ if w.ndim == 4:
525
+ w = w.transpose([3, 2, 0, 1])
526
+ elif w.ndim == 3:
527
+ w = w.transpose([2, 0, 1])
528
+ elif w.ndim == 2:
529
+ w = w.transpose([1, 0])
530
+ return torch.from_numpy(w)
531
+
532
+ w = np.load(checkpoint_path)
533
+ if not prefix and 'opt/target/embedding/kernel' in w:
534
+ prefix = 'opt/target/'
535
+
536
+ if hasattr(model.patch_embed, 'backbone'):
537
+ # hybrid
538
+ backbone = model.patch_embed.backbone
539
+ stem_only = not hasattr(backbone, 'stem')
540
+ stem = backbone if stem_only else backbone.stem
541
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
542
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
543
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
544
+ if not stem_only:
545
+ for i, stage in enumerate(backbone.stages):
546
+ for j, block in enumerate(stage.blocks):
547
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
548
+ for r in range(3):
549
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
550
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
551
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
552
+ if block.downsample is not None:
553
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
554
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
555
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
556
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
557
+ else:
558
+ embed_conv_w = adapt_input_conv(
559
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
560
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
561
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
562
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
563
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
564
+ if pos_embed_w.shape != model.pos_embed.shape:
565
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
566
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
567
+ model.pos_embed.copy_(pos_embed_w)
568
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
569
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
570
+ if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
571
+ model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
572
+ model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
573
+ if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
574
+ model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
575
+ model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
576
+ for i, block in enumerate(model.blocks.children()):
577
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
578
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
579
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
580
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
581
+ block.attn.qkv.weight.copy_(torch.cat([
582
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
583
+ block.attn.qkv.bias.copy_(torch.cat([
584
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
585
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
586
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
587
+ for r in range(2):
588
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
589
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
590
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
591
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
592
+
593
+
594
+ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
595
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
596
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
597
+ _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
598
+ ntok_new = posemb_new.shape[1]
599
+ if num_tokens:
600
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
601
+ ntok_new -= num_tokens
602
+ else:
603
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
604
+ gs_old = int(math.sqrt(len(posemb_grid)))
605
+ if not len(gs_new): # backwards compatibility
606
+ gs_new = [int(math.sqrt(ntok_new))] * 2
607
+ assert len(gs_new) >= 2
608
+ _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
609
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
610
+ posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
611
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
612
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
613
+ return posemb
614
+
615
+
616
+ def checkpoint_filter_fn(state_dict, model):
617
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
618
+ out_dict = {}
619
+ if 'model' in state_dict:
620
+ # For deit models
621
+ state_dict = state_dict['model']
622
+ for k, v in state_dict.items():
623
+ if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
624
+ # For old models that I trained prior to conv based patchification
625
+ O, I, H, W = model.patch_embed.proj.weight.shape
626
+ v = v.reshape(O, -1, H, W)
627
+ elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
628
+ # To resize pos embedding when using model at different size from pretrained weights
629
+ v = resize_pos_embed(
630
+ v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
631
+ out_dict[k] = v
632
+ return out_dict
633
+
634
+
635
+ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
636
+ default_cfg = default_cfg or default_cfgs[variant]
637
+ if kwargs.get('features_only', None):
638
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
639
+
640
+ # NOTE this extra code to support handling of repr size for in21k pretrained models
641
+ default_num_classes = default_cfg['num_classes']
642
+ num_classes = kwargs.get('num_classes', default_num_classes)
643
+ repr_size = kwargs.pop('representation_size', None)
644
+ if repr_size is not None and num_classes != default_num_classes:
645
+ # Remove representation layer if fine-tuning. This may not always be the desired action,
646
+ # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
647
+ _logger.warning("Removing representation layer for fine-tuning.")
648
+ repr_size = None
649
+
650
+ model = build_model_with_cfg(
651
+ VisionTransformer, variant, pretrained,
652
+ default_cfg=default_cfg,
653
+ representation_size=repr_size,
654
+ pretrained_filter_fn=checkpoint_filter_fn,
655
+ pretrained_custom_load='npz' in default_cfg['url'],
656
+ **kwargs)
657
+ return model
658
+
659
+
660
+ @register_model
661
+ def vit_tiny_patch16_224(pretrained=False, **kwargs):
662
+ """ ViT-Tiny (Vit-Ti/16)
663
+ """
664
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
665
+ model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
666
+ return model
667
+
668
+
669
+ @register_model
670
+ def vit_tiny_patch16_384(pretrained=False, **kwargs):
671
+ """ ViT-Tiny (Vit-Ti/16) @ 384x384.
672
+ """
673
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
674
+ model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
675
+ return model
676
+
677
+
678
+ @register_model
679
+ def vit_small_patch32_224(pretrained=False, **kwargs):
680
+ """ ViT-Small (ViT-S/32)
681
+ """
682
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
683
+ model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
684
+ return model
685
+
686
+
687
+ @register_model
688
+ def vit_small_patch32_384(pretrained=False, **kwargs):
689
+ """ ViT-Small (ViT-S/32) at 384x384.
690
+ """
691
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
692
+ model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
693
+ return model
694
+
695
+
696
+ @register_model
697
+ def vit_small_patch16_224(pretrained=False, **kwargs):
698
+ """ ViT-Small (ViT-S/16)
699
+ NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
700
+ """
701
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
702
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
703
+ return model
704
+
705
+
706
+ @register_model
707
+ def vit_small_patch16_384(pretrained=False, **kwargs):
708
+ """ ViT-Small (ViT-S/16)
709
+ NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
710
+ """
711
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
712
+ model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
713
+ return model
714
+
715
+
716
+ @register_model
717
+ def vit_base_patch32_224(pretrained=False, **kwargs):
718
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
719
+ """
720
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
721
+ model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
722
+ return model
723
+
724
+
725
+ @register_model
726
+ def vit_base_patch32_384(pretrained=False, **kwargs):
727
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
728
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
729
+ """
730
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
731
+ model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
732
+ return model
733
+
734
+
735
+ @register_model
736
+ def vit_base_patch16_224(pretrained=False, **kwargs):
737
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
738
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
739
+ """
740
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
741
+ model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
742
+ return model
743
+
744
+
745
+ @register_model
746
+ def vit_base_patch16_384(pretrained=False, **kwargs):
747
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
748
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
749
+ """
750
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
751
+ model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
752
+ return model
753
+
754
+
755
+ @register_model
756
+ def vit_large_patch32_224(pretrained=False, **kwargs):
757
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
758
+ """
759
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
760
+ model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
761
+ return model
762
+
763
+
764
+ @register_model
765
+ def vit_large_patch32_384(pretrained=False, **kwargs):
766
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
767
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
768
+ """
769
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
770
+ model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
771
+ return model
772
+
773
+
774
+ @register_model
775
+ def vit_large_patch16_224(pretrained=False, **kwargs):
776
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
777
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
778
+ """
779
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
780
+ model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
781
+ return model
782
+
783
+
784
+ @register_model
785
+ def vit_large_patch16_384(pretrained=False, **kwargs):
786
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
787
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
788
+ """
789
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
790
+ model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
791
+ return model
792
+
793
+
794
+ @register_model
795
+ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
796
+ """ ViT-Tiny (Vit-Ti/16).
797
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
798
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
799
+ """
800
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
801
+ model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
802
+ return model
803
+
804
+
805
+ @register_model
806
+ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
807
+ """ ViT-Small (ViT-S/16)
808
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
809
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
810
+ """
811
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
812
+ model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
813
+ return model
814
+
815
+
816
+ @register_model
817
+ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
818
+ """ ViT-Small (ViT-S/16)
819
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
820
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
821
+ """
822
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
823
+ model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
824
+ return model
825
+
826
+
827
+ @register_model
828
+ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
829
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
830
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
831
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
832
+ """
833
+ model_kwargs = dict(
834
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
835
+ model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
836
+ return model
837
+
838
+
839
+ @register_model
840
+ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
841
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
842
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
843
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
844
+ """
845
+ model_kwargs = dict(
846
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
847
+ model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
848
+ return model
849
+
850
+
851
+ @register_model
852
+ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
853
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
854
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
855
+ NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
856
+ """
857
+ model_kwargs = dict(
858
+ patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
859
+ model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
860
+ return model
861
+
862
+
863
+ @register_model
864
+ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
865
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
866
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
867
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
868
+ """
869
+ model_kwargs = dict(
870
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
871
+ model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
872
+ return model
873
+
874
+
875
+ @register_model
876
+ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
877
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
878
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
879
+ NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
880
+ """
881
+ model_kwargs = dict(
882
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
883
+ model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
884
+ return model
885
+
886
+
887
+ @register_model
888
+ def deit_tiny_patch16_224(pretrained=False, **kwargs):
889
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
890
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
891
+ """
892
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
893
+ model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
894
+ return model
895
+
896
+
897
+ @register_model
898
+ def deit_small_patch16_224(pretrained=False, **kwargs):
899
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
900
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
901
+ """
902
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
903
+ model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
904
+ return model
905
+
906
+
907
+ @register_model
908
+ def deit_base_patch16_224(pretrained=False, **kwargs):
909
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
910
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
911
+ """
912
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
913
+ model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
914
+ return model
915
+
916
+
917
+ @register_model
918
+ def deit_base_patch16_384(pretrained=False, **kwargs):
919
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
920
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
921
+ """
922
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
923
+ model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
924
+ return model
925
+
926
+
927
+ @register_model
928
+ def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
929
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
930
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
931
+ """
932
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
933
+ model = _create_vision_transformer(
934
+ 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
935
+ return model
936
+
937
+
938
+ @register_model
939
+ def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
940
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
941
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
942
+ """
943
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
944
+ model = _create_vision_transformer(
945
+ 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
946
+ return model
947
+
948
+
949
+ @register_model
950
+ def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
951
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
952
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
953
+ """
954
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
955
+ model = _create_vision_transformer(
956
+ 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
957
+ return model
958
+
959
+
960
+ @register_model
961
+ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
962
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
963
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
964
+ """
965
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
966
+ model = _create_vision_transformer(
967
+ 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
968
+ return model
969
+
970
+
971
+ @register_model
972
+ def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
973
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
974
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
975
+ """
976
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
977
+ model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
978
+ return model
979
+
980
+
981
+ @register_model
982
+ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
983
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
984
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
985
+ """
986
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
987
+ model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
988
+ return model
989
+
990
+
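A minimal sketch of driving the masked VisionTransformer above. forward_features() splits its input along dim 1 into data[:, 0] (the image) and data[:, 1] (the mask), so the sketch assumes the model is built with in_chans=1 and fed a 2-channel packed tensor; an all-zero mask leaves attention unmasked, while a non-zero mask is patch-embedded and turned into an additive -1e6 attention bias. The sizes and the zero mask are illustrative only, and the import assumes timm and the repository are available:

import torch
from model.Vision_Transformer_with_mask import VisionTransformer

model = VisionTransformer(img_size=224, patch_size=16, in_chans=1,
                          embed_dim=768, depth=12, num_heads=12,
                          as_backbone=True)      # backbone mode: returns patch tokens, no cls token or head
image = torch.randn(2, 1, 224, 224)              # channel 0 of the packed input: the image
mask = torch.zeros(2, 1, 224, 224)               # channel 1: all zeros -> no attention masking
data = torch.cat([image, mask], dim=1)           # (B, 2, H, W)
tokens = model(data)                             # patch tokens, shape (B, 196, 768)
print(tokens.shape)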
model/__pycache__/CoordAttention.cpython-38.pyc ADDED
Binary file (3.58 kB)

model/__pycache__/Vision_Transformer_with_mask.cpython-38.pyc ADDED
Binary file (37.2 kB)

model/__pycache__/features.cpython-38.pyc ADDED
Binary file (12.4 kB)

model/__pycache__/helpers.cpython-38.pyc ADDED
Binary file (14.8 kB)

model/__pycache__/hub.cpython-38.pyc ADDED
Binary file (3.45 kB)

model/__pycache__/registry.cpython-38.pyc ADDED
Binary file (4.77 kB)

model/features.py ADDED
@@ -0,0 +1,284 @@
+ """ PyTorch Feature Extraction Helpers
+
+ A collection of classes, functions, modules to help extract features from models
+ and provide a common interface for describing them.
+
+ The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
+ https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ from collections import OrderedDict, defaultdict
+ from copy import deepcopy
+ from functools import partial
+ from typing import Dict, List, Tuple
+
+ import torch
+ import torch.nn as nn
+
+
+ class FeatureInfo:
+
+     def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
+         prev_reduction = 1
+         for fi in feature_info:
+             # sanity check the mandatory fields, there may be additional fields depending on the model
+             assert 'num_chs' in fi and fi['num_chs'] > 0
+             assert 'reduction' in fi and fi['reduction'] >= prev_reduction
+             prev_reduction = fi['reduction']
+             assert 'module' in fi
+         self.out_indices = out_indices
+         self.info = feature_info
+
+     def from_other(self, out_indices: Tuple[int]):
+         return FeatureInfo(deepcopy(self.info), out_indices)
+
+     def get(self, key, idx=None):
+         """ Get value by key at specified index (indices)
+         if idx == None, returns value for key at each output index
+         if idx is an integer, return value for that feature module index (ignoring output indices)
+         if idx is a list/tuple, return value for each module index (ignoring output indices)
+         """
+         if idx is None:
+             return [self.info[i][key] for i in self.out_indices]
+         if isinstance(idx, (tuple, list)):
+             return [self.info[i][key] for i in idx]
+         else:
+             return self.info[idx][key]
+
+     def get_dicts(self, keys=None, idx=None):
+         """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
+         """
+         if idx is None:
+             if keys is None:
+                 return [self.info[i] for i in self.out_indices]
+             else:
+                 return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
+         if isinstance(idx, (tuple, list)):
+             return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
+         else:
+             return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
+
+     def channels(self, idx=None):
+         """ feature channels accessor
+         """
+         return self.get('num_chs', idx)
+
+     def reduction(self, idx=None):
+         """ feature reduction (output stride) accessor
+         """
+         return self.get('reduction', idx)
+
+     def module_name(self, idx=None):
+         """ feature module name accessor
+         """
+         return self.get('module', idx)
+
+     def __getitem__(self, item):
+         return self.info[item]
+
+     def __len__(self):
+         return len(self.info)
+
+
+ class FeatureHooks:
+     """ Feature Hook Helper
+
+     This module helps with the setup and extraction of hooks for extracting features from
+     internal nodes in a model by node name. This works quite well in eager Python but needs
+     redesign for torchscript.
+     """
+
+     def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
+         # setup feature hooks
+         modules = {k: v for k, v in named_modules}
+         for i, h in enumerate(hooks):
+             hook_name = h['module']
+             m = modules[hook_name]
+             hook_id = out_map[i] if out_map else hook_name
+             hook_fn = partial(self._collect_output_hook, hook_id)
+             hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
+             if hook_type == 'forward_pre':
+                 m.register_forward_pre_hook(hook_fn)
+             elif hook_type == 'forward':
+                 m.register_forward_hook(hook_fn)
+             else:
+                 assert False, "Unsupported hook type"
+         self._feature_outputs = defaultdict(OrderedDict)
+
+     def _collect_output_hook(self, hook_id, *args):
+         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
+         if isinstance(x, tuple):
+             x = x[0]  # unwrap input tuple
+         self._feature_outputs[x.device][hook_id] = x
+
+     def get_output(self, device) -> Dict[str, torch.tensor]:
+         output = self._feature_outputs[device]
+         self._feature_outputs[device] = OrderedDict()  # clear after reading
+         return output
+
+
+ def _module_list(module, flatten_sequential=False):
+     # a yield/iter would be better for this but wouldn't be compatible with torchscript
+     ml = []
+     for name, module in module.named_children():
+         if flatten_sequential and isinstance(module, nn.Sequential):
+             # first level of Sequential containers is flattened into containing model
+             for child_name, child_module in module.named_children():
+                 combined = [name, child_name]
+                 ml.append(('_'.join(combined), '.'.join(combined), child_module))
+         else:
+             ml.append((name, name, module))
+     return ml
+
+
+ def _get_feature_info(net, out_indices):
+     feature_info = getattr(net, 'feature_info')
+     if isinstance(feature_info, FeatureInfo):
+         return feature_info.from_other(out_indices)
+     elif isinstance(feature_info, (list, tuple)):
+         return FeatureInfo(net.feature_info, out_indices)
+     else:
+         assert False, "Provided feature_info is not valid"
+
+
+ def _get_return_layers(feature_info, out_map):
+     module_names = feature_info.module_name()
+     return_layers = {}
+     for i, name in enumerate(module_names):
+         return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+     return return_layers
+
+
+ class FeatureDictNet(nn.ModuleDict):
+     """ Feature extractor with OrderedDict return
+
+     Wrap a model and extract features as specified by the out indices, the network is
+     partially re-built from contained modules.
+
+     There is a strong assumption that the modules have been registered into the model in the same
+     order as they are used. There should be no reuse of the same nn.Module more than once, including
+     trivial modules like `self.relu = nn.ReLU`.
+
+     Only submodules that are directly assigned to the model class (`model.feature1`) or at most
+     one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
+     All Sequential containers that are directly assigned to the original model will have their
+     modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
+
+     Arguments:
+         model (nn.Module): model from which we will extract the features
+         out_indices (tuple[int]): model output indices to extract features for
+         out_map (sequence): list or tuple specifying desired return id for each out index,
+             otherwise str(index) is used
+         feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
+             vs select element [0]
+         flatten_sequential (bool): whether to flatten sequential modules assigned to model
+     """
+     def __init__(
+             self, model,
+             out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+         super(FeatureDictNet, self).__init__()
+         self.feature_info = _get_feature_info(model, out_indices)
+         self.concat = feature_concat
+         self.return_layers = {}
+         return_layers = _get_return_layers(self.feature_info, out_map)
+         modules = _module_list(model, flatten_sequential=flatten_sequential)
+         remaining = set(return_layers.keys())
+         layers = OrderedDict()
+         for new_name, old_name, module in modules:
+             layers[new_name] = module
+             if old_name in remaining:
+                 # return id has to be consistently str type for torchscript
+                 self.return_layers[new_name] = str(return_layers[old_name])
+                 remaining.remove(old_name)
+             if not remaining:
+                 break
+         assert not remaining and len(self.return_layers) == len(return_layers), \
+             f'Return layers ({remaining}) are not present in model'
+         self.update(layers)
+
+     def _collect(self, x) -> (Dict[str, torch.Tensor]):
+         out = OrderedDict()
+         for name, module in self.items():
+             x = module(x)
+             if name in self.return_layers:
+                 out_id = self.return_layers[name]
+                 if isinstance(x, (tuple, list)):
+                     # If model tap is a tuple or list, concat or select first element
+                     # FIXME this may need to be more generic / flexible for some nets
+                     out[out_id] = torch.cat(x, 1) if self.concat else x[0]
+                 else:
+                     out[out_id] = x
+         return out
+
+     def forward(self, x) -> Dict[str, torch.Tensor]:
+         return self._collect(x)
+
+
+ class FeatureListNet(FeatureDictNet):
+     """ Feature extractor with list return
+
+     See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
+     In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
+     """
+     def __init__(
+             self, model,
+             out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+         super(FeatureListNet, self).__init__(
+             model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
+             flatten_sequential=flatten_sequential)
+
+     def forward(self, x) -> (List[torch.Tensor]):
+         return list(self._collect(x).values())
+
+
+ class FeatureHookNet(nn.ModuleDict):
+     """ FeatureHookNet
+
+     Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
+
+     If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
+     network in any way.
+
+     If `no_rewrite` is False, the model will be re-written as in the
+     FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
+
+     FIXME this does not currently work with Torchscript, see FeatureHooks class
+     """
+     def __init__(
+             self, model,
+             out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
+             feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
+         super(FeatureHookNet, self).__init__()
+         assert not torch.jit.is_scripting()
+         self.feature_info = _get_feature_info(model, out_indices)
+         self.out_as_dict = out_as_dict
+         layers = OrderedDict()
+         hooks = []
+         if no_rewrite:
+             assert not flatten_sequential
+             if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
+                 model.reset_classifier(0)
+             layers['body'] = model
+             hooks.extend(self.feature_info.get_dicts())
+         else:
+             modules = _module_list(model, flatten_sequential=flatten_sequential)
+             remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                          for f in self.feature_info.get_dicts()}
+             for new_name, old_name, module in modules:
+                 layers[new_name] = module
+                 for fn, fm in module.named_modules(prefix=old_name):
+                     if fn in remaining:
+                         hooks.append(dict(module=fn, hook_type=remaining[fn]))
+                         del remaining[fn]
+                 if not remaining:
+                     break
+             assert not remaining, f'Return layers ({remaining}) are not present in model'
+         self.update(layers)
+         self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
+
+     def forward(self, x):
+         for name, module in self.items():
+             x = module(x)
+         out = self.hooks.get_output(x.device)
+         return out if self.out_as_dict else list(out.values())
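
A short sketch of how these wrappers are normally reached (via `features_only=True`, which routes through `build_model_with_cfg` in helpers.py below and wraps the backbone in FeatureListNet); the backbone name is illustrative and assumes upstream timm is installed:

import torch
import timm  # assumption: timm's resnet18 exposes a compatible feature_info

backbone = timm.create_model('resnet18', features_only=True, out_indices=(1, 2, 3, 4))

feats = backbone(torch.randn(1, 3, 224, 224))  # list of feature maps, shallow to deep
for chs, red, f in zip(backbone.feature_info.channels(),
                       backbone.feature_info.reduction(),
                       feats):
    print(chs, red, tuple(f.shape))  # e.g. 64 4 (1, 64, 56, 56) for the first tap
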
model/helpers.py ADDED
@@ -0,0 +1,508 @@
+ """ Model creation / weight loading / state_dict helpers
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ import logging
+ import os
+ import math
+ from collections import OrderedDict
+ from copy import deepcopy
+ from typing import Any, Callable, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+
+ from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
+ from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
+ from .layers import Conv2dSame, Linear
+
+
+ _logger = logging.getLogger(__name__)
+
+
+ def load_state_dict(checkpoint_path, use_ema=False):
+     if checkpoint_path and os.path.isfile(checkpoint_path):
+         checkpoint = torch.load(checkpoint_path, map_location='cpu')
+         state_dict_key = 'state_dict'
+         if isinstance(checkpoint, dict):
+             if use_ema and 'state_dict_ema' in checkpoint:
+                 state_dict_key = 'state_dict_ema'
+         if state_dict_key and state_dict_key in checkpoint:
+             new_state_dict = OrderedDict()
+             for k, v in checkpoint[state_dict_key].items():
+                 # strip `module.` prefix
+                 name = k[7:] if k.startswith('module') else k
+                 new_state_dict[name] = v
+             state_dict = new_state_dict
+         else:
+             state_dict = checkpoint
+         _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
+         return state_dict
+     else:
+         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
+         raise FileNotFoundError()
+
+
+ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
+     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
+         # numpy checkpoint, try to load via model specific load_pretrained fn
+         if hasattr(model, 'load_pretrained'):
+             model.load_pretrained(checkpoint_path)
+         else:
+             raise NotImplementedError('Model cannot load numpy checkpoint')
+         return
+     state_dict = load_state_dict(checkpoint_path, use_ema)
+     model.load_state_dict(state_dict, strict=strict)
+
+
+ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
+     resume_epoch = None
+     if os.path.isfile(checkpoint_path):
+         checkpoint = torch.load(checkpoint_path, map_location='cpu')
+         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+             if log_info:
+                 _logger.info('Restoring model state from checkpoint...')
+             new_state_dict = OrderedDict()
+             for k, v in checkpoint['state_dict'].items():
+                 name = k[7:] if k.startswith('module') else k
+                 new_state_dict[name] = v
+             model.load_state_dict(new_state_dict)
+
+             if optimizer is not None and 'optimizer' in checkpoint:
+                 if log_info:
+                     _logger.info('Restoring optimizer state from checkpoint...')
+                 optimizer.load_state_dict(checkpoint['optimizer'])
+
+             if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
+                 if log_info:
+                     _logger.info('Restoring AMP loss scaler state from checkpoint...')
+                 loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
+
+             if 'epoch' in checkpoint:
+                 resume_epoch = checkpoint['epoch']
+                 if 'version' in checkpoint and checkpoint['version'] > 1:
+                     resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
+
+             if log_info:
+                 _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+         else:
+             model.load_state_dict(checkpoint)
+             if log_info:
+                 _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
+         return resume_epoch
+     else:
+         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
+         raise FileNotFoundError()
+
+
+ def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
+     r"""Loads a custom (read non .pth) weight file
+
+     Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
+     a passed in custom load fn, or the `load_pretrained` model member fn.
+
+     If the object is already present in `model_dir`, it's deserialized and returned.
+     The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
+     `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
+
+     Args:
+         model: The instantiated model to load weights into
+         default_cfg (dict): Default pretrained model cfg
+         load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
+             'load_pretrained' on the model will be called if it exists
+         progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
+         check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
+             ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
+             digits of the SHA256 hash of the contents of the file. The hash is used to
+             ensure unique names and to verify the contents of the file. Default: False
+     """
+     default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+     pretrained_url = default_cfg.get('url', None)
+     if not pretrained_url:
+         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
+         return
+     cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)
+
+     if load_fn is not None:
+         load_fn(model, cached_file)
+     elif hasattr(model, 'load_pretrained'):
+         model.load_pretrained(cached_file)
+     else:
+         _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
+
+
+ def adapt_input_conv(in_chans, conv_weight):
+     conv_type = conv_weight.dtype
+     conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
+     O, I, J, K = conv_weight.shape
+     if in_chans == 1:
+         if I > 3:
+             assert conv_weight.shape[1] % 3 == 0
+             # For models with space2depth stems
+             conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+             conv_weight = conv_weight.sum(dim=2, keepdim=False)
+         else:
+             conv_weight = conv_weight.sum(dim=1, keepdim=True)
+     elif in_chans != 3:
+         if I != 3:
+             raise NotImplementedError('Weight format not supported by conversion.')
+         else:
+             # NOTE this strategy should be better than random init, but there could be other combinations of
+             # the original RGB input layer weights that'd work better for specific cases.
+             repeat = int(math.ceil(in_chans / 3))
+             conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+             conv_weight *= (3 / float(in_chans))
+     conv_weight = conv_weight.to(conv_type)
+     return conv_weight
+
+
+ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+     """ Load pretrained checkpoint
+
+     Args:
+         model (nn.Module) : PyTorch model module
+         default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
+         num_classes (int): num_classes for model
+         in_chans (int): in_chans for model
+         filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
+         strict (bool): strict load of checkpoint
+         progress (bool): enable progress bar for weight download
+
+     """
+     default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+     pretrained_url = default_cfg.get('url', None)
+     hf_hub_id = default_cfg.get('hf_hub', None)
+     if not pretrained_url and not hf_hub_id:
+         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
+         return
+     if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
+         _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
+         state_dict = load_state_dict_from_hf(hf_hub_id)
+     else:
+         _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
+         state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
+     if filter_fn is not None:
+         # for backwards compat with filter fns that take one arg, try one first, then two
+         try:
+             state_dict = filter_fn(state_dict)
+         except TypeError:
+             state_dict = filter_fn(state_dict, model)
+
+     input_convs = default_cfg.get('first_conv', None)
+     if input_convs is not None and in_chans != 3:
+         if isinstance(input_convs, str):
+             input_convs = (input_convs,)
+         for input_conv_name in input_convs:
+             weight_name = input_conv_name + '.weight'
+             try:
+                 state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
+                 _logger.info(
+                     f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
+             except NotImplementedError as e:
+                 del state_dict[weight_name]
+                 strict = False
+                 _logger.warning(
+                     f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
+
+     classifiers = default_cfg.get('classifier', None)
+     label_offset = default_cfg.get('label_offset', 0)
+     if classifiers is not None:
+         if isinstance(classifiers, str):
+             classifiers = (classifiers,)
+         if num_classes != default_cfg['num_classes']:
+             for classifier_name in classifiers:
+                 # completely discard fully connected if model num_classes doesn't match pretrained weights
+                 del state_dict[classifier_name + '.weight']
+                 del state_dict[classifier_name + '.bias']
+             strict = False
+         elif label_offset > 0:
+             for classifier_name in classifiers:
+                 # special case for pretrained weights with an extra background class in pretrained weights
+                 classifier_weight = state_dict[classifier_name + '.weight']
+                 state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+                 classifier_bias = state_dict[classifier_name + '.bias']
+                 state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
+
+     model.load_state_dict(state_dict, strict=strict)
+
+
+ def extract_layer(model, layer):
+     layer = layer.split('.')
+     module = model
+     if hasattr(model, 'module') and layer[0] != 'module':
+         module = model.module
+     if not hasattr(model, 'module') and layer[0] == 'module':
+         layer = layer[1:]
+     for l in layer:
+         if hasattr(module, l):
+             if not l.isdigit():
+                 module = getattr(module, l)
+             else:
+                 module = module[int(l)]
+         else:
+             return module
+     return module
+
+
+ def set_layer(model, layer, val):
+     layer = layer.split('.')
+     module = model
+     if hasattr(model, 'module') and layer[0] != 'module':
+         module = model.module
+     lst_index = 0
+     module2 = module
+     for l in layer:
+         if hasattr(module2, l):
+             if not l.isdigit():
+                 module2 = getattr(module2, l)
+             else:
+                 module2 = module2[int(l)]
+             lst_index += 1
+     lst_index -= 1
+     for l in layer[:lst_index]:
+         if not l.isdigit():
+             module = getattr(module, l)
+         else:
+             module = module[int(l)]
+     l = layer[lst_index]
+     setattr(module, l, val)
+
+
+ def adapt_model_from_string(parent_module, model_string):
+     separator = '***'
+     state_dict = {}
+     lst_shape = model_string.split(separator)
+     for k in lst_shape:
+         k = k.split(':')
+         key = k[0]
+         shape = k[1][1:-1].split(',')
+         if shape[0] != '':
+             state_dict[key] = [int(i) for i in shape]
+
+     new_module = deepcopy(parent_module)
+     for n, m in parent_module.named_modules():
+         old_module = extract_layer(parent_module, n)
+         if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
+             if isinstance(old_module, Conv2dSame):
+                 conv = Conv2dSame
+             else:
+                 conv = nn.Conv2d
+             s = state_dict[n + '.weight']
+             in_channels = s[1]
+             out_channels = s[0]
+             g = 1
+             if old_module.groups > 1:
+                 in_channels = out_channels
+                 g = in_channels
+             new_conv = conv(
+                 in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
+                 bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
+                 groups=g, stride=old_module.stride)
+             set_layer(new_module, n, new_conv)
+         if isinstance(old_module, nn.BatchNorm2d):
+             new_bn = nn.BatchNorm2d(
+                 num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
+                 affine=old_module.affine, track_running_stats=True)
+             set_layer(new_module, n, new_bn)
+         if isinstance(old_module, nn.Linear):
+             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
+             num_features = state_dict[n + '.weight'][1]
+             new_fc = Linear(
+                 in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
+             set_layer(new_module, n, new_fc)
+             if hasattr(new_module, 'num_features'):
+                 new_module.num_features = num_features
+     new_module.eval()
+     parent_module.eval()
+
+     return new_module
+
+
+ def adapt_model_from_file(parent_module, model_variant):
+     adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
+     with open(adapt_file, 'r') as f:
+         return adapt_model_from_string(parent_module, f.read().strip())
+
+
+ def default_cfg_for_features(default_cfg):
+     default_cfg = deepcopy(default_cfg)
+     # remove default pretrained cfg fields that don't have much relevance for feature backbone
+     to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool')  # add default final pool size?
+     for tr in to_remove:
+         default_cfg.pop(tr, None)
+     return default_cfg
+
+
+ def overlay_external_default_cfg(default_cfg, kwargs):
+     """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
+     """
+     external_default_cfg = kwargs.pop('external_default_cfg', None)
+     if external_default_cfg:
+         default_cfg.pop('url', None)  # url should come from external cfg
+         default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
+         default_cfg.update(external_default_cfg)
+
+
+ def set_default_kwargs(kwargs, names, default_cfg):
+     for n in names:
+         # for legacy reasons, model __init__ args uses img_size + in_chans as separate args while
+         # default_cfg has one input_size=(C, H, W) entry
+         if n == 'img_size':
+             input_size = default_cfg.get('input_size', None)
+             if input_size is not None:
+                 assert len(input_size) == 3
+                 kwargs.setdefault(n, input_size[-2:])
+         elif n == 'in_chans':
+             input_size = default_cfg.get('input_size', None)
+             if input_size is not None:
+                 assert len(input_size) == 3
+                 kwargs.setdefault(n, input_size[0])
+         else:
+             default_val = default_cfg.get(n, None)
+             if default_val is not None:
+                 kwargs.setdefault(n, default_cfg[n])
+
+
+ def filter_kwargs(kwargs, names):
+     if not kwargs or not names:
+         return
+     for n in names:
+         kwargs.pop(n, None)
+
+
+ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
+     """ Update the default_cfg and kwargs before passing to model
+
+     FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
+     could/should be replaced by an improved configuration mechanism
+
+     Args:
+         default_cfg: input default_cfg (updated in-place)
+         kwargs: keyword args passed to model build fn (updated in-place)
+         kwargs_filter: keyword arg keys that must be removed before model __init__
+     """
+     # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+     overlay_external_default_cfg(default_cfg, kwargs)
+     # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+     default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
+     if default_cfg.get('fixed_input_size', False):
+         # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
+         default_kwarg_names += ('img_size',)
+     set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
+     # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+     filter_kwargs(kwargs, names=kwargs_filter)
+
+
+ def build_model_with_cfg(
+         model_cls: Callable,
+         variant: str,
+         pretrained: bool,
+         default_cfg: dict,
+         model_cfg: Optional[Any] = None,
+         feature_cfg: Optional[dict] = None,
+         pretrained_strict: bool = True,
+         pretrained_filter_fn: Optional[Callable] = None,
+         pretrained_custom_load: bool = False,
+         kwargs_filter: Optional[Tuple[str]] = None,
+         **kwargs):
+     """ Build model with specified default_cfg and optional model_cfg
+
+     This helper fn aids in the construction of a model including:
+       * handling default_cfg and associated pretrained weight loading
+       * passing through optional model_cfg for models with config based arch spec
+       * features_only model adaptation
+       * pruning config / model adaptation
+
+     Args:
+         model_cls (nn.Module): model class
+         variant (str): model variant name
+         pretrained (bool): load pretrained weights
+         default_cfg (dict): model's default pretrained/task config
+         model_cfg (Optional[Dict]): model's architecture config
+         feature_cfg (Optional[Dict]): feature extraction adapter config
+         pretrained_strict (bool): load pretrained weights strictly
+         pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
+         pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
+         kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+         **kwargs: model args passed through to model __init__
+     """
+     pruned = kwargs.pop('pruned', False)
+     features = False
+     feature_cfg = feature_cfg or {}
+     default_cfg = deepcopy(default_cfg) if default_cfg else {}
+     update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
+     default_cfg.setdefault('architecture', variant)
+
+     # Setup for feature extraction wrapper done at end of this fn
+     if kwargs.pop('features_only', False):
+         features = True
+         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
+         if 'out_indices' in kwargs:
+             feature_cfg['out_indices'] = kwargs.pop('out_indices')
+
+     # Build the model
+     model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
+     model.default_cfg = default_cfg
+
+     if pruned:
+         model = adapt_model_from_file(model, variant)
+
+     # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
+     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
+     if pretrained:
+         if pretrained_custom_load:
+             load_custom_pretrained(model)
+         else:
+             load_pretrained(
+                 model,
+                 num_classes=num_classes_pretrained,
+                 in_chans=kwargs.get('in_chans', 3),
+                 filter_fn=pretrained_filter_fn,
+                 strict=pretrained_strict)
+
+     # Wrap the model in a feature extraction module if enabled
+     if features:
+         feature_cls = FeatureListNet
+         if 'feature_cls' in feature_cfg:
+             feature_cls = feature_cfg.pop('feature_cls')
+             if isinstance(feature_cls, str):
+                 feature_cls = feature_cls.lower()
+                 if 'hook' in feature_cls:
+                     feature_cls = FeatureHookNet
+                 else:
+                     assert False, f'Unknown feature class {feature_cls}'
+         model = feature_cls(model, **feature_cfg)
+         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
+
+     return model
+
+
+ def model_parameters(model, exclude_head=False):
+     if exclude_head:
+         # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
+         return [p for p in model.parameters()][:-2]
+     else:
+         return model.parameters()
+
+
+ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
+     if not depth_first and include_root:
+         fn(module=module, name=name)
+     for child_name, child_module in module.named_children():
+         child_name = '.'.join((name, child_name)) if name else child_name
+         named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+     if depth_first and include_root:
+         fn(module=module, name=name)
+     return module
+
+
+ def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
+     if not depth_first and include_root:
+         yield name, module
+     for child_name, child_module in module.named_children():
+         child_name = '.'.join((name, child_name)) if name else child_name
+         yield from named_modules(
+             module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+     if depth_first and include_root:
+         yield name, module
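
As a concrete illustration of `adapt_input_conv` above: a pretrained 3-channel conv weight is summed down to 1 channel or tiled and rescaled up to more channels. A minimal sketch, assuming the repository root is on the import path (the `model.helpers` module path follows this repo's layout):

import torch
from model.helpers import adapt_input_conv  # path assumption, see note above

w = torch.randn(64, 3, 7, 7)     # (out_ch, in_ch, kH, kW) pretrained RGB stem weight

w1 = adapt_input_conv(1, w)      # grayscale input: channels summed -> (64, 1, 7, 7)
w6 = adapt_input_conv(6, w)      # 6-band input: RGB weights tiled, scaled by 3/6 -> (64, 6, 7, 7)
print(w1.shape, w6.shape)
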
model/hub.py ADDED
@@ -0,0 +1,96 @@
+ import json
+ import logging
+ import os
+ from functools import partial
+ from typing import Union, Optional
+
+ import torch
+ from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
+ try:
+     from torch.hub import get_dir
+ except ImportError:
+     from torch.hub import _get_torch_home as get_dir
+
+ from timm import __version__
+ try:
+     from huggingface_hub import hf_hub_url
+     from huggingface_hub import cached_download
+     cached_download = partial(cached_download, library_name="timm", library_version=__version__)
+ except ImportError:
+     hf_hub_url = None
+     cached_download = None
+
+ _logger = logging.getLogger(__name__)
+
+
+ def get_cache_dir(child_dir=''):
+     """
+     Returns the location of the directory where models are cached (and creates it if necessary).
+     """
+     # Issue warning to move data if old env is set
+     if os.getenv('TORCH_MODEL_ZOO'):
+         _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+     hub_dir = get_dir()
+     child_dir = () if not child_dir else (child_dir,)
+     model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
+     os.makedirs(model_dir, exist_ok=True)
+     return model_dir
+
+
+ def download_cached_file(url, check_hash=True, progress=False):
+     parts = urlparse(url)
+     filename = os.path.basename(parts.path)
+     cached_file = os.path.join(get_cache_dir(), filename)
+     if not os.path.exists(cached_file):
+         _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+         hash_prefix = None
+         if check_hash:
+             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+             hash_prefix = r.group(1) if r else None
+         download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+     return cached_file
+
+
+ def has_hf_hub(necessary=False):
+     if hf_hub_url is None and necessary:
+         # if no HF Hub module installed and it is necessary to continue, raise error
+         raise RuntimeError(
+             'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+     return hf_hub_url is not None
+
+
+ def hf_split(hf_id):
+     rev_split = hf_id.split('@')
+     assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
+     hf_model_id = rev_split[0]
+     hf_revision = rev_split[-1] if len(rev_split) > 1 else None
+     return hf_model_id, hf_revision
+
+
+ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+     with open(json_file, "r", encoding="utf-8") as reader:
+         text = reader.read()
+     return json.loads(text)
+
+
+ def _download_from_hf(model_id: str, filename: str):
+     hf_model_id, hf_revision = hf_split(model_id)
+     url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
+     return cached_download(url, cache_dir=get_cache_dir('hf'))
+
+
+ def load_model_config_from_hf(model_id: str):
+     assert has_hf_hub(True)
+     cached_file = _download_from_hf(model_id, 'config.json')
+     default_cfg = load_cfg_from_json(cached_file)
+     default_cfg['hf_hub'] = model_id  # insert hf_hub id for pretrained weight load during model creation
+     model_name = default_cfg.get('architecture')
+     return default_cfg, model_name
+
+
+ def load_state_dict_from_hf(model_id: str):
+     assert has_hf_hub(True)
+     cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
+     state_dict = torch.load(cached_file, map_location='cpu')
+     return state_dict
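
A small sketch of the id handling and cache helpers above; the model id is a made-up placeholder, and importing `model.hub` assumes `timm` is installed (it pulls in `timm.__version__`):

from model.hub import hf_split, get_cache_dir  # path assumption, see note above

model_id, revision = hf_split('some-org/some-model@main')  # placeholder id
print(model_id, revision)                # 'some-org/some-model' 'main'
print(hf_split('some-org/some-model'))   # revision defaults to None

print(get_cache_dir('hf'))               # <torch hub dir>/checkpoints/hf, created if missing
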
model/layers/__init__.py ADDED
@@ -0,0 +1,40 @@
+ from .activations import *
+ from .adaptive_avgmax_pool import \
+     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+ from .blur_pool import BlurPool2d
+ from .classifier import ClassifierHead, create_classifier
+ from .cond_conv2d import CondConv2d, get_condconv_initializer
+ from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
+     set_layer_config
+ from .conv2d_same import Conv2dSame, conv2d_same
+ from .conv_bn_act import ConvBnAct
+ from .create_act import create_act_layer, get_act_layer, get_act_fn
+ from .create_attn import get_attn, create_attn
+ from .create_conv2d import create_conv2d
+ from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
+ from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
+ from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
+ from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+ from .gather_excite import GatherExcite
+ from .global_context import GlobalContext
+ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
+ from .inplace_abn import InplaceAbn
+ from .involution import Involution
+ from .linear import Linear
+ from .mixed_conv2d import MixedConv2d
+ from .mlp import Mlp, GluMlp, GatedMlp
+ from .non_local_attn import NonLocalAttn, BatNonLocalAttn
+ from .norm import GroupNorm, LayerNorm2d
+ from .norm_act import BatchNormAct2d, GroupNormAct
+ from .padding import get_padding, get_same_padding, pad_same
+ from .patch_embed import PatchEmbed
+ from .pool2d_same import AvgPool2dSame, create_pool2d
+ from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
+ from .selective_kernel import SelectiveKernel
+ from .separable_conv import SeparableConv2d, SeparableConvBnAct
+ from .space_to_depth import SpaceToDepthModule
+ from .split_attn import SplitAttn
+ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
+ from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
+ from .test_time_pool import TestTimePoolHead, apply_test_time_pool
+ from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
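
These re-exports make the individual layer modules importable from one place; a quick consumption sketch, assuming the referenced submodules from this commit are present and the `model` package is importable:

from model.layers import PatchEmbed, to_2tuple, trunc_normal_  # path assumption

print(to_2tuple(224))  # (224, 224)

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
trunc_normal_(patch_embed.proj.weight, std=0.02)  # typical ViT-style init
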
model/layers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (3.1 kB). View file
 
model/layers/__pycache__/activations.cpython-38.pyc ADDED
Binary file (6.66 kB). View file
 
model/layers/__pycache__/activations_jit.cpython-38.pyc ADDED
Binary file (4.14 kB). View file
 
model/layers/__pycache__/activations_me.cpython-38.pyc ADDED
Binary file (8.92 kB). View file
 
model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc ADDED
Binary file (4.78 kB). View file
 
model/layers/__pycache__/blur_pool.cpython-38.pyc ADDED
Binary file (2.1 kB). View file
 
model/layers/__pycache__/bottleneck_attn.cpython-38.pyc ADDED
Binary file (4.76 kB). View file
 
model/layers/__pycache__/cbam.cpython-38.pyc ADDED
Binary file (5.17 kB). View file
 
model/layers/__pycache__/classifier.cpython-38.pyc ADDED
Binary file (2.24 kB). View file
 
model/layers/__pycache__/cond_conv2d.cpython-38.pyc ADDED
Binary file (3.84 kB). View file
 
model/layers/__pycache__/config.cpython-38.pyc ADDED
Binary file (3.44 kB). View file
 
model/layers/__pycache__/conv2d_same.cpython-38.pyc ADDED
Binary file (1.96 kB). View file
 
model/layers/__pycache__/conv_bn_act.cpython-38.pyc ADDED
Binary file (1.66 kB). View file
 
model/layers/__pycache__/create_act.cpython-38.pyc ADDED
Binary file (3.65 kB). View file
 
model/layers/__pycache__/create_attn.cpython-38.pyc ADDED
Binary file (2.09 kB). View file
 
model/layers/__pycache__/create_conv2d.cpython-38.pyc ADDED
Binary file (1.09 kB). View file
 
model/layers/__pycache__/create_norm_act.cpython-38.pyc ADDED
Binary file (2.33 kB). View file
 
model/layers/__pycache__/drop.cpython-38.pyc ADDED
Binary file (5.74 kB). View file
 
model/layers/__pycache__/eca.cpython-38.pyc ADDED
Binary file (6.15 kB). View file
 
model/layers/__pycache__/evo_norm.cpython-38.pyc ADDED
Binary file (3.39 kB). View file
 
model/layers/__pycache__/gather_excite.cpython-38.pyc ADDED
Binary file (3.11 kB). View file
 
model/layers/__pycache__/global_context.cpython-38.pyc ADDED
Binary file (2.43 kB). View file
 
model/layers/__pycache__/halo_attn.cpython-38.pyc ADDED
Binary file (5.59 kB). View file
 
model/layers/__pycache__/helpers.cpython-38.pyc ADDED
Binary file (1.03 kB). View file
 
model/layers/__pycache__/inplace_abn.cpython-38.pyc ADDED
Binary file (3.18 kB). View file
 
model/layers/__pycache__/involution.cpython-38.pyc ADDED
Binary file (1.83 kB). View file
 
model/layers/__pycache__/lambda_layer.cpython-38.pyc ADDED
Binary file (3.01 kB). View file
 
model/layers/__pycache__/linear.cpython-38.pyc ADDED
Binary file (1.08 kB). View file
 
model/layers/__pycache__/mixed_conv2d.cpython-38.pyc ADDED
Binary file (2.29 kB). View file
 
model/layers/__pycache__/mlp.cpython-38.pyc ADDED
Binary file (3.81 kB). View file
 
model/layers/__pycache__/non_local_attn.cpython-38.pyc ADDED
Binary file (5.64 kB). View file
 
model/layers/__pycache__/norm.cpython-38.pyc ADDED
Binary file (1.52 kB). View file
 
model/layers/__pycache__/norm_act.cpython-38.pyc ADDED
Binary file (3.07 kB). View file
 
model/layers/__pycache__/padding.cpython-38.pyc ADDED
Binary file (1.8 kB). View file
 
model/layers/__pycache__/patch_embed.cpython-38.pyc ADDED
Binary file (1.68 kB). View file
 
model/layers/__pycache__/pool2d_same.cpython-38.pyc ADDED
Binary file (3.12 kB). View file
 
model/layers/__pycache__/selective_kernel.cpython-38.pyc ADDED
Binary file (5.51 kB). View file
 
model/layers/__pycache__/separable_conv.cpython-38.pyc ADDED
Binary file (2.98 kB). View file