"""Generate images using pretrained network pickle."""

import os
import re
from typing import List, Optional

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy
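
# Note: dnnlib and legacy are not pip packages; they are local modules that
# ship alongside this script in the StyleGAN2-ADA-PyTorch repository, so the
# script is meant to be run from the repository root (or with it on PYTHONPATH).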


def num_range(s: str) -> List[int]:
    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
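    # For example, '600-605' parses to [600, 601, 602, 603, 604, 605] and
    # '85,265,297,849' parses to [85, 265, 297, 849].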
    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2))+1))
    vals = s.split(',')
    return [int(x) for x in vals]
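

# Command-line interface. Options that give an explicit destination name
# (e.g. '--network', 'network_pkl') map onto the parameter of that name;
# '--seeds' is parsed by num_range above, so both '0-35' and '0,3,7' work.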
@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=num_range, help='List of random seeds')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images(
    ctx: click.Context,
    network_pkl: str,
    seeds: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
    projected_w: Optional[str]
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate curated MetFaces images without truncation (Fig.10 left)
    python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
    python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
    python generate.py --outdir=out --seeds=0-35 --class=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl

    \b
    # Render an image from projected W
    python generate.py --outdir=out --projected-w=projected_w.npz \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """
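
    # Load the generator. A CUDA device is assumed; 'G_ema' is the
    # exponentially moving averaged copy of the generator weights, which is
    # the snapshot normally used for inference.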
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)

    os.makedirs(outdir, exist_ok=True)
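
    # Optionally synthesize directly from previously projected W vectors.
    # The .npz file is expected to contain an array 'w' of shape
    # [num_images, G.num_ws, G.w_dim], e.g. as produced by the repository's
    # projector script.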
    if projected_w is not None:
        if seeds is not None:
            print('warn: --seeds is ignored when using --projected-w')
        print(f'Generating images from projected W "{projected_w}"')
        ws = np.load(projected_w)['w']
        ws = torch.tensor(ws, device=device)
        assert ws.shape[1:] == (G.num_ws, G.w_dim)
        for idx, w in enumerate(ws):
            img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
            # Convert NCHW float output in [-1, 1] to NHWC uint8 in [0, 255].
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
        return

    if seeds is None:
        ctx.fail('--seeds option is required when not using --projected-w')
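
    # Build the class label: conditional networks (G.c_dim > 0) expect a
    # one-hot vector, while unconditional networks take an empty label of
    # width zero.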
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            ctx.fail('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print('warn: --class=lbl ignored when running on an unconditional network')
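
    # Generate one image per seed: draw z ~ N(0, I) deterministically from the
    # seed, run the generator with the requested truncation and noise mode,
    # and save the result as a PNG.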
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx + 1, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        # Convert NCHW float output in [-1, 1] to NHWC uint8 in [0, 255].
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')


if __name__ == "__main__":
    generate_images()