hank1996 commited on
Commit
d566fd1
·
1 Parent(s): 4e3f0a4

Update models/common.py

Browse files
Files changed (1) hide show
  1. models/common.py +2021 -0
models/common.py CHANGED
@@ -0,0 +1,2021 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ from copy import copy
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import requests
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torchvision.ops import DeformConv2d
13
+ from PIL import Image
14
+ from torch.cuda import amp
15
+
16
+ from utils.datasets import letterbox
17
+ from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
18
+ from utils.plots import color_list, plot_one_box
19
+ from utils.torch_utils import time_synchronized
20
+
21
+
22
+ ##### basic ####
23
+
24
def autopad(k, p=None):  # kernel, padding
    """Return 'same' padding for kernel size k (int or per-dim list), unless p is given."""
    if p is not None:
        return p
    return k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
29
+
30
+
31
class MP(nn.Module):
    """Max pooling with equal kernel and stride: downsamples spatially by factor k."""

    def __init__(self, k=2):
        super(MP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return self.m(x)
38
+
39
+
40
class SP(nn.Module):
    """Spatial-size-preserving max pooling (kernel k, stride s, 'same' padding)."""

    def __init__(self, k=3, s=1):
        super(SP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)

    def forward(self, x):
        return self.m(x)
47
+
48
+
49
class ReOrg(nn.Module):
    """Space-to-depth: split each 2x2 spatial patch into 4 channels. x(b,c,w,h) -> y(b,4c,w/2,h/2)."""

    def __init__(self):
        super(ReOrg, self).__init__()

    def forward(self, x):
        top_left = x[..., ::2, ::2]
        bottom_left = x[..., 1::2, ::2]
        top_right = x[..., ::2, 1::2]
        bottom_right = x[..., 1::2, 1::2]
        return torch.cat([top_left, bottom_left, top_right, bottom_right], 1)
55
+
56
+
57
class Concat(nn.Module):
    """Concatenate a list of tensors along dimension `d` (channels by default)."""

    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)
64
+
65
+
66
class Chuncat(nn.Module):
    """Halve each input along dim `d`, then concat all first halves followed by all second halves."""

    def __init__(self, dimension=1):
        super(Chuncat, self).__init__()
        self.d = dimension

    def forward(self, x):
        halves = [xi.chunk(2, self.d) for xi in x]
        firsts = [h[0] for h in halves]
        seconds = [h[1] for h in halves]
        return torch.cat(firsts + seconds, self.d)
79
+
80
+
81
class Shortcut(nn.Module):
    """Element-wise sum of the first two tensors in the input list."""

    def __init__(self, dimension=0):
        super(Shortcut, self).__init__()
        self.d = dimension  # kept for config compatibility; not used in forward

    def forward(self, x):
        a, b = x[0], x[1]
        return a + b
88
+
89
+
90
class Foldcut(nn.Module):
    """Fold a tensor in half along dim `d` and add the two halves."""

    def __init__(self, dimension=0):
        super(Foldcut, self).__init__()
        self.d = dimension

    def forward(self, x):
        front, back = x.chunk(2, self.d)
        return front + back
98
+
99
+
100
class Conv(nn.Module):
    """Standard convolution block: Conv2d -> BatchNorm2d -> activation (LeakyReLU(0.1) by default)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # act=True -> LeakyReLU(0.1); an nn.Module -> used as given; anything else -> identity.
        if act is True:
            self.act = nn.LeakyReLU(0.1, inplace=True)
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        # Used after BN has been folded into the conv weights (no separate BN pass).
        return self.act(self.conv(x))
113
+
114
+
115
class RobustConv(nn.Module):
    """Depthwise k x k conv followed by a pointwise 1x1 conv, with optional layer-scale.

    Use high kernel size 7-11 for downsampling and other layers; train 300-450 epochs.
    """

    def __init__(self, c1, c2, k=7, s=1, p=None, g=1, act=True, layer_scale_init_value=1e-6):  # ch_in, ch_out, kernel, stride, padding, groups
        super(RobustConv, self).__init__()
        self.conv_dw = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
        self.conv1x1 = nn.Conv2d(c1, c2, 1, 1, 0, groups=1, bias=True)
        # Learnable per-channel scale (layer scale); disabled when init value <= 0.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2))
        else:
            self.gamma = None

    def forward(self, x):
        x = x.to(memory_format=torch.channels_last)
        x = self.conv1x1(self.conv_dw(x))
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        return x
129
+
130
+
131
class RobustConv2(nn.Module):
    """Strided depthwise conv followed by a transposed conv that restores resolution, with optional layer-scale.

    Use [32, 5, 2] or [32, 7, 4] or [32, 11, 8] for one of the paths in CSP.
    """

    def __init__(self, c1, c2, k=7, s=4, p=None, g=1, act=True, layer_scale_init_value=1e-6):  # ch_in, ch_out, kernel, stride, padding, groups
        super(RobustConv2, self).__init__()
        self.conv_strided = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
        # Transposed conv with kernel == stride == s exactly undoes the spatial downsampling.
        self.conv_deconv = nn.ConvTranspose2d(in_channels=c1, out_channels=c2, kernel_size=s, stride=s,
                                              padding=0, bias=True, dilation=1, groups=1
                                              )
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2))
        else:
            self.gamma = None

    def forward(self, x):
        x = self.conv_deconv(self.conv_strided(x))
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        return x
146
+
147
+
148
def DWConv(c1, c2, k=1, s=1, act=True):
    """Depthwise(-ish) convolution: groups set to gcd(c1, c2)."""
    groups = math.gcd(c1, c2)
    return Conv(c1, c2, k, s, g=groups, act=act)
151
+
152
+
153
class GhostConv(nn.Module):
    """Ghost Convolution (https://github.com/huawei-noah/ghostnet).

    A primary conv produces half of c2; a cheap 5x5 depthwise conv generates the other half.
    """

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super(GhostConv, self).__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

    def forward(self, x):
        primary = self.cv1(x)
        cheap = self.cv2(primary)
        return torch.cat([primary, cheap], 1)
164
+
165
+
166
class Stem(nn.Module):
    """Stem block: stride-2 conv, then parallel conv/pool paths, concatenated and fused by a 1x1 conv."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Stem, self).__init__()
        c_ = int(c2 / 2)  # hidden channels
        self.cv1 = Conv(c1, c_, 3, 2)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 3, 2)
        self.pool = torch.nn.MaxPool2d(2, stride=2)
        self.cv4 = Conv(2 * c_, c2, 1, 1)

    def forward(self, x):
        y = self.cv1(x)
        conv_path = self.cv3(self.cv2(y))
        pool_path = self.pool(y)
        return self.cv4(torch.cat((conv_path, pool_path), dim=1))
