hank1996 committed
Commit e754800 · 1 Parent(s): 4c84f7d

Update models/yolo.py

Files changed (1)
models/yolo.py +592 -0
models/yolo.py CHANGED
@@ -0,0 +1,592 @@
+ import argparse
+ import logging
+ import math  # used by the bias-initialization helpers below
+ import sys
+ from copy import deepcopy
+ from pathlib import Path  # used by Model.__init__
+
+ sys.path.append('./')  # to run '$ python *.py' files in subdirectories
+ logger = logging.getLogger(__name__)
+
+ import torch
+ import torch.nn as nn
+
+ from models.common import *
+ from models.experimental import *
+ from utils.autoanchor import check_anchor_order
+ from utils.general import make_divisible, check_file, set_logging
+ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
+     select_device, copy_attr
+ from utils.loss import SigmoidBin
+
+ try:
+     import thop  # for FLOPS computation
+ except ImportError:
+     thop = None
+
+
+ class Detect(nn.Module):
+     stride = None  # strides computed during build
+     export = False  # onnx export
+
+     def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+         super(Detect, self).__init__()
+         self.nc = nc  # number of classes
+         self.no = nc + 5  # number of outputs per anchor
+         self.nl = len(anchors)  # number of detection layers
+         self.na = len(anchors[0]) // 2  # number of anchors
+         self.grid = [torch.zeros(1)] * self.nl  # init grid
+         a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+         self.register_buffer('anchors', a)  # shape(nl,na,2)
+         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
+         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+
+     def forward(self, x):
+         # x = x.copy()  # for profiling
+         z = []  # inference output
+         self.training |= self.export
+         for i in range(self.nl):
+             x[i] = self.m[i](x[i])  # conv
+             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
+             x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+             if not self.training:  # inference
+                 if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+                 y = x[i].sigmoid()
+                 y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                 y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                 z.append(y.view(bs, -1, self.no))
+
+         return x if self.training else (torch.cat(z, 1), x)
+
+     @staticmethod
+     def _make_grid(nx=20, ny=20):
+         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+         return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
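+ # Decode sketch (illustrative note, read off the forward pass above): at
+ # inference each sigmoid output y is mapped as xy = (2*y - 0.5 + grid) * stride
+ # and wh = (2*y)**2 * anchor, so a cell can place a box center up to half a
+ # cell outside itself and predict 0-4x the anchor size. Worked example with
+ # stride 8, grid cell (3, 5) and y_xy = (0.5, 0.5):
+ # xy = (2*0.5 - 0.5 + (3, 5)) * 8 = (28, 44) px.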
+
+
+ class IDetect(nn.Module):
+     stride = None  # strides computed during build
+     export = False  # onnx export
+
+     def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+         super(IDetect, self).__init__()
+         self.nc = nc  # number of classes
+         self.no = nc + 5  # number of outputs per anchor
+         self.nl = len(anchors)  # number of detection layers
+         self.na = len(anchors[0]) // 2  # number of anchors
+         self.grid = [torch.zeros(1)] * self.nl  # init grid
+         a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+         self.register_buffer('anchors', a)  # shape(nl,na,2)
+         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
+         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+
+         self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
+         self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)
+
+     def forward(self, x):
+         # x = x.copy()  # for profiling
+         z = []  # inference output
+         self.training |= self.export
+         for i in range(self.nl):
+             x[i] = self.m[i](self.ia[i](x[i]))  # conv
+             x[i] = self.im[i](x[i])
+             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
+             x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+             if not self.training:  # inference
+                 if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+                 y = x[i].sigmoid()
+                 y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                 y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                 z.append(y.view(bs, -1, self.no))
+
+         return x if self.training else (torch.cat(z, 1), x)
+
+     @staticmethod
+     def _make_grid(nx=20, ny=20):
+         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+         return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
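+ # Note on the implicit layers (a reading of the code above; ImplicitA and
+ # ImplicitM come from models.common): ImplicitA adds a learned per-channel
+ # constant before the 1x1 output conv and ImplicitM scales its output,
+ # YOLOR-style "implicit knowledge". Both are input-independent, so in
+ # principle they can be folded into the conv's bias and weight for deployment.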
+
+
+ class IAuxDetect(nn.Module):
+     stride = None  # strides computed during build
+     export = False  # onnx export
+
+     def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+         super(IAuxDetect, self).__init__()
+         self.nc = nc  # number of classes
+         self.no = nc + 5  # number of outputs per anchor
+         self.nl = len(anchors)  # number of detection layers
+         self.na = len(anchors[0]) // 2  # number of anchors
+         self.grid = [torch.zeros(1)] * self.nl  # init grid
+         a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+         self.register_buffer('anchors', a)  # shape(nl,na,2)
+         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
+         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[:self.nl])  # lead output conv
+         self.m2 = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[self.nl:])  # auxiliary output conv
+
+         self.ia = nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
+         self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch[:self.nl])
+
+     def forward(self, x):
+         # x = x.copy()  # for profiling
+         z = []  # inference output
+         self.training |= self.export
+         for i in range(self.nl):
+             x[i] = self.m[i](self.ia[i](x[i]))  # conv
+             x[i] = self.im[i](x[i])
+             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
+             x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+             x[i + self.nl] = self.m2[i](x[i + self.nl])
+             x[i + self.nl] = x[i + self.nl].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+             if not self.training:  # inference
+                 if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+                 y = x[i].sigmoid()
+                 y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                 y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                 z.append(y.view(bs, -1, self.no))
+
+         return x if self.training else (torch.cat(z, 1), x[:self.nl])
+
+     @staticmethod
+     def _make_grid(nx=20, ny=20):
+         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+         return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
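+ # Shape contract (implied by the indexing above): x must hold 2 * nl feature
+ # maps, the first nl from the lead heads and the last nl from the auxiliary
+ # heads at matching resolutions. The aux heads only feed the training loss
+ # (deep supervision); inference returns just the lead outputs, x[:self.nl].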
+
+
+ class IBin(nn.Module):
+     stride = None  # strides computed during build
+     export = False  # onnx export
+
+     def __init__(self, nc=80, anchors=(), ch=(), bin_count=21):  # detection layer
+         super(IBin, self).__init__()
+         self.nc = nc  # number of classes
+         self.bin_count = bin_count
+
+         self.w_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
+         self.h_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
+         # classes, x,y,obj
+         self.no = nc + 3 + \
+             self.w_bin_sigmoid.get_length() + self.h_bin_sigmoid.get_length()  # w-bce, h-bce
+         # + self.x_bin_sigmoid.get_length() + self.y_bin_sigmoid.get_length()
+
+         self.nl = len(anchors)  # number of detection layers
+         self.na = len(anchors[0]) // 2  # number of anchors
+         self.grid = [torch.zeros(1)] * self.nl  # init grid
+         a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+         self.register_buffer('anchors', a)  # shape(nl,na,2)
+         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
+         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+
+         self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
+         self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)
+
+     def forward(self, x):
+         # self.x_bin_sigmoid.use_fw_regression = True
+         # self.y_bin_sigmoid.use_fw_regression = True
+         self.w_bin_sigmoid.use_fw_regression = True
+         self.h_bin_sigmoid.use_fw_regression = True
+
+         # x = x.copy()  # for profiling
+         z = []  # inference output
+         self.training |= self.export
+         for i in range(self.nl):
+             x[i] = self.m[i](self.ia[i](x[i]))  # conv
+             x[i] = self.im[i](x[i])
+             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
+             x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+             if not self.training:  # inference
+                 if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+                 y = x[i].sigmoid()
+                 y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                 # y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+
+                 # px = (self.x_bin_sigmoid.forward(y[..., 0:12]) + self.grid[i][..., 0]) * self.stride[i]
+                 # py = (self.y_bin_sigmoid.forward(y[..., 12:24]) + self.grid[i][..., 1]) * self.stride[i]
+
+                 pw = self.w_bin_sigmoid.forward(y[..., 2:24]) * self.anchor_grid[i][..., 0]
+                 ph = self.h_bin_sigmoid.forward(y[..., 24:46]) * self.anchor_grid[i][..., 1]
+
+                 # y[..., 0] = px
+                 # y[..., 1] = py
+                 y[..., 2] = pw
+                 y[..., 3] = ph
+
+                 y = torch.cat((y[..., 0:4], y[..., 46:]), dim=-1)
+
+                 z.append(y.view(bs, -1, y.shape[-1]))
+
+         return x if self.training else (torch.cat(z, 1), x)
+
+     @staticmethod
+     def _make_grid(nx=20, ny=20):
+         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+         return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
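+ # Channel layout per anchor (with the default bin_count=21, so each SigmoidBin
+ # spans 22 channels, matching the 2:24 and 24:46 slices above): xy (2), w bins
+ # (22), h bins (22), obj (index 46), then nc class scores. After decoding, the
+ # bins collapse to scalar w/h, so z keeps the usual (xy, wh, obj, cls) layout.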
+
+
+ class Model(nn.Module):
+     def __init__(self, cfg='yolor-csp-c.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
+         super(Model, self).__init__()
+         self.traced = False
+         if isinstance(cfg, dict):
+             self.yaml = cfg  # model dict
+         else:  # is *.yaml
+             import yaml  # for torch hub
+             self.yaml_file = Path(cfg).name
+             with open(cfg) as f:
+                 self.yaml = yaml.load(f, Loader=yaml.SafeLoader)  # model dict
+
+         # Define model
+         ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
+         if nc and nc != self.yaml['nc']:
+             logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+             self.yaml['nc'] = nc  # override yaml value
+         if anchors:
+             logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
+             self.yaml['anchors'] = round(anchors)  # override yaml value
+         self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
+         self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
+         # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
+
+         # Build strides, anchors
+         m = self.model[-1]  # Detect()
+         if isinstance(m, Detect):
+             s = 256  # 2x min stride
+             m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
+             m.anchors /= m.stride.view(-1, 1, 1)
+             check_anchor_order(m)
+             self.stride = m.stride
+             self._initialize_biases()  # only run once
+             # print('Strides: %s' % m.stride.tolist())
+         if isinstance(m, IDetect):
+             s = 256  # 2x min stride
+             m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
+             m.anchors /= m.stride.view(-1, 1, 1)
+             check_anchor_order(m)
+             self.stride = m.stride
+             self._initialize_biases()  # only run once
+             # print('Strides: %s' % m.stride.tolist())
+         if isinstance(m, IAuxDetect):
+             s = 256  # 2x min stride
+             m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))[:4]])  # forward
+             # print(m.stride)
+             m.anchors /= m.stride.view(-1, 1, 1)
+             check_anchor_order(m)
+             self.stride = m.stride
+             self._initialize_aux_biases()  # only run once
+             # print('Strides: %s' % m.stride.tolist())
+         if isinstance(m, IBin):
+             s = 256  # 2x min stride
+             m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
+             m.anchors /= m.stride.view(-1, 1, 1)
+             check_anchor_order(m)
+             self.stride = m.stride
+             self._initialize_biases_bin()  # only run once
+             # print('Strides: %s' % m.stride.tolist())
+
+         # Init weights, biases
+         initialize_weights(self)
+         self.info()
+         logger.info('')
+
+     def forward(self, x, augment=False, profile=False):
+         if augment:
+             img_size = x.shape[-2:]  # height, width
+             s = [1, 0.83, 0.67]  # scales
+             f = [None, 3, None]  # flips (2-ud, 3-lr)
+             y = []  # outputs
+             for si, fi in zip(s, f):
+                 xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
+                 yi = self.forward_once(xi)[0]  # forward
+                 # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
+                 yi[..., :4] /= si  # de-scale
+                 if fi == 2:
+                     yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
+                 elif fi == 3:
+                     yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
+                 y.append(yi)
+             return torch.cat(y, 1), None  # augmented inference, train
+         else:
+             return self.forward_once(x, profile)  # single-scale inference, train
+
+     def forward_once(self, x, profile=False):
+         y, dt = [], []  # outputs
+         for m in self.model:
+             if m.f != -1:  # if not from previous layer
+                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
+
+             if not hasattr(self, 'traced'):
+                 self.traced = False
+
+             if self.traced:
+                 if isinstance(m, (Detect, IDetect, IAuxDetect)):
+                     break
+
+             if profile:
+                 c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
+                 o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPS
+                 for _ in range(10):
+                     m(x.copy() if c else x)
+                 t = time_synchronized()
+                 for _ in range(10):
+                     m(x.copy() if c else x)
+                 dt.append((time_synchronized() - t) * 100)
+                 print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
+
+             x = m(x)  # run
+
+             y.append(x if m.i in self.save else None)  # save output
+
+         if profile:
+             print('%.1fms total' % sum(dt))
+         return x
+
+     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
+         # https://arxiv.org/abs/1708.02002 section 3.3
+         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
+         m = self.model[-1]  # Detect() module
+         for mi, s in zip(m.m, m.stride):  # from
+             b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
+             b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
+             b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
+             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+     def _initialize_aux_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
+         # https://arxiv.org/abs/1708.02002 section 3.3
+         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
+         m = self.model[-1]  # Detect() module
+         for mi, mi2, s in zip(m.m, m.m2, m.stride):  # from
+             b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
+             b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
+             b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
+             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+             b2 = mi2.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
+             b2.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
+             b2.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
+             mi2.bias = torch.nn.Parameter(b2.view(-1), requires_grad=True)
+
+     def _initialize_biases_bin(self, cf=None):  # initialize biases into Detect(), cf is class frequency
+         # https://arxiv.org/abs/1708.02002 section 3.3
+         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
+         m = self.model[-1]  # Bin() module
+         bc = m.bin_count
+         for mi, s in zip(m.m, m.stride):  # from
+             b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
+             old = b[:, (0, 1, 2, bc + 3)].data
+             obj_idx = 2 * bc + 4
+             b[:, :obj_idx].data += math.log(0.6 / (bc + 1 - 0.99))
+             b[:, obj_idx].data += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
+             b[:, (obj_idx + 1):].data += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
+             b[:, (0, 1, 2, bc + 3)].data = old
+             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+     def _print_biases(self):
+         m = self.model[-1]  # Detect() module
+         for mi in m.m:  # from
+             b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
+             print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
+
+     # def _print_weights(self):
+     #     for m in self.model.modules():
+     #         if type(m) is Bottleneck:
+     #             print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights
+
+     def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
+         print('Fusing layers... ')
+         for m in self.model.modules():
+             if isinstance(m, RepConv):
+                 # print(f" fuse_repvgg_block")
+                 m.fuse_repvgg_block()
+             elif isinstance(m, RepConv_OREPA):
+                 # print(f" switch_to_deploy")
+                 m.switch_to_deploy()
+             elif type(m) is Conv and hasattr(m, 'bn'):
+                 m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+                 delattr(m, 'bn')  # remove batchnorm
+                 m.forward = m.fuseforward  # update forward
+         self.info()
+         return self
+
+     def nms(self, mode=True):  # add or remove NMS module
+         present = type(self.model[-1]) is NMS  # last layer is NMS
+         if mode and not present:
+             print('Adding NMS... ')
+             m = NMS()  # module
+             m.f = -1  # from
+             m.i = self.model[-1].i + 1  # index
+             self.model.add_module(name='%s' % m.i, module=m)  # add
+             self.eval()
+         elif not mode and present:
+             print('Removing NMS... ')
+             self.model = self.model[:-1]  # remove
+         return self
+
+     def autoshape(self):  # add autoShape module
+         print('Adding autoShape... ')
+         m = autoShape(self)  # wrap model
+         copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
+         return m
+
+     def info(self, verbose=False, img_size=640):  # print model information
+         model_info(self, verbose, img_size)
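+ # Usage sketch (illustrative; assumes a compatible model yaml on disk, such as
+ # the cfg files shipped with the repository):
+ #   model = Model('cfg/training/yolov7.yaml', ch=3, nc=80)
+ #   model.eval()
+ #   pred, _ = model(torch.zeros(1, 3, 640, 640))  # decoded boxes, shape (1, N, nc + 5)
+ # The dummy forward at s=256 above derives each stride as 256 / feature_height,
+ # e.g. feature maps of 32/16/8 px give strides (8, 16, 32).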
+
+
+ def parse_model(d, ch):  # model_dict, input_channels(3)
+     logger.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
+     anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
+     na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
+     no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
+
+     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
+     for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
+         m = eval(m) if isinstance(m, str) else m  # eval strings
+         for j, a in enumerate(args):
+             try:
+                 args[j] = eval(a) if isinstance(a, str) else a  # eval strings
+             except:
+                 pass
+
+         n = max(round(n * gd), 1) if n > 1 else n  # depth gain
+         if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC,
+                  SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv,
+                  Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
+                  RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
+                  Res, ResCSPA, ResCSPB, ResCSPC,
+                  RepRes, RepResCSPA, RepResCSPB, RepResCSPC,
+                  ResX, ResXCSPA, ResXCSPB, ResXCSPC,
+                  RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC,
+                  Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
+                  SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
+                  SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC]:
+             c1, c2 = ch[f], args[0]
+             if c2 != no:  # if not output
+                 c2 = make_divisible(c2 * gw, 8)
+
+             args = [c1, c2, *args[1:]]
+             if m in [DownC, SPPCSPC, GhostSPPCSPC,
+                      BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
+                      RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
+                      ResCSPA, ResCSPB, ResCSPC,
+                      RepResCSPA, RepResCSPB, RepResCSPC,
+                      ResXCSPA, ResXCSPB, ResXCSPC,
+                      RepResXCSPA, RepResXCSPB, RepResXCSPC,
+                      GhostCSPA, GhostCSPB, GhostCSPC,
+                      STCSPA, STCSPB, STCSPC,
+                      ST2CSPA, ST2CSPB, ST2CSPC]:
+                 args.insert(2, n)  # number of repeats
+                 n = 1
+         elif m is nn.BatchNorm2d:
+             args = [ch[f]]
+         elif m is Concat:
+             c2 = sum([ch[x] for x in f])
+         elif m is Chuncat:
+             c2 = sum([ch[x] for x in f])
+         elif m is Shortcut:
+             c2 = ch[f[0]]
+         elif m is Foldcut:
+             c2 = ch[f] // 2
+         elif m in [Detect, IDetect, IAuxDetect, IBin]:
+             args.append([ch[x] for x in f])
+             if isinstance(args[1], int):  # number of anchors
+                 args[1] = [list(range(args[1] * 2))] * len(f)
+         elif m is ReOrg:
+             c2 = ch[f] * 4
+         elif m is Contract:
+             c2 = ch[f] * args[0] ** 2
+         elif m is Expand:
+             c2 = ch[f] // args[0] ** 2
+         else:
+             c2 = ch[f]
+
+         m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
+         t = str(m)[8:-2].replace('__main__.', '')  # module type
+         np = sum([x.numel() for x in m_.parameters()])  # number params
+         m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
+         logger.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
+         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
+         layers.append(m_)
+         if i == 0:
+             ch = []
+         ch.append(c2)
+     return nn.Sequential(*layers), sorted(save)
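+ # Scaling sketch (hypothetical multiples, to illustrate the depth/width gains
+ # above): with depth_multiple=0.33 and width_multiple=0.5, a block declared
+ # with n=9 repeats becomes max(round(9 * 0.33), 1) = 3, and a declared c2=256
+ # becomes make_divisible(256 * 0.5, 8) = 128.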
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--cfg', type=str, default='yolor-csp-c.yaml', help='model.yaml')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--profile', action='store_true', help='profile model speed')
+     opt = parser.parse_args()
+     opt.cfg = check_file(opt.cfg)  # check file
+     set_logging()
+     device = select_device(opt.device)
+
+     # Create model
+     model = Model(opt.cfg).to(device)
+     model.train()
+
+     if opt.profile:
+         img = torch.rand(1, 3, 640, 640).to(device)
+         y = model(img, profile=True)
+
+     # Profile
+     # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
+     # y = model(img, profile=True)
+
+     # Tensorboard
+     # from torch.utils.tensorboard import SummaryWriter
+     # tb_writer = SummaryWriter()
+     # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
+     # tb_writer.add_graph(model.model, img)  # add model to tensorboard
+     # tb_writer.add_image('test', img[0], dataformats='CWH')  # add model to tensorboard
+
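
A minimal smoke test for the module above (a sketch; the cfg path and an importable repository root are assumptions, not part of this commit):

    import torch
    from models.yolo import Model

    model = Model('cfg/training/yolov7.yaml')  # any compatible model yaml
    model.eval()  # detection head now returns (decoded, raw)
    pred, _ = model(torch.zeros(1, 3, 640, 640))
    print(pred.shape)  # e.g. torch.Size([1, 25200, 85]) for nc=80, strides 8/16/32

Equivalently, the file can be run directly, e.g. python models/yolo.py --cfg yolor-csp-c.yaml --device cpu, which builds the model and logs a per-layer summary; adding --profile times each layer.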