update find_direction
- .DS_Store +0 -0
- find_direction.py +13 -121
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ

find_direction.py CHANGED
@@ -8,34 +8,17 @@
 
 """Generate images using pretrained network pickle."""
 
-import os
-import re
-import random
 import math
-import time
-import click
 import legacy
-from typing import List, Optional
-
-import cv2
 import clip
 import dnnlib
 import numpy as np
 import torch
-from torch import linalg as LA
 import torch.nn.functional as F
-import
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-import PIL.Image
+from torchvision.transforms import Compose, Resize, CenterCrop
 from PIL import Image
-import matplotlib.pyplot as plt
-
 from torch_utils import misc
-from torch_utils import persistence
-from torch_utils.ops import conv2d_resample
 from torch_utils.ops import upfirdn2d
-from torch_utils.ops import bias_act
-from torch_utils.ops import fma
 import id_loss
 
 
@@ -81,8 +64,6 @@ def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None
     assert img is None or img.dtype == torch.float32
     return x, img
 
-
-
 def unravel_index(index, shape):
     out = []
     for dim in reversed(shape):
@@ -90,108 +71,27 @@ def unravel_index(index, shape):
         index = index // dim
     return tuple(reversed(out))
 
-
-def num_range(s: str) -> List[int]:
-    """
-    Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.
-    """
-
-    range_re = re.compile(r'^(\d+)-(\d+)$')
-    m = range_re.match(s)
-    if m:
-        return list(range(int(m.group(1)), int(m.group(2)) + 1))
-    vals = s.split(',')
-    return [int(x) for x in vals]
-
-
-@click.command()
-@click.pass_context
-@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
-@click.option('--seeds', type=num_range, help='List of random seeds')
-@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
-@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
-@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
-@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
-@click.option('--projected_s', help='Projection result file', type=str, metavar='FILE')
-@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
-@click.option('--text_prompt', help='Text', type=str, required=True)
-@click.option('--resolution', help='Resolution of output images', type=int, required=True)
-@click.option('--batch_size', help='Batch Size', type=int, required=True)
-@click.option('--identity_power', help='How much change occurs on the face', type=str, required=True)
-def generate_images(
-    ctx: click.Context,
+def find_direction(
     network_pkl: str,
-    seeds: Optional[List[int]],
-    truncation_psi: float,
-    noise_mode: str,
-    outdir: str,
-    class_idx: Optional[int],
-    projected_w: Optional[str],
-    projected_s: Optional[str],
     text_prompt: str,
-
-
-
+    truncation_psi: float = 0.7,
+    noise_mode: str = "const",
+    resolution: int = 256,
+    identity_power: float = 0.5,
 ):
-    """
-    Generate images using pretrained network pickle.
-
-    Examples:
-    # Generate curated MetFaces images without truncation (Fig.10 left)
-    python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-    # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
-    python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-
-    # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
-    python generate.py --outdir=out --seeds=0-35 --class=1 \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
-
-    # Render an image from projected W
-    python generate.py --outdir=out --projected_w=projected_w.npz \\
-        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
-    """
+    seeds = np.random.randint(0, 1000, 128)
 
+    batch_size = 1
     print('Loading networks from "%s"...' % network_pkl)
-
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     with dnnlib.util.open_url(network_pkl) as f:
         G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
 
-    os.makedirs(outdir, exist_ok=True)
-
-    # Synthesize the result of a W projection
-    if projected_w is not None:
-        if seeds is not None:
-            print('warn: --seeds is ignored when using --projected-w')
-        print(f'Generating images from projected W "{projected_w}"')
-        ws = np.load(projected_w)['w']
-        ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
-        assert ws.shape[1:] == (G.num_ws, G.w_dim)
-        for idx, w in enumerate(ws):
-            img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
-            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-            img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
-        return
-
-    if seeds is None:
-        ctx.fail('--seeds option is required when not using --projected-w')
-
     # Labels
+    class_idx = None
     label = torch.zeros([1, G.c_dim], device=device).requires_grad_()
     if G.c_dim != 0:
-        if class_idx is None:
-            ctx.fail('Must specify class label with --class when using a conditional network')
         label[:, class_idx] = 1
-    else:
-        if class_idx is not None:
-            print('warn: --class=lbl ignored when running on an unconditional network')
 
     model, preprocess = clip.load("ViT-B/32", device=device)
     text = clip.tokenize([text_prompt]).to(device)
@@ -211,8 +111,6 @@ def generate_images(
     transf = Compose([Resize(224, interpolation=Image.BICUBIC), CenterCrop(224)])
 
     styles_array = []
-    print("seeds:", seeds)
-    t1 = time.time()
     for seed_idx, seed in enumerate(seeds):
         if seed == seeds[-1]:
             print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
@@ -260,8 +158,7 @@ def generate_images(
     styles_array.append(styles)
 
    resolution_dict = {256: 6, 512: 7, 1024: 8}
-
-    id_coeff = id_coeff_dict[identity_power]
+    id_coeff = identity_power
     styles_direction = torch.zeros(1, 26, 512, device=device)
     styles_direction_grad_el2 = torch.zeros(1, 26, 512, device=device)
     styles_direction.requires_grad_()
@@ -272,7 +169,6 @@ def generate_images(
     temp_photos = []
     grads = []
     for i in range(math.ceil(len(seeds) / batch_size)):
-        # print(i*batch_size, "processed", time.time()-t1)
 
         styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device)
         seed = seeds[i]
@@ -325,6 +221,7 @@ def generate_images(
     styles_direction *= 0
 
     for i in range(math.ceil(len(seeds) / batch_size)):
+
         seed = seeds[i]
         styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device)
         img2 = torch.tensor(temp_photos[i]).to(device)
@@ -364,9 +261,4 @@ def generate_images(
     styles_direction = styles_direction.detach()
     styles_direction[styles_direction_grad_el2 > (len(seeds) / batch_size) / 4] = 0
 
-
-    np.savez(output_filepath, s=styles_direction.cpu().numpy())
-
-
-if __name__ == "__main__":
-    generate_images()
+    return styles_direction.cpu().numpy()
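With this commit, find_direction.py no longer defines a Click CLI: the generate_images entry point becomes an importable find_direction(network_pkl, text_prompt, ...) function that returns the style-space direction as a NumPy array instead of saving it with np.savez. Below is a minimal sketch of how a caller (for example the Space's app code) might use it; the import path, prompt, and output filename are illustrative assumptions, and the checkpoint URL is the one quoted in the deleted docstring:

    import numpy as np

    # Hypothetical caller; assumes find_direction.py is importable as a module.
    from find_direction import find_direction

    # Compute a global style-space edit direction for a text prompt.
    direction = find_direction(
        network_pkl='https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl',
        text_prompt='a smiling face',  # illustrative prompt
        resolution=256,                # must be a key of resolution_dict: 256, 512, or 1024
        identity_power=0.5,            # coefficient applied to the id_loss identity term
    )

    # The function now returns the direction rather than writing it to disk.
    # The old script saved it via np.savez(..., s=...), so downstream loaders
    # may still expect the array under the key 's'.
    assert direction.shape == (1, 26, 512)
    np.savez('direction.npz', s=direction)

Note that seeds are now drawn inside the function via np.random.randint(0, 1000, 128), so repeated calls optimize over different latents unless the caller fixes NumPy's global seed with np.random.seed(...) beforehand.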