180
+
181
+
182
class DownC(nn.Module):
    """Downsampling block: a conv path (1x1 then strided 3x3) and a maxpool + 1x1 path, concatenated.

    NOTE(review): the original comment claimed this was an SPP layer — it is not.
    Parameter `n` is unused; kept for model-config compatibility.
    """

    def __init__(self, c1, c2, n=1, k=2):
        super(DownC, self).__init__()
        c_ = int(c1)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2 // 2, 3, k)
        self.cv3 = Conv(c1, c2 // 2, 1, 1)
        self.mp = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        conv_branch = self.cv2(self.cv1(x))
        pool_branch = self.cv3(self.mp(x))
        return torch.cat((conv_branch, pool_branch), dim=1)
194
+
195
+
196
class SPP(nn.Module):
    """Spatial pyramid pooling layer as used in YOLOv3-SPP."""

    def __init__(self, c1, c2, k=(5, 9, 13)):
        super(SPP, self).__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        x = self.cv1(x)
        pooled = [pool(x) for pool in self.m]
        return self.cv2(torch.cat([x] + pooled, 1))
208
+
209
+
210
class Bottleneck(nn.Module):
    """Darknet bottleneck: 1x1 reduce then 3x3 conv, with optional residual add when shapes match."""

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        y = self.cv2(self.cv1(x))
        if self.add:
            return x + y
        return y
221
+
222
+
223
class Res(nn.Module):
    """ResNet bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand, with optional residual add."""

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Res, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 3, 1, g=g)
        self.cv3 = Conv(c_, c2, 1, 1)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        y = self.cv3(self.cv2(self.cv1(x)))
        if self.add:
            return x + y
        return y
235
+
236
+
237
class ResX(Res):
    """ResNeXt-style bottleneck: a Res bottleneck whose 3x3 conv is grouped (g=32 by default)."""

    def __init__(self, c1, c2, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        # Fix: the original passed the misspelled name `shortcu`, which raised
        # NameError on every instantiation. Also dropped the dead `c_` local.
        super().__init__(c1, c2, shortcut, g, e)
242
+
243
+
244
class Ghost(nn.Module):
    """Ghost Bottleneck (https://github.com/huawei-noah/ghostnet)."""

    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        super(Ghost, self).__init__()
        c_ = c2 // 2
        # Depthwise conv only participates when downsampling (s == 2).
        dw = DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity()
        self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
                                  dw,  # dw
                                  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
        if s == 2:
            self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
                                          Conv(c1, c2, 1, 1, act=False))
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)
257
+
258
+ ##### end of basic #####
259
+
260
+
261
+ ##### cspnet #####
262
+
263
class SPPCSPC(nn.Module):
    """CSP-style SPP block (https://github.com/WongKinYiu/CrossStagePartialNetworks)."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
        super(SPPCSPC, self).__init__()
        c_ = int(2 * c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 3, 1)
        self.cv4 = Conv(c_, c_, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
        self.cv5 = Conv(4 * c_, c_, 1, 1)
        self.cv6 = Conv(c_, c_, 3, 1)
        self.cv7 = Conv(2 * c_, c2, 1, 1)

    def forward(self, x):
        # SPP branch: reduce, pool at several scales, fuse.
        spp_in = self.cv4(self.cv3(self.cv1(x)))
        pooled = torch.cat([spp_in] + [pool(spp_in) for pool in self.m], 1)
        spp_out = self.cv6(self.cv5(pooled))
        # Cross-stage bypass branch.
        bypass = self.cv2(x)
        return self.cv7(torch.cat((spp_out, bypass), dim=1))
282
+
283
class GhostSPPCSPC(SPPCSPC):
    """SPPCSPC variant whose convolutions are all replaced by GhostConv modules."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
        super().__init__(c1, c2, n, shortcut, g, e, k)
        c_ = int(2 * c2 * e)  # hidden channels
        # Overwrite the parent's Conv layers with Ghost equivalents of the same shapes.
        self.cv1 = GhostConv(c1, c_, 1, 1)
        self.cv2 = GhostConv(c1, c_, 1, 1)
        self.cv3 = GhostConv(c_, c_, 3, 1)
        self.cv4 = GhostConv(c_, c_, 1, 1)
        self.cv5 = GhostConv(4 * c_, c_, 1, 1)
        self.cv6 = GhostConv(c_, c_, 3, 1)
        self.cv7 = GhostConv(2 * c_, c2, 1, 1)
295
+
296
+
297
class GhostStem(Stem):
    """Stem variant whose convolutions are replaced by GhostConv modules."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__(c1, c2, k, s, p, g, act)
        c_ = int(c2 / 2)  # hidden channels
        self.cv1 = GhostConv(c1, c_, 3, 2)
        self.cv2 = GhostConv(c_, c_, 1, 1)
        self.cv3 = GhostConv(c_, c_, 3, 2)
        self.cv4 = GhostConv(2 * c_, c2, 1, 1)
306
+
307
+
308
class BottleneckCSPA(nn.Module):
    """CSP variant A: bottleneck stack on one branch, plain 1x1 on the other, fused by a 1x1 conv."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSPA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        transform = self.m(self.cv1(x))
        bypass = self.cv2(x)
        return self.cv3(torch.cat((transform, bypass), dim=1))
322
+
323
+
324
class BottleneckCSPB(nn.Module):
    """CSP variant B: both branches stem from the same 1x1 projection of the input."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSPB, self).__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        shared = self.cv1(x)
        transform = self.m(shared)
        bypass = self.cv2(shared)
        return self.cv3(torch.cat((transform, bypass), dim=1))
339
+
340
+
341
class BottleneckCSPC(nn.Module):
    """CSP variant C: transform branch gets an extra 1x1 conv before fusion."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSPC, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        transform = self.cv3(self.m(self.cv1(x)))
        bypass = self.cv2(x)
        return self.cv4(torch.cat((transform, bypass), dim=1))
356
+
357
+
358
class ResCSPA(BottleneckCSPA):
    """CSP-A whose transform branch stacks Res blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(Res(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)))
364
+
365
+
366
class ResCSPB(BottleneckCSPB):
    """CSP-B whose transform branch stacks Res blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2)  # hidden channels
        self.m = nn.Sequential(*(Res(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)))
372
+
373
+
374
class ResCSPC(BottleneckCSPC):
    """CSP-C whose transform branch stacks Res blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(Res(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)))
380
+
381
+
382
class ResXCSPA(ResCSPA):
    """CSP-A with grouped (ResNeXt-style, g=32) Res blocks in the transform branch."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(Res(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))
388
+
389
+
390
class ResXCSPB(ResCSPB):
    """CSP-B with grouped (ResNeXt-style, g=32) Res blocks in the transform branch."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2)  # hidden channels
        self.m = nn.Sequential(*(Res(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))
396
+
397
+
398
class ResXCSPC(ResCSPC):
    """CSP-C with grouped (ResNeXt-style, g=32) Res blocks in the transform branch."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(Res(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))
404
+
405
+
406
class GhostCSPA(BottleneckCSPA):
    """CSP-A whose transform branch stacks Ghost bottlenecks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(Ghost(hidden, hidden) for _ in range(n)))
412
+
413
+
414
class GhostCSPB(BottleneckCSPB):
    """CSP-B whose transform branch stacks Ghost bottlenecks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2)  # hidden channels
        self.m = nn.Sequential(*(Ghost(hidden, hidden) for _ in range(n)))
