Abhinowww commited on
Commit
d083399
·
1 Parent(s): 1c7b57f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -0
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import torch
4
+ import gradio as gr
5
+ import torch
6
+ torch.backends.cudnn.benchmark = True
7
+ from torchvision import transforms, utils
8
+ from util import *
9
+ from PIL import Image
10
+ import math
11
+ import random
12
+ import numpy as np
13
+ from torch import nn, autograd, optim
14
+ from torch.nn import functional as F
15
+ from tqdm import tqdm
16
+ import lpips
17
+ from model import *
18
+
19
+
20
+ #from e4e_projection import projection as e4e_projection
21
+
22
+ from copy import deepcopy
23
+ import imageio
24
+
25
+ import os
26
+ import sys
27
+ import numpy as np
28
+ from PIL import Image
29
+ import torch
30
+ import torchvision.transforms as transforms
31
+ from argparse import Namespace
32
+ from e4e.models.psp import pSp
33
+ from util import *
34
+ from huggingface_hub import hf_hub_download
35
+
36
+ device= 'cpu'
37
+ model_path_e = hf_hub_download(repo_id="akhaliq/JoJoGAN_e4e_ffhq_encode", filename="e4e_ffhq_encode.pt")
38
+ ckpt = torch.load(model_path_e, map_location='cpu')
39
+ opts = ckpt['opts']
40
+ opts['checkpoint_path'] = model_path_e
41
+ opts= Namespace(**opts)
42
+ net = pSp(opts, device).eval().to(device)
43
+
44
+ @ torch.no_grad()
45
+ def projection(img, name, device='cuda'):
46
+
47
+
48
+ transform = transforms.Compose(
49
+ [
50
+ transforms.Resize(256),
51
+ transforms.CenterCrop(256),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
54
+ ]
55
+ )
56
+ img = transform(img).unsqueeze(0).to(device)
57
+ images, w_plus = net(img, randomize_noise=False, return_latents=True)
58
+ result_file = {}
59
+ result_file['latent'] = w_plus[0]
60
+ torch.save(result_file, name)
61
+ return w_plus[0]
62
+
63
+
64
+
65
+
66
+ device = 'cpu'
67
+
68
+
69
+ latent_dim = 512
70
+
71
+ model_path_s = hf_hub_download(repo_id="akhaliq/jojogan-stylegan2-ffhq-config-f", filename="stylegan2-ffhq-config-f.pt")
72
+ original_generator = Generator(1024, latent_dim, 8, 2).to(device)
73
+ ckpt = torch.load(model_path_s, map_location=lambda storage, loc: storage)
74
+ original_generator.load_state_dict(ckpt["g_ema"], strict=False)
75
+ mean_latent = original_generator.mean_latent(10000)
76
+
77
+ generatorjojo = deepcopy(original_generator)
78
+
79
+ generatordisney = deepcopy(original_generator)
80
+
81
+ generatorjinx = deepcopy(original_generator)
82
+
83
+ generatorcaitlyn = deepcopy(original_generator)
84
+
85
+ generatoryasuho = deepcopy(original_generator)
86
+
87
+ generatorarcanemulti = deepcopy(original_generator)
88
+
89
+ generatorart = deepcopy(original_generator)
90
+
91
+ generatorspider = deepcopy(original_generator)
92
+
93
+ generatorsketch = deepcopy(original_generator)
94
+
95
+
96
+ transform = transforms.Compose(
97
+ [
98
+ transforms.Resize((1024, 1024)),
99
+ transforms.ToTensor(),
100
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
101
+ ]
102
+ )
103
+
104
+
105
+
106
+
107
+ modeljojo = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_preserve_color.pt")
108
+
109
+
110
+ ckptjojo = torch.load(modeljojo, map_location=lambda storage, loc: storage)
111
+ generatorjojo.load_state_dict(ckptjojo["g"], strict=False)
112
+
113
+
114
+ modeldisney = hf_hub_download(repo_id="akhaliq/jojogan-disney", filename="disney_preserve_color.pt")
115
+
116
+ ckptdisney = torch.load(modeldisney, map_location=lambda storage, loc: storage)
117
+ generatordisney.load_state_dict(ckptdisney["g"], strict=False)
118
+
119
+
120
+ modeljinx = hf_hub_download(repo_id="akhaliq/jojo-gan-jinx", filename="arcane_jinx_preserve_color.pt")
121
+
122
+ ckptjinx = torch.load(modeljinx, map_location=lambda storage, loc: storage)
123
+ generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
124
+
125
+
126
+ modelcaitlyn = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_caitlyn_preserve_color.pt")
127
+
128
+ ckptcaitlyn = torch.load(modelcaitlyn, map_location=lambda storage, loc: storage)
129
+ generatorcaitlyn.load_state_dict(ckptcaitlyn["g"], strict=False)
130
+
131
+
132
+ modelyasuho = hf_hub_download(repo_id="akhaliq/JoJoGAN-jojo", filename="jojo_yasuho_preserve_color.pt")
133
+
134
+ ckptyasuho = torch.load(modelyasuho, map_location=lambda storage, loc: storage)
135
+ generatoryasuho.load_state_dict(ckptyasuho["g"], strict=False)
136
+
137
+
138
+ model_arcane_multi = hf_hub_download(repo_id="akhaliq/jojogan-arcane", filename="arcane_multi_preserve_color.pt")
139
+
140
+ ckptarcanemulti = torch.load(model_arcane_multi, map_location=lambda storage, loc: storage)
141
+ generatorarcanemulti.load_state_dict(ckptarcanemulti["g"], strict=False)
142
+
143
+
144
+ modelart = hf_hub_download(repo_id="akhaliq/jojo-gan-art", filename="art.pt")
145
+
146
+ ckptart = torch.load(modelart, map_location=lambda storage, loc: storage)
147
+ generatorart.load_state_dict(ckptart["g"], strict=False)
148
+
149
+
150
+ modelSpiderverse = hf_hub_download(repo_id="akhaliq/jojo-gan-spiderverse", filename="Spiderverse-face-500iters-8face.pt")
151
+
152
+ ckptspider = torch.load(modelSpiderverse, map_location=lambda storage, loc: storage)
153
+ generatorspider.load_state_dict(ckptspider["g"], strict=False)
154
+
155
+ modelSketch = hf_hub_download(repo_id="akhaliq/jojogan-sketch", filename="sketch_multi.pt")
156
+
157
+ ckptsketch = torch.load(modelSketch, map_location=lambda storage, loc: storage)
158
+ generatorsketch.load_state_dict(ckptsketch["g"], strict=False)
159
+
160
+ def inference(img, model):
161
+ img.save('out.jpg')
162
+ aligned_face = align_face('out.jpg')
163
+
164
+ my_w = projection(aligned_face, "test.pt", device).unsqueeze(0)
165
+ if model == 'JoJo':
166
+ with torch.no_grad():
167
+ my_sample = generatorjojo(my_w, input_is_latent=True)
168
+ elif model == 'Disney':
169
+ with torch.no_grad():
170
+ my_sample = generatordisney(my_w, input_is_latent=True)
171
+ elif model == 'Jinx':
172
+ with torch.no_grad():
173
+ my_sample = generatorjinx(my_w, input_is_latent=True)
174
+ elif model == 'Caitlyn':
175
+ with torch.no_grad():
176
+ my_sample = generatorcaitlyn(my_w, input_is_latent=True)
177
+ elif model == 'Yasuho':
178
+ with torch.no_grad():
179
+ my_sample = generatoryasuho(my_w, input_is_latent=True)
180
+ elif model == 'Arcane Multi':
181
+ with torch.no_grad():
182
+ my_sample = generatorarcanemulti(my_w, input_is_latent=True)
183
+ elif model == 'Art':
184
+ with torch.no_grad():
185
+ my_sample = generatorart(my_w, input_is_latent=True)
186
+ elif model == 'Spider-Verse':
187
+ with torch.no_grad():
188
+ my_sample = generatorspider(my_w, input_is_latent=True)
189
+ else:
190
+ with torch.no_grad():
191
+ my_sample = generatorsketch(my_w, input_is_latent=True)
192
+
193
+
194
+ npimage = my_sample[0].permute(1, 2, 0).detach().numpy()
195
+ imageio.imwrite('filename.jpeg', npimage)
196
+ return 'filename.jpeg'
197
+
198
+ title = "JoJoGAN"
199
+ description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
200
+
201
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"
202
+
203
+ examples=[['mona.png','Jinx']]
204
+ gr.Interface(inference, [gr.inputs.Image(type="pil"),gr.inputs.Dropdown(choices=['JoJo', 'Disney','Jinx','Caitlyn','Yasuho','Arcane Multi','Art','Spider-Verse','Sketch'], type="value", default='JoJo', label="Model")], gr.outputs.Image(type="file"),title=title,description=description,article=article,allow_flagging=False,examples=examples,allow_screenshot=False).launch()