import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def get_noise(n_samples, z_dim, device='cpu'):
    '''Sample a batch of latent vectors from a standard normal, shape (n_samples, z_dim).'''
    return torch.randn((n_samples, z_dim), device=device)

def get_random_labels(n_samples, device='cpu'):
    '''Sample uniform random class labels in [0, 10), shape (n_samples,), dtype long.'''
    return torch.randint(0, 10, (n_samples,), device=device, dtype=torch.long)

def get_generator_block(input_dim, output_dim):
    '''Linear -> BatchNorm1d -> ReLU block used to build the generator.'''
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True)
    )

class Generator(nn.Module):
    '''Conditional generator: maps a noise vector concatenated with a one-hot
    class vector to a flattened 28x28 image.'''
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super().__init__()

        # input is of shape (batch_size, z_dim + 10): noise plus one-hot class
        self.gen = nn.Sequential(
            get_generator_block(z_dim + 10, hidden_dim),           # 128
            get_generator_block(hidden_dim, hidden_dim * 2),       # 256
            get_generator_block(hidden_dim * 2, hidden_dim * 4),   # 512
            get_generator_block(hidden_dim * 4, hidden_dim * 8),   # 1024
            nn.Linear(hidden_dim * 8, im_dim),                     # 784
            nn.Sigmoid(),  # pixel values between 0 and 1
        )

    def forward(self, noise, classes):
        '''
        noise:   (batch_size, z_dim) noise vector for each image in the batch
        classes: (batch_size) long tensor giving the condition class for each image
        '''
        # one-hot encode the condition class, e.g. 3 -> [0,0,0,1,0,0,0,0,0,0]
        one_hot_vec = F.one_hot(classes, num_classes=10).type(torch.float32)  # (batch_size, 10)
        conditioned_noise = torch.cat((noise, one_hot_vec), dim=1)  # (batch_size, z_dim + 10)
        return self.gen(conditioned_noise)

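# A minimal usage sketch (illustrative, not part of the original file): generate
# a batch of images all conditioned on the digit 3, then view them as a grid.
# The z_dim value and batch size here are arbitrary assumptions.
#
#   gen = Generator(z_dim=64)
#   noise = get_noise(16, z_dim=64)
#   labels = torch.full((16,), 3, dtype=torch.long)
#   fake = gen(noise, labels)  # (16, 784), values in [0, 1]
#   grid = make_grid(fake.view(-1, 1, 28, 28), nrow=4)
#   plt.imshow(grid.permute(1, 2, 0).squeeze()); plt.show()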

def get_discriminator_block(input_dim, output_dim):
    '''Linear -> LeakyReLU block used to build the discriminator.'''
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True)
    )

class Discriminator(nn.Module):
    '''Conditional discriminator: scores a flattened image concatenated with a
    one-hot class vector. Outputs raw logits (no sigmoid; see note below).'''
    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim + 10, hidden_dim * 4),  # 512
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),  # 256
            get_discriminator_block(hidden_dim * 2, hidden_dim),   # 128
            nn.Linear(hidden_dim, 1),
            # No nn.Sigmoid() here: a sigmoid followed by BCELoss is less numerically
            # stable than BCEWithLogitsLoss alone, which fuses both operations and
            # uses the log-sum-exp trick. See:
            # https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
        )

    def forward(self, image_batch):
        '''image_batch: (batch_size, 784 + 10) flattened image concatenated with a
        one-hot class vector. Returns (batch_size, 1) real/fake logits.'''
        return self.disc(image_batch)
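

# A minimal single-step training sketch (an illustrative assumption; the original
# file contains no training loop). `real` and `labels` are assumed to come from
# an MNIST DataLoader batch, with `real` flattened to (batch_size, 784) in [0, 1].
# Both networks receive the same one-hot condition, and the discriminator's raw
# logits go straight into BCEWithLogitsLoss, matching the note above.
#
#   z_dim = 64
#   gen, disc = Generator(z_dim=z_dim), Discriminator()
#   gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-4)
#   disc_opt = torch.optim.Adam(disc.parameters(), lr=1e-4)
#   criterion = nn.BCEWithLogitsLoss()  # expects raw logits, not probabilities
#
#   def condition(images, labels):
#       one_hot = F.one_hot(labels, num_classes=10).float()
#       return torch.cat((images, one_hot), dim=1)  # (batch_size, 784 + 10)
#
#   # discriminator step: real images labeled 1, detached fakes labeled 0
#   fake = gen(get_noise(real.size(0), z_dim), labels).detach()
#   disc_loss = (criterion(disc(condition(real, labels)), torch.ones(real.size(0), 1)) +
#                criterion(disc(condition(fake, labels)), torch.zeros(real.size(0), 1))) / 2
#   disc_opt.zero_grad(); disc_loss.backward(); disc_opt.step()
#
#   # generator step: push the discriminator's output toward "real" on fresh fakes
#   fake = gen(get_noise(real.size(0), z_dim), labels)
#   gen_loss = criterion(disc(condition(fake, labels)), torch.ones(real.size(0), 1))
#   gen_opt.zero_grad(); gen_loss.backward(); gen_opt.step()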