420
+
421
+
422
class GhostCSPC(BottleneckCSPC):
    """CSP-C whose transform branch stacks Ghost bottlenecks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(Ghost(hidden, hidden) for _ in range(n)))
428
+
429
+ ##### end of cspnet #####
430
+
431
+
432
+ ##### yolor #####
433
+
434
class ImplicitA(nn.Module):
    """YOLOR implicit knowledge, additive form: adds a learned per-channel bias to the input."""

    def __init__(self, channel, mean=0., std=.02):
        super(ImplicitA, self).__init__()
        self.channel = channel
        self.mean = mean
        self.std = std
        self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
        nn.init.normal_(self.implicit, mean=self.mean, std=self.std)

    def forward(self, x):
        return x + self.implicit
445
+
446
+
447
class ImplicitM(nn.Module):
    """YOLOR implicit knowledge, multiplicative form: scales the input by a learned per-channel factor."""

    def __init__(self, channel, mean=0., std=.02):
        super(ImplicitM, self).__init__()
        self.channel = channel
        self.mean = mean
        self.std = std
        self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
        nn.init.normal_(self.implicit, mean=self.mean, std=self.std)

    def forward(self, x):
        return x * self.implicit
458
+
459
+ ##### end of yolor #####
460
+
461
+
462
+ ##### repvgg #####
463
+
464
class RepConv(nn.Module):
    # Represented convolution
    # https://arxiv.org/abs/2101.03697
    # RepVGG-style block: trains with three parallel branches (3x3 conv+BN,
    # 1x1 conv+BN, and an identity BN when shapes allow) that can later be
    # collapsed ("re-parameterized") into a single 3x3 conv for inference.

    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, deploy=False):
        # c1/c2: in/out channels; k must be 3; s: stride; p: padding (auto);
        # g: conv groups; act: True -> LeakyReLU(0.1), nn.Module -> used as is,
        # anything else -> identity; deploy: build the single fused conv directly.
        super(RepConv, self).__init__()

        self.deploy = deploy
        self.groups = g
        self.in_channels = c1
        self.out_channels = c2

        assert k == 3
        assert autopad(k, p) == 1

        # Padding for the 1x1 branch so it stays spatially aligned with the 3x3 branch.
        padding_11 = autopad(k, p) - k // 2

        self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

        if deploy:
            # Inference mode: one conv carrying the fused parameters (with bias).
            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)

        else:
            # Identity branch exists only when input/output shapes match (c1 == c2, stride 1).
            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)

            self.rbr_dense = nn.Sequential(
                nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2),
            )

            self.rbr_1x1 = nn.Sequential(
                nn.Conv2d( c1, c2, 1, s, padding_11, groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2),
            )

    def forward(self, inputs):
        # Fused path (deploy mode or after fuse_repvgg_block): single conv + activation.
        if hasattr(self, "rbr_reparam"):
            return self.act(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        # Training path: sum of the three branches, then activation.
        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        # Fold each branch's conv+BN, pad the 1x1 kernel to 3x3, and sum the
        # branches into one equivalent (kernel, bias) pair.
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return (
            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
            bias3x3 + bias1x1 + biasid,
        )

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        # Zero-pad a 1x1 kernel to a centered 3x3 kernel so it can be added to the 3x3 one.
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        # Return (kernel, bias) of a branch with its BN folded in.
        # A bare BatchNorm2d branch (the identity) gets a synthesized identity kernel.
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch[0].weight
            running_mean = branch[1].running_mean
            running_var = branch[1].running_var
            gamma = branch[1].weight
            beta = branch[1].bias
            eps = branch[1].eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros(
                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
                )
                # A 1 at the kernel center in each channel's own slot == identity conv.
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # Standard conv+BN folding: w' = w * gamma/std, b' = beta - mean * gamma/std.
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def repvgg_convert(self):
        # Export helper: the fused kernel/bias as numpy arrays.
        kernel, bias = self.get_equivalent_kernel_bias()
        return (
            kernel.detach().cpu().numpy(),
            bias.detach().cpu().numpy(),
        )

    def fuse_conv_bn(self, conv, bn):
        # Fold a BatchNorm into the preceding conv; returns a new conv with bias.

        std = (bn.running_var + bn.eps).sqrt()
        bias = bn.bias - bn.running_mean * bn.weight / std

        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        weights = conv.weight * t

        bn = nn.Identity()
        conv = nn.Conv2d(in_channels = conv.in_channels,
                              out_channels = conv.out_channels,
                              kernel_size = conv.kernel_size,
                              stride=conv.stride,
                              padding = conv.padding,
                              dilation = conv.dilation,
                              groups = conv.groups,
                              bias = True,
                              padding_mode = conv.padding_mode)

        conv.weight = torch.nn.Parameter(weights)
        conv.bias = torch.nn.Parameter(bias)
        return conv

    def fuse_repvgg_block(self):
        # In-place re-parameterization: fold the three training branches into a
        # single conv (self.rbr_reparam) and switch the module to deploy mode.
        if self.deploy:
            return
        print(f"RepConv.fuse_repvgg_block")

        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])

        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
        rbr_1x1_bias = self.rbr_1x1.bias
        # Pad the fused 1x1 kernel to 3x3 so it can be summed with the dense kernel.
        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])

        # Fuse self.rbr_identity
        if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
            # Express the identity as a 1x1 conv with identity weights, then fold its BN.
            identity_conv_1x1 = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=self.groups,
                bias=False)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
            # Build an identity matrix in the (out, in) plane, then restore the 1x1 spatial dims.
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
            identity_conv_1x1.weight.data.fill_(0.0)
            identity_conv_1x1.weight.data.fill_diagonal_(1.0)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)

            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
            bias_identity_expanded = identity_conv_1x1.bias
            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
        else:
            # No identity branch: contribute zeros of the matching shapes.
            bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
            weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )

        # Sum the three branches' fused parameters into the dense conv.
        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)

        self.rbr_reparam = self.rbr_dense
        self.deploy = True

        if self.rbr_identity is not None:
            del self.rbr_identity
            self.rbr_identity = None

        if self.rbr_1x1 is not None:
            del self.rbr_1x1
            self.rbr_1x1 = None

        if self.rbr_dense is not None:
            del self.rbr_dense
            self.rbr_dense = None
646
+
647
class RepBottleneck(Bottleneck):
    """Bottleneck whose 3x3 convolution is a re-parameterizable RepConv."""

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        # Fix: the original called super().__init__(c1, c2, shortcut=True, g=1, e=0.5),
        # ignoring the caller's arguments — so `shortcut` was never honored and, for
        # e != 0.5, cv1's output channels disagreed with cv2's expected input channels.
        super().__init__(c1, c2, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.cv2 = RepConv(c_, c2, 3, 1, g=g)
653
+
654
+
655
class RepBottleneckCSPA(BottleneckCSPA):
    """CSP-A whose transform branch stacks RepBottleneck blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(RepBottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))
661
+
662
+
663
class RepBottleneckCSPB(BottleneckCSPB):
    """CSP-B whose transform branch stacks RepBottleneck blocks."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2)  # hidden channels
        self.m = nn.Sequential(*(RepBottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))
669
+
670
+
671
class RepBottleneckCSPC(BottleneckCSPC):
    """CSP-C whose transform branch stacks RepBottleneck blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(RepBottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)))
677
+
678
+
679
class RepRes(Res):
    """Res bottleneck whose middle 3x3 conv is a re-parameterizable RepConv."""

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__(c1, c2, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.cv2 = RepConv(hidden, hidden, 3, 1, g=g)
685
+
686
+
687
class RepResCSPA(ResCSPA):
    """CSP-A whose transform branch stacks RepRes blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(RepRes(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)))
693
+
694
+
695
class RepResCSPB(ResCSPB):
    """CSP-B whose transform branch stacks RepRes blocks."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2)  # hidden channels
        self.m = nn.Sequential(*(RepRes(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)))
701
+
702
+
703
class RepResCSPC(ResCSPC):
    """CSP-C whose transform branch stacks RepRes blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(RepRes(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)))
709
+
710
+
711
class RepResX(ResX):
    """ResX bottleneck whose middle 3x3 grouped conv is a re-parameterizable RepConv."""

    def __init__(self, c1, c2, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__(c1, c2, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.cv2 = RepConv(hidden, hidden, 3, 1, g=g)
717
+
718
+
719
class RepResXCSPA(ResXCSPA):
    """CSP-A whose transform branch stacks RepResX blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(RepResX(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)))
