RashiAgarwal commited on
Commit
ad60f7d
·
1 Parent(s): 95f89cc

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +176 -0
model.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of YOLOv3 architecture
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ """
9
+ Information about architecture config:
10
+ Tuple is structured by (filters, kernel_size, stride)
11
+ Every conv is a same convolution.
12
+ List is structured by "B" indicating a residual block followed by the number of repeats
13
+ "S" is for scale prediction block and computing the yolo loss
14
+ "U" is for upsampling the feature map and concatenating with a previous layer
15
+ """
16
+ config = [
17
+ (32, 3, 1),
18
+ (64, 3, 2),
19
+ ["B", 1],
20
+ (128, 3, 2),
21
+ ["B", 2],
22
+ (256, 3, 2),
23
+ ["B", 8],
24
+ (512, 3, 2),
25
+ ["B", 8],
26
+ (1024, 3, 2),
27
+ ["B", 4], # To this point is Darknet-53
28
+ (512, 1, 1),
29
+ (1024, 3, 1),
30
+ "S",
31
+ (256, 1, 1),
32
+ "U",
33
+ (256, 1, 1),
34
+ (512, 3, 1),
35
+ "S",
36
+ (128, 1, 1),
37
+ "U",
38
+ (128, 1, 1),
39
+ (256, 3, 1),
40
+ "S",
41
+ ]
42
+
43
+
44
+ class CNNBlock(nn.Module):
45
+ def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
46
+ super().__init__()
47
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
48
+ self.bn = nn.BatchNorm2d(out_channels)
49
+ self.leaky = nn.LeakyReLU(0.1)
50
+ self.use_bn_act = bn_act
51
+
52
+ def forward(self, x):
53
+ if self.use_bn_act:
54
+ return self.leaky(self.bn(self.conv(x)))
55
+ else:
56
+ return self.conv(x)
57
+
58
+
59
+ class ResidualBlock(nn.Module):
60
+ def __init__(self, channels, use_residual=True, num_repeats=1):
61
+ super().__init__()
62
+ self.layers = nn.ModuleList()
63
+ for repeat in range(num_repeats):
64
+ self.layers += [
65
+ nn.Sequential(
66
+ CNNBlock(channels, channels // 2, kernel_size=1),
67
+ CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
68
+ )
69
+ ]
70
+
71
+ self.use_residual = use_residual
72
+ self.num_repeats = num_repeats
73
+
74
+ def forward(self, x):
75
+ for layer in self.layers:
76
+ if self.use_residual:
77
+ x = x + layer(x)
78
+ else:
79
+ x = layer(x)
80
+
81
+ return x
82
+
83
+
84
+ class ScalePrediction(nn.Module):
85
+ def __init__(self, in_channels, num_classes):
86
+ super().__init__()
87
+ self.pred = nn.Sequential(
88
+ CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
89
+ CNNBlock(
90
+ 2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
91
+ ),
92
+ )
93
+ self.num_classes = num_classes
94
+
95
+ def forward(self, x):
96
+ return (
97
+ self.pred(x)
98
+ .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
99
+ .permute(0, 1, 3, 4, 2)
100
+ )
101
+
102
+
103
+ class YOLOv3(nn.Module):
104
+ def __init__(self, in_channels=3, num_classes=80):
105
+ super().__init__()
106
+ self.num_classes = num_classes
107
+ self.in_channels = in_channels
108
+ self.layers = self._create_conv_layers()
109
+
110
+ def forward(self, x):
111
+ outputs = [] # for each scale
112
+ route_connections = []
113
+ for layer in self.layers:
114
+ if isinstance(layer, ScalePrediction):
115
+ outputs.append(layer(x))
116
+ continue
117
+
118
+ x = layer(x)
119
+
120
+ if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
121
+ route_connections.append(x)
122
+
123
+ elif isinstance(layer, nn.Upsample):
124
+ x = torch.cat([x, route_connections[-1]], dim=1)
125
+ route_connections.pop()
126
+
127
+ return outputs
128
+
129
+ def _create_conv_layers(self):
130
+ layers = nn.ModuleList()
131
+ in_channels = self.in_channels
132
+
133
+ for module in config:
134
+ if isinstance(module, tuple):
135
+ out_channels, kernel_size, stride = module
136
+ layers.append(
137
+ CNNBlock(
138
+ in_channels,
139
+ out_channels,
140
+ kernel_size=kernel_size,
141
+ stride=stride,
142
+ padding=1 if kernel_size == 3 else 0,
143
+ )
144
+ )
145
+ in_channels = out_channels
146
+
147
+ elif isinstance(module, list):
148
+ num_repeats = module[1]
149
+ layers.append(ResidualBlock(in_channels, num_repeats=num_repeats,))
150
+
151
+ elif isinstance(module, str):
152
+ if module == "S":
153
+ layers += [
154
+ ResidualBlock(in_channels, use_residual=False, num_repeats=1),
155
+ CNNBlock(in_channels, in_channels // 2, kernel_size=1),
156
+ ScalePrediction(in_channels // 2, num_classes=self.num_classes),
157
+ ]
158
+ in_channels = in_channels // 2
159
+
160
+ elif module == "U":
161
+ layers.append(nn.Upsample(scale_factor=2),)
162
+ in_channels = in_channels * 3
163
+
164
+ return layers
165
+
166
+
167
+ if __name__ == "__main__":
168
+ num_classes = 20
169
+ IMAGE_SIZE = 416
170
+ model = YOLOv3(num_classes=num_classes)
171
+ x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
172
+ out = model(x)
173
+ assert model(x)[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
174
+ assert model(x)[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
175
+ assert model(x)[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
176
+ print("Success!")