from torch import cat
from torch.optim import Adam
from torch.nn import Sequential, ModuleList, \
    Conv2d, Linear, \
    LeakyReLU, Tanh, \
    BatchNorm1d, BatchNorm2d, \
    ConvTranspose2d, UpsamplingBilinear2d

from .neuralnetwork import NeuralNetwork

# parameters for the cVAE
colors_dim = 3        # output image channels (RGB)
labels_dim = 37       # dimensionality of the conditioning label vector
momentum = 0.99       # BatchNorm momentum
negative_slope = 0.2  # LeakyReLU negative slope
optimizer = Adam
betas = (0.5, 0.999)  # Adam betas

# hyperparameters
learning_rate = 2e-4
latent_dim = 128


def genUpsample(input_channels, output_channels, stride, pad):
    """Transposed-convolution upsampling block: ConvTranspose2d -> BatchNorm -> LeakyReLU."""
    return Sequential(
        ConvTranspose2d(input_channels, output_channels, 4, stride, pad, bias=False),
        BatchNorm2d(output_channels),
        LeakyReLU(negative_slope=negative_slope))
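
# Shape check (sketch, assuming `import torch`): with kernel_size=4 the ConvTranspose2d
# output size is (in - 1) * stride - 2 * pad + 4, so genUpsample(dim_z, ngf * 16, 1, 0)
# turns a (B, dim_z, 1, 1) tensor into a (B, ngf * 16, 4, 4) feature map, e.g.:
#   genUpsample(128, 512, 1, 0)(torch.randn(2, 128, 1, 1)).shape == (2, 512, 4, 4)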


def genUpsample2(input_channels, output_channels, kernel_size):
    """Two 'same'-padded conv blocks followed by 2x bilinear upsampling."""
    return Sequential(
        Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        BatchNorm2d(output_channels),
        LeakyReLU(negative_slope=negative_slope),
        Conv2d(output_channels, output_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2),
        BatchNorm2d(output_channels),
        LeakyReLU(negative_slope=negative_slope),
        UpsamplingBilinear2d(scale_factor=2))
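
# Shape check (sketch): the two 'same'-padded convolutions preserve spatial size and the
# final UpsamplingBilinear2d doubles it, so a (B, C_in, H, W) input becomes
# (B, C_out, 2 * H, 2 * W).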


class ConditionalDecoder(NeuralNetwork):
    def __init__(self, ll_scaling=1.0, dim_z=latent_dim):
        super(ConditionalDecoder, self).__init__()
        self.dim_z = dim_z
        ngf = 32
        # project the fused (latent + label) vector from 1x1 to a 4x4 feature map
        self.init = genUpsample(self.dim_z, ngf * 16, 1, 0)
        # embed the label vector into the latent space
        self.embedding = Sequential(
            Linear(labels_dim, self.dim_z),
            BatchNorm1d(self.dim_z, momentum=momentum),
            LeakyReLU(negative_slope=negative_slope),
        )
        # fuse the concatenated latent code and label embedding back to dim_z
        self.dense_init = Sequential(
            Linear(self.dim_z * 2, self.dim_z),
            BatchNorm1d(self.dim_z, momentum=momentum),
            LeakyReLU(negative_slope=negative_slope),
        )
        # progressive upsampling stages (4x4 -> 8 -> 16 -> 32 -> 64)
        # with a Tanh RGB output head per stage
        self.m_modules = ModuleList()
        self.c_modules = ModuleList()
        for i in range(4):
            self.m_modules.append(genUpsample2(ngf * 2**(4 - i), ngf * 2**(3 - i), 3))
            self.c_modules.append(Sequential(Conv2d(ngf * 2**(3 - i), colors_dim, 3, 1, 1, bias=False), Tanh()))
        self.set_optimizer(optimizer, lr=learning_rate * ll_scaling, betas=betas)

    def forward(self, latent, labels, step=3):
        # embed the labels and concatenate them with the latent code
        y = self.embedding(labels)
        out = cat((latent, y), dim=1)
        out = self.dense_init(out)
        # reshape to (B, dim_z, 1, 1) and grow to a 4x4 feature map
        out = out.unsqueeze(2).unsqueeze(3)
        out = self.init(out)
        # apply the first `step` upsampling stages, then the stage-specific RGB head
        for i in range(step):
            out = self.m_modules[i](out)
        out = self.c_modules[step](self.m_modules[step](out))
        return out
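

# Usage sketch (assumes the NeuralNetwork base class wraps torch.nn.Module and that
# `import torch` is available; shapes follow from the architecture above):
#   decoder = ConditionalDecoder()
#   z = torch.randn(16, latent_dim)    # latent codes from the cVAE prior/posterior
#   y = torch.randn(16, labels_dim)    # conditioning label vectors
#   imgs = decoder(z, y, step=3)       # -> (16, colors_dim, 64, 64), values in [-1, 1] from Tanh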