725
+
726
+
727
class RepResXCSPB(ResXCSPB):
    """ResXCSPB variant (CSPNet, https://github.com/WongKinYiu/CrossStagePartialNetworks)
    whose stacked units are RepResX blocks; type-B uses full c2 as hidden width."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2)  # hidden channels (note: e is intentionally not applied here)
        blocks = [RepResX(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)]
        self.m = nn.Sequential(*blocks)
733
+
734
+
735
class RepResXCSPC(ResXCSPC):
    """ResXCSPC variant (CSPNet, https://github.com/WongKinYiu/CrossStagePartialNetworks)
    whose stacked units are RepResX blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        blocks = [RepResX(hidden, hidden, shortcut, g, e=0.5) for _ in range(n)]
        self.m = nn.Sequential(*blocks)
741
+
742
+ ##### end of repvgg #####
743
+
744
+
745
+ ##### transformer #####
746
+
747
+ class TransformerLayer(nn.Module):
748
+ # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
749
+ def __init__(self, c, num_heads):
750
+ super().__init__()
751
+ self.q = nn.Linear(c, c, bias=False)
752
+ self.k = nn.Linear(c, c, bias=False)
753
+ self.v = nn.Linear(c, c, bias=False)
754
+ self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
755
+ self.fc1 = nn.Linear(c, c, bias=False)
756
+ self.fc2 = nn.Linear(c, c, bias=False)
757
+
758
+ def forward(self, x):
759
+ x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
760
+ x = self.fc2(self.fc1(x)) + x
761
+ return x
762
+
763
+
764
class TransformerBlock(nn.Module):
    """Vision-transformer block (https://arxiv.org/abs/2010.11929): optional
    channel-matching conv, learnable position embedding, then a stack of
    TransformerLayer modules."""

    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        self.conv = Conv(c1, c2) if c1 != c2 else None
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape  # NOTE: names follow the original code (dims 2 and 3)
        # (b, c2, w, h) -> (w*h, b, c2): sequence-first layout expected by attention
        p = x.flatten(2).unsqueeze(0).transpose(0, 3).squeeze(3)
        x = p + self.linear(p)  # add learned positional embedding
        x = self.tr(x)
        # (w*h, b, c2) -> (b, c2, w, h)
        return x.unsqueeze(3).transpose(0, 3).reshape(b, self.c2, w, h)
791
+
792
+ ##### end of transformer #####
793
+
794
+
795
+ ##### yolov5 #####
796
+
797
class Focus(nn.Module):
    """Focus wh information into channel space: x(b,c,w,h) -> y(b,4c,w/2,h/2)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Focus, self).__init__()
        # the 4-phase pixel sampling below quadruples channels before the conv
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        # take every other pixel at the four phase offsets and stack on channels
        phases = (x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2])
        return self.conv(torch.cat(phases, 1))
        # return self.conv(self.contract(x))
807
+
808
+
809
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF); equivalent to SPP(k=(5, 9, 13))."""

    def __init__(self, c1, c2, k=5):
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        # repeated k-pooling emulates the larger SPP kernels at lower cost
        y1 = self.m(x)
        y2 = self.m(y1)
        y3 = self.m(y2)
        return self.cv2(torch.cat([x, y1, y2, y3], 1))
823
+
824
+
825
class Contract(nn.Module):
    """Contract width-height into channels, e.g. x(1,64,80,80) -> x(1,256,40,40)."""

    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        s = self.gain
        n, c, hh, ww = x.size()  # H and W must be divisible by gain
        x = x.view(n, c, hh // s, s, ww // s, s)  # e.g. x(1,64,40,2,40,2)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # e.g. x(1,2,2,64,40,40)
        return x.view(n, c * s * s, hh // s, ww // s)  # e.g. x(1,256,40,40)
837
+
838
+
839
class Expand(nn.Module):
    """Expand channels into width-height, e.g. x(1,64,80,80) -> x(1,16,160,160).

    Inverse operation of Contract for the same gain.
    """

    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        s = self.gain
        n, c, hh, ww = x.size()  # C must be divisible by gain**2
        x = x.view(n, s, s, c // s ** 2, hh, ww)  # e.g. x(1,2,2,16,80,80)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # e.g. x(1,16,80,2,80,2)
        return x.view(n, c // s ** 2, hh * s, ww * s)  # e.g. x(1,16,160,160)
851
+
852
+
853
class NMS(nn.Module):
    """Non-Maximum Suppression module; thin wrapper over non_max_suppression()."""

    conf = 0.25  # confidence threshold
    iou = 0.45  # IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self):
        super(NMS, self).__init__()

    def forward(self, x):
        # x[0] is the raw inference output of the preceding model
        return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
864
+
865
+
866
class autoShape(nn.Module):
    """Input-robust model wrapper for cv2/np/PIL/torch inputs.

    Performs pre-processing (letterbox resize, channel fixing, normalization),
    inference under autocast, NMS, and coordinate rescaling, returning a
    Detections object.

    NOTE(review): forward() reads self.stride and self.names, which are never
    assigned in this class -- presumably they are copied onto the wrapper from
    the wrapped model elsewhere; verify against the caller.
    """
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self, model):
        super(autoShape, self).__init__()
        # inference-only wrapper: the wrapped model is put in eval mode
        self.model = model.eval()

    def autoshape(self):
        # idempotent: already wrapped, return self unchanged
        print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
        return self

    @torch.no_grad()
    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
        #   filename:   imgs = 'data/samples/zidane.jpg'
        #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
        #   numpy:           = np.zeros((640,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        t = [time_synchronized()]
        p = next(self.model.parameters())  # for device and type
        if isinstance(imgs, torch.Tensor):  # torch
            # already a tensor: assume caller did the pre-processing, run directly
            with amp.autocast(enabled=p.device.type != 'cpu'):
                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(imgs):
            f = f'image{i}'  # filename
            if isinstance(im, str):  # filename or uri
                im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
            elif isinstance(im, Image.Image):  # PIL Image
                im, f = np.asarray(im), getattr(im, 'filename', f) or f
            files.append(Path(f).with_suffix('.jpg').name)
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = (size / max(s))  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im  # update
        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
        t.append(time_synchronized())

        with amp.autocast(enabled=p.device.type != 'cpu'):
            # Inference
            y = self.model(x, augment, profile)[0]  # forward
            t.append(time_synchronized())

            # Post-process
            y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
            for i in range(n):
                # rescale boxes from padded inference shape back to original image shape
                scale_coords(shape1, y[i][:, :4], shape0[i])

            t.append(time_synchronized())
            return Detections(imgs, y, files, t, self.names, x.shape)
934
+
935
+
936
class Detections:
    """Container for YOLO inference results with display/save/pandas helpers.

    Args:
        imgs: list of source images as numpy arrays (HWC).
        pred: list of per-image tensors, each row = (xyxy, conf, cls).
        files: list of image filenames.
        times: optional 4-element profiling timestamps (pre, inference, NMS, end).
        names: class-index -> class-name mapping.
        shape: inference BCHW shape.
    """

    def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
        super(Detections, self).__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        # fix: times is documented as optional (default None) but was dereferenced unconditionally
        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) if times is not None else (0., 0., 0.)  # ms
        self.s = shape  # inference BCHW shape

    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
        colors = color_list()
        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
            info = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '  # fix: do not shadow builtin str
            if pred is not None:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    info += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render:
                    for *box, conf, cls in pred:  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
            img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
            if pprint:
                print(info.rstrip(', '))
            if show:
                img.show(self.files[i])  # show
            if save:
                f = self.files[i]
                img.save(Path(save_dir) / f)  # save
                print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
            if render:
                self.imgs[i] = np.asarray(img)

    def print(self):
        self.display(pprint=True)  # print results
        print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)

    def show(self):
        self.display(show=True)  # show results

    def save(self, save_dir='runs/hub/exp'):
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp')  # increment save_dir
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        self.display(save=True, save_dir=save_dir)  # save results

    def render(self):
        self.display(render=True)  # render results
        return self.imgs

    def pandas(self):
        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
        return new

    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        # fix: arguments were previously passed positionally out of order
        # (self.names landed in the files slot and self.s in the times slot)
        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], None, self.names, self.s)
             for i in range(self.n)]
        for d in x:
            for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x

    def __len__(self):
        return self.n
