dblasko's picture
Upload 11 files
9b9b1dc
raw
history blame contribute delete
No virus
4.51 kB
import torch
import torch.nn as nn
import torch.nn.functional as fun
import numpy as np
class DownsamplingBlock(nn.Module):
"""
Downsamples the input to halve the dimensions while doubling the channels through two parallel conv + antialiased downsampling branches.
In: HxWxC
Out: H/2xW/2x2C
"""
def __init__(self, in_channels, bias=False):
super().__init__()
self.branch1 = (
nn.Sequential( # 1x1 conv + PReLU -> 3x3 conv + PReLU -> AD -> 1x1 conv
nn.Conv2d(
in_channels, in_channels, kernel_size=1, padding=0, bias=bias
),
nn.PReLU(),
nn.Conv2d(
in_channels, in_channels, kernel_size=3, padding=1, bias=bias
),
nn.PReLU(),
DownSample(channels=in_channels, filter_size=3, stride=2),
nn.Conv2d(
in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias
),
)
)
self.branch2 = nn.Sequential(
DownSample(channels=in_channels, filter_size=3, stride=2),
nn.Conv2d(
in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias
),
)
def forward(self, x):
return self.branch1(x) + self.branch2(x) # H/2xW/2x2C
class DownsamplingModule(nn.Module):
"""
Downsampling module of the network composed of (scaling factor) DownsamplingBlocks.
In: HxWxC
Out: H/2^(scaling factor) x W/2^(scaling factor) x C^2(scaling factor)
"""
def __init__(self, in_channels, scaling_factor, stride=2):
super().__init__()
self.scaling_factor = int(np.log2(scaling_factor))
blocks = []
for i in range(self.scaling_factor):
blocks.append(DownsamplingBlock(in_channels))
in_channels = int(in_channels * stride)
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
x = self.blocks(x)
return x # H/2^(scaling factor) x W/2^(scaling factor) x C^2(scaling factor)
class DownSample(nn.Module):
"""
Antialiased downsampling module using the blur-pooling method.
From Adobe's implementation available here: https://github.com/yilundu/improved_contrastive_divergence/blob/master/downsample.py
"""
def __init__(
self, pad_type="reflect", filter_size=3, stride=2, channels=None, pad_off=0
):
super().__init__()
self.filter_size = filter_size
self.stride = stride
self.pad_off = pad_off
self.channels = channels
self.pad_sizes = [
int(1.0 * (filter_size - 1) / 2),
int(np.ceil(1.0 * (filter_size - 1) / 2)),
int(1.0 * (filter_size - 1) / 2),
int(np.ceil(1.0 * (filter_size - 1) / 2)),
]
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
self.off = int((self.stride - 1) / 2.0)
if self.filter_size == 1:
a = np.array([1.0])
elif self.filter_size == 2:
a = np.array([1.0, 1.0])
elif self.filter_size == 3:
a = np.array([1.0, 2.0, 1.0])
elif self.filter_size == 4:
a = np.array([1.0, 3.0, 3.0, 1.0])
elif self.filter_size == 5:
a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
elif self.filter_size == 6:
a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
elif self.filter_size == 7:
a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
filt = torch.Tensor(a[:, None] * a[None, :])
filt = filt / torch.sum(filt)
self.register_buffer(
"filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
)
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
def forward(self, x):
if self.filter_size == 1:
if self.pad_off == 0:
return x[:, :, :: self.stride, :: self.stride]
else:
return self.pad(x)[:, :, :: self.stride, :: self.stride]
else:
return fun.conv2d(
self.pad(x), self.filt, stride=self.stride, groups=x.shape[1]
)
def get_pad_layer(pad_type):
if pad_type == "reflect":
pad_layer = nn.ReflectionPad2d
elif pad_type == "replication":
pad_layer = nn.ReplicationPad2d
else:
print("Pad Type [%s] not recognized" % pad_type)
return pad_layer