Gholamreza
committed on
Upload 5 files
Browse files- README.md +26 -14
- app.py +33 -0
- conditional_gan.py +47 -0
- generated_digit.png +0 -0
- models.py +67 -0
README.md
CHANGED
@@ -1,14 +1,26 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Generating MNIST digits using Conditional GAN
|
2 |
+
|
3 |
+
This is a simple implementation of Conditional Generative Adversarial Networks (GAN) for generating MNIST digits.
|
4 |
+
|
5 |
+
![cover](demos/gen_all_digits.png)
|
6 |
+
|
7 |
+
I use a simple BCE loss function to compute the loss and the Adam optimizer (lr=0.0001) for training.
|
8 |
+
|
9 |
+
## Architecture
|
10 |
+
|
11 |
+
- The **generator** is a series of Linear layers with BatchNorm and ReLU activations.
|
12 |
+
- The **discriminator** is a series of Linear layers with BatchNorm and LeakyReLU activations.
|
13 |
+
- The Conditioning class is appended to the noise vector as a one-hot vector.
|
14 |
+
|
15 |
+
## Huggingface Space
|
16 |
+
|
17 |
+
You can try generating digits using this model on Huggingface Space.
|
18 |
+
https://huggingface.co/spaces/gholamreza/Conditional-GAN-MNIST
|
19 |
+
|
20 |
+
![Huggingface Space](demos/gradio_app.png)
|
21 |
+
|
22 |
+
## Training History
|
23 |
+
|
24 |
+
![losses_plot](demos/losses.png)
|
25 |
+
|
26 |
+
Visit https://github.com/gholamrezadar/GAN-MNIST for a simpler version of this code and more details.
|
app.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gradio as gr
|
3 |
+
from models import Generator
|
4 |
+
from conditional_gan import generate_digit
|
5 |
+
|
6 |
+
# Module-level generator built with default hyper-parameters; init() loads the
# pretrained weights into it and generate_mnist_digit() uses it for sampling.
generator = Generator()
|
7 |
+
|
8 |
+
def init():
    """Load the pretrained weights into the module-level ``generator``.

    The checkpoint at ``models/generator.pt`` is mapped onto CUDA when it is
    available and onto the CPU otherwise, and the generator is then moved to
    that same device.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the generator
    weights = torch.load('models/generator.pt', map_location=target)
    generator.load_state_dict(weights)
    generator.to(target)
|
14 |
+
|
15 |
+
def generate_mnist_digit(digit):
    """Generate samples of ``digit`` with the module-level generator.

    Delegates to ``generate_digit`` and returns its result unchanged.
    """
    result = generate_digit(generator, digit)
    return result
|
17 |
+
|
18 |
+
# Gradio Interface
|
19 |
+
def gradio_generate(digit):
    """Gradio click handler: generate an image for the selected digit.

    Args:
        digit: Value coming from the "Select a Digit" dropdown. It is ``None``
            when the user presses "Generate" without choosing anything,
            because the dropdown has no default value.

    Returns:
        Whatever ``generate_mnist_digit`` returns for the chosen digit.

    Raises:
        gr.Error: If no digit has been selected, so the UI shows a readable
            message instead of a tensor-construction traceback.
    """
    if digit is None:
        raise gr.Error("Please select a digit before generating.")
    # Coerce to int so the label tensor can be built even if the frontend
    # hands the choice back as a string.
    return generate_mnist_digit(int(digit))
|
21 |
+
|
22 |
+
with gr.Blocks() as demo:
    # Page heading
    gr.Markdown(value="# MNIST Digit Generator")

    # Inputs: digit selector (0-9) and the trigger button
    digit = gr.Dropdown(choices=list(range(10)), label="Select a Digit")
    generate_button = gr.Button(value="Generate")

    # Output: the handler returns a file path, hence type="filepath"
    output_image = gr.Image(label="Generated Image", type="filepath")

    # Wire the button to the generation handler
    generate_button.click(fn=gradio_generate, inputs=digit, outputs=output_image)
|
29 |
+
|
30 |
+
if __name__ == "__main__":
    # Load the weights first, then start serving the UI
    init()
    print("* Model loaded")
    demo.launch()
|
conditional_gan.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file provides the necessary functions for generating images using pretrained models
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.utils import make_grid
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
from models import get_noise
|
9 |
+
|
10 |
+
# Device used by the helpers below for labels and noise: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
|
12 |
+
def display_image_grid(images, num_rows=5, title=""):
    """Show up to the first 25 images of a batch as a single grid figure.

    Args:
        images: Batch of images. If the last dimension is not 28 the batch is
            reshaped to ``(-1, 1, 28, 28)`` first.
        num_rows: Forwarded to ``make_grid`` as ``nrow``.
        title: Title drawn above the grid.
    """
    last_dim_is_28 = images.shape[-1] == 28
    if not last_dim_is_28:
        images = images.view(-1, 1, 28, 28)

    plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title(title)

    # Only the first 25 images are drawn
    subset = images.detach().cpu()[:25]
    tiled = make_grid(subset, nrow=num_rows)
    # make_grid output is channel-first; imshow expects channel-last
    canvas = tiled.permute(1, 2, 0).numpy()

    plt.imshow(canvas)
    plt.show()
|
21 |
+
|
22 |
+
def check_generation(generator):
    """Display a grid of 100 generated samples covering all ten classes.

    The labels cycle through 0-9 ten times and the grid is built with
    ``nrow=10``.

    Args:
        generator: Conditional generator called as ``generator(noise, labels)``.
    """
    generator.eval()

    class_ids = torch.tensor(list(range(10)) * 10).to(device)
    latent = get_noise(100, 10, device=device)
    samples = generator(latent, class_ids).view(-1, 1, 28, 28)

    tiled = make_grid(samples.detach().cpu(), nrow=10)
    # make_grid output is channel-first; imshow expects channel-last
    canvas = tiled.permute(1, 2, 0).numpy()

    plt.figure(figsize=(9, 9))
    plt.title("Generated Images")
    plt.axis('off')
    plt.xlabel("Class")
    plt.imshow(canvas)
    plt.show()
|
33 |
+
|
34 |
+
def generate_digit(generator, digit):
    """Generate 25 samples of ``digit`` and save them as a 5x5 grid image.

    Args:
        generator: Conditional generator called as ``generator(noise, labels)``.
        digit: Integer class label used for all 25 samples.

    Returns:
        The path of the written image, always ``'generated_digit.png'``.
        The file is overwritten on every call.
    """
    generator.eval()
    labels = torch.tensor([digit] * 25).to(device)

    # Pure inference: no autograd graph is needed
    with torch.no_grad():
        fake_eval_batch = generator(get_noise(25, 10, device=device), labels).view(-1, 1, 28, 28)
    grid = make_grid(fake_eval_batch.detach().cpu(), nrow=5).permute(1, 2, 0).numpy()

    fig = plt.figure(figsize=(5, 5))
    try:
        # no border
        plt.axis('off')
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(grid)
        plt.savefig('generated_digit.png', bbox_inches='tight', pad_inches=0)  # Save the generated image
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running app leaks one figure per request.
        plt.close(fig)
    return 'generated_digit.png'  # Return the image path
|
generated_digit.png
ADDED
models.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision.utils import make_grid
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample standard-normal noise of shape ``(n_samples, z_dim)`` on ``device``."""
    noise = torch.randn(n_samples, z_dim, device=device)
    return noise
