amazinghaha committed on
Commit
a6116c3
·
1 Parent(s): 88cd70c

Upload resnet_gn.py

Browse files
Files changed (1) hide show
  1. resnet_gn.py +375 -0
resnet_gn.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ import math
6
+ from functools import partial
7
+
8
# Names exported by ``from <module> import *``.
__all__ = [
    'ResNet',
    'resnet10',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'resnet200',
]
12
+
13
class FilterResponseNormNd(nn.Module):
    """Filter-response normalisation with a learned lower threshold.

    The input is divided by the root of its mean square taken over every
    axis after the first two, scaled and shifted per feature
    (``gamma``/``beta``) and finally clamped from below by ``tau``.
    """

    def __init__(self, ndim, num_features, eps=1e-6,
                 learnable_eps=False):
        """
        Input Variables:
        ----------------
            ndim: An integer, the number of dimensions of the expected input
                tensor (3, 4 or 5).
            num_features: An integer, the size of axis 1 of the input.
            eps: A scalar added (as an absolute value) to the mean square
                before the reciprocal square root is taken.
            learnable_eps: A bool; when true ``eps`` receives gradients.
        """
        assert ndim in [3, 4, 5], \
            'FilterResponseNorm only supports 3d, 4d or 5d inputs.'
        super().__init__()
        # One value per feature, broadcastable over batch and trailing axes.
        param_shape = (1, num_features) + (1,) * (ndim - 2)
        self.eps = nn.Parameter(
            torch.ones(param_shape) * eps,
            requires_grad=bool(learnable_eps))
        self.gamma = nn.Parameter(torch.empty(param_shape))
        self.beta = nn.Parameter(torch.empty(param_shape))
        self.tau = nn.Parameter(torch.empty(param_shape))
        self.reset_parameters()

    def forward(self, x):
        # Every axis after (batch, feature) is averaged over.
        reduce_axes = tuple(range(2, x.dim()))
        mean_square = x.pow(2).mean(dim=reduce_axes, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.eps.abs())
        affine = self.gamma * normalised + self.beta
        return torch.max(affine, self.tau)

    def reset_parameters(self):
        """Set ``gamma`` to one and ``beta``/``tau`` to zero."""
        nn.init.ones_(self.gamma)
        for zero_param in (self.beta, self.tau):
            nn.init.zeros_(zero_param)
47
+
48
def conv3x3x3(in_planes, out_planes, stride=1):
    """Return a bias-free ``nn.Conv3d`` with kernel size 3 and padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride forwarded to ``nn.Conv3d``.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)
57
+
58
+
59
def downsample_basic_block(x, planes, stride):
    """Parameter-free shortcut: subsample ``x`` and zero-pad its channels.

    ``x`` is subsampled with a kernel-size-1 average pool using ``stride``
    and then extended along axis 1 with zeros until it has ``planes``
    channels.

    The padding is allocated with ``new_zeros`` so it always matches the
    dtype and device of the pooled tensor, and the pooled tensor is
    concatenated directly (no ``.data``/``Variable`` detour) so gradients
    flow through the shortcut back to ``x``.

    Args:
        x: 5-D tensor; axis 1 is treated as the channel axis.
        planes: number of channels of the returned tensor; must be at least
            the channel count of ``x``.
        stride: stride passed to ``F.avg_pool3d``.

    Returns:
        Tensor with ``planes`` channels whose leading channels hold the
        subsampled input and whose remaining channels are zero.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    missing_channels = planes - out.size(1)
    zero_pads = out.new_zeros(out.size(0), missing_channels, *out.shape[2:])
    return torch.cat([out, zero_pads], dim=1)
