|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision.utils import make_grid
|
|
import matplotlib.pyplot as plt
|
|
|
|
def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Sample a batch of latent vectors from a standard normal distribution.

    n_samples: number of vectors to draw (rows of the result)
    z_dim: length of each vector (columns of the result)
    device: torch device on which the tensor is allocated
    returns: tensor of shape (n_samples, z_dim)
    '''
    shape = (n_samples, z_dim)
    latent = torch.randn(shape, device=device)
    return latent
|
|
|
|
def get_random_labels(n_samples, device='cpu', n_classes=10):
    '''
    Draw random integer class labels.

    n_samples: number of labels to draw
    device: torch device on which the tensor is allocated
    n_classes: number of distinct classes; labels fall in [0, n_classes).
        Defaults to 10, the value that used to be hard-coded here.
    returns: long tensor of shape (n_samples,)
    '''
    # Asking randint for int64 directly makes a follow-up .type() cast
    # unnecessary.
    return torch.randint(
        0, n_classes, (n_samples,), dtype=torch.long, device=device
    )
|
|
|
|
def get_generator_block(input_dim, output_dim):
    '''
    Build one hidden stage of the generator.

    The stage applies, in order, a linear layer, 1-D batch normalisation
    over the output features, and an in-place ReLU.

    input_dim: number of input features of the linear layer
    output_dim: number of output features of the linear layer
    returns: nn.Sequential holding the three layers
    '''
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
|
|
|
class Generator(nn.Module):
    '''
    Fully connected conditional generator.

    The noise vector is concatenated with a one-hot encoding of the
    requested class and passed through four generator blocks of widths
    hidden_dim, 2x, 4x and 8x hidden_dim, followed by a linear projection
    to im_dim and a sigmoid, so every output value lies in [0, 1].

    z_dim: length of the noise vector
    im_dim: number of values in one generated (flattened) image
    hidden_dim: width of the first hidden layer
    n_classes: number of condition classes (width of the one-hot vector).
        Defaults to 10, the value that used to be hard-coded here.
    '''

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128, n_classes=10):
        super().__init__()
        # Kept so forward() builds one-hot vectors of the same width the
        # first layer was sized for.
        self.n_classes = n_classes
        self.gen = nn.Sequential(
            get_generator_block(z_dim + n_classes, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid(),
        )

    def forward(self, noise, classes):
        '''
        noise (batch_size, z_dim) noise vector for each image in a batch
        classes:long (batch_size) condition class for each image in a batch
        returns (batch_size, im_dim) generated images with values in [0, 1]
        '''
        # Follow the noise dtype instead of forcing float32, so the
        # concatenated input keeps the dtype the caller chose.
        one_hot_vec = F.one_hot(classes, num_classes=self.n_classes).to(
            dtype=noise.dtype
        )
        # torch.cat is the long-standing name; torch.concat is a newer alias.
        conditioned_noise = torch.cat((noise, one_hot_vec), dim=1)
        return self.gen(conditioned_noise)
|
|
|
|
|
|
def get_discriminator_block(input_dim, output_dim):
    '''
    Build one hidden stage of the discriminator.

    The stage is a linear layer followed by an in-place LeakyReLU with a
    negative slope of 0.2.

    input_dim: number of input features of the linear layer
    output_dim: number of output features of the linear layer
    returns: nn.Sequential holding the two layers
    '''
    linear = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(linear, activation)
|
|
|
|
class Discriminator(nn.Module):
    '''
    Fully connected conditional discriminator.

    The input is passed through three discriminator blocks of widths
    4x, 2x and 1x hidden_dim and then a linear layer with one output.
    There is no final activation, so the result is one unbounded score
    per input row.
    NOTE(review): presumably paired with a logits-based loss such as
    BCEWithLogitsLoss; confirm against the training loop.

    im_dim: number of values in one (flattened) image
    hidden_dim: width of the last hidden layer
    n_classes: width of the class encoding expected alongside each image.
        Defaults to 10, the value that used to be hard-coded here.
    '''

    def __init__(self, im_dim=784, hidden_dim=128, n_classes=10):
        super().__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim + n_classes, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, image_batch):
        '''
        image_batch (batch_size, im_dim + n_classes); 784 + 10 with defaults
        returns (batch_size, 1) raw score for each row
        '''
        return self.disc(image_batch)