|
9 |
+
|
10 |
+
def get_random_labels(n_samples, device='cpu'):
    """Draw ``n_samples`` uniform random class labels in ``[0, 9]`` as a long tensor."""
    labels = torch.randint(low=0, high=10, size=(n_samples,), device=device)
    return labels.to(torch.long)
|
12 |
+
|
13 |
+
def get_generator_block(input_dim, output_dim):
    """Build one generator stage: Linear -> BatchNorm1d -> in-place ReLU.

    Args:
        input_dim: Number of input features of the linear layer.
        output_dim: Number of output features, also the BatchNorm width.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
19 |
+
|
20 |
+
class Generator(nn.Module):
    """Conditional MLP generator: noise plus a one-hot class label to a flat image.

    Args:
        z_dim: Size of the noise vector.
        im_dim: Size of the flattened output image.
        hidden_dim: Width of the first hidden layer; later layers use 2x, 4x
            and 8x this width.
        n_classes: Number of condition classes (width of the one-hot vector).
            Defaults to 10, matching the previously hard-coded value, so
            existing checkpoints load unchanged.
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128, n_classes=10):
        super().__init__()
        self.n_classes = n_classes

        # input is of shape (batch_size, z_dim + n_classes)
        self.gen = nn.Sequential(
            get_generator_block(z_dim + n_classes, hidden_dim),  # 128 by default
            get_generator_block(hidden_dim, hidden_dim * 2),  # 256 by default
            get_generator_block(hidden_dim * 2, hidden_dim * 4),  # 512 by default
            get_generator_block(hidden_dim * 4, hidden_dim * 8),  # 1024 by default
            nn.Linear(hidden_dim * 8, im_dim),  # 784 by default
            nn.Sigmoid(),  # output between 0 and 1
        )

    def forward(self, noise, classes):
        """Generate a batch of flattened images conditioned on ``classes``.

        Args:
            noise: ``(batch_size, z_dim)`` noise vector for each image.
            classes: ``(batch_size,)`` integer condition class for each image.

        Returns:
            ``(batch_size, im_dim)`` tensor with values in ``[0, 1]``.
        """
        # F.one_hot only accepts int64 indices, so coerce other integer dtypes.
        # one-hot encode the class, e.g. 3 -> [0,0,0,1,0,0,0,0,0,0]
        one_hot_vec = F.one_hot(classes.long(), num_classes=self.n_classes).type(torch.float32)
        # torch.cat is used because the torch.concat alias is missing in older releases
        conditioned_noise = torch.cat((noise, one_hot_vec), dim=1)  # (batch_size, z_dim + n_classes)
        return self.gen(conditioned_noise)
|
44 |
+
|
45 |
+
|
46 |
+
def get_discriminator_block(input_dim, output_dim):
    """Build one discriminator stage: Linear -> in-place LeakyReLU (slope 0.2).

    Args:
        input_dim: Number of input features of the linear layer.
        output_dim: Number of output features of the linear layer.

    Returns:
        An ``nn.Sequential`` holding the two layers in that order.
    """
    linear = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(linear, activation)
|
51 |
+
|
52 |
+
class Discriminator(nn.Module):
    """MLP discriminator over a flattened image concatenated with 10 condition values.

    Args:
        im_dim: Size of the flattened image part of the input.
        hidden_dim: Width of the last hidden layer; earlier layers use 4x and
            2x this width.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        # Layer widths from input to last hidden layer: 794 -> 512 -> 256 -> 128 by default
        widths = [im_dim + 10, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        stages = [
            get_discriminator_block(in_width, out_width)
            for in_width, out_width in zip(widths, widths[1:])
        ]
        # No final Sigmoid: a sigmoid followed by BCE is less numerically
        # stable than BCEWithLogitsLoss alone, so raw logits are returned.
        # https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
        self.disc = nn.Sequential(*stages, nn.Linear(hidden_dim, 1))

    def forward(self, image_batch):
        '''image_batch (batch_size, 784+10)'''
        logits = self.disc(image_batch)
        return logits
|