# vitaliykinakh's picture
# Initial
# 8d6cd57
from torch import cat
from torch.optim import Adam
from torch.nn import Sequential, ModuleList, \
Conv2d, Linear, \
LeakyReLU, Tanh, \
BatchNorm1d, BatchNorm2d, \
ConvTranspose2d, UpsamplingBilinear2d
from .neuralnetwork import NeuralNetwork
# ---------------------------------------------------------------------------
# Configuration for the cVAE decoder
# ---------------------------------------------------------------------------

# Data dimensions.
labels_dim = 37  # size of the label vector fed to the embedding layer
colors_dim = 3   # number of channels produced by the output convolutions

# Layer settings.
negative_slope = 0.2  # slope of LeakyReLU for negative inputs
# Momentum handed to the BatchNorm1d layers.
# NOTE(review): in PyTorch, momentum is the weight of the *newest* batch
# statistic (running = (1 - momentum) * running + momentum * batch), which is
# the opposite of the Keras convention -- confirm that 0.99 is intended.
momentum = 0.99

# Optimizer class and its beta coefficients.
betas = (0.5, 0.999)
optimizer = Adam

# Hyperparameters.
latent_dim = 128
learning_rate = 2e-4
def genUpsample(input_channels, output_channels, stride, pad):
    """Build a transposed-convolution block: ConvTranspose2d -> BatchNorm2d -> LeakyReLU.

    Args:
        input_channels: number of channels of the incoming feature map.
        output_channels: number of channels produced by the block.
        stride: stride of the transposed convolution.
        pad: padding of the transposed convolution.

    Returns:
        A ``Sequential`` holding the three layers in the order listed above.
    """
    # Kernel size is fixed to 4; the bias is disabled because the following
    # batch norm supplies its own learnable shift.
    deconv = ConvTranspose2d(
        input_channels,
        output_channels,
        kernel_size=4,
        stride=stride,
        padding=pad,
        bias=False,
    )
    norm = BatchNorm2d(output_channels)
    activation = LeakyReLU(negative_slope=negative_slope)
    return Sequential(deconv, norm, activation)
def genUpsample2(input_channels, output_channels, kernel_size):
    """Build a block of two conv stages followed by a 2x bilinear upsampling.

    Each stage is Conv2d -> BatchNorm2d -> LeakyReLU with stride 1 and padding
    ``(kernel_size - 1) // 2``. The first stage maps ``input_channels`` to
    ``output_channels``; the second keeps ``output_channels``.

    Args:
        input_channels: number of channels of the incoming feature map.
        output_channels: number of channels produced by both conv stages.
        kernel_size: integer kernel size used by both convolutions.

    Returns:
        A ``Sequential`` of seven layers (two conv stages, then the upsampler).
    """
    same_pad = (kernel_size - 1) // 2

    layers = []
    # The two stages differ only in the number of channels they receive.
    for stage_in in (input_channels, output_channels):
        layers.append(
            Conv2d(
                stage_in,
                output_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=same_pad,
            )
        )
        layers.append(BatchNorm2d(output_channels))
        layers.append(LeakyReLU(negative_slope=negative_slope))
    layers.append(UpsamplingBilinear2d(scale_factor=2))

    return Sequential(*layers)
class ConditionalDecoder(NeuralNetwork):
    """Decoder that turns a latent code plus a label vector into an image.

    The label vector is embedded to ``dim_z`` features, concatenated with the
    latent code, and fused back to ``dim_z`` features. The result is reshaped
    to a 1x1 feature map, expanded by a transposed convolution, and passed
    through a chain of upsampling blocks (``m_modules``). Each block has a
    matching output head (``c_modules``) that maps its feature map to
    ``colors_dim`` channels through a Tanh.
    """

    def __init__(self, ll_scaling=1.0, dim_z=latent_dim):
        """Create all layers and register the optimizer.

        Args:
            ll_scaling: factor multiplied into the module-level learning rate.
            dim_z: width of the latent code; also the width of the label
                embedding and of the fused vector.
        """
        super().__init__()
        self.dim_z = dim_z
        base_width = 32

        # Kernel 4, stride 1, padding 0: expands the 1x1 map produced in
        # forward() into a 4x4 map with base_width * 16 channels.
        self.init = genUpsample(self.dim_z, base_width * 16, 1, 0)

        # Projects the label vector into the same width as the latent code.
        self.embedding = Sequential(
            Linear(labels_dim, self.dim_z),
            BatchNorm1d(self.dim_z, momentum=momentum),
            LeakyReLU(negative_slope=negative_slope),
        )

        # Fuses the concatenated [latent, label embedding] back to dim_z.
        self.dense_init = Sequential(
            Linear(self.dim_z * 2, self.dim_z),
            BatchNorm1d(self.dim_z, momentum=momentum),
            LeakyReLU(negative_slope=negative_slope),
        )

        # Upsampling blocks and their per-resolution output heads. Block
        # ``level`` halves the channel count and doubles the resolution.
        self.m_modules = ModuleList()
        self.c_modules = ModuleList()
        for level in range(4):
            wide = base_width * 2 ** (4 - level)
            narrow = base_width * 2 ** (3 - level)
            self.m_modules.append(genUpsample2(wide, narrow, 3))
            to_image = Sequential(
                Conv2d(narrow, colors_dim, 3, 1, 1, bias=False),
                Tanh(),
            )
            self.c_modules.append(to_image)

        # set_optimizer is not defined in this class; presumably it is
        # provided by NeuralNetwork -- verify against the base class.
        self.set_optimizer(optimizer, lr=learning_rate * ll_scaling, betas=betas)

    def forward(self, latent, labels, step=3):
        """Decode ``latent`` conditioned on ``labels``.

        Args:
            latent: latent codes; concatenated with the label embedding along
                dim 1, so it needs ``dim_z`` features per sample.
            labels: label vectors fed to the embedding layer.
            step: index of the last upsampling block and output head to use.
                Blocks ``0 .. step`` are applied, so a larger ``step`` yields
                a higher output resolution.

        Returns:
            The output of ``c_modules[step]``: a tensor with ``colors_dim``
            channels whose values have passed through Tanh.
        """
        label_code = self.embedding(labels)
        fused = self.dense_init(cat((latent, label_code), dim=1))

        # Add two trailing singleton axes so the vector becomes a 1x1 map.
        features = self.init(fused.unsqueeze(2).unsqueeze(3))

        for idx in range(step):
            features = self.m_modules[idx](features)

        head = self.c_modules[step]
        return head(self.m_modules[step](features))