Abhinowww commited on
Commit
f7f0543
·
1 Parent(s): 0395f78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -62
app.py CHANGED
@@ -34,7 +34,7 @@ from util import *
34
  from huggingface_hub import hf_hub_download
35
 
36
  device= 'cpu'
37
- model_path_e = hf_hub_download(repo_id="akhaliq/JoJoGAN_e4e_ffhq_encode", filename="e4e_ffhq_encode.pt")
38
  ckpt = torch.load(model_path_e, map_location='cpu')
39
  opts = ckpt['opts']
40
  opts['checkpoint_path'] = model_path_e
@@ -68,29 +68,29 @@ device = 'cpu'
68
 
69
  latent_dim = 512
70
 
71
- model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
72
  original_generator = Generator(1024, latent_dim, 8, 2).to(device)
73
  ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
74
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
  mean_latent = original_generator.mean_latent(10000)
76
 
77
- generatorjojo = deepcopy(original_generator)
78
 
79
- generatordisney = deepcopy(original_generator)
80
 
81
- generatorjinx = deepcopy(original_generator)
82
 
83
- generatorcaitlyn = deepcopy(original_generator)
84
 
85
- generatoryasuho = deepcopy(original_generator)
86
 
87
- generatorarcanemulti = deepcopy(original_generator)
88
 
89
- generatorart = deepcopy(original_generator)
90
 
91
- generatorspider = deepcopy(original_generator)
92
 
93
- generatorsketch = deepcopy(original_generator)
94
 
95
 
