import torch
import torch.nn as nn
import torch.nn.functional as F

class DownsamplingBlock(nn.Module):
    """Defines the Unet downsampling block.

    Consists of a Convolution-BatchNorm-LeakyReLU layer with k filters.
    """
    def __init__(self, c_in, c_out, kernel_size=4, stride=2,
                 padding=1, negative_slope=0.2, use_norm=True):
        """Initializes the Unet DownsamplingBlock.

        Args:
            c_in (int): Number of input channels.
            c_out (int): Number of output channels.
            kernel_size (int, optional): Size of the convolving kernel. Default is 4.
            stride (int, optional): Stride of the convolution. Default is 2.
            padding (int, optional): Zero-padding added to both sides of the input. Default is 1.
            negative_slope (float, optional): Negative slope of the LeakyReLU activation. Default is 0.2.
            use_norm (bool, optional): If True, adds a BatchNorm layer after the convolution. Default is True.
        """
        super(DownsamplingBlock, self).__init__()
        block = []
        block += [nn.Conv2d(in_channels=c_in, out_channels=c_out,
                            kernel_size=kernel_size, stride=stride, padding=padding,
                            # No need for a bias if a BatchNorm layer follows the conv.
                            bias=(not use_norm))]
        if use_norm:
            block += [nn.BatchNorm2d(num_features=c_out)]
        block += [nn.LeakyReLU(negative_slope=negative_slope)]
        self.conv_block = nn.Sequential(*block)

    def forward(self, x):
        return self.conv_block(x)
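
# A minimal usage sketch (an illustration, not part of the original module):
# with the default 4x4 kernel, stride 2, and padding 1, each downsampling
# block halves the spatial resolution, e.g.
#
#   down = DownsamplingBlock(c_in=3, c_out=64)
#   y = down(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 64, 128, 128])
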

class UpsamplingBlock(nn.Module):
    """Defines the Unet upsampling block."""
    def __init__(self, c_in, c_out, kernel_size=4, stride=2,
                 padding=1, use_dropout=False, use_upsampling=False, mode='nearest'):
        """Initializes the Unet UpsamplingBlock.

        Args:
            c_in (int): Number of input channels.
            c_out (int): Number of output channels.
            kernel_size (int, optional): Size of the convolving kernel. Default is 4.
            stride (int, optional): Stride of the convolution. Default is 2.
            padding (int, optional): Zero-padding added to both sides of the input. Default is 1.
            use_dropout (bool, optional): If True, adds a Dropout layer after BatchNorm. Default is False.
            use_upsampling (bool, optional): If True, uses upsampling followed by a regular
                convolution instead of a transposed convolution. Default is False.
            mode (str, optional): The upsampling algorithm: one of 'nearest',
                'bilinear', 'bicubic'. Default is 'nearest'.
        """
        super(UpsamplingBlock, self).__init__()
        block = []
        if use_upsampling:
            # Transposed convolution can cause checkerboard artifacts. Upsampling
            # followed by a regular convolution apparently produces better results.
            # For further reading, see: https://distill.pub/2016/deconv-checkerboard/
            # Odena, et al., "Deconvolution and Checkerboard Artifacts", Distill, 2016. http://doi.org/10.23915/distill.00003
            mode = mode if mode in ('nearest', 'bilinear', 'bicubic') else 'nearest'
            block += [nn.Sequential(
                nn.Upsample(scale_factor=2, mode=mode),
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=3, stride=1, padding=padding,
                          bias=False)
            )]
        else:
            block += [nn.ConvTranspose2d(in_channels=c_in,
                                         out_channels=c_out,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         bias=False)]
        block += [nn.BatchNorm2d(num_features=c_out)]
        if use_dropout:
            block += [nn.Dropout(0.5)]
        block += [nn.ReLU()]
        self.conv_block = nn.Sequential(*block)

    def forward(self, x):
        return self.conv_block(x)
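
# A minimal smoke-test sketch for both decoder paths; the shapes and channel
# counts are illustrative assumptions, not values taken from this repository:
#
#   x = torch.randn(1, 128, 64, 64)
#   up_transposed = UpsamplingBlock(c_in=128, c_out=64)
#   up_resized = UpsamplingBlock(c_in=128, c_out=64,
#                                use_upsampling=True, mode='bilinear')
#   # Both variants double the spatial resolution: (64, 64) -> (128, 128).
#   up_transposed(x).shape  # torch.Size([1, 64, 128, 128])
#   up_resized(x).shape     # torch.Size([1, 64, 128, 128])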