1014
+
1015
+
1016
class Classify(nn.Module):
    """Classification head: x(b,c1,20,20) -> x(b,c2)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Classify, self).__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        inputs = x if isinstance(x, list) else [x]
        z = torch.cat([self.aap(y) for y in inputs], 1)  # pool each, cat on channels
        return self.flat(self.conv(z))  # flatten to x(b,c2)
1027
+
1028
+ ##### end of yolov5 ######
1029
+
1030
+
1031
+ ##### orepa #####
1032
+
1033
def transI_fusebn(kernel, bn):
    """Fold a BatchNorm2d into the preceding conv kernel.

    Returns the fused (kernel, bias) pair such that conv(x, fused) + bias
    reproduces bn(conv(x, kernel)) using the BN running statistics.
    """
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    fused_kernel = kernel * (gamma / std).reshape(-1, 1, 1, 1)
    fused_bias = bn.bias - bn.running_mean * gamma / std
    return fused_kernel, fused_bias
1037
+
1038
+
1039
class ConvBN(nn.Module):
    """Conv2d optionally followed by BatchNorm, fusable for deployment.

    In training mode (deploy=False) the conv is bias-free and a BN follows;
    switch_to_deploy() folds the BN into a single biased conv in place.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, deploy=False, nonlinear=None):
        super().__init__()
        self.nonlinear = nn.Identity() if nonlinear is None else nonlinear
        if deploy:
            # deploy mode: a single biased conv, BN already folded in
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  dilation=dilation, groups=groups, bias=True)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  dilation=dilation, groups=groups, bias=False)
            self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        y = self.conv(x)
        if hasattr(self, 'bn'):  # training-mode branch still has its BN
            y = self.bn(y)
        return self.nonlinear(y)

    def switch_to_deploy(self):
        """Fold the BN into the conv and replace both with one biased conv."""
        kernel, bias = transI_fusebn(self.conv.weight, self.bn)
        fused = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                          kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                          padding=self.conv.padding, dilation=self.conv.dilation,
                          groups=self.conv.groups, bias=True)
        fused.weight.data = kernel
        fused.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.conv = fused