70
+
71
+
72
class BasicBlock(nn.Module):
    """Residual block of two 3x3x3 convolutions, each followed by GroupNorm.

    ``expansion`` is the ratio between the block's output channels and
    ``planes``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: channels of the incoming tensor.
            planes: channels produced by both convolutions; passed to
                ``nn.GroupNorm(32, planes)``.
            stride: stride of the first convolution.
            downsample: optional callable applied to the input to produce
                the shortcut; when ``None`` the input itself is used.
        """
        super().__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.gn1 = nn.GroupNorm(32, planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.gn2 = nn.GroupNorm(32, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv -> norm -> relu -> conv -> norm.
        out = self.relu(self.gn1(self.conv1(x)))
        out = self.gn2(self.conv2(out))

        # Shortcut branch: identity unless a downsample callable was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
106
+
107
+
108
class Bottleneck(nn.Module):
    """Residual block made of 1x1x1, 3x3x3 and 1x1x1 convolutions.

    Each convolution is followed by GroupNorm with 32 groups; the last
    convolution produces ``planes * 4`` channels.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: channels of the incoming tensor.
            planes: channels of the two inner convolutions.
            stride: stride of the 3x3x3 convolution.
            downsample: optional callable applied to the input to produce
                the shortcut; when ``None`` the input itself is used.
        """
        super().__init__()
        widened = planes * 4
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.gn1 = nn.GroupNorm(32, planes)
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.gn2 = nn.GroupNorm(32, planes)
        self.conv3 = nn.Conv3d(planes, widened, kernel_size=1, bias=False)
        self.gn3 = nn.GroupNorm(32, widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: reduce -> spatial conv -> expand.
        out = self.relu(self.gn1(self.conv1(x)))
        out = self.relu(self.gn2(self.conv2(out)))
        out = self.gn3(self.conv3(out))

        # Shortcut branch: identity unless a downsample callable was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
150
+
151
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with ReLU between them.

    ReLU is applied after every layer whose index is below
    ``num_layers - 1``; an optional sigmoid is applied to the final output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        """
        Args:
            input_dim: input width of the first layer.
            hidden_dim: width of every intermediate layer.
            output_dim: output width of the last layer.
            num_layers: number of linear layers to build.
            sigmoid_output: apply a sigmoid to the result of ``forward``.
        """
        super().__init__()
        self.num_layers = num_layers
        hidden_widths = [hidden_dim] * (num_layers - 1)
        in_widths = [input_dim] + hidden_widths
        out_widths = hidden_widths + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out)
             for d_in, d_out in zip(in_widths, out_widths)]
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        return torch.sigmoid(x) if self.sigmoid_output else x
174
+
175
class ResNet(nn.Module):
    """3D ResNet with GroupNorm layers and an MLP classification head.

    The network is: 7x7x7 stem convolution -> GroupNorm -> ReLU -> max pool
    -> four residual stages -> average pool -> flatten -> two-layer MLP
    (``classfily``). The flattened features of the most recent forward pass
    are kept in ``self.feature``.
    """

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=400,
                 feature_dim=81920):
        """
        Args:
            block: residual block class (``BasicBlock`` or ``Bottleneck``);
                must expose an ``expansion`` attribute and accept
                ``(inplanes, planes, stride, downsample)``.
            layers: four integers, the number of blocks in each stage.
            sample_size: used only to size the average-pool kernel,
                ``ceil(sample_size / 32)`` along the last two axes.
            sample_duration: used only to size the average-pool kernel,
                ``ceil(sample_duration / 16)`` along the first pooled axis.
            shortcut_type: ``'A'`` selects the parameter-free
                ``downsample_basic_block`` shortcut; any other value selects
                a 1x1x1 convolution followed by GroupNorm.
            num_classes: output width of the head. When it equals 1,
                ``forward`` applies a sigmoid to the output.
            feature_dim: width of the flattened tensor fed to the head. The
                default, 81920, is the value that used to be hard-coded
                (NOTE(review): it equals 2048 * 40, i.e. it only fits one
                specific block/input-size combination -- pass the right
                value for other configurations).
        """
        # Initialise the Module machinery before assigning any attribute.
        super(ResNet, self).__init__()
        self.num_classes = num_classes
        self.inplanes = 64  # running input width, updated by _make_layer

        # The stem expects a single input channel.
        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=7,
            stride=(1, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.gn1 = nn.GroupNorm(32, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=2)

        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d(
            (last_duration, last_size, last_size), stride=1)

        # The attribute name (including its spelling) is part of the
        # state-dict keys, so it is kept as is.
        self.classfily = MLP(
            feature_dim, 256, self.num_classes, 2, sigmoid_output=False)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # In-place initialiser; the parameter object is left in place.
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.GroupNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block receives ``stride`` and, when the shape changes
        (stride != 1 or channel mismatch), a shortcut projection. Updates
        ``self.inplanes`` to the stage's output width.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=out_planes,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        out_planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False),
                    nn.GroupNorm(32, out_planes))

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.conv1(x)
        x = self.gn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        # Kept on the module so callers can read the pre-head features.
        self.feature = x
        x = self.classfily(x)
        if self.num_classes == 1:
            x = torch.sigmoid(x)
        return x
