add app and requirements
- app.py +193 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,193 @@
import os

import gradio as gr

MODEL_DIR = 'models/pretrain'
os.makedirs(MODEL_DIR, exist_ok=True)

# Fetch the pretrained StyleGAN2 (FFHQ, 1024x1024) checkpoint.
# NOTE: interpolate MODEL_DIR in Python; a bare $MODEL_DIR would be expanded
# by the shell, where it is undefined. Quoting the URL makes the backslash
# escapes unnecessary.
os.system(f'wget "https://hkustconnect-my.sharepoint.com/:u:/g/personal/jzhubt_connect_ust_hk/ETYVen9KXGlAia2gH6pcZswB9Lw-21vWrE75OACvG2SBow?e=SCGqg0&download=1" -O {MODEL_DIR}/stylegan2-ffhq-config-f-1024x1024.pth --quiet')

# python 3.7
"""Demo."""
import io
import warnings

import cv2
import numpy as np
import torch
from PIL import Image

from models import build_model

warnings.filterwarnings(action='ignore', category=UserWarning)

def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Post-processes image to pixel range [0, 255] with dtype `uint8`.

    This function is particularly used to handle the results produced by deep
    models.

    NOTE: The input image is assumed to be with format `NCHW`, and the
    returned image will always be with format `NHWC`.

    Args:
        image: The input image for post-processing.
        min_val: Expected minimum value of the input image.
        max_val: Expected maximum value of the input image.

    Returns:
        The post-processed image.
    """
    assert isinstance(image, np.ndarray)

    image = image.astype(np.float64)
    image = (image - min_val) / (max_val - min_val) * 255
    image = np.clip(image + 0.5, 0, 255).astype(np.uint8)

    assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
    return image.transpose(0, 2, 3, 1)
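
# A minimal sanity check of the value mapping (illustrative only): an
# all-zero NCHW tensor in [-1, 1] should come back as NHWC mid-gray.
#   out = postprocess_image(np.zeros((1, 3, 4, 4)))
#   assert out.shape == (1, 4, 4, 3) and int(out[0, 0, 0, 0]) == 128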


def to_numpy(data):
    """Converts the input data to `numpy.ndarray`."""
    if isinstance(data, (int, float)):
        return np.array(data)
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    raise TypeError(f'Unsupported data type `{type(data)}` for '
                    f'converting to `numpy.ndarray`!')


def linear_interpolate(latent_code,
                       boundary,
                       layer_index=None,
                       start_distance=-10.0,
                       end_distance=10.0,
                       steps=7):
    """Interpolates between the latent code and boundary.

    NOTE: `layer_index` may be a list of W+ layer indices to manipulate;
    `None` selects all layers (indexing with `None` broadcasts the mask
    across every layer).
    """
    assert (len(latent_code.shape) == 3 and len(boundary.shape) == 3 and
            latent_code.shape[0] == 1 and boundary.shape[0] == 1 and
            latent_code.shape[1] == boundary.shape[1])
    linspace = np.linspace(start_distance, end_distance, steps)
    linspace = linspace.reshape([-1, 1, 1]).astype(np.float32)
    inter_code = linspace * boundary
    is_manipulatable = np.zeros(inter_code.shape, dtype=bool)
    is_manipulatable[:, layer_index, :] = True
    mani_code = np.where(is_manipulatable, latent_code + inter_code, latent_code)
    return mani_code
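
# Illustrative usage (hypothetical inputs; the demo below edits codes
# directly in `inference` instead of calling this helper):
#   direction = np.tile(boundaries['eyebrows'], [1, generator.num_layers, 1])
#   codes = linear_interpolate(wp_single, direction, layer_index=[8, 9, 10, 11])
#   # codes.shape -> (7, num_layers, 512), one code per interpolation step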


def imshow(images, col, viz_size=256):
    """Shows images in one figure."""
    num, height, width, channels = images.shape
    assert num % col == 0
    row = num // col

    fused_image = np.zeros((viz_size * row, viz_size * col, channels),
                           dtype=np.uint8)

    for idx, image in enumerate(images):
        i, j = divmod(idx, col)
        y = i * viz_size
        x = j * viz_size
        if height != viz_size or width != viz_size:
            image = cv2.resize(image, (viz_size, viz_size))
        fused_image[y:y + viz_size, x:x + viz_size] = image

    data = io.BytesIO()
    if channels == 4:
        Image.fromarray(fused_image).save(data, 'png')
    elif channels == 3:
        Image.fromarray(fused_image).save(data, 'jpeg')
    else:
        raise ValueError(f'Unsupported number of channels: {channels}')
    im_data = data.getvalue()
    image = Image.open(io.BytesIO(im_data))
    return image
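
# Example (hypothetical input): four RGB frames tiled into a 2x2 grid of
# 256x256 thumbnails.
#   grid = imshow(np.zeros((4, 1024, 1024, 3), dtype=np.uint8), col=2)
#   # grid.size -> (512, 512)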

print('Building generator')

checkpoint_path = f'{MODEL_DIR}/stylegan2-ffhq-config-f-1024x1024.pth'
config = dict(model_type='StyleGAN2Generator',
              resolution=1024,
              w_dim=512,
              fmaps_base=int(1 * (32 << 10)),
              fmaps_max=512)
generator = build_model(**config)
print(f'Loading checkpoint from `{checkpoint_path}` ...')
checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
if 'generator_smooth' in checkpoint:
    # Prefer the EMA ("smoothed") generator weights when available.
    generator.load_state_dict(checkpoint['generator_smooth'])
else:
    generator.load_state_dict(checkpoint['generator'])
generator = generator.eval().cpu()
print('Finished loading checkpoint.')
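
# `fmaps_base=int(1 * (32 << 10))` keeps the full config-f channel budget
# (32 << 10 == 32768); the leading multiplier is presumably where a slimmer
# variant such as config-e would halve the network width.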

print('Loading boundaries')
ATTRS = ['eyebrows', 'eyesize', 'gaze_direction', 'nose_length', 'mouth',
         'lipstick']
boundaries = {}
for attr in ATTRS:
    boundary_path = f'directions/ffhq/stylegan2/{attr}.npy'
    boundary = np.load(boundary_path)
    boundaries[attr] = boundary
print('Generator and boundaries are ready.')
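
# Each .npy file is assumed to hold a single semantic direction in W space;
# `inference` below tiles it across all W+ layers and applies it only to
# the layers relevant to that attribute.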


def inference(num_of_image, seed, trunc_psi, eyebrows, eyesize,
              gaze_direction, nose_length, mouth, lipstick):
    print('Sampling latent codes with given seed.')
    num_of_image = int(num_of_image)
    trunc_layers = 8
    np.random.seed(int(seed))  # np.random.seed rejects the float a slider may pass
    latent_z = np.random.randn(num_of_image, generator.z_dim)
    latent_z = torch.from_numpy(latent_z.astype(np.float32)).cpu()
    with torch.no_grad():
        wp = generator.mapping(latent_z, None)['wp']
        if trunc_psi < 1.0:
            # Truncation trick: blend the first `trunc_layers` W+ layers
            # toward the average latent code.
            w_avg = generator.w_avg
            w_avg = w_avg.reshape(1, -1, generator.w_dim)[:, :trunc_layers]
            wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi)
        images_ori = generator.synthesis(wp)['image']
        images_ori = postprocess_image(to_numpy(images_ori))
    print('Original images generated.')
    # The unedited samples are composed here for debugging; only the edited
    # grid below is returned to the UI.
    imshow(images_ori, col=images_ori.shape[0])
    latent_wp = to_numpy(wp)

    # Apply each attribute edit only to the W+ layers it targets: layers
    # 8-11 for eyebrows and lipstick, layers 4-7 for the other attributes.
    attr_steps = {'eyebrows': eyebrows, 'eyesize': eyesize,
                  'gaze_direction': gaze_direction, 'nose_length': nose_length,
                  'mouth': mouth, 'lipstick': lipstick}
    new_codes = latent_wp.copy()
    for attr_name in ATTRS:
        if attr_name in ['eyebrows', 'lipstick']:
            layers_idx = [8, 9, 10, 11]
        else:
            layers_idx = [4, 5, 6, 7]
        step = attr_steps[attr_name]  # explicit lookup avoids eval()
        direction = np.tile(boundaries[attr_name], [1, generator.num_layers, 1])
        new_codes[:, layers_idx, :] += direction[:, layers_idx, :] * step
    new_codes = torch.from_numpy(new_codes.astype(np.float32)).cpu()
    with torch.no_grad():
        images_mani = generator.synthesis(new_codes)['image']
        images_mani = postprocess_image(to_numpy(images_mani))
    return imshow(images_mani, col=images_mani.shape[0])
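
# Quick local sanity check, bypassing the UI (illustrative values only):
#   img = inference(1, 210, 0.7, 0, 0, 0, 0, 0, 6)
#   img.save('sample_edit.jpg')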

gr.Interface(inference,
             [gr.Slider(1, 3, value=1, step=1, label="num_of_image"),
              gr.Slider(0, 10000, value=210, step=1, label="seed"),
              gr.Slider(0, 1, value=0.7, step=0.1, label="truncation psi"),
              gr.Slider(-12, 12, value=0, label="eyebrows"),
              gr.Slider(-12, 12, value=0, label="eyesize"),
              gr.Slider(-12, 12, value=0, label="gaze_direction"),
              gr.Slider(-12, 12, value=0, label="nose_length"),
              gr.Slider(-12, 12, value=0, label="mouth"),
              gr.Slider(-12, 12, value=0, label="lipstick")],
             gr.Image(type="pil")).launch()
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
scikit-video
pillow
opencv-python-headless
numpy
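# NOTE: `gradio` is imported by app.py but assumed to be provided by the
# hosting runtime (e.g. the Hugging Face Spaces SDK); add it here for local runs.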