File size: 3,711 Bytes
2620eb0
 
 
 
dcc6440
2620eb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28801c1
2620eb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch.nn as nn
import torch
import torch.nn.functional as F


# Dropout layer that stays active even in evaluation mode.
class DropoutAlways(nn.Dropout2d):
    """Channel-wise (2-D) dropout that is applied regardless of train/eval mode.

    Behaves like ``nn.Dropout2d`` except that ``forward`` always passes
    ``training=True`` to ``F.dropout2d``, so calling ``.eval()`` on the model
    does not switch it off. The ``p`` and ``inplace`` constructor arguments
    inherited from ``nn.Dropout2d`` are both honoured.
    """

    def forward(self, x):
        # training=True on purpose: this layer must keep dropping in eval mode.
        # Forward self.inplace too; previously it was accepted but ignored.
        return F.dropout2d(x, self.p, training=True, inplace=self.inplace)

class DownBlock(nn.Module):
    """Encoder stage: a 4x4, stride-2 convolution that halves the spatial size.

    ``self.block`` is ``Sequential(conv, norm, LeakyReLU(0.2))``. The layer
    indices are fixed so that checkpoint keys (``block.0.*``, ``block.1.*``)
    do not depend on ``normalize``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        normalize: when True, follow the conv with an affine ``InstanceNorm2d``
            and omit the conv bias (the norm subtracts the per-channel mean, so
            a bias would be cancelled anyway). When False, keep the conv bias and
            use ``nn.Identity`` as a parameter-free placeholder in slot 1.
    """

    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode='reflect',
            bias=not normalize,
        )
        if normalize:
            norm = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            norm = nn.Identity()
        self.block = nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Apply conv -> norm -> LeakyReLU(0.2) to ``x``."""
        return self.block(x)


class UpBlock(nn.Module):
    """Decoder stage: a 4x4, stride-2 transposed convolution that doubles the spatial size.

    ``self.block`` is ``Sequential(conv_t, norm, dropout, activation)``.
    ``nn.Identity`` placeholders keep the layer indices, and therefore the
    checkpoint keys, the same for every combination of options.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the transposed conv.
        normalize: when True, add an affine ``InstanceNorm2d`` and omit the conv
            bias; when False, keep the bias and skip normalization.
        dropout: when True, insert ``DropoutAlways`` (dropout that stays active
            in eval mode).
        activation: ``'relu'`` or ``'tanh'``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
            Previously any unrecognised value silently selected ``Tanh``.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=False, activation='relu'):
        super().__init__()

        activations = {'relu': nn.ReLU, 'tanh': nn.Tanh}
        if activation not in activations:
            raise ValueError(
                f"activation must be one of {sorted(activations)}, got {activation!r}"
            )

        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=not normalize),
            nn.InstanceNorm2d(out_channels, affine=True) if normalize else nn.Identity(),
            DropoutAlways() if dropout else nn.Identity(),
            activations[activation](),
        )

    def forward(self, x):
        """Apply transposed conv -> norm -> dropout -> activation to ``x``."""
        return self.block(x)


class Generator(nn.Module):
    """U-Net generator with eight down-sampling and eight up-sampling stages.

    Each of the first seven decoder outputs is concatenated along the channel
    dimension with the encoder output of the same spatial size (skip
    connection). The last decoder stage has no normalization and a Tanh, so
    outputs lie in [-1, 1].

    Args:
        in_channels: channels of the input image. Defaults to 1, the
            previously hard-coded value.
        out_channels: channels of the generated output. Defaults to 2, the
            previously hard-coded value.

    Input height and width must be multiples of 256. Eight stride-2 stages
    reduce 256 to 1. A smaller input leaves the last encoder with a 1-pixel map,
    which its reflect padding cannot handle. A non-multiple makes the
    skip-connection shapes disagree in ``torch.cat``.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        # Encoder (spatial sizes shown for a 256x256 input)
        self.encoder1 = DownBlock(in_channels, 64, normalize=False)  # 256x256 -> 128x128
        self.encoder2 = DownBlock(64, 128)  # 128x128 -> 64x64
        self.encoder3 = DownBlock(128, 256)  # 64x64 -> 32x32
        self.encoder4 = DownBlock(256, 512)  # 32x32 -> 16x16
        self.encoder5 = DownBlock(512, 512)  # 16x16 -> 8x8
        self.encoder6 = DownBlock(512, 512)  # 8x8 -> 4x4
        self.encoder7 = DownBlock(512, 512)  # 4x4 -> 2x2
        # No norm here: InstanceNorm cannot normalize a single spatial element.
        self.encoder8 = DownBlock(512, 512, normalize=False)  # 2x2 -> 1x1

        # Decoder. Inputs after the first are doubled by the skip concatenation.
        self.decoder1 = UpBlock(512, 512, dropout=True)  # 1x1 -> 2x2
        self.decoder2 = UpBlock(512 * 2, 512, dropout=True)  # 2x2 -> 4x4
        self.decoder3 = UpBlock(512 * 2, 512, dropout=True)  # 4x4 -> 8x8
        self.decoder4 = UpBlock(512 * 2, 512)  # 8x8 -> 16x16
        self.decoder5 = UpBlock(512 * 2, 256)  # 16x16 -> 32x32
        self.decoder6 = UpBlock(256 * 2, 128)  # 32x32 -> 64x64
        self.decoder7 = UpBlock(128 * 2, 64)  # 64x64 -> 128x128
        self.decoder8 = UpBlock(64 * 2, out_channels, normalize=False, activation='tanh')  # 128x128 -> 256x256

    def forward(self, x):
        """Run the U-Net on ``x``.

        Args:
            x: batched tensor of shape (N, in_channels, H, W); channels are on
                dim 1, which is the dim the skip connections concatenate on.

        Returns:
            Tensor of shape (N, out_channels, H, W) with values in [-1, 1].
        """
        encoders = (
            self.encoder1, self.encoder2, self.encoder3, self.encoder4,
            self.encoder5, self.encoder6, self.encoder7, self.encoder8,
        )
        decoders = (
            self.decoder1, self.decoder2, self.decoder3, self.decoder4,
            self.decoder5, self.decoder6, self.decoder7, self.decoder8,
        )

        # Contracting path: remember every encoder output for the skips.
        skips = []
        h = x
        for encoder in encoders:
            h = encoder(h)
            skips.append(h)
        # The innermost (bottleneck) output feeds decoder1 directly; it is not a skip.
        skips.pop()

        # Expanding path: decoder k's output is joined with the encoder output of
        # matching size (encoder7 for decoder1, ..., encoder1 for decoder7).
        for decoder, skip in zip(decoders[:-1], reversed(skips)):
            h = torch.cat([decoder(h), skip], dim=1)

        return decoders[-1](h)