PKaushik commited on
Commit
331b555
1 Parent(s): d570deb
Files changed (1) hide show
  1. yolov6/models/effidehead.py +211 -0
yolov6/models/effidehead.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from yolov6.layers.common import *
5
+
6
+
7
class Detect(nn.Module):
    '''Efficient Decoupled Head.

    With hardware-aware design, the decoupled head is optimized with
    hybrid-channel methods: each detection level gets a stem plus separate
    classification and regression branches.
    '''
    def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None):  # detection layer
        super().__init__()
        assert head_layers is not None
        self.nc = num_classes   # number of classes
        self.no = num_classes + 5   # outputs per anchor: box(4) + objectness(1) + classes
        self.nl = num_layers    # number of detection layers
        # Anchors may arrive as explicit (x, y) pairs per layer or as a plain count.
        if isinstance(anchors, (list, tuple)):
            self.na = len(anchors[0]) // 2
        else:
            self.na = anchors
        self.anchors = anchors
        self.grid = [torch.zeros(1)] * num_layers  # per-level grids, built lazily in forward
        self.prior_prob = 1e-2  # prior foreground probability used for bias init
        self.inplace = inplace
        self.stride = torch.tensor([8, 16, 32])  # strides computed during build

        # Decoupled-head branches; one entry per detection level.
        # NOTE: registration order kept stable so parameter iteration order is unchanged.
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()

        # head_layers is a flat sequence of 6 modules per level:
        # stem, cls_conv, reg_conv, cls_pred, reg_pred, obj_pred.
        for level in range(num_layers):
            base = level * 6
            self.stems.append(head_layers[base])
            self.cls_convs.append(head_layers[base + 1])
            self.reg_convs.append(head_layers[base + 2])
            self.cls_preds.append(head_layers[base + 3])
            self.reg_preds.append(head_layers[base + 4])
            self.obj_preds.append(head_layers[base + 5])

    def initialize_biases(self):
        '''Initialize cls/obj prediction biases so outputs start near prior_prob.'''
        bias_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        for conv in list(self.cls_preds) + list(self.obj_preds):
            b = conv.bias.view(self.na, -1)
            b.data.fill_(bias_value)
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def forward(self, x):
        '''Run the head on a list of feature maps (mutated in place).

        Training: returns the list of raw per-level maps shaped
        (bs, na, ny, nx, no). Inference: returns decoded predictions
        concatenated over levels, shaped (bs, sum(na*ny*nx), no).
        '''
        decoded = []
        for i in range(self.nl):
            x[i] = self.stems[i](x[i])
            feat = x[i]
            cls_output = self.cls_preds[i](self.cls_convs[i](feat))
            reg_feat = self.reg_convs[i](feat)
            reg_output = self.reg_preds[i](reg_feat)
            obj_output = self.obj_preds[i](reg_feat)
            if self.training:
                x[i] = torch.cat([reg_output, obj_output, cls_output], 1)
                bs, _, ny, nx = x[i].shape
                x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
                continue
            # Inference: sigmoid obj/cls, then decode boxes against the level grid.
            y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
            bs, _, ny, nx = y.shape
            y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if self.grid[i].shape[2:4] != y.shape[2:4]:
                # (Re)build the cached cell-offset grid for this feature size.
                d = self.stride.device
                yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
                self.grid[i] = torch.stack((xv, yv), 2).view(1, self.na, ny, nx, 2).float()
            if self.inplace:
                y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i]  # wh
            else:
                xy = (y[..., 0:2] + self.grid[i]) * self.stride[i]  # xy
                wh = torch.exp(y[..., 2:4]) * self.stride[i]  # wh
                y = torch.cat((xy, wh, y[..., 4:]), -1)
            decoded.append(y.view(bs, -1, self.no))
        return x if self.training else torch.cat(decoded, 1)
89
+
90
+
91
def build_effidehead_layer(channels_list, num_anchors, num_classes):
    '''Build the flat module sequence consumed by ``Detect``.

    The original implementation repeated the same six-module section three
    times verbatim (differing only in the channel index 6 / 8 / 10); it is
    collapsed here into a single loop that emits the identical 18 modules in
    the identical order.

    Args:
        channels_list: per-stage channel counts; indices 6, 8 and 10 give the
            head input width for the three detection levels.
        num_anchors: anchors per grid cell.
        num_classes: number of object classes.

    Returns:
        nn.Sequential of 3 * 6 modules, indexable by ``Detect`` as
        ``head_layers[level * 6 + k]`` with k = stem, cls_conv, reg_conv,
        cls_pred, reg_pred, obj_pred.
    '''
    layers = []
    for ch in (channels_list[6], channels_list[8], channels_list[10]):
        layers.extend([
            # stem: 1x1 conv keeping the channel width
            Conv(
                in_channels=ch,
                out_channels=ch,
                kernel_size=1,
                stride=1
            ),
            # cls_conv: 3x3 conv feeding the classification prediction
            Conv(
                in_channels=ch,
                out_channels=ch,
                kernel_size=3,
                stride=1
            ),
            # reg_conv: 3x3 conv feeding box and objectness predictions
            Conv(
                in_channels=ch,
                out_channels=ch,
                kernel_size=3,
                stride=1
            ),
            # cls_pred: per-anchor class scores
            nn.Conv2d(
                in_channels=ch,
                out_channels=num_classes * num_anchors,
                kernel_size=1
            ),
            # reg_pred: per-anchor box regression (4 values)
            nn.Conv2d(
                in_channels=ch,
                out_channels=4 * num_anchors,
                kernel_size=1
            ),
            # obj_pred: per-anchor objectness score
            nn.Conv2d(
                in_channels=ch,
                out_channels=1 * num_anchors,
                kernel_size=1
            ),
        ])
    return nn.Sequential(*layers)