1072
+
1073
class OREPA_3x3_RepConv(nn.Module):
    """OREPA re-parameterizable kxk conv.

    Several weight branches (origin, average, frequency prior, 1x1-kxk,
    depthwise-separable) are scaled per-output-channel by rows of self.vector
    and summed into one kernel by weight_gen() on every forward pass; the
    result feeds a single F.conv2d followed by BN and a nonlinearity.

    NOTE(review): `single_init` and `deploy` are accepted but never used in
    this version of the class.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1,
                 internal_channels_1x1_3x3=None,
                 deploy=False, nonlinear=None, single_init=False):
        super(OREPA_3x3_RepConv, self).__init__()
        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear

        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        # the combined kernel is applied with "same" padding only
        assert padding == kernel_size // 2

        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # counts the weight branches created below (indexes rows of self.vector)
        self.branch_counter = 0

        # branch 0: plain kxk "origin" kernel
        self.weight_rbr_origin = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.weight_rbr_origin, a=math.sqrt(1.0))
        self.branch_counter += 1

        if groups < out_channels:
            # branch 1: 1x1 conv expanded by a fixed kxk average filter
            # branch 2: 1x1 conv expanded by the frequency prior from fre_init()
            self.weight_rbr_avg_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
            self.weight_rbr_pfir_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
            nn.init.kaiming_uniform_(self.weight_rbr_avg_conv, a=1.0)
            nn.init.kaiming_uniform_(self.weight_rbr_pfir_conv, a=1.0)
            # NOTE(review): the next two bare attribute accesses are no-op
            # statements (likely leftover debugging); they have no effect
            self.weight_rbr_avg_conv.data
            self.weight_rbr_pfir_conv.data
            self.register_buffer('weight_rbr_avg_avg', torch.ones(kernel_size, kernel_size).mul(1.0/kernel_size/kernel_size))
            self.branch_counter += 1

        else:
            raise NotImplementedError
        self.branch_counter += 1

        if internal_channels_1x1_3x3 is None:
            internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels  # For mobilenet, it is better to have 2X internal channels

        if internal_channels_1x1_3x3 == in_channels:
            # identity-initialized 1x1 stage: parameter starts at zero and an
            # identity kernel buffer is added at weight-generation time
            self.weight_rbr_1x1_kxk_idconv1 = nn.Parameter(torch.zeros(in_channels, int(in_channels/self.groups), 1, 1))
            id_value = np.zeros((in_channels, int(in_channels/self.groups), 1, 1))
            for i in range(in_channels):
                id_value[i, i % int(in_channels/self.groups), 0, 0] = 1
            id_tensor = torch.from_numpy(id_value).type_as(self.weight_rbr_1x1_kxk_idconv1)
            self.register_buffer('id_tensor', id_tensor)

        else:
            self.weight_rbr_1x1_kxk_conv1 = nn.Parameter(torch.Tensor(internal_channels_1x1_3x3, int(in_channels/self.groups), 1, 1))
            nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv1, a=math.sqrt(1.0))
        # branch 3: kxk stage of the 1x1-kxk composition
        self.weight_rbr_1x1_kxk_conv2 = nn.Parameter(torch.Tensor(out_channels, int(internal_channels_1x1_3x3/self.groups), kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv2, a=math.sqrt(1.0))
        self.branch_counter += 1

        # branch 4: depthwise-separable (expand_ratio x depthwise kxk, then pointwise)
        expand_ratio = 8
        self.weight_rbr_gconv_dw = nn.Parameter(torch.Tensor(in_channels*expand_ratio, 1, kernel_size, kernel_size))
        self.weight_rbr_gconv_pw = nn.Parameter(torch.Tensor(out_channels, in_channels*expand_ratio, 1, 1))
        nn.init.kaiming_uniform_(self.weight_rbr_gconv_dw, a=math.sqrt(1.0))
        nn.init.kaiming_uniform_(self.weight_rbr_gconv_pw, a=math.sqrt(1.0))
        self.branch_counter += 1

        if out_channels == in_channels and stride == 1:
            self.branch_counter += 1

        # per-branch, per-output-channel mixing coefficients
        self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
        self.bn = nn.BatchNorm2d(out_channels)

        self.fre_init()

        nn.init.constant_(self.vector[0, :], 0.25)  #origin
        nn.init.constant_(self.vector[1, :], 0.25)  #avg
        nn.init.constant_(self.vector[2, :], 0.0)  #prior
        nn.init.constant_(self.vector[3, :], 0.5)  #1x1_kxk
        nn.init.constant_(self.vector[4, :], 0.5)  #dws_conv

    def fre_init(self):
        """Build the fixed cosine frequency-prior buffer used by branch 2."""
        prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, self.kernel_size)
        half_fg = self.out_channels/2
        for i in range(self.out_channels):
            for h in range(3):
                for w in range(3):
                    # first half of channels varies along h, second half along w
                    if i < half_fg:
                        prior_tensor[i, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3)
                    else:
                        prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_fg)/3)

        self.register_buffer('weight_rbr_prior', prior_tensor)

    def weight_gen(self):
        """Combine all branches into the single effective conv kernel."""
        weight_rbr_origin = torch.einsum('oihw,o->oihw', self.weight_rbr_origin, self.vector[0, :])

        weight_rbr_avg = torch.einsum('oihw,o->oihw', torch.einsum('oihw,hw->oihw', self.weight_rbr_avg_conv, self.weight_rbr_avg_avg), self.vector[1, :])

        weight_rbr_pfir = torch.einsum('oihw,o->oihw', torch.einsum('oihw,ohw->oihw', self.weight_rbr_pfir_conv, self.weight_rbr_prior), self.vector[2, :])

        weight_rbr_1x1_kxk_conv1 = None
        if hasattr(self, 'weight_rbr_1x1_kxk_idconv1'):
            # identity variant: add the fixed identity kernel before squeezing
            weight_rbr_1x1_kxk_conv1 = (self.weight_rbr_1x1_kxk_idconv1 + self.id_tensor).squeeze()
        elif hasattr(self, 'weight_rbr_1x1_kxk_conv1'):
            weight_rbr_1x1_kxk_conv1 = self.weight_rbr_1x1_kxk_conv1.squeeze()
        else:
            raise NotImplementedError
        weight_rbr_1x1_kxk_conv2 = self.weight_rbr_1x1_kxk_conv2

        if self.groups > 1:
            # contract the 1x1 and kxk stages group-wise
            g = self.groups
            t, ig = weight_rbr_1x1_kxk_conv1.size()
            o, tg, h, w = weight_rbr_1x1_kxk_conv2.size()
            weight_rbr_1x1_kxk_conv1 = weight_rbr_1x1_kxk_conv1.view(g, int(t/g), ig)
            weight_rbr_1x1_kxk_conv2 = weight_rbr_1x1_kxk_conv2.view(g, int(o/g), tg, h, w)
            weight_rbr_1x1_kxk = torch.einsum('gti,gothw->goihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2).view(o, ig, h, w)
        else:
            weight_rbr_1x1_kxk = torch.einsum('ti,othw->oihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2)

        weight_rbr_1x1_kxk = torch.einsum('oihw,o->oihw', weight_rbr_1x1_kxk, self.vector[3, :])

        weight_rbr_gconv = self.dwsc2full(self.weight_rbr_gconv_dw, self.weight_rbr_gconv_pw, self.in_channels)
        weight_rbr_gconv = torch.einsum('oihw,o->oihw', weight_rbr_gconv, self.vector[4, :])

        weight = weight_rbr_origin + weight_rbr_avg + weight_rbr_1x1_kxk + weight_rbr_pfir + weight_rbr_gconv

        return weight

    def dwsc2full(self, weight_dw, weight_pw, groups):
        """Collapse a depthwise + pointwise pair into one dense conv kernel."""
        t, ig, h, w = weight_dw.size()
        o, _, _, _ = weight_pw.size()
        tg = int(t/groups)
        i = int(ig*groups)
        weight_dw = weight_dw.view(groups, tg, ig, h, w)
        weight_pw = weight_pw.squeeze().view(o, groups, tg)

        weight_dsc = torch.einsum('gtihw,ogt->ogihw', weight_dw, weight_pw)
        return weight_dsc.view(o, i, h, w)

    def forward(self, inputs):
        # regenerate the combined kernel every call, then one plain conv + BN
        weight = self.weight_gen()
        out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)

        return self.nonlinear(self.bn(out))
1224
+
1225
class RepConv_OREPA(nn.Module):
    """RepVGG-style block built from OREPA branches.

    Training form sums three parallel paths: a 3x3 OREPA conv (rbr_dense), a
    1x1 ConvBN (rbr_1x1) and an optional identity BN (rbr_identity), followed
    by optional SE and a nonlinearity. switch_to_deploy() fuses the paths into
    a single 3x3 conv (rbr_reparam).
    """

    def __init__(self, c1, c2, k=3, s=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, nonlinear=nn.SiLU()):
        super(RepConv_OREPA, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = c1
        self.out_channels = c2

        self.padding = padding
        self.dilation = dilation
        # NOTE(review): self.groups is assigned twice; redundant but harmless
        self.groups = groups

        # only the 3x3 / padding-1 configuration is supported
        assert k == 3
        assert padding == 1

        padding_11 = padding - k // 2  # padding for the parallel 1x1 branch

        if nonlinear is None:
            self.nonlinearity = nn.Identity()
        else:
            self.nonlinearity = nonlinear

        if use_se:
            self.se = SEBlock(self.out_channels, internal_neurons=self.out_channels // 16)
        else:
            self.se = nn.Identity()

        if deploy:
            # deploy form: one fused biased conv
            self.rbr_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s,
                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

        else:
            # identity path exists only when shapes allow a pure residual
            self.rbr_identity = nn.BatchNorm2d(num_features=self.in_channels) if self.out_channels == self.in_channels and s == 1 else None
            self.rbr_dense = OREPA_3x3_RepConv(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s, padding=padding, groups=groups, dilation=1)
            self.rbr_1x1 = ConvBN(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, stride=s, padding=padding_11, groups=groups, dilation=1)
            print('RepVGG Block, identity = ', self.rbr_identity)

    def forward(self, inputs):
        # deploy form short-circuits to the fused conv
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        out1 = self.rbr_dense(inputs)
        out2 = self.rbr_1x1(inputs)
        out3 = id_out
        out = out1 + out2 + out3

        return self.nonlinearity(self.se(out))

    #   Optional. This improves the accuracy and facilitates quantization.
    #   1.  Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    #   2.  Use like this.
    #       loss = criterion(....)
    #       for every RepVGGBlock blk:
    #           loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2()
    #       optimizer.zero_grad()
    #       loss.backward()

    #   Not used for OREPA
    def get_custom_L2(self):
        """Custom L2 regularizer over the equivalent fused kernel (unused for OREPA)."""
        K3 = self.rbr_dense.weight_gen()
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()  # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1  # The equivalent resultant central point of 3x3 kernel.
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()  # Normalize for an L2 coefficient comparable to regular L2.
        return l2_loss_eq_kernel + l2_loss_circle

    def get_equivalent_kernel_bias(self):
        """Fuse the three branches into one 3x3 (kernel, bias) pair."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        # zero-pad a 1x1 kernel to 3x3 so it can be summed with the dense kernel
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1,1,1,1])

    def _fuse_bn_tensor(self, branch):
        """Return the branch's conv kernel with its BN folded in (kernel, bias)."""
        if branch is None:
            return 0, 0
        if not isinstance(branch, nn.BatchNorm2d):
            if isinstance(branch, OREPA_3x3_RepConv):
                kernel = branch.weight_gen()
            elif isinstance(branch, ConvBN):
                kernel = branch.conv.weight
            else:
                raise NotImplementedError
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            # pure-BN identity branch: synthesize an identity 3x3 kernel once
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        """Replace the training branches with one fused conv, in place."""
        if hasattr(self, 'rbr_reparam'):
            return
        print(f"RepConv_OREPA.switch_to_deploy")
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.in_channels, out_channels=self.rbr_dense.out_channels,
                                     kernel_size=self.rbr_dense.kernel_size, stride=self.rbr_dense.stride,
                                     padding=self.rbr_dense.padding, dilation=self.rbr_dense.dilation, groups=self.rbr_dense.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
