File size: 4,507 Bytes
9b9b1dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from math import comb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as fun


class DownsamplingBlock(nn.Module):
    """
    Halve the spatial resolution while doubling the channel count.

    Two parallel paths are summed: a convolutional path
    (1x1 conv -> PReLU -> 3x3 conv -> PReLU -> antialiased downsample -> 1x1 conv)
    and a shortcut path (antialiased downsample -> 1x1 conv).

    In: HxWxC
    Out: H/2xW/2x2C
    """

    def __init__(self, in_channels, bias=False):
        super().__init__()
        doubled = in_channels * 2
        # Main path: extract features before the antialiased downsample.
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=bias),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=bias),
            nn.PReLU(),
            DownSample(channels=in_channels, filter_size=3, stride=2),
            nn.Conv2d(in_channels, doubled, kernel_size=1, padding=0, bias=bias),
        )
        # Shortcut path: downsample first, then project to the doubled width.
        self.branch2 = nn.Sequential(
            DownSample(channels=in_channels, filter_size=3, stride=2),
            nn.Conv2d(in_channels, doubled, kernel_size=1, padding=0, bias=bias),
        )

    def forward(self, x):
        """Sum the two branches; result is H/2 x W/2 x 2C."""
        return self.branch1(x) + self.branch2(x)


class DownsamplingModule(nn.Module):
    """
    Stack of log2(scaling_factor) DownsamplingBlocks.

    Each block halves H and W and doubles C, so:
    In: HxWxC
    Out: H/scaling_factor x W/scaling_factor x C*scaling_factor

    Args:
        in_channels: number of channels of the input feature map.
        scaling_factor: total spatial reduction factor; must be a power of 2
            (1 yields an identity stack with zero blocks).
        stride: per-block channel growth factor (the blocks themselves always
            output 2*C, so values other than 2 would desynchronize the channel
            bookkeeping — kept for interface compatibility).

    Raises:
        ValueError: if scaling_factor is not a positive power of 2.
    """

    def __init__(self, in_channels, scaling_factor, stride=2):
        super().__init__()
        # np.log2 followed by int() silently truncated non-power-of-2 factors
        # (e.g. scaling_factor=3 built only one block); fail loudly instead.
        if scaling_factor < 1 or (scaling_factor & (scaling_factor - 1)) != 0:
            raise ValueError(
                "scaling_factor must be a positive power of 2, got %r" % scaling_factor
            )
        self.scaling_factor = int(np.log2(scaling_factor))

        blocks = []
        for _ in range(self.scaling_factor):
            blocks.append(DownsamplingBlock(in_channels))
            in_channels = int(in_channels * stride)  # next block sees doubled width
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        # H/scaling_factor x W/scaling_factor x C*scaling_factor
        return self.blocks(x)


class DownSample(nn.Module):
    """
    Antialiased downsampling ("blur pool"): pad, blur with a separable binomial
    filter, then subsample with the given stride.

    From Adobe's implementation available here:
    https://github.com/yilundu/improved_contrastive_divergence/blob/master/downsample.py

    Args:
        pad_type: padding mode name understood by get_pad_layer.
        filter_size: side length of the square blur filter; must be >= 1.
        stride: subsampling stride.
        channels: number of input channels (required for the depthwise filter;
            the buffer construction fails if left as None).
        pad_off: extra padding added on every side.

    Raises:
        ValueError: if filter_size < 1.
    """

    def __init__(
        self, pad_type="reflect", filter_size=3, stride=2, channels=None, pad_off=0
    ):
        super().__init__()
        if filter_size < 1:
            raise ValueError("filter_size must be >= 1, got %r" % filter_size)
        self.filter_size = filter_size
        self.stride = stride
        self.pad_off = pad_off
        self.channels = channels
        # Left/right/top/bottom pad so the blurred map keeps its size before
        # the strided subsampling; even filter sizes pad one extra on the
        # right/bottom (ceil).
        self.pad_sizes = [
            int(1.0 * (filter_size - 1) / 2),
            int(np.ceil(1.0 * (filter_size - 1) / 2)),
            int(1.0 * (filter_size - 1) / 2),
            int(np.ceil(1.0 * (filter_size - 1) / 2)),
        ]

        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.off = int((self.stride - 1) / 2.0)

        # Row (filter_size - 1) of Pascal's triangle: the 1-D binomial filter
        # (a discrete Gaussian approximation). This reproduces the original
        # hard-coded table for sizes 1..7 exactly and generalizes to any size;
        # the original raised UnboundLocalError for filter_size > 7.
        a = np.array(
            [comb(filter_size - 1, k) for k in range(filter_size)], dtype=np.float64
        )

        # Outer product -> 2-D filter; normalize so the blur preserves the mean.
        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        # One copy per channel for a depthwise (grouped) convolution.
        self.register_buffer(
            "filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
        )
        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, x):
        if self.filter_size == 1:
            # A 1x1 blur is the identity: just (pad and) subsample.
            if self.pad_off == 0:
                return x[:, :, :: self.stride, :: self.stride]
            else:
                return self.pad(x)[:, :, :: self.stride, :: self.stride]

        else:
            # Depthwise blur + strided subsampling in a single grouped conv.
            return fun.conv2d(
                self.pad(x), self.filt, stride=self.stride, groups=x.shape[1]
            )


def get_pad_layer(pad_type):
    """
    Return the nn padding-layer class for the given pad type name.

    Args:
        pad_type: "reflect" or "replication".

    Returns:
        The corresponding nn.Module subclass (not an instance).

    Raises:
        ValueError: if pad_type is not recognized. (The original printed a
        message and then crashed with UnboundLocalError on the return.)
    """
    if pad_type == "reflect":
        return nn.ReflectionPad2d
    if pad_type == "replication":
        return nn.ReplicationPad2d
    raise ValueError("Pad Type [%s] not recognized" % pad_type)