Added StyleCLIP support
- app.py (+32 -6)
- generate_videos.py (+2 -2)
- styleclip/styleclip_global.py (+158 -0)
app.py
CHANGED

@@ -19,12 +19,16 @@ from torchvision import utils
 
 from model.sg2_model import Generator
 from generate_videos import generate_frames, video_from_interpolations, project_code_by_edit_name
+from styleclip.styleclip_global import project_code_with_styleclip, style_tensor_to_style_dict
+
+import clip
 
 model_dir = "models"
 os.makedirs(model_dir, exist_ok=True)
 
 model_repos = {"e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
                "dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
+               "sc_fs3": ("rinong/stylegan-nada-models", "fs3.npy"),
                "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt"),
                "anime": ("rinong/stylegan-nada-models", "anime.pt"),
                "joker": ("rinong/stylegan-nada-models", "joker.pt"),

@@ -70,7 +74,7 @@ class ImageEditor(object):
 
         self.generators = {}
 
-        self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
+        self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib", "sc_fs3"]]
 
         for model in self.model_list:
             g_ema = Generator(

@@ -108,6 +112,10 @@ class ImageEditor(object):
             model_paths["dlib"]
         )
 
+        self.styleclip_fs3 = torch.from_numpy(np.load(model_paths["sc_fs3"])).to(self.device)
+
+        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+
         print("setup complete")
 
     def get_style_list(self):

@@ -186,7 +194,15 @@ class ImageEditor(object):
             target_latents.append(project_code_by_edit_name(np_source_latent, attribute_name, strength))
 
         elif edit_choices["edit_type"] == "StyleCLIP":
-
+            source_s_dict = generators[0].get_s_code(source_latent, input_is_latent=True)
+            target_latents.append(project_code_with_styleclip(source_s_dict,
+                                                              edit_choices["src_text"],
+                                                              edit_choices["tar_text"],
+                                                              edit_choices["alpha"],
+                                                              edit_choices["beta"],
+                                                              generators[0],
+                                                              self.styleclip_fs3,
+                                                              self.clip_model))
 
         # if edit type is none or if all slides were set to 0
         if not target_latents:

@@ -228,9 +244,13 @@ class ImageEditor(object):
         with torch.no_grad():
             for g_ema in generators:
                 latent_for_gen = random.choice(target_latents)
-                latent_for_gen = [torch.from_numpy(latent_for_gen).float().to(self.device)]
 
-                img, _ = g_ema(latent_for_gen, input_is_latent=True, truncation=1, randomize_noise=False)
+                if edit_choices["edit_type"] == "StyleCLIP":
+                    latent_for_gen = style_tensor_to_style_dict(latent_for_gen, g_ema)
+                    img, _ = g_ema(latent_for_gen, input_is_s_code=True, input_is_latent=True, truncation=1, randomize_noise=False)
+                else:
+                    latent_for_gen = [torch.from_numpy(latent_for_gen).float().to(self.device)]
+                    img, _ = g_ema(latent_for_gen, input_is_latent=True, truncation=1, randomize_noise=False)
 
                 output_path = os.path.join(out_dir, f"out_{len(output_paths)}.jpg")
                 utils.save_image(img, output_path, nrow=1, normalize=True, range=(-1, 1))

@@ -294,6 +314,9 @@ with blocks:
     gr.Markdown(
         "For more information about the paper and code for training your own models (with examples OR text), see below."
     )
+
+
+    gr.Markdown("<h4 style='font-size: 110%;margin-top:.5em'>On biases</h4><div>This model relies on StyleGAN and CLIP, both of which are prone to biases such as poor representation of minorities or reinforcement of societal biases, such as gender norms.</div>")
 
     with gr.Row():
         input_img = gr.inputs.Image(type="filepath", label="Input image")

@@ -306,7 +329,8 @@ with blocks:
         with gr.Tabs():
             with gr.TabItem("InterFaceGAN Editing Options"):
                 gr.Markdown("Move the sliders to make the chosen attribute stronger (e.g. the person older) or leave at 0 to disable editing.")
-                gr.Markdown("If multiple options are provided, they will be used randomly between images (or sequentially for a video), <u>not</u> together")
+                gr.Markdown("If multiple options are provided, they will be used randomly between images (or sequentially for a video), <u>not</u> together.")
+                gr.Markdown("Please note that some directions may be entangled. For example, hair length adjustments are likely to also modify the perceived gender.")
 
                 pose_slider = gr.Slider(label="Pose", minimum=-1, maximum=1, value=0, step=0.05)
                 smile_slider = gr.Slider(label="Smile", minimum=-1, maximum=1, value=0, step=0.05)

@@ -343,7 +367,9 @@ with blocks:
         with gr.Row():
             vid_button = gr.Button("Generate Video")
             loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
-
+        with gr.Row():
+            gr.Markdown("Warning: Video generation requires the synthesis of hundreds of frames and is expected to take several minutes.")
+            gr.Markdown("To reduce queue times, we significantly reduced the number of video frames. Using more than 3 styles will further reduce the frames per style, leading to quicker transitions. For better control, we recommend cloning the gradio app, adjusting `num_alphas` in `generate_videos`, and running the code locally.")
         with gr.Column():
             vid_output = gr.outputs.Video(label="Output Video")
 
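For readers following the diff, the StyleCLIP branch added above reduces to the flow sketched below. This is an illustrative restatement of the calls in the commit, not extra code from it; `g_ema`, `source_latent`, `styleclip_fs3`, `clip_model` and `edit_choices` stand in for the corresponding `ImageEditor` state and the dict built from the Gradio inputs.

```python
# Illustrative sketch only: how the pieces added above fit together.
from styleclip.styleclip_global import project_code_with_styleclip, style_tensor_to_style_dict


def apply_styleclip_edit(g_ema, source_latent, edit_choices, styleclip_fs3, clip_model):
    # 1. Map the inverted W+ latent to per-layer style codes (S space).
    source_s_dict = g_ema.get_s_code(source_latent, input_is_latent=True)

    # 2. Shift the flattened style tensor along the CLIP-derived global direction.
    edited_s = project_code_with_styleclip(source_s_dict,
                                           edit_choices["src_text"],
                                           edit_choices["tar_text"],
                                           edit_choices["alpha"],
                                           edit_choices["beta"],
                                           g_ema,
                                           styleclip_fs3,
                                           clip_model)

    # 3. Convert the edited tensor back to a per-layer dict and synthesize the image.
    s_dict = style_tensor_to_style_dict(edited_s, g_ema)
    img, _ = g_ema(s_dict, input_is_s_code=True, input_is_latent=True,
                   truncation=1, randomize_noise=False)
    return img
```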
generate_videos.py
CHANGED

@@ -62,14 +62,14 @@ def generate_frames(source_latent, target_latents, g_ema_list, output_dir):
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    num_alphas = min(
+    num_alphas = min(10, 30 // len(target_latents))
 
     alphas = np.linspace(0, 1, num=num_alphas)
 
     latents = interpolate_with_target_latents(source_latent, target_latents, alphas)
 
     segments = len(g_ema_list) - 1
-
+
     if segments:
         segment_length = len(latents) / segments
 
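The new `num_alphas` value caps the interpolation steps per target latent at 10 and, once more than three targets are queued, divides a budget of roughly 30 frames between them. A quick illustration of the resulting frame counts, mirroring the formula above:

```python
# Frame budget from the change above: at most 10 interpolation steps per target,
# shrinking once more than 3 targets share the ~30-frame budget.
for n_targets in range(1, 6):
    num_alphas = min(10, 30 // n_targets)
    print(f"{n_targets} target(s) -> {num_alphas} frames per transition")
# 1 -> 10, 2 -> 10, 3 -> 10, 4 -> 7, 5 -> 6
```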
styleclip/styleclip_global.py
ADDED

@@ -0,0 +1,158 @@
+import numpy as np
+import torch
+from tqdm import tqdm
+from pathlib import Path
+import os
+
+import clip
+
+imagenet_templates = [
+    'a bad photo of a {}.',
+    'a photo of many {}.',
+    'a sculpture of a {}.',
+    'a photo of the hard to see {}.',
+    'a low resolution photo of the {}.',
+    'a rendering of a {}.',
+    'graffiti of a {}.',
+    'a bad photo of the {}.',
+    'a cropped photo of the {}.',
+    'a tattoo of a {}.',
+    'the embroidered {}.',
+    'a photo of a hard to see {}.',
+    'a bright photo of a {}.',
+    'a photo of a clean {}.',
+    'a photo of a dirty {}.',
+    'a dark photo of the {}.',
+    'a drawing of a {}.',
+    'a photo of my {}.',
+    'the plastic {}.',
+    'a photo of the cool {}.',
+    'a close-up photo of a {}.',
+    'a black and white photo of the {}.',
+    'a painting of the {}.',
+    'a painting of a {}.',
+    'a pixelated photo of the {}.',
+    'a sculpture of the {}.',
+    'a bright photo of the {}.',
+    'a cropped photo of a {}.',
+    'a plastic {}.',
+    'a photo of the dirty {}.',
+    'a jpeg corrupted photo of a {}.',
+    'a blurry photo of the {}.',
+    'a photo of the {}.',
+    'a good photo of the {}.',
+    'a rendering of the {}.',
+    'a {} in a video game.',
+    'a photo of one {}.',
+    'a doodle of a {}.',
+    'a close-up photo of the {}.',
+    'a photo of a {}.',
+    'the origami {}.',
+    'the {} in a video game.',
+    'a sketch of a {}.',
+    'a doodle of the {}.',
+    'a origami {}.',
+    'a low resolution photo of a {}.',
+    'the toy {}.',
+    'a rendition of the {}.',
+    'a photo of the clean {}.',
+    'a photo of a large {}.',
+    'a rendition of a {}.',
+    'a photo of a nice {}.',
+    'a photo of a weird {}.',
+    'a blurry photo of a {}.',
+    'a cartoon {}.',
+    'art of a {}.',
+    'a sketch of the {}.',
+    'a embroidered {}.',
+    'a pixelated photo of a {}.',
+    'itap of the {}.',
+    'a jpeg corrupted photo of the {}.',
+    'a good photo of a {}.',
+    'a plushie {}.',
+    'a photo of the nice {}.',
+    'a photo of the small {}.',
+    'a photo of the weird {}.',
+    'the cartoon {}.',
+    'art of the {}.',
+    'a drawing of the {}.',
+    'a photo of the large {}.',
+    'a black and white photo of a {}.',
+    'the plushie {}.',
+    'a dark photo of a {}.',
+    'itap of a {}.',
+    'graffiti of the {}.',
+    'a toy {}.',
+    'itap of my {}.',
+    'a photo of a cool {}.',
+    'a photo of a small {}.',
+    'a tattoo of the {}.',
+]
+
+FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
+                    [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
+
+def zeroshot_classifier(model, classnames, templates, device):
+
+    with torch.no_grad():
+        zeroshot_weights = []
+        for classname in tqdm(classnames):
+            texts = [template.format(classname) for template in templates]  # format with class
+            texts = clip.tokenize(texts).to(device)  # tokenize
+            class_embeddings = model.encode_text(texts)  # embed with text encoder
+            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+            class_embedding = class_embeddings.mean(dim=0)
+            class_embedding /= class_embedding.norm()
+            zeroshot_weights.append(class_embedding)
+        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
+    return zeroshot_weights
+
+
+def get_direction(neutral_class, target_class, beta, di, clip_model=None):
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if clip_model is None:
+        clip_model, _ = clip.load("ViT-B/32", device=device)
+
+    class_names = [neutral_class, target_class]
+    class_weights = zeroshot_classifier(clip_model, class_names, imagenet_templates, device)
+
+    dt = class_weights[:, 1] - class_weights[:, 0]
+    dt = dt / dt.norm()
+    relevance = di @ dt
+    mask = relevance.abs() > beta
+    direction = relevance * mask
+    direction_max = direction.abs().max()
+    if direction_max > 0:
+        direction = direction / direction_max
+    else:
+        raise ValueError(f'Beta value {beta} is too high for mapping from {neutral_class} to {target_class},'
+                         f' try setting it to a lower value')
+    return direction
+
+def style_tensor_to_style_dict(style_tensor, reference_generator):
+    style_layers = reference_generator.modulation_layers
+
+    style_dict = {}
+    for layer_idx, layer in enumerate(style_layers):
+        style_dict[layer] = style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]]
+
+    return style_dict
+
+def style_dict_to_style_tensor(style_dict, reference_generator):
+    style_layers = reference_generator.modulation_layers
+
+    style_tensor = torch.zeros(1, 9088)
+    for layer in style_dict:
+        layer_idx = style_layers.index(layer)
+        style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]
+
+    return style_tensor
+
+def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
+    edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
+
+    source_s = style_dict_to_style_tensor(source_latent, reference_generator)
+
+    return source_s + alpha * edit_direction
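A minimal usage sketch of the new module, assuming the `fs3.npy` relevance statistics have been downloaded the way `app.py` does via the `sc_fs3` entry; the file path and the prompts below are placeholders rather than values taken from the commit:

```python
# Illustrative only: compute a CLIP-guided global direction in StyleGAN's S space.
import numpy as np
import torch
import clip

from styleclip.styleclip_global import get_direction

device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-channel relevance statistics (the "sc_fs3" file fetched in app.py).
fs3 = torch.from_numpy(np.load("models/fs3.npy")).to(device)
clip_model, _ = clip.load("ViT-B/32", device=device)

# Direction from the neutral prompt towards the target prompt; beta zeroes out
# channels whose relevance to the text direction falls below the threshold.
direction = get_direction("face", "smiling face", beta=0.13,
                          di=fs3, clip_model=clip_model)
print(direction.shape)  # one value per style channel
```

This direction is what `project_code_with_styleclip` scales by `alpha` and adds to the flattened style code of the source image.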