1362
+
1363
+ ##### end of orepa #####
1364
+
1365
+
1366
+ ##### swin transformer #####
1367
+
1368
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention (Swin Transformer) with a learned
    relative position bias; supports an optional attention mask for shifted
    windows.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # encode (dh, dw) pairs into a single flat table index
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        # x: (num_windows*B, N, C) where N = Wh*Ww tokens per window
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: (nW, N, N) shifted-window mask, broadcast over batch and heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # print(attn.dtype, v.dtype)
        # NOTE(review): the bare except below falls back to half precision on a
        # dtype-mismatch matmul failure; it will also hide unrelated errors
        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except:
            #print(attn.dtype, v.dtype)
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
1437
+
1438
class Mlp(nn.Module):
    """Transformer MLP: fc1 -> activation -> dropout -> fc2 -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)  # shared dropout after both linear layers

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
1456
+
1457
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C).
        window_size: side length of each window; must divide both H and W.

    Returns:
        Tensor of shape (num_windows * B, window_size, window_size, C).
    """
    B, H, W, C = x.shape
    # Bug fix: the original assert only validated H even though its message
    # mentions both h and w; check W too so a non-divisible width fails with
    # a clear error instead of an opaque reshape failure.
    assert H % window_size == 0 and W % window_size == 0, \
        'feature map h and w can not divide by window size'
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
1464
+
1465
def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: merge (num_windows*B, ws, ws, C) windows
    back into a (B, H, W, C) feature map."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    grid_h = H // window_size
    grid_w = W // window_size
    merged = windows.view(B, grid_h, grid_w, window_size, window_size, -1)
    merged = merged.permute(0, 1, 3, 2, 4, 5).contiguous()
    return merged.view(B, H, W, -1)
1471
+
1472
+
1473
class SwinTransformerLayer(nn.Module):
    """One Swin Transformer (v1) layer that operates on (B, C, H, W) maps.

    The input is padded so H and W are divisible by ``window_size``, flattened
    to a token sequence, processed by (shifted-)window attention and an MLP
    with pre-norm residual connections, then reshaped back to (B, C, H, W)
    with the padding cropped off.

    Args:
        dim: channel dimension of the tokens.
        num_heads: number of attention heads.
        window_size: side length of the square attention window.
        shift_size: cyclic-shift amount; 0 gives W-MSA, >0 gives SW-MSA.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, drop, attn_drop, drop_path: passed through to the
            attention / MLP / stochastic-depth submodules.
        act_layer, norm_layer: activation and normalization constructors.
    """

    def __init__(self, dim, num_heads, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # Upstream Swin clamps window_size against a fixed input_resolution
        # here; this port has no fixed resolution, so forward() pads instead:
        # if min(self.input_resolution) <= self.window_size:
        #     # if window size is larger than input resolution, we don't partition windows
        #     self.shift_size = 0
        #     self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def create_mask(self, H, W):
        """Build the (nW, ws*ws, ws*ws) additive attention mask for SW-MSA.

        Tokens that come from different pre-shift regions must not attend to
        each other, so those pairs get a large negative bias (-100.0).
        H and W are the (already padded) feature-map height and width.
        """
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        # Three bands per axis: interior, the band shifted out, the band
        # wrapped around — each (h, w) band combination gets a unique id.
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # Pairwise region-id difference: non-zero means "different region".
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        """Apply (shifted-)window attention + MLP to x of shape (B, C, H, W)."""
        # reshape x[b c h w] to x[b l c]
        _, _, H_, W_ = x.shape

        # Pad on the right/bottom so both spatial dims divide window_size.
        Padding = False
        if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
            Padding = True
            pad_r = (self.window_size - W_ % self.window_size) % self.window_size
            pad_b = (self.window_size - H_ % self.window_size) % self.window_size
            x = F.pad(x, (0, pad_r, 0, pad_b))

        B, C, H, W = x.shape
        L = H * W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

        # Mask is built per-forward (not in __init__) because H/W vary.
        if self.shift_size > 0:
            attn_mask = self.create_mask(H, W).to(x.device)
        else:
            attn_mask = None

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift (SW-MSA only)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # Residual connections: attention branch, then FFN branch (pre-norm).
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

        if Padding:
            x = x[:, :, :H_, :W_]  # reverse padding

        return x
1583
+
1584
+
1585
class SwinTransformerBlock(nn.Module):
    """Optional 1x1 channel projection followed by a stack of Swin (v1) layers
    alternating between regular and shifted windows."""

    def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
        super().__init__()
        # Only project when the channel count actually changes.
        self.conv = Conv(c1, c2) if c1 != c2 else None
        # Even-indexed layers use plain W-MSA, odd-indexed layers SW-MSA.
        layers = [
            SwinTransformerLayer(
                dim=c2,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=window_size // 2 if i % 2 else 0,
            )
            for i in range(num_layers)
        ]
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        return self.blocks(x)
1601
+
1602
+
1603
class STCSPA(nn.Module):
    """CSP bottleneck (variant A) whose main branch is a Swin transformer stack.

    See https://github.com/WongKinYiu/CrossStagePartialNetworks for the CSP design.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)

    def forward(self, x):
        transformer_branch = self.m(self.cv1(x))
        shortcut_branch = self.cv2(x)
        return self.cv3(torch.cat((transformer_branch, shortcut_branch), dim=1))
1619
+
1620
+
1621
class STCSPB(nn.Module):
    """CSP bottleneck (variant B) with a Swin transformer main branch.

    Both branches share the cv1 projection; note that unlike variant A the
    hidden width is c2 itself (the expansion factor `e` is unused upstream).
    See https://github.com/WongKinYiu/CrossStagePartialNetworks.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)

    def forward(self, x):
        shared = self.cv1(x)
        transformer_branch = self.m(shared)
        shortcut_branch = self.cv2(shared)
        return self.cv3(torch.cat((transformer_branch, shortcut_branch), dim=1))
1638
+
1639
+
1640
class STCSPC(nn.Module):
    """CSP bottleneck (variant C) with a Swin transformer main branch and an
    extra 1x1 conv after the transformer stack.

    See https://github.com/WongKinYiu/CrossStagePartialNetworks.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)

    def forward(self, x):
        transformer_branch = self.cv3(self.m(self.cv1(x)))
        shortcut_branch = self.cv2(x)
        return self.cv4(torch.cat((transformer_branch, shortcut_branch), dim=1))
1657
+
1658
+ ##### end of swin transformer #####
1659
+
1660
+
1661
+ ##### swin transformer v2 #####
1662
+
1663
class WindowAttention_v2(nn.Module):
    """Swin Transformer V2 window attention.

    Uses cosine-similarity attention with a learnable per-head temperature
    (``logit_scale``) and a continuous, log-spaced relative position bias
    produced by a small MLP (``cpb_mlp``), as described in the Swin-v2 paper.

    Args:
        dim: number of input channels.
        window_size: (Wh, Ww) of the attention window.
        num_heads: number of attention heads.
        qkv_bias: if True, add learnable bias to q and v (k stays bias-free).
        attn_drop: dropout on the attention weights.
        proj_drop: dropout on the output projection.
        pretrained_window_size: window size used during pre-training, for
            rescaling the relative-coordinate table.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0]):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        # Learnable per-head temperature for cosine attention; clamped in forward().
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))

        # get relative_coords_table: all possible (dh, dw) offsets inside the window
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        # Normalize offsets by the (pretrained) window extent, then map to a
        # signed log scale in [-1, 1] so the bias transfers across window sizes.
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        # qkv has no built-in bias; q/v biases are separate parameters so that
        # k can stay bias-free (Swin-v2 design).
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Attend within windows.

        Args:
            x: (num_windows*B, N, C) window tokens, N = Wh*Ww.
            mask: optional (num_windows, N, N) additive mask for SW-MSA.

        Returns:
            (num_windows*B, N, C) tensor.
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # Assemble [q_bias, 0, v_bias] so k receives no bias.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention with clamped learnable temperature
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # Bug fix: the original bare `except:` also swallowed KeyboardInterrupt
        # and SystemExit. A dtype mismatch between attn and v (e.g. under mixed
        # precision) raises RuntimeError, so catch exactly that and retry with
        # attn cast to half.
        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except RuntimeError:
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Approximate FLOPs for one window of N tokens (cpb_mlp excluded)."""
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
1781
+
1782
class Mlp_v2(nn.Module):
    """Feed-forward block for Swin-v2 layers: fc1 -> act -> drop -> fc2 -> drop."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out_features or in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Run the fixed pipeline; dropout is applied after each projection.
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
1799
+
1800
+
1801
def window_partition_v2(x, window_size):
    """Cut a (B, H, W, C) map into (num_windows*B, ws, ws, C) tiles.

    H and W are assumed to be divisible by window_size (the caller pads).
    """
    B, H, W, C = x.shape
    tiles = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
