Yannic Kilcher committed
Commit 5c824be • Parent: 6f160b3
added interfaces for interpolation and projection

Files changed:
- .gitignore +1 -0
- README.md +5 -0
- interface.py +70 -0
- interface_projector.py +126 -0
- interpolate.py +10 -0
- projector.py +54 -5
.gitignore
CHANGED
@@ -1,2 +1,3 @@
 __pycache__/
 .cache/
+proj.mp4
README.md
CHANGED
@@ -1,3 +1,8 @@
+## Project repo for apes by ykilcher
+
+Note: most of the code is taken from nvlabs/stylegan2-ada-pytorch (original readme below).
+I added Gradio interfaces and CLIP projection.
+
 ## StyleGAN2-ADA — Official PyTorch implementation
 
 ![Teaser image](./docs/stylegan2-ada-teaser-1024x252.png)
interface.py
ADDED
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+
+import gradio as gr
+
+import numpy as np
+import torch
+import pickle
+import types
+
+from huggingface_hub import hf_hub_url, cached_download
+
+# with open('../models/gamma500/network-snapshot-010000.pkl', 'rb') as f:
+with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
+    G = pickle.load(f)['G_ema']  # torch.nn.Module
+
+device = torch.device("cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    G = G.to(device)
+else:
+    _old_forward = G.forward
+
+    def _new_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_forward(*args, **kwargs)  # _old_forward is already bound to G
+
+    G.forward = types.MethodType(_new_forward, G)
+
+    _old_synthesis_forward = G.synthesis.forward
+
+    def _new_synthesis_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_synthesis_forward(*args, **kwargs)  # already bound to G.synthesis
+
+    G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
+
+
+def generate(num_images, interpolate):
+    if interpolate:
+        z1 = torch.randn([1, G.z_dim])  # latent codes
+        z2 = torch.randn([1, G.z_dim])  # latent codes
+        zs = torch.cat([z1 + (z2 - z1) * i / (num_images - 1) for i in range(num_images)], 0)
+    else:
+        zs = torch.randn([num_images, G.z_dim])  # latent codes
+    with torch.no_grad():
+        zs = zs.to(device)
+        img = G(zs, None, force_fp32=True, truncation_psi=1, noise_mode='const')
+        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+    return img.cpu().numpy()
+
+
+def greet(num_images, interpolate):
+    img = generate(round(num_images), interpolate)
+    imgs = list(img)
+    if len(imgs) == 1:
+        return imgs[0]
+    grid_len = int(np.ceil(np.sqrt(len(imgs)))) * 2
+    grid_height = int(np.ceil(len(imgs) / grid_len))
+    grid = np.zeros((grid_height * imgs[0].shape[0], grid_len * imgs[0].shape[1], 3), dtype=np.uint8)
+    for i, img in enumerate(imgs):
+        y = (i // grid_len) * img.shape[0]
+        x = (i % grid_len) * img.shape[1]
+        grid[y:y + img.shape[0], x:x + img.shape[1], :] = img
+    return grid
+
+
+iface = gr.Interface(fn=greet, inputs=[
+    gr.inputs.Number(default=1, label="Num Images"),
+    gr.inputs.Checkbox(default=False, label="Interpolate")
+], outputs="image")
+iface.launch()
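For reference, a minimal, self-contained sketch of the grid packing that greet() performs when more than one image is requested. The make_grid name and the dummy uint8 arrays are illustrative only and not part of the repo:

import numpy as np

def make_grid(imgs):
    # Same layout as greet(): grid width is 2 * ceil(sqrt(n)) tiles.
    grid_len = int(np.ceil(np.sqrt(len(imgs)))) * 2
    grid_height = int(np.ceil(len(imgs) / grid_len))
    h, w = imgs[0].shape[:2]
    grid = np.zeros((grid_height * h, grid_len * w, 3), dtype=np.uint8)
    for i, img in enumerate(imgs):
        y = (i // grid_len) * h
        x = (i % grid_len) * w
        grid[y:y + h, x:x + w, :] = img
    return grid

# Four dummy 8x8 "images" stand in for G's output.
dummy = [np.full((8, 8, 3), fill, dtype=np.uint8) for fill in (0, 85, 170, 255)]
print(make_grid(dummy).shape)  # (8, 32, 3): one row of four tiles, since grid_len = 4

With four images, grid_len comes out to 4, so the tiles land in a single wide row rather than a 2x2 square.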
interface_projector.py
ADDED
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+import gradio as gr
+
+import numpy as np
+import torch
+import pickle
+import PIL.Image
+import types
+
+from projector import project, imageio, _MODELS
+
+from huggingface_hub import hf_hub_url, cached_download
+
+# with open("../models/gamma500/network-snapshot-010000.pkl", "rb") as f:
+# with open("../models/gamma400/network-snapshot-010600.pkl", "rb") as f:
+# with open("../models/gamma400/network-snapshot-019600.pkl", "rb") as f:
+with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
+    G = pickle.load(f)["G_ema"]  # torch.nn.Module
+
+device = torch.device("cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    G = G.to(device)
+else:
+    _old_forward = G.forward
+
+    def _new_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_forward(*args, **kwargs)  # _old_forward is already bound to G
+
+    G.forward = types.MethodType(_new_forward, G)
+
+    _old_synthesis_forward = G.synthesis.forward
+
+    def _new_synthesis_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_synthesis_forward(*args, **kwargs)  # already bound to G.synthesis
+
+    G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
+
+
+def generate(
+    target_image_upload,
+    # target_image_webcam,
+    num_steps,
+    seed,
+    learning_rate,
+    model_name,
+    normalize_for_clip,
+    loss_type,
+    regularize_noise_weight,
+    initial_noise_factor,
+):
+    seed = round(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    target_image = target_image_upload
+    # if target_image is None:
+    #     target_image = target_image_webcam
+    num_steps = round(num_steps)
+    print(type(target_image))
+    print(target_image.dtype)
+    print(target_image.max())
+    print(target_image.min())
+    print(target_image.shape)
+    target_pil = PIL.Image.fromarray(target_image).convert("RGB")
+    w, h = target_pil.size
+    s = min(w, h)
+    target_pil = target_pil.crop(
+        ((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)
+    )
+    target_pil = target_pil.resize(
+        (G.img_resolution, G.img_resolution), PIL.Image.LANCZOS
+    )
+    target_uint8 = np.array(target_pil, dtype=np.uint8)
+    target_image = torch.from_numpy(target_uint8.transpose([2, 0, 1])).to(device)
+    projected_w_steps = project(
+        G,
+        target=target_image,
+        num_steps=num_steps,
+        device=device,
+        verbose=True,
+        initial_learning_rate=learning_rate,
+        model_name=model_name,
+        normalize_for_clip=normalize_for_clip,
+        loss_type=loss_type,
+        regularize_noise_weight=regularize_noise_weight,
+        initial_noise_factor=initial_noise_factor,
+    )
+    with torch.no_grad():
+        video = imageio.get_writer('proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
+        for w in projected_w_steps:
+            synth_image = G.synthesis(w.to(device).unsqueeze(0), noise_mode="const")
+            synth_image = (synth_image + 1) * (255 / 2)
+            synth_image = (
+                synth_image.permute(0, 2, 3, 1)
+                .clamp(0, 255)
+                .to(torch.uint8)[0]
+                .cpu()
+                .numpy()
+            )
+            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
+        video.close()
+    return synth_image, "proj.mp4"
+
+
+iface = gr.Interface(
+    fn=generate,
+    inputs=[
+        gr.inputs.Image(source="upload", optional=True),
+        # gr.inputs.Image(source="webcam", optional=True),
+        gr.inputs.Number(default=250, label="steps"),
+        gr.inputs.Number(default=69420, label="seed"),
+        gr.inputs.Number(default=0.05, label="learning_rate"),
+        gr.inputs.Dropdown(default='RN50', label="model_name", choices=['vgg16', *_MODELS.keys()]),
+        gr.inputs.Checkbox(default=True, label="normalize_for_clip"),
+        gr.inputs.Dropdown(
+            default="l2", label="loss_type", choices=["l2", "l1", "cosine"]
+        ),
+        gr.inputs.Number(default=1e5, label="regularize_noise_weight"),
+        gr.inputs.Number(default=0.05, label="initial_noise_factor"),
+    ],
+    outputs=["image", "video"],
+)
+iface.launch(inbrowser=True)
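generate() center-crops the uploaded image to a square and resizes it to G.img_resolution before projecting. A standalone sketch of just that preprocessing step; the preprocess_target helper, the dummy array standing in for the upload, and the resolution argument are illustrative (the repo reads the resolution from the loaded generator):

import numpy as np
import PIL.Image

def preprocess_target(target_image, img_resolution=1024):
    # Center-crop to a square, then resize to the generator resolution,
    # mirroring the preprocessing in interface_projector.generate().
    target_pil = PIL.Image.fromarray(target_image).convert("RGB")
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((img_resolution, img_resolution), PIL.Image.LANCZOS)
    return np.array(target_pil, dtype=np.uint8)

# A random 200x300 RGB array stands in for the uploaded image.
dummy_upload = np.random.randint(0, 256, size=(200, 300, 3), dtype=np.uint8)
print(preprocess_target(dummy_upload, img_resolution=256).shape)  # (256, 256, 3)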
interpolate.py
ADDED
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+import torch
+import pickle
+
+with open('../models/gamma500/network-snapshot-010000.pkl', 'rb') as f:
+    G = pickle.load(f)['G_ema']  # torch.nn.Module
+z = torch.randn([1, G.z_dim])  # latent codes
+c = None  # class labels (not used in this example)
+img = G(z, c, force_fp32=True)  # NCHW, float32, dynamic range [-1, +1]
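interpolate.py stops at the raw generator output, which the final comment notes is NCHW float32 in [-1, +1]. A short sketch of converting such a tensor to a uint8 HWC image and saving it, mirroring the conversion used in interface.py; the dummy tensor and the sample.png filename are illustrative:

import torch
import PIL.Image

# Dummy stand-in for img = G(z, c, force_fp32=True): NCHW float32 in [-1, +1].
img = torch.rand(1, 3, 256, 256) * 2 - 1

# Map [-1, +1] -> [0, 255] uint8 and reorder to HWC, as the interfaces do.
img_uint8 = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img_uint8[0].cpu().numpy(), "RGB").save("sample.png")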
projector.py
CHANGED
@@ -22,6 +22,18 @@ import torch.nn.functional as F
 import dnnlib
 import legacy
 
+_MODELS = {
+    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+}
+
 def project(
     G,
     target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
@@ -35,6 +47,9 @@ def project(
     noise_ramp_length = 0.75,
     regularize_noise_weight = 1e5,
     verbose = False,
+    model_name = 'vgg16',
+    loss_type = 'l2',
+    normalize_for_clip = True,
     device: torch.device
 ):
     assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
@@ -56,16 +71,38 @@ def project(
     # Setup noise inputs.
     noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
 
+    USE_CLIP = model_name != 'vgg16'
     # Load VGG16 feature detector.
     url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+    if USE_CLIP:
+        # url = 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'
+        # url = 'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt'
+        # url = 'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt'
+        # url = 'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt'
+        url = _MODELS[model_name]
     with dnnlib.util.open_url(url) as f:
         vgg16 = torch.jit.load(f).eval().to(device)
 
     # Features for target image.
     target_images = target.unsqueeze(0).to(device).to(torch.float32)
-    if target_images.shape[2] > 256:
-        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
-    target_features = vgg16(target_images, resize_images=False, return_lpips=True)
+    if USE_CLIP:
+        image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(device)[:, None, None]
+        image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(device)[:, None, None]
+        # target_images = F.interpolate(target_images, size=(224, 224), mode='area')
+        target_images = F.interpolate(target_images, size=(vgg16.input_resolution.item(), vgg16.input_resolution.item()), mode='area')
+        print("target_images.shape:", target_images.shape)
+        def _encode_image(image):
+            image = image / 255.
+            # image = torch.sigmoid(image)
+            if normalize_for_clip:
+                image = (image - image_mean) / image_std
+            return vgg16.encode_image(image)
+        target_features = _encode_image(target_images.clamp(0, 255))
+        target_features = target_features.detach()
+    else:
+        if target_images.shape[2] > 256:
+            target_images = F.interpolate(target_images, size=(256, 256), mode='area')
+        target_features = vgg16(target_images, resize_images=False, return_lpips=True)
 
     w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
     w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
@@ -98,8 +135,20 @@ def project(
             synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
 
         # Features for synth images.
-        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
-        dist = (target_features - synth_features).square().sum()
+        if USE_CLIP:
+            synth_images = F.interpolate(synth_images, size=(vgg16.input_resolution.item(), vgg16.input_resolution.item()), mode='area')
+            synth_features = _encode_image(synth_images)
+            if loss_type == 'cosine':
+                target_features_normalized = target_features / target_features.norm(dim=-1, keepdim=True).detach()
+                synth_features_normalized = synth_features / synth_features.norm(dim=-1, keepdim=True).detach()
+                dist = 1.0 - torch.sum(synth_features_normalized * target_features_normalized)
+            elif loss_type == 'l1':
+                dist = (target_features - synth_features).abs().sum()
+            else:
+                dist = (target_features - synth_features).square().sum()
+        else:
+            synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+            dist = (target_features - synth_features).square().sum()
 
         # Noise regularization.
         reg_loss = 0.0