File size: 1,533 Bytes
1ba3df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import torch.nn as nn
from models.module import ResidualBlocks

# Default value of ``Decoder``'s ``in_channels`` constructor argument.
_DECODER_CHANNEL_DEFAULT = 512


class Decoder(nn.Module):
    """Base decoder: a chain of sub-modules applied one after another.

    The base class only registers an empty ``nn.ModuleList`` as
    ``self.module``; subclasses append their layers to it. The ``hp``,
    ``in_channels`` and ``out_channels`` arguments are accepted here but not
    used by the base class itself.
    """

    def __init__(self, hp, in_channels=_DECODER_CHANNEL_DEFAULT, out_channels=1):
        super().__init__()
        self.module = nn.ModuleList([])

    def forward(self, x):
        """Feed ``x`` through every entry of ``self.module`` in order.

        With no registered sub-modules the input is returned unchanged.
        """
        out = x
        for layer in self.module:
            out = layer(out)
        return out


class VanillaDecoder(Decoder):
    """Decoder that upsamples a feature map and projects it to ``out_channels``.

    Layout of ``self.module`` (applied in order by ``Decoder.forward``):

    1. Optionally, ``ResidualBlocks(in_channels, n_blocks=...)`` when
       ``hp.decoder.residual_blocks > 0`` (presumably channel-preserving,
       since the next stage consumes ``in_channels`` channels).
    2. ``hp.decoder.depth`` upsampling stages. Stage ``i`` (1-based) is
       ``ConvTranspose2d -> BatchNorm2d -> ReLU`` mapping
       ``in_channels // 2**(i-1)`` to ``in_channels // 2**i`` channels. With
       kernel 3, stride 2, padding 1 and output_padding 1, each stage doubles
       the spatial height and width.
    3. A 7x7 reflect-padded ``Conv2d`` to ``out_channels`` followed by
       ``Tanh``, so outputs lie in [-1, 1].

    Args:
        hp: Hyper-parameter object; reads ``hp.decoder.depth`` and
            ``hp.decoder.residual_blocks``.
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the final convolution.

    Raises:
        ValueError: If ``hp.decoder.depth`` is negative, or so large that
            halving ``in_channels`` that many times leaves zero channels
            (which would otherwise build silently degenerate zero-channel
            layers).
    """

    def __init__(self, hp, in_channels, out_channels):
        super().__init__(hp, in_channels, out_channels)
        self.depth = hp.decoder.depth
        self.blocks = hp.decoder.residual_blocks

        if self.depth < 0:
            raise ValueError(
                f"decoder depth must be non-negative, got {self.depth}"
            )
        # Channel width at the input of stage 1, 2, ..., plus the final width:
        # in_channels, in_channels // 2, ..., in_channels // 2**depth.
        channels = [in_channels // (2 ** i) for i in range(self.depth + 1)]
        if channels[-1] < 1:
            raise ValueError(
                f"in_channels={in_channels} is too small for decoder depth "
                f"{self.depth}: the last upsampling stage would have "
                f"{channels[-1]} channels"
            )

        # super().__init__ already registered an empty self.module; append to it.
        if self.blocks > 0:
            self.module.append(ResidualBlocks(in_channels, n_blocks=self.blocks))

        # Upsampling stages: each halves the channel count and doubles H and W.
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            self.module.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=False),  # BatchNorm supplies the shift
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),  # in-place ReLU
            ))

        # Output projection; padding=3 keeps spatial size for a 7x7 kernel.
        final = nn.Sequential(
            nn.Conv2d(channels[-1], out_channels, kernel_size=7, padding=3,
                      padding_mode='reflect'),
            nn.Tanh(),
        )
        self.module.append(final)