NimaBoscarino commited on
Commit
72caa6f
1 Parent(s): d015ad6

Add num_punks selector

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -1,8 +1,11 @@
 
 
 
1
  from torch import nn
 
 
2
 
3
  class Generator(nn.Module):
4
- # Refer to the link below for explanations about nc, nz, and ngf
5
- # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs
6
  def __init__(self, nc=4, nz=100, ngf=64):
7
  super(Generator, self).__init__()
8
  self.network = nn.Sequential(
@@ -23,29 +26,29 @@ class Generator(nn.Module):
23
  output = self.network(input)
24
  return output
25
 
26
- from huggingface_hub import hf_hub_download
27
- import torch
28
 
29
  model = Generator()
30
  weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
31
- model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available
32
 
33
- from torchvision.utils import save_image
34
 
35
- def predict(seed):
36
- num_punks = 4
37
  torch.manual_seed(seed)
38
  z = torch.randn(num_punks, 100, 1, 1)
39
  punks = model(z)
40
  save_image(punks, "punks.png", normalize=True)
41
  return 'punks.png'
42
 
43
- import gradio as gr
44
 
45
  gr.Interface(
46
  predict,
47
  inputs=[
48
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
 
49
  ],
50
  outputs="image",
51
- ).launch()
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from huggingface_hub import hf_hub_download
4
  from torch import nn
5
+ from torchvision.utils import save_image
6
+
7
 
8
  class Generator(nn.Module):
 
 
9
  def __init__(self, nc=4, nz=100, ngf=64):
10
  super(Generator, self).__init__()
11
  self.network = nn.Sequential(
 
26
  output = self.network(input)
27
  return output
28
 
 
 
29
 
30
  model = Generator()
31
  weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
32
+ model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
33
 
 
34
 
35
+ def predict(seed, num_punks):
 
36
  torch.manual_seed(seed)
37
  z = torch.randn(num_punks, 100, 1, 1)
38
  punks = model(z)
39
  save_image(punks, "punks.png", normalize=True)
40
  return 'punks.png'
41
 
 
42
 
43
  gr.Interface(
44
  predict,
45
  inputs=[
46
  gr.inputs.Slider(label='Seed', minimum=0, maximum=1000, default=42),
47
+ gr.inputs.Slider(label='Number of Punks', minimum=4, maximum=64, step=1, default=10),
48
  ],
49
  outputs="image",
50
+ title="Cryptopunks GAN",
51
+ description="These CryptoPunks do not exist. Generate random punks with an initial seed!",
52
+ article="<p style='text-align: center'><a href='https://arxiv.org/pdf/1511.06434.pdf'>Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a> | <a href='https://github.com/teddykoker/cryptopunks-gan'>Github Repo</a></p>",
53
+ examples=[[123], [42], [456], [1337]],
54
+ ).launch(cache_examples=True)