wzhouxiff commited on
Commit
e767025
1 Parent(s): 6d9a40b

Create srvgg_arch.py

Browse files
Files changed (1) hide show
  1. srvgg_arch.py +67 -0
srvgg_arch.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+
4
+
5
+ class SRVGGNetCompact(nn.Module):
6
+ """A compact VGG-style network structure for super-resolution.
7
+
8
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
9
+ conducted on the HR feature space.
10
+
11
+ Args:
12
+ num_in_ch (int): Channel number of inputs. Default: 3.
13
+ num_out_ch (int): Channel number of outputs. Default: 3.
14
+ num_feat (int): Channel number of intermediate features. Default: 64.
15
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
16
+ upscale (int): Upsampling factor. Default: 4.
17
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
18
+ """
19
+
20
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
21
+ super(SRVGGNetCompact, self).__init__()
22
+ self.num_in_ch = num_in_ch
23
+ self.num_out_ch = num_out_ch
24
+ self.num_feat = num_feat
25
+ self.num_conv = num_conv
26
+ self.upscale = upscale
27
+ self.act_type = act_type
28
+
29
+ self.body = nn.ModuleList()
30
+ # the first conv
31
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
32
+ # the first activation
33
+ if act_type == 'relu':
34
+ activation = nn.ReLU(inplace=True)
35
+ elif act_type == 'prelu':
36
+ activation = nn.PReLU(num_parameters=num_feat)
37
+ elif act_type == 'leakyrelu':
38
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
39
+ self.body.append(activation)
40
+
41
+ # the body structure
42
+ for _ in range(num_conv):
43
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
44
+ # activation
45
+ if act_type == 'relu':
46
+ activation = nn.ReLU(inplace=True)
47
+ elif act_type == 'prelu':
48
+ activation = nn.PReLU(num_parameters=num_feat)
49
+ elif act_type == 'leakyrelu':
50
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
51
+ self.body.append(activation)
52
+
53
+ # the last conv
54
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
55
+ # upsample
56
+ self.upsampler = nn.PixelShuffle(upscale)
57
+
58
+ def forward(self, x):
59
+ out = x
60
+ for i in range(0, len(self.body)):
61
+ out = self.body[i](out)
62
+
63
+ out = self.upsampler(out)
64
+ # add the nearest upsampled image, so that the network learns the residual
65
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
66
+ out += base
67
+ return out