AlekseyKorshuk commited on
Commit
d14e266
1 Parent(s): 33289a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -4
app.py CHANGED
@@ -1,13 +1,21 @@
1
  import json
 
 
 
2
  from huggingnft.lightweight_gan.train import timestamped_filename
3
  from streamlit_option_menu import option_menu
4
 
5
  from huggingface_hub import hf_hub_download, file_download
 
6
 
7
  from huggingface_hub.hf_api import HfApi
8
  import streamlit as st
9
  from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer
10
  from accelerate import Accelerator
 
 
 
 
11
 
12
  hfapi = HfApi()
13
  model_names = [model.modelId[model.modelId.index("/") + 1:] for model in hfapi.list_models(author="huggingnft")]
@@ -33,7 +41,7 @@ INTERPOLATION_TEXT = "Text about Interpolation"
33
  COLLECTION2COLLECTION_TEXT = "Text about Collection2Collection"
34
 
35
  STOPWORDS = ["-old"]
36
- COLLECTION2COLLECTION_KEYS = ["_2_"]
37
 
38
 
39
  def load_lightweight_model(model_name):
@@ -61,6 +69,11 @@ def clean_models(model_names, stopwords):
61
  cleaned_model_names.append(model_name)
62
  return cleaned_model_names
63
 
 
 
 
 
 
64
 
65
  model_names = clean_models(model_names, STOPWORDS)
66
 
@@ -141,7 +154,7 @@ if choose == "Generate image":
141
  nrow=nrows,
142
  checkpoint=-1,
143
  types=generation_type
144
- )
145
  )
146
 
147
  if choose == "Interpolation":
@@ -184,13 +197,75 @@ if choose == "Interpolation":
184
 
185
  if choose == "Collection2Collection":
186
  st.title(choose)
187
- st.markdown(INTERPOLATION_TEXT)
188
 
189
  model_name = st.selectbox(
190
  'Choose model:',
191
  set(model_names) - set(clean_models(model_names, COLLECTION2COLLECTION_KEYS))
192
  )
 
 
 
 
 
 
193
  generate_image_button = st.button("Generate")
194
 
195
  if generate_image_button:
196
- st.markdown("generating Collection2Collection")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
+
3
+ import torch
4
+
5
  from huggingnft.lightweight_gan.train import timestamped_filename
6
  from streamlit_option_menu import option_menu
7
 
8
  from huggingface_hub import hf_hub_download, file_download
9
+ from PIL import Image
10
 
11
  from huggingface_hub.hf_api import HfApi
12
  import streamlit as st
13
  from huggingnft.lightweight_gan.lightweight_gan import Generator, LightweightGAN, evaluate_in_chunks, Trainer
14
  from accelerate import Accelerator
15
+ from huggan.pytorch.cyclegan.modeling_cyclegan import GeneratorResNet
16
+ from torchvision import transforms as T
17
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip
18
+ from torchvision.utils import make_grid
19
 
20
  hfapi = HfApi()
21
  model_names = [model.modelId[model.modelId.index("/") + 1:] for model in hfapi.list_models(author="huggingnft")]
 
41
  COLLECTION2COLLECTION_TEXT = "Text about Collection2Collection"
42
 
43
  STOPWORDS = ["-old"]
44
+ COLLECTION2COLLECTION_KEYS = ["__2__"]
45
 
46
 
47
  def load_lightweight_model(model_name):
 
69
  cleaned_model_names.append(model_name)
70
  return cleaned_model_names
71
 
72
+ def get_concat_h(im1, im2):
73
+ dst = Image.new('RGB', (im1.width + im2.width, im1.height))
74
+ dst.paste(im1, (0, 0))
75
+ dst.paste(im2, (im1.width, 0))
76
+ return dst
77
 
78
  model_names = clean_models(model_names, STOPWORDS)
79
 
 
154
  nrow=nrows,
155
  checkpoint=-1,
156
  types=generation_type
157
+ )[0]
158
  )
159
 
160
  if choose == "Interpolation":
 
197
 
198
  if choose == "Collection2Collection":
199
  st.title(choose)
200
+ st.markdown(COLLECTION2COLLECTION_TEXT)
201
 
202
  model_name = st.selectbox(
203
  'Choose model:',
204
  set(model_names) - set(clean_models(model_names, COLLECTION2COLLECTION_KEYS))
205
  )
206
+ nrows = st.number_input("Number of images to generate:",
207
+ min_value=1,
208
+ max_value=10,
209
+ step=1,
210
+ value=1,
211
+ )
212
  generate_image_button = st.button("Generate")
213
 
214
  if generate_image_button:
215
+ n_channels = 3
216
+
217
+ image_size = 256
218
+
219
+ input_shape = (image_size, image_size)
220
+
221
+ transform = Compose([
222
+ T.ToPILImage(),
223
+ T.Resize(input_shape),
224
+ ToTensor(),
225
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
226
+ ])
227
+
228
+ # generator = modeling_dcgan.Generator.from_pretrained("huggingnft/cryptopunks")
229
+ with st.spinner(text=f"Downloading selected model..."):
230
+ translator = GeneratorResNet.from_pretrained(f'huggingnft/{model_name}',
231
+ input_shape=(n_channels, image_size, image_size),
232
+ num_residual_blocks=9)
233
+
234
+ z = torch.randn(nrows, 100, 1, 1)
235
+
236
+ with st.spinner(text=f"Downloading selected model..."):
237
+ model = load_lightweight_model(f"huggingnft/{model_name.split('__2__')[0]}")
238
+
239
+ with st.spinner(text=f"Generating input images..."):
240
+ punks = model.generate_app(
241
+ num=timestamped_filename(),
242
+ nrow=4,
243
+ checkpoint=-1,
244
+ types="default"
245
+ )[1]
246
+
247
+ pipe_transform = T.Resize((256, 256))
248
+
249
+ input = pipe_transform(punks)
250
+
251
+ with st.spinner(text=f"Generating output images..."):
252
+ output = translator(input)
253
+
254
+ out_img = make_grid(output,
255
+ nrow=4, normalize=True)
256
+
257
+ # out_img = make_grid(punks,
258
+ # nrow=8, normalize=True)
259
+
260
+ out_transform = Compose([
261
+ T.ToPILImage()
262
+ ])
263
+
264
+ results = []
265
+
266
+ for out_punk, out_ape in zip(input, output):
267
+ results.append(
268
+ get_concat_h(out_transform(make_grid(out_punk, nrow=1, normalize=True)), out_transform(make_grid(out_ape, nrow=1, normalize=True)))
269
+ )
270
+ for result in results:
271
+ st.image(result)