|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class DownsamplingBlock(nn.Module):
    """U-Net downsampling (encoder) block.

    Applies, in order: a 2D convolution, an optional BatchNorm2d, and a
    LeakyReLU activation. The layers are stored in ``self.conv_block``.
    """

    def __init__(self, c_in, c_out, kernel_size=4, stride=2,
                 padding=1, negative_slope=0.2, use_norm=True):
        """Build the Conv2d -> [BatchNorm2d] -> LeakyReLU stack.

        Args:
            c_in (int): Number of input channels.
            c_out (int): Number of output channels (convolution filters).
            kernel_size (int, optional): Size of the convolving kernel.
                Default is 4.
            stride (int, optional): Stride of the convolution. Default is 2.
            padding (int, optional): Zero-padding added to both sides of the
                input. Default is 1.
            negative_slope (float, optional): Negative slope of the LeakyReLU
                activation. Default is 0.2.
            use_norm (bool, optional): If True, a BatchNorm2d layer follows
                the convolution and the convolution is created without a
                bias term. Default is True.
        """
        super().__init__()

        # The convolution only carries a bias when no norm layer follows it.
        conv = nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_norm,
        )

        layers = [conv]
        if use_norm:
            layers.append(nn.BatchNorm2d(num_features=c_out))
        layers.append(nn.LeakyReLU(negative_slope=negative_slope))

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block and return the result."""
        return self.conv_block(x)
|
|
|
|
|
class UpsamplingBlock(nn.Module):
    """U-Net upsampling (decoder) block.

    Applies, in order: an upscaling stage (either a transposed convolution
    or ``nn.Upsample`` followed by a 3x3 convolution), BatchNorm2d, an
    optional Dropout, and ReLU. The layers are stored in ``self.conv_block``.
    """

    def __init__(self, c_in, c_out, kernel_size=4, stride=2,
                 padding=1, use_dropout=False, use_upsampling=False, mode='nearest'):
        """Build the upscaling -> BatchNorm2d -> [Dropout] -> ReLU stack.

        Args:
            c_in (int): Number of input channels.
            c_out (int): Number of output channels.
            kernel_size (int, optional): Kernel size of the transposed
                convolution. Not used when ``use_upsampling`` is True (that
                path uses a fixed 3x3 convolution). Default is 4.
            stride (int, optional): Stride of the transposed convolution.
                Not used when ``use_upsampling`` is True (that path uses
                stride 1 after a 2x upsample). Default is 2.
            padding (int, optional): Zero-padding passed to whichever
                convolution is created. Default is 1.
            use_dropout (bool, optional): If True, insert ``Dropout(0.5)``
                between the BatchNorm2d and the ReLU. Default is False.
            use_upsampling (bool, optional): If True, use ``nn.Upsample``
                plus a convolution instead of a transposed convolution.
                Default is False.
            mode (str, optional): Upsampling algorithm, one of 'nearest',
                'bilinear', 'bicubic'. Any other value falls back to
                'nearest'. Only used when ``use_upsampling`` is True.
                Default is 'nearest'.
        """
        super().__init__()

        if not use_upsampling:
            upscale = nn.ConvTranspose2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            )
        else:
            # Unrecognised modes are replaced by 'nearest' rather than rejected.
            if mode not in ('nearest', 'bilinear', 'bicubic'):
                mode = 'nearest'

            # Kept as a nested Sequential so parameter names stay
            # ``conv_block.0.1.*`` for this path.
            upscale = nn.Sequential(
                nn.Upsample(scale_factor=2, mode=mode),
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    stride=1,
                    padding=padding,
                    bias=False,
                ),
            )

        layers = [upscale, nn.BatchNorm2d(num_features=c_out)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.ReLU())

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block and return the result."""
        return self.conv_block(x)