PKaushik committed on
Commit
4e9f3e3
1 Parent(s): 65d9523
Files changed (1)
  1. yolov6/utils/torch_utils.py +110 -0
yolov6/utils/torch_utils.py ADDED
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import time
+from contextlib import contextmanager
+from copy import deepcopy
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from yolov6.utils.events import LOGGER
+
+try:
+    import thop  # for FLOPs computation
+except ImportError:
+    thop = None
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    """
+    Context manager to make all processes in distributed training wait for the local master to do something.
+    """
+    if local_rank not in [-1, 0]:
+        dist.barrier(device_ids=[local_rank])
+    yield
+    if local_rank == 0:
+        dist.barrier(device_ids=[0])
+
+
+def time_sync():
+    # Waits for all kernels in all streams on a CUDA device to complete if cuda is available.
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return time.time()
+
+
+def initialize_weights(model):
+    for m in model.modules():
+        t = type(m)
+        if t is nn.Conv2d:
+            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        elif t is nn.BatchNorm2d:
+            m.eps = 1e-3
+            m.momentum = 0.03
+        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+            m.inplace = True
+
+
+def fuse_conv_and_bn(conv, bn):
+    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+    fusedconv = (
+        nn.Conv2d(
+            conv.in_channels,
+            conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            groups=conv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(conv.weight.device)
+    )
+
+    # prepare filters
+    w_conv = conv.weight.clone().view(conv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
+
+    # prepare spatial bias
+    b_conv = (
+        torch.zeros(conv.weight.size(0), device=conv.weight.device)
+        if conv.bias is None
+        else conv.bias
+    )
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
+        torch.sqrt(bn.running_var + bn.eps)
+    )
+    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fusedconv
+
+
+def fuse_model(model):
+    from yolov6.layers.common import Conv
+
+    for m in model.modules():
+        if type(m) is Conv and hasattr(m, "bn"):
+            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+            delattr(m, "bn")  # remove batchnorm
+            m.forward = m.forward_fuse  # update forward
+    return model
+
+
+def get_model_info(model, img_size=640):
+    """Get model Params and GFlops.
+    Code based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/model_utils.py
+    """
+    from thop import profile
+    stride = 32
+    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
+
+    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
+    params /= 1e6
+    flops /= 1e9
+    img_size = img_size if isinstance(img_size, list) else [img_size, img_size]
+    flops *= img_size[0] * img_size[1] / stride / stride * 2  # Gflops
+    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
+    return info
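For reference, fuse_conv_and_bn() folds an eval-mode batchnorm into the preceding convolution: the fused weight is the conv weight scaled per output channel by gamma / sqrt(running_var + eps), and the fused bias is beta + (b_conv - running_mean) * gamma / sqrt(running_var + eps), which is what the "prepare filters" and "prepare spatial bias" blocks compute. A minimal sanity-check sketch (illustration only, not part of the committed file), assuming the module is importable as yolov6.utils.torch_utils and using a hypothetical toy Conv2d/BatchNorm2d pair:

import torch
import torch.nn as nn

from yolov6.utils.torch_utils import fuse_conv_and_bn

# Toy conv + BN pair; shapes are arbitrary and chosen only for the check.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)

# Populate BN running statistics with a few dummy forward passes, then
# switch to eval mode, which is what the fusion assumes.
bn.train()
with torch.no_grad():
    for _ in range(10):
        bn(conv(torch.randn(4, 3, 32, 32)))
conv.eval()
bn.eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y_ref = bn(conv(x))                      # conv followed by batchnorm
    y_fused = fuse_conv_and_bn(conv, bn)(x)  # single fused convolution

print(torch.allclose(y_ref, y_fused, atol=1e-4))  # expected: True

fuse_model() applies the same fusion to every Conv block that still has a bn attribute and then switches that block's forward to forward_fuse, so it is meant to run once on a fully built model before inference.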
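get_model_info() imports thop at call time, so it needs the optional thop dependency even though the module-level import is guarded. A minimal sketch, assuming thop is installed and using a hypothetical two-layer stand-in instead of a real detector:

import torch.nn as nn

from yolov6.utils.torch_utils import get_model_info

# Hypothetical stand-in for a detection model; any nn.Module with parameters works.
toy = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.SiLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
)

# Profiles a 32x32 dummy input with thop, then rescales FLOPs to img_size.
print(get_model_info(toy, img_size=640))  # e.g. "Params: 0.01M, Gflops: ..."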