96
  transform = transforms.Compose(
@@ -104,58 +104,58 @@ transform = transforms.Compose(
104
 
105
 
106
 
107
- modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
108
 
109
 
110
- ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
111
- generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
112
 
113
 
114
- modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
115
 
116
- ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
117
  generatordisney.load_state_dict(ckptdisney["g"], strict=False)
118
 
119
 
120
- modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
121
 
122
- ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
123
- generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
124
 
125
 
126
- modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
127
 
128
- ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
129
- generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
130
 
131
 
132
- modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
133
 
134
- ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
135
- generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
136
 
137
 
138
- model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
139
 
140
- ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
141
- generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
142
 
143
 
144
- modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
145
 
146
- ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
147
- generatorart.load_state_dict(ckptart["g"], strict=False)
148
 
149
 
150
- modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
151
 
152
- ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
153
- generatorspider.load_state_dict(ckptspider["g"], strict=False)
154
 
155
- modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
156
 
157
- ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
158
- generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
159
 
160
  def inference(img, model):
161
  img.save('out.jpg')
@@ -168,37 +168,38 @@ def inference(img, model):
168
  elif model == 'Disney':
169
  with torch.no_grad():
170
  my_sample = generatordisney(my_w, input_is_latent=True)
171
- elif model == 'Jinx':
172
- with torch.no_grad():
173
- my_sample = generatorjinx(my_w, input_is_latent=True)
174
- elif model == 'Caitlyn':
175
- with torch.no_grad():
176
- my_sample = generatorcaitlyn(my_w, input_is_latent=True)
177
- elif model == 'Yasuho':
178
- with torch.no_grad():
179
- my_sample = generatoryasuho(my_w, input_is_latent=True)
180
- elif model == 'Arcane Multi':
181
- with torch.no_grad():
182
- my_sample = generatorarcanemulti(my_w, input_is_latent=True)
183
- elif model == 'Art':
184
- with torch.no_grad():
185
- my_sample = generatorart(my_w, input_is_latent=True)
186
- elif model == 'Spider-Verse':
187
- with torch.no_grad():
188
- my_sample = generatorspider(my_w, input_is_latent=True)
189
- else:
190
- with torch.no_grad():
191
- my_sample = generatorsketch(my_w, input_is_latent=True)
192
 
193
 
194
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
195
  imageio.imwrite('filename.jpeg', npimage)
196
  return 'filename.jpeg'
197
 
198
- title = "JoJoGAN"
199
- description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
200
 
201
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
202
 
203
- examples=[['mona.png','Jinx']]
204
- gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="pil"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
 
 
34
  from huggingface_hub import hf_hub_download
35
 
36
  device= 'cpu'
37
+ model_path_e = hf_hub_download(repo_id="Abhinowww/Capstone", filename="e4e_ffhq_encode.pt")
38
  ckpt = torch.load(model_path_e, map_location='cpu')
39
  opts = ckpt['opts']
40
  opts['checkpoint_path'] = model_path_e
 
68
 
69
  latent_dim = 512
70
 
71
+ model_path_s = hf_hub_download(repo_id="Abhinowww/Capstone", filename="stylegan2-ffhq-config-f.pt")
72
  original_generator = Generator(1024, latent_dim, 8, 2).to(device)
73
  ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
74
  original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
  mean_latent = original_generator.mean_latent(10000)
76
 
77
+ generatorjoker = deepcopy(original_generator)
78
 
79
+ generatorvoldemort = deepcopy(original_generator)
80
 
81
+ # generatorjinx = deepcopy(original_generator)
82
 
83
+ # generatorcaitlyn = deepcopy(original_generator)
84
 
85
+ # generatoryasuho = deepcopy(original_generator)
86
 
87
+ # generatorarcanemulti = deepcopy(original_generator)
88
 
89
+ # generatorart = deepcopy(original_generator)
90
 
91
+ # generatorspider = deepcopy(original_generator)
92
 
93
+ # generatorsketch = deepcopy(original_generator)
94
 
95
 
96
  transform = transforms.Compose(
 
104
 
105
 
106
 
107
+ modeljoker = hf_hub_download(repo_id="Abhinowww/Capstone", filename="JokerEightHundredFalse.pt")
108
 
109
 
110
+ ckptjoker = torch.load(modeljoker, map_location=lambda storage, loc: storage)
111
+ generatorjoker.load_state_dict(ckptjojo["g"], strict=False)
112
 
113
 
114
+ modelvoldemort = hf_hub_download(repo_id="Abhinowww/Capstone", filename="VoldemortEightHundredFalse.pt")
115
 
116
+ ckptdisney = torch.load(modelvoldemort, map_location=lambda storage, loc: storage)
117
  generatordisney.load_state_dict(ckptdisney["g"], strict=False)
118
 
119
 
120
+ # modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
121
 
122
+ # ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
123
+ # generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
124
 
125
 
126
+ # modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
127
 
128
+ # ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
129
+ # generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
130
 
131
 
132
+ # modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
133
 
134
+ # ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
135
+ # generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
136
 
137
 
138
+ # model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
139
 
140
+ # ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
141
+ # generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
142
 
143
 
144
+ # modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
145
 
146
+ # ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
147
+ # generatorart.load_state_dict(ckptart["g"], strict=False)
148
 
149
 
150
+ # modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
151
 
152
+ # ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
153
+ # generatorspider.load_state_dict(ckptspider["g"], strict=False)
154
 
155
+ # modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
156
 
157
+ # ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
158
+ # generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
159
 
160
  def inference(img, model):
161
  img.save('out.jpg')
 
168
  elif model == 'Disney':
169
  with torch.no_grad():
170
  my_sample = generatordisney(my_w, input_is_latent=True)
171
+ # elif model == 'Jinx':
172
+ # with torch.no_grad():
173
+ # my_sample = generatorjinx(my_w, input_is_latent=True)
174
+ # elif model == 'Caitlyn':
175
+ # with torch.no_grad():
176
+ # my_sample = generatorcaitlyn(my_w, input_is_latent=True)
177
+ # elif model == 'Yasuho':
178
+ # with torch.no_grad():
179
+ # my_sample = generatoryasuho(my_w, input_is_latent=True)
180
+ # elif model == 'Arcane Multi':
181
+ # with torch.no_grad():
182
+ # my_sample = generatorarcanemulti(my_w, input_is_latent=True)
183
+ # elif model == 'Art':
184
+ # with torch.no_grad():
185
+ # my_sample = generatorart(my_w, input_is_latent=True)
186
+ # elif model == 'Spider-Verse':
187
+ # with torch.no_grad():
188
+ # my_sample = generatorspider(my_w, input_is_latent=True)
189
+ # else:
190
+ # with torch.no_grad():
191
+ # my_sample = generatorsketch(my_w, input_is_latent=True)
192
 
193
 
194
  npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
195
  imageio.imwrite('filename.jpeg', npimage)
196
  return 'filename.jpeg'
197
 
198
+ title = "Capstone"
199
+ description = "Capstone Project. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
200
 
201
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
202
 
203
+ examples=[['mona.png','Joker']]
204
+ # gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="pil"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()
205
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['Joker', 'Voldemort'], type="value", default='Joker', label="Model")], gr.outputs.Image(type="pil"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()