283
+
284
+
285
+ # def initialize_weights(self):
286
+ # # print(self.modules())
287
+ #
288
+ # for m in self.modules():
289
+ # if isinstance(m, nn.Linear):
290
+ # # print(m.weight.data.type())
291
+ # # input()
292
+ # # m.weight.data.fill_(1.0)
293
+ # nn.init.kaiming_normal_(m.weight,a=0, mode='fan_in', nonlinearity='relu')
294
+ # print(m.weight)
295
+
296
def weights_init(m):
    """Xavier-normal initialiser intended for ``model.apply(weights_init)``.

    Modules whose class name contains ``Conv2d``, ``Conv3d`` or ``Linear``
    get Xavier-normal weights and, when they own a bias, a zero bias. All
    other modules are left untouched.

    ``Conv3d`` is matched because the convolutions built in this file are
    3-D, and the bias is only touched when present because those
    convolutions are created with ``bias=False``.

    Args:
        m: the module handed in by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    if not any(tag in classname for tag in ('Conv2d', 'Conv3d', 'Linear')):
        return
    nn.init.xavier_normal_(m.weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)
304
+
305
def get_fine_tuning_parameters(model, ft_begin_index):
    """Build optimizer parameter groups that freeze the early layers.

    Args:
        model: module whose ``named_parameters()`` are grouped.
        ft_begin_index: first ``layer<N>`` stage to fine-tune. ``0`` means
            "train everything" and returns ``model.parameters()`` as is.

    Returns:
        ``model.parameters()`` when ``ft_begin_index == 0``; otherwise a
        list with one dict per parameter. Parameters whose name contains
        ``layer<N>`` (for ``ft_begin_index <= N <= 4``), ``fc`` or
        ``classfily`` get ``{'params': p}``; every other parameter gets
        ``{'params': p, 'lr': 0.0}``.
    """
    if ft_begin_index == 0:
        return model.parameters()

    ft_module_names = ['layer{}'.format(i) for i in range(ft_begin_index, 5)]
    # 'classfily' is the head attribute of ResNet in this file; without it
    # the classifier would be given lr=0.0. 'fc' is kept for models that
    # still name their head that way.
    ft_module_names.extend(['fc', 'classfily'])

    parameters = []
    for name, param in model.named_parameters():
        # Substring match, so wrapper prefixes in the name do not matter.
        if any(ft_name in name for ft_name in ft_module_names):
            parameters.append({'params': param})
        else:
            parameters.append({'params': param, 'lr': 0.0})

    return parameters
324
+
325
+
326
def resnet10(**kwargs):
    """Build the ResNet-10 variant: ``BasicBlock`` with one block per stage.

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
331
+
332
+
333
def resnet18(**kwargs):
    """Build the ResNet-18 variant: ``BasicBlock`` with [2, 2, 2, 2] blocks.

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
338
+
339
+
340
def resnet34(**kwargs):
    """Build the ResNet-34 variant: ``BasicBlock`` with [3, 4, 6, 3] blocks.

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
345
+
346
+
347
def resnet50(**kwargs):
    """Build the ResNet-50 variant: ``Bottleneck`` with [3, 4, 6, 3] blocks.

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
353
+
354
+
355
def resnet101(**kwargs):
    """Build the ResNet-101 variant: ``Bottleneck`` with [3, 4, 23, 3] blocks.

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
361
+
362
+
363
def resnet152(**kwargs):
    """Build the ResNet-152 variant: ``Bottleneck`` with [3, 8, 36, 3] blocks.

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
368
+
369
+
370
def resnet200(**kwargs):
    """Build the ResNet-200 variant: ``Bottleneck`` with [3, 24, 36, 3] blocks.

    All keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)