geninhu commited on
Commit
6c2e66e
1 Parent(s): 5deacda

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import streamlit as st
4
+
5
+ from models import Generator, Discriminrator
6
+ from utils import image_to_base64
7
+ import torch
8
+ import torchvision.transforms as T
9
+ from torchvision.utils import make_grid
10
+ from PIL import Image
11
+
12
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+
14
+
15
+ model_name = {
16
+ "aurora": 'huggan/fastgan-few-shot-aurora-bs8',
17
+ "painting": 'huggan/fastgan-few-shot-painting-bs8',
18
+ "shell": 'huggan/fastgan-few-shot-shells',
19
+ "fauvism": 'huggan/fastgan-few-shot-fauvism-still-life',
20
+ }
21
+
22
+ #@st.cache(allow_output_mutation=True)
23
+ def load_generator(model_name_or_path):
24
+ generator = Generator(in_channels=256, out_channels=3)
25
+ generator = generator.from_pretrained(model_name_or_path, in_channels=256, out_channels=3)
26
+ _ = generator.to('cuda')
27
+ _ = generator.eval()
28
+
29
+ return generator
30
+
31
+ def _denormalize(input: torch.Tensor) -> torch.Tensor:
32
+ return (input * 127.5) + 127.5
33
+
34
+
35
+ def generate_images(generator, number_imgs):
36
+ noise = torch.zeros(number_imgs, 256, 1, 1, device='cuda').normal_(0.0, 1.0)
37
+ with torch.no_grad():
38
+ gan_images, _ = generator(noise)
39
+
40
+ gan_images = _denormalize(gan_images.detach()).cpu()
41
+ gan_images = make_grid(gan_images, nrow=number_imgs, normalize=True)
42
+ gan_images = gan_images.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
43
+ gan_images = Image.fromarray(gan_images)
44
+ return gan_images
45
+
46
+
47
+ def main():
48
+
49
+ st.set_page_config(
50
+ page_title="FastGAN Generator",
51
+ page_icon="🖥️",
52
+ layout="wide",
53
+ initial_sidebar_state="expanded"
54
+ )
55
+
56
+ # st.sidebar.markdown(
57
+ # """
58
+ # <style>
59
+ # .aligncenter {
60
+ # text-align: center;
61
+ # }
62
+ # </style>
63
+ # <p class="aligncenter">
64
+ # <img src="https://e7.pngegg.com/pngimages/510/121/png-clipart-machine-learning-deep-learning-artificial-intelligence-algorithm-machine-learning-angle-text.png"/>
65
+ # </p>
66
+ # """,
67
+ # unsafe_allow_html=True,
68
+ # )
69
+ st.sidebar.markdown(
70
+ """
71
+ ___
72
+ <p style='text-align: center'>
73
+ FastGAN is an few-shot GAN model that generates images of several types!
74
+ </p>
75
+ <p style='text-align: center'>
76
+ Model training and Space creation by
77
+ <br/>
78
+ <a href="https://huggingface.co/vumichien" target="_blank">Chien Vu</a> | <a href="https://huggingface.co/geninhu" target="_blank">Nhu Hoang</a>
79
+ <br/>
80
+ </p>
81
+
82
+ <p style='text-align: center'>
83
+ <a href="https://github.com/silentz/Towards-Faster-And-Stabilized-GAN-Training-For-High-Fidelity-Few-Shot-Image-Synthesis" target="_blank">based on FastGAN model</a> | <a href="https://arxiv.org/abs/2101.04775" target="_blank">Article</a>
84
+ </p>
85
+ """,
86
+ unsafe_allow_html=True,
87
+ )
88
+
89
+ st.header("Welcome to FastGAN")
90
+
91
+ col1, col2, col3, col4 = st.columns([3,3,3,3])
92
+ with col1:
93
+ st.markdown('Fauvism GAN [model](https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life)', unsafe_allow_html=True)
94
+ st.image('fauvism.png', width=300)
95
+
96
+ with col2:
97
+ st.markdown('Aurora GAN [model](https://huggingface.co/huggan/fastgan-few-shot-aurora-bs8)', unsafe_allow_html=True)
98
+ st.image('aurora.png', width=300)
99
+
100
+ with col3:
101
+ st.markdown('Painting GAN [model](https://huggingface.co/huggan/fastgan-few-shot-painting-bs8)', unsafe_allow_html=True)
102
+ st.image('painting.png', width=300)
103
+ with col4:
104
+ st.markdown('Shell GAN [model](https://huggingface.co/huggan/fastgan-few-shot-shells)', unsafe_allow_html=True)
105
+ st.image('shell.png', width=300)
106
+
107
+ # Choose generator
108
+ col11, col12, col13 = st.columns([4,4,2])
109
+ with col11:
110
+ st.markdown('Choose type of image to generate', unsafe_allow_html=True)
111
+ img_type = st.selectbox("", index=0, options=["shell", "aurora", "painting", "fauvism"])
112
+
113
+ with col12:
114
+ number_imgs = st.number_input('How many images you want to generate ?', min_value=1, max_value=5)
115
+ if number_imgs is None:
116
+ st.write('Invalid number ! Please insert number of images to generate !')
117
+ raise ValueError('Invalid number ! Please insert number of images to generate !')
118
+ with col13:
119
+ generate_button = st.button('Get Image!')
120
+
121
+ # row2 = st.columns([10])
122
+ # with row2:
123
+ if generate_button:
124
+ st.markdown("""
125
+ <small><i>Predictions may take up to 1mn under high load. Please stand by.</i></small>
126
+ """,
127
+ unsafe_allow_html=True,)
128
+ generator = load_generator(model_name[img_type])
129
+ gan_images = generate_images(generator, number_imgs)
130
+ # margin = 0.1 # for better position of zoom in arrow
131
+ # n_columns = 2
132
+ # cols = st.columns([1] + [margin, 1] * (n_columns - 1))
133
+ # for i, img in enumerate(gan_images):
134
+ # cols[(i % n_columns) * 2].image(img)
135
+
136
+ st.image(gan_images, width=200*number_imgs)
137
+
138
+
139
+ if __name__ == '__main__':
140
+ main()