1807
+
1808
+
1809
def window_reverse_v2(windows, window_size, H, W):
    """Reassemble (num_windows*B, ws, ws, C) windows into a (B, H, W, C) map."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    rows, cols = H // window_size, W // window_size
    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
1815
+
1816
+
1817
class SwinTransformerLayer_v2(nn.Module):
    """One Swin Transformer V2 layer operating on (B, C, H, W) maps.

    Pads the input so H and W divide ``window_size``, runs (shifted-)window
    cosine attention and an MLP with post-norm residuals (norm applied after
    each sub-block, Swin-v2 style), then reshapes back and crops the padding.

    Args:
        dim: channel dimension of the tokens.
        num_heads: number of attention heads.
        window_size: side length of the square attention window.
        shift_size: cyclic-shift amount; 0 gives W-MSA, >0 gives SW-MSA.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias, drop, attn_drop, drop_path: passed to the submodules.
        act_layer, norm_layer: activation and normalization constructors.
        pretrained_window_size: forwarded to WindowAttention_v2.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # Upstream Swin clamps window_size against a fixed input_resolution;
        # this port has no fixed resolution and pads in forward() instead.
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention_v2(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=(pretrained_window_size, pretrained_window_size))

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def create_mask(self, H, W):
        """Build the (nW, ws*ws, ws*ws) additive attention mask for SW-MSA.

        Token pairs originating from different pre-shift regions receive a
        -100.0 bias so softmax effectively zeroes their attention.
        """
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # Consistency fix: use the v2 partition helper (the original called the
        # v1 window_partition; results are identical since H and W are already
        # padded to multiples of window_size here).
        mask_windows = window_partition_v2(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        """Apply (shifted-)window attention + MLP to x of shape (B, C, H, W)."""
        # reshape x[b c h w] to x[b l c]
        _, _, H_, W_ = x.shape

        # Pad right/bottom so both spatial dims divide window_size.
        Padding = False
        if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
            Padding = True
            pad_r = (self.window_size - W_ % self.window_size) % self.window_size
            pad_b = (self.window_size - H_ % self.window_size) % self.window_size
            x = F.pad(x, (0, pad_r, 0, pad_b))

        B, C, H, W = x.shape
        L = H * W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

        # Mask is built per-forward because H/W vary between calls.
        if self.shift_size > 0:
            attn_mask = self.create_mask(H, W).to(x.device)
        else:
            attn_mask = None

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift (SW-MSA only)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition_v2(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse_v2(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # Post-norm residual (Swin-v2): normalize the branch output, not the input.
        x = shortcut + self.drop_path(self.norm1(x))

        # FFN, also post-norm
        x = x + self.drop_path(self.norm2(self.mlp(x)))
        x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

        if Padding:
            x = x[:, :, :H_, :W_]  # reverse padding

        return x

    def extra_repr(self) -> str:
        # Bug fix: the original formatted self.input_resolution, which is never
        # assigned, so printing the module (nn.Module.__repr__ calls
        # extra_repr) raised AttributeError.
        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self, input_resolution=None):
        """Approximate FLOPs for one layer at the given (H, W) resolution.

        Backward-compatible fix: this module stores no input_resolution, so the
        resolution may now be passed explicitly; calling with no argument keeps
        the original behavior (AttributeError on the missing attribute).
        """
        flops = 0
        H, W = input_resolution if input_resolution is not None else self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
1945
+
1946
+
1947
class SwinTransformer2Block(nn.Module):
    """Optional 1x1 channel projection followed by a stack of Swin-v2 layers
    alternating between regular and shifted windows."""

    def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
        super().__init__()
        # Only project when the channel count actually changes.
        self.conv = Conv(c1, c2) if c1 != c2 else None
        # Even-indexed layers use plain W-MSA, odd-indexed layers SW-MSA.
        layers = [
            SwinTransformerLayer_v2(
                dim=c2,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=window_size // 2 if i % 2 else 0,
            )
            for i in range(num_layers)
        ]
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        return self.blocks(x)
1963
+
1964
+
1965
class ST2CSPA(nn.Module):
    """CSP bottleneck (variant A) whose main branch is a Swin-v2 transformer stack.

    See https://github.com/WongKinYiu/CrossStagePartialNetworks for the CSP design.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformer2Block(c_, c_, num_heads, n)

    def forward(self, x):
        transformer_branch = self.m(self.cv1(x))
        shortcut_branch = self.cv2(x)
        return self.cv3(torch.cat((transformer_branch, shortcut_branch), dim=1))
1981
+
1982
+
1983
class ST2CSPB(nn.Module):
    """CSP bottleneck (variant B) with a Swin-v2 transformer main branch.

    Both branches share the cv1 projection; the hidden width is c2 itself
    (the expansion factor `e` is unused upstream).
    See https://github.com/WongKinYiu/CrossStagePartialNetworks.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformer2Block(c_, c_, num_heads, n)

    def forward(self, x):
        shared = self.cv1(x)
        transformer_branch = self.m(shared)
        shortcut_branch = self.cv2(shared)
        return self.cv3(torch.cat((transformer_branch, shortcut_branch), dim=1))
2000
+
2001
+
2002
class ST2CSPC(nn.Module):
    """CSP bottleneck (variant C) with a Swin-v2 transformer main branch and an
    extra 1x1 conv after the transformer stack.

    See https://github.com/WongKinYiu/CrossStagePartialNetworks.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformer2Block(c_, c_, num_heads, n)

    def forward(self, x):
        transformer_branch = self.cv3(self.m(self.cv1(x)))
        shortcut_branch = self.cv2(x)
        return self.cv4(torch.cat((transformer_branch, shortcut_branch), dim=1))
2019
+
2020
+ ##### end of swin transformer v2 #####
2021
+