diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b5ee9bf994cc9441cb659c3527160b4ee5bcb33 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,97 @@ +Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. + + +NVIDIA Source Code License for StyleGAN3 + + +======================================================================= + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. Notwithstanding + the foregoing, NVIDIA and its affiliates may use the Work and any + derivative works commercially. As used herein, "non-commercially" + means for research or evaluation purposes only. + + 3.4 Patent Claims. If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grant in Section 2.1) will terminate + immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor’s or its affiliates’ names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grant in Section 2.1) will + terminate immediately. + +4. 
Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +======================================================================= diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5c169484e7883263acb0404500e611bad4971279 --- /dev/null +++ b/app.py @@ -0,0 +1,169 @@ +import os +os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html") +os.system("git clone https://github.com/NVlabs/stylegan3") +os.system("git clone https://github.com/openai/CLIP") +os.system("pip install -e ./CLIP") +os.system("pip install einops ninja scipy numpy Pillow tqdm gradio") +import sys +sys.path.append('./CLIP') +sys.path.append('./stylegan3') +import io +import time +import pickle +import shutil +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F +import requests +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +import clip +import gradio as gr +from tqdm import tqdm +from torchvision.transforms import Compose, Resize, ToTensor, Normalize +from einops import rearrange +device = torch.device('cuda:0') +def fetch(url_or_path): + if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'): + r = requests.get(url_or_path) + r.raise_for_status() + fd = io.BytesIO() + fd.write(r.content) + fd.seek(0) + return fd + return open(url_or_path, 'rb') +def fetch_model(url_or_path): + basename = os.path.basename(url_or_path) + if os.path.exists(basename): + return basename + else: + os.system(f"wget -c '{url_or_path}'") + return basename +def norm1(prompt): + "Normalize to the unit sphere."
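+ # Divide by the L2 norm along the last dimension (same effect as F.normalize(prompt, dim=-1)).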
+ return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt() +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) +class MakeCutouts(torch.nn.Module): + def __init__(self, cut_size, cutn, cut_pow=1.): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.cut_pow = cut_pow + def forward(self, input): + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(self.cutn): + size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) +make_cutouts = MakeCutouts(224, 32, 0.5) +def embed_image(image): + n = image.shape[0] + cutouts = make_cutouts(image) + embeds = clip_model.embed_cutout(cutouts) + embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n) + return embeds +def embed_url(url): + image = Image.open(fetch(url)).convert('RGB') + return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0) +class CLIP(object): + def __init__(self): + clip_model = "ViT-B/32" + self.model, _ = clip.load(clip_model) + self.model = self.model.requires_grad_(False) + self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + @torch.no_grad() + def embed_text(self, prompt): + "Normalized clip text embedding." + return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float()) + def embed_cutout(self, image): + "Normalized clip image embedding." 
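+ # Apply CLIP's mean/std normalization before encoding, then scale the embedding to unit length.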
+ return norm1(self.model.encode_image(self.normalize(image))) + +clip_model = CLIP() +# Load stylegan model +base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/" +model_name = "stylegan3-t-ffhqu-1024x1024.pkl" +#model_name = "stylegan3-r-metfacesu-1024x1024.pkl" +#model_name = "stylegan3-t-afhqv2-512x512.pkl" +network_url = base_url + model_name +os.system("wget -c https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl") +with open('stylegan3-t-ffhqu-1024x1024.pkl', 'rb') as fp: + G = pickle.load(fp)['G_ema'].to(device) +zs = torch.randn([10000, G.mapping.z_dim], device=device) +w_stds = G.mapping(zs, None).std(0) + + +def inference(text): + target = clip_model.embed_text(text) + steps = 600 + seed = 2 + tf = Compose([ + Resize(224), + lambda x: torch.clamp((x+1)/2,min=0,max=1), + ]) + torch.manual_seed(seed) + timestring = time.strftime('%Y%m%d%H%M%S') + with torch.no_grad(): + qs = [] + losses = [] + for _ in range(8): + q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds + images = G.synthesis(q * w_stds + G.mapping.w_avg) + embeds = embed_image(images.add(1).div(2)) + loss = spherical_dist_loss(embeds, target).mean(0) + i = torch.argmin(loss) + qs.append(q[i]) + losses.append(loss[i]) + qs = torch.stack(qs) + losses = torch.stack(losses) + print(losses) + print(losses.shape, qs.shape) + i = torch.argmin(losses) + q = qs[i].unsqueeze(0) + q.requires_grad_() + q_ema = q + opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999)) + loop = tqdm(range(steps)) + for i in loop: + opt.zero_grad() + w = q * w_stds + image = G.synthesis(w + G.mapping.w_avg, noise_mode='const') + embed = embed_image(image.add(1).div(2)) + loss = spherical_dist_loss(embed, target).mean() + loss.backward() + opt.step() + loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item()) + q_ema = q_ema * 0.9 + q * 0.1 + image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const') + if i % 10 == 0: + display(TF.to_pil_image(tf(image)[0])) + pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1)) + #os.makedirs(f'samples/{timestring}', exist_ok=True) + #pil_image.save(f'samples/{timestring}/{i:04}.jpg') + return pil_image + + +title = "StyleGAN+CLIP_with_Latent_Bootstraping" +description = "Gradio demo for StyleGAN+CLIP_with_Latent_Bootstraping. To use it, simply add your text, or click one of the examples to load them. Read more at the links below." +article = "
Colab notebook by https://twitter.com/EricHallahan
" +examples = [['elon musk']] +gr.Interface( + inference, + "text", + gr.outputs.Image(type="pil", label="Output"), + title=title, + description=description, + article=article, + enable_queue=True, + examples=examples + ).launch(debug=True) diff --git a/avg_spectra.py b/avg_spectra.py new file mode 100644 index 0000000000000000000000000000000000000000..afaef87de54e49df230b432b52fda92667d17667 --- /dev/null +++ b/avg_spectra.py @@ -0,0 +1,276 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Compare average power spectra between real and generated images, +or between multiple generators.""" + +import os +import numpy as np +import torch +import torch.fft +import scipy.ndimage +import matplotlib.pyplot as plt +import click +import tqdm +import dnnlib + +import legacy +from training import dataset + +#---------------------------------------------------------------------------- +# Setup an iterator for streaming images, in uint8 NCHW format, based on the +# respective command line options. + +def stream_source_images(source, num, seed, device, data_loader_kwargs=None): # => num_images, image_size, image_iter + ext = source.split('.')[-1].lower() + if data_loader_kwargs is None: + data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) + + if ext == 'pkl': + if num is None: + raise click.ClickException('--num is required when --source points to network pickle') + with dnnlib.util.open_url(source) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) + def generate_image(seed): + rnd = np.random.RandomState(seed) + z = torch.from_numpy(rnd.randn(1, G.z_dim)).to(device) + c = torch.zeros([1, G.c_dim], device=device) + if G.c_dim > 0: + c[:, rnd.randint(G.c_dim)] = 1 + return (G(z=z, c=c) * 127.5 + 128).clamp(0, 255).to(torch.uint8) + _ = generate_image(seed) # warm up + image_iter = (generate_image(seed + idx) for idx in range(num)) + return num, G.img_resolution, image_iter + + elif ext == 'zip' or os.path.isdir(source): + dataset_obj = dataset.ImageFolderDataset(path=source, max_size=num, random_seed=seed) + if num is not None and num != len(dataset_obj): + raise click.ClickException(f'--source contains fewer than {num} images') + data_loader = torch.utils.data.DataLoader(dataset_obj, batch_size=1, **data_loader_kwargs) + image_iter = (image.to(device) for image, _label in data_loader) + return len(dataset_obj), dataset_obj.resolution, image_iter + + else: + raise click.ClickException('--source must point to network pickle, dataset zip, or directory') + +#---------------------------------------------------------------------------- +# Load average power spectrum from the specified .npz file and construct +# the corresponding heatmap for visualization. 
+ +def construct_heatmap(npz_file, smooth): + npz_data = np.load(npz_file) + spectrum = npz_data['spectrum'] + image_size = npz_data['image_size'] + hmap = np.log10(spectrum) * 10 # dB + hmap = np.fft.fftshift(hmap) + hmap = np.concatenate([hmap, hmap[:1, :]], axis=0) + hmap = np.concatenate([hmap, hmap[:, :1]], axis=1) + if smooth > 0: + sigma = spectrum.shape[0] / image_size * smooth + hmap = scipy.ndimage.gaussian_filter(hmap, sigma=sigma, mode='nearest') + return hmap, image_size + +#---------------------------------------------------------------------------- + +@click.group() +def main(): + """Compare average power spectra between real and generated images, + or between multiple generators. + + Example: + + \b + # Calculate dataset mean and std, needed in subsequent steps. + python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip + + \b + # Calculate average spectrum for the training data. + python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \\ + --dest=tmp/training-data.npz --mean=112.684 --std=69.509 + + \b + # Calculate average spectrum for a pre-trained generator. + python avg_spectra.py calc \\ + --source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \\ + --dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000 + + \b + # Display results. + python avg_spectra.py heatmap tmp/training-data.npz + python avg_spectra.py heatmap tmp/stylegan3-r.npz + python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz + + \b + # Save as PNG. + python avg_spectra.py heatmap tmp/training-data.npz --save=tmp/training-data.png --dpi=300 + python avg_spectra.py heatmap tmp/stylegan3-r.npz --save=tmp/stylegan3-r.png --dpi=300 + python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz --save=tmp/slices.png --dpi=300 + """ + +#---------------------------------------------------------------------------- + +@main.command() +@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True) +@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1)) +@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) +def stats(source, num, seed, device=torch.device('cuda')): + """Calculate dataset mean and standard deviation needed by 'calc'.""" + torch.multiprocessing.set_start_method('spawn') + num_images, _image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device) + + # Accumulate moments. + moments = torch.zeros([3], dtype=torch.float64, device=device) + for image in tqdm.tqdm(image_iter, total=num_images): + image = image.to(torch.float64) + moments += torch.stack([torch.ones_like(image).sum(), image.sum(), image.square().sum()]) + moments = moments / moments[0] + + # Compute mean and standard deviation. 
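+ # After normalization, moments holds [1, E[x], E[x^2]], so the standard deviation is sqrt(E[x^2] - E[x]^2).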
+ mean = moments[1] + std = (moments[2] - moments[1].square()).sqrt() + print(f'--mean={mean:g} --std={std:g}') + +#---------------------------------------------------------------------------- + +@main.command() +@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True) +@click.option('--dest', help='Where to store the result', metavar='NPZ', required=True) +@click.option('--mean', help='Dataset mean for whitening', metavar='FLOAT', type=float, required=True) +@click.option('--std', help='Dataset standard deviation for whitening', metavar='FLOAT', type=click.FloatRange(min=0), required=True) +@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1)) +@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True) +@click.option('--beta', help='Shape parameter for the Kaiser window', metavar='FLOAT', type=click.FloatRange(min=0), default=8, show_default=True) +@click.option('--interp', help='Frequency-domain interpolation factor', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True) +def calc(source, dest, mean, std, num, seed, beta, interp, device=torch.device('cuda')): + """Calculate average power spectrum and store it in .npz file.""" + torch.multiprocessing.set_start_method('spawn') + num_images, image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device) + spectrum_size = image_size * interp + padding = spectrum_size - image_size + + # Setup window function. + window = torch.kaiser_window(image_size, periodic=False, beta=beta, device=device) + window *= window.square().sum().rsqrt() + window = window.ger(window).unsqueeze(0).unsqueeze(1) + + # Accumulate power spectrum. + spectrum = torch.zeros([spectrum_size, spectrum_size], dtype=torch.float64, device=device) + for image in tqdm.tqdm(image_iter, total=num_images): + image = (image.to(torch.float64) - mean) / std + image = torch.nn.functional.pad(image * window, [0, padding, 0, padding]) + spectrum += torch.fft.fftn(image, dim=[2,3]).abs().square().mean(dim=[0,1]) + spectrum /= num_images + + # Save result. + if os.path.dirname(dest): + os.makedirs(os.path.dirname(dest), exist_ok=True) + np.savez(dest, spectrum=spectrum.cpu().numpy(), image_size=image_size) + +#---------------------------------------------------------------------------- + +@main.command() +@click.argument('npz-file', nargs=1) +@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]') +@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True) +@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=1.25, show_default=True) +def heatmap(npz_file, save, smooth, dpi): + """Visualize 2D heatmap based on the given .npz file.""" + hmap, image_size = construct_heatmap(npz_file=npz_file, smooth=smooth) + + # Setup plot. + plt.figure(figsize=[6, 4.8], dpi=dpi, tight_layout=True) + freqs = np.linspace(-0.5, 0.5, num=hmap.shape[0], endpoint=True) * image_size + ticks = np.linspace(freqs[0], freqs[-1], num=5, endpoint=True) + levels = np.linspace(-40, 20, num=13, endpoint=True) + + # Draw heatmap. 
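+ # Filled contours show power in dB; a faint line-contour overlay is drawn on top for readability.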
+ plt.xlim(ticks[0], ticks[-1]) + plt.ylim(ticks[0], ticks[-1]) + plt.xticks(ticks) + plt.yticks(ticks) + plt.contourf(freqs, freqs, hmap, levels=levels, extend='both', cmap='Blues') + plt.gca().set_aspect('equal') + plt.colorbar(ticks=levels) + plt.contour(freqs, freqs, hmap, levels=levels, extend='both', linestyles='solid', linewidths=1, colors='midnightblue', alpha=0.2) + + # Display or save. + if save is None: + plt.show() + else: + if os.path.dirname(save): + os.makedirs(os.path.dirname(save), exist_ok=True) + plt.savefig(save) + +#---------------------------------------------------------------------------- + +@main.command() +@click.argument('npz-files', nargs=-1, required=True) +@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]') +@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True) +@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) +def slices(npz_files, save, dpi, smooth): + """Visualize 1D slices based on the given .npz files.""" + cases = [dnnlib.EasyDict(npz_file=npz_file) for npz_file in npz_files] + for c in cases: + c.hmap, c.image_size = construct_heatmap(npz_file=c.npz_file, smooth=smooth) + c.label = os.path.splitext(os.path.basename(c.npz_file))[0] + + # Check consistency. + image_size = cases[0].image_size + hmap_size = cases[0].hmap.shape[0] + if any(c.image_size != image_size or c.hmap.shape[0] != hmap_size for c in cases): + raise click.ClickException('All .npz must have the same resolution') + + # Setup plot. + plt.figure(figsize=[12, 4.6], dpi=dpi, tight_layout=True) + hmap_center = hmap_size // 2 + hmap_range = np.arange(hmap_center, hmap_size) + freqs0 = np.linspace(0, image_size / 2, num=(hmap_size // 2 + 1), endpoint=True) + freqs45 = np.linspace(0, image_size / np.sqrt(2), num=(hmap_size // 2 + 1), endpoint=True) + xticks0 = np.linspace(freqs0[0], freqs0[-1], num=9, endpoint=True) + xticks45 = np.round(np.linspace(freqs45[0], freqs45[-1], num=9, endpoint=True)) + yticks = np.linspace(-50, 30, num=9, endpoint=True) + + # Draw 0 degree slice. + plt.subplot(1, 2, 1) + plt.title('0\u00b0 slice') + plt.xlim(xticks0[0], xticks0[-1]) + plt.ylim(yticks[0], yticks[-1]) + plt.xticks(xticks0) + plt.yticks(yticks) + for c in cases: + plt.plot(freqs0, c.hmap[hmap_center, hmap_range], label=c.label) + plt.grid() + plt.legend(loc='upper right') + + # Draw 45 degree slice. + plt.subplot(1, 2, 2) + plt.title('45\u00b0 slice') + plt.xlim(xticks45[0], xticks45[-1]) + plt.ylim(yticks[0], yticks[-1]) + plt.xticks(xticks45) + plt.yticks(yticks) + for c in cases: + plt.plot(freqs45, c.hmap[hmap_range, hmap_range], label=c.label) + plt.grid() + plt.legend(loc='upper right') + + # Display or save. + if save is None: + plt.show() + else: + if os.path.dirname(save): + os.makedirs(os.path.dirname(save), exist_ok=True) + plt.savefig(save) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/calc_metrics.py b/calc_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..74a398a407f56e749e3a88eb9d8ff976191758f4 --- /dev/null +++ b/calc_metrics.py @@ -0,0 +1,188 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Calculate quality metrics for previous training run or pretrained network pickle.""" + +import os +import click +import json +import tempfile +import copy +import torch + +import dnnlib +import legacy +from metrics import metric_main +from metrics import metric_utils +from torch_utils import training_stats +from torch_utils import custom_ops +from torch_utils import misc +from torch_utils.ops import conv2d_gradfix + +#---------------------------------------------------------------------------- + +def subprocess_fn(rank, args, temp_dir): + dnnlib.util.Logger(should_flush=True) + + # Init torch.distributed. + if args.num_gpus > 1: + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + if rank != 0 or not args.verbose: + custom_ops.verbosity = 'none' + + # Configure torch. + device = torch.device('cuda', rank) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + conv2d_gradfix.enabled = True + + # Print network summary. + G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) + if rank == 0 and args.verbose: + z = torch.empty([1, G.z_dim], device=device) + c = torch.empty([1, G.c_dim], device=device) + misc.print_module_summary(G, [z, c]) + + # Calculate each metric. + for metric in args.metrics: + if rank == 0 and args.verbose: + print(f'Calculating {metric}...') + progress = metric_utils.ProgressMonitor(verbose=args.verbose) + result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, + num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) + if rank == 0: + metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) + if rank == 0 and args.verbose: + print() + + # Done. 
+ if rank == 0 and args.verbose: + print('Exiting...') + +#---------------------------------------------------------------------------- + +def parse_comma_separated_list(s): + if isinstance(s, list): + return s + if s is None or s.lower() == 'none' or s == '': + return [] + return s.split(',') + +#---------------------------------------------------------------------------- + +@click.command() +@click.pass_context +@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) +@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) +@click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]') +@click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL') +@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) +@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) + +def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): + """Calculate quality metrics for previous training run or pretrained network pickle. + + Examples: + + \b + # Previous training run: look up options automatically, save result to JSONL file. + python calc_metrics.py --metrics=eqt50k_int,eqr50k \\ + --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl + + \b + # Pre-trained network pickle: specify dataset explicitly, print result to stdout. + python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl + + \b + Recommended metrics: + fid50k_full Frechet inception distance against the full dataset. + kid50k_full Kernel inception distance against the full dataset. + pr50k3_full Precision and recall against the full dataset. + ppl2_wend Perceptual path length in W, endpoints, full image. + eqt50k_int Equivariance w.r.t. integer translation (EQ-T). + eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac). + eqr50k Equivariance w.r.t. rotation (EQ-R). + + \b + Legacy metrics: + fid50k Frechet inception distance against 50k real images. + kid50k Kernel inception distance against 50k real images. + pr50k3 Precision and recall against 50k real images. + is50k Inception score for CIFAR-10. + """ + dnnlib.util.Logger(should_flush=True) + + # Validate arguments. + args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) + if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): + ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) + if not args.num_gpus >= 1: + ctx.fail('--gpus must be at least 1') + + # Load network. + if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): + ctx.fail('--network must point to a file or URL') + if args.verbose: + print(f'Loading network from "{network_pkl}"...') + with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: + network_dict = legacy.load_network_pkl(f) + args.G = network_dict['G_ema'] # subclass of torch.nn.Module + + # Initialize dataset options.
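+ # Prefer an explicit --data path; otherwise fall back to the training_set_kwargs stored in the network pickle.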
+ if data is not None: + args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) + elif network_dict['training_set_kwargs'] is not None: + args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) + else: + ctx.fail('Could not look up dataset options; please specify --data') + + # Finalize dataset options. + args.dataset_kwargs.resolution = args.G.img_resolution + args.dataset_kwargs.use_labels = (args.G.c_dim != 0) + if mirror is not None: + args.dataset_kwargs.xflip = mirror + + # Print dataset options. + if args.verbose: + print('Dataset options:') + print(json.dumps(args.dataset_kwargs, indent=2)) + + # Locate run dir. + args.run_dir = None + if os.path.isfile(network_pkl): + pkl_dir = os.path.dirname(network_pkl) + if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): + args.run_dir = pkl_dir + + # Launch processes. + if args.verbose: + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + if args.num_gpus == 1: + subprocess_fn(rank=0, args=args, temp_dir=temp_dir) + else: + torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + calc_metrics() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/dataset_tool.py b/dataset_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..e9382fb1265489053eaed0166385a10ef67965c2 --- /dev/null +++ b/dataset_tool.py @@ -0,0 +1,455 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Tool for creating ZIP/PNG based datasets.""" + +import functools +import gzip +import io +import json +import os +import pickle +import re +import sys +import tarfile +import zipfile +from pathlib import Path +from typing import Callable, Optional, Tuple, Union + +import click +import numpy as np +import PIL.Image +from tqdm import tqdm + +#---------------------------------------------------------------------------- + +def error(msg): + print('Error: ' + msg) + sys.exit(1) + +#---------------------------------------------------------------------------- + +def parse_tuple(s: str) -> Tuple[int, int]: + '''Parse a 'M,N' or 'MxN' integer tuple. 
+ + Example: + '4x2' returns (4,2) + '0,1' returns (0,1) + ''' + if m := re.match(r'^(\d+)[x,](\d+)$', s): + return (int(m.group(1)), int(m.group(2))) + raise ValueError(f'cannot parse tuple {s}') + +#---------------------------------------------------------------------------- + +def maybe_min(a: int, b: Optional[int]) -> int: + if b is not None: + return min(a, b) + return a + +#---------------------------------------------------------------------------- + +def file_ext(name: Union[str, Path]) -> str: + return str(name).split('.')[-1] + +#---------------------------------------------------------------------------- + +def is_image_ext(fname: Union[str, Path]) -> bool: + ext = file_ext(fname).lower() + return f'.{ext}' in PIL.Image.EXTENSION # type: ignore + +#---------------------------------------------------------------------------- + +def open_image_folder(source_dir, *, max_images: Optional[int]): + input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)] + + # Load labels. + labels = {} + meta_fname = os.path.join(source_dir, 'dataset.json') + if os.path.isfile(meta_fname): + with open(meta_fname, 'r') as file: + labels = json.load(file)['labels'] + if labels is not None: + labels = { x[0]: x[1] for x in labels } + else: + labels = {} + + max_idx = maybe_min(len(input_images), max_images) + + def iterate_images(): + for idx, fname in enumerate(input_images): + arch_fname = os.path.relpath(fname, source_dir) + arch_fname = arch_fname.replace('\\', '/') + img = np.array(PIL.Image.open(fname)) + yield dict(img=img, label=labels.get(arch_fname)) + if idx >= max_idx-1: + break + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_image_zip(source, *, max_images: Optional[int]): + with zipfile.ZipFile(source, mode='r') as z: + input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] + + # Load labels. 
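+ # dataset.json, if present, maps archive file names to integer class labels; entries missing from it yield label=None.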
+ labels = {} + if 'dataset.json' in z.namelist(): + with z.open('dataset.json', 'r') as file: + labels = json.load(file)['labels'] + if labels is not None: + labels = { x[0]: x[1] for x in labels } + else: + labels = {} + + max_idx = maybe_min(len(input_images), max_images) + + def iterate_images(): + with zipfile.ZipFile(source, mode='r') as z: + for idx, fname in enumerate(input_images): + with z.open(fname, 'r') as file: + img = PIL.Image.open(file) # type: ignore + img = np.array(img) + yield dict(img=img, label=labels.get(fname)) + if idx >= max_idx-1: + break + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]): + import cv2 # pip install opencv-python # pylint: disable=import-error + import lmdb # pip install lmdb # pylint: disable=import-error + + with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: + max_idx = maybe_min(txn.stat()['entries'], max_images) + + def iterate_images(): + with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: + for idx, (_key, value) in enumerate(txn.cursor()): + try: + try: + img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) + if img is None: + raise IOError('cv2.imdecode failed') + img = img[:, :, ::-1] # BGR => RGB + except IOError: + img = np.array(PIL.Image.open(io.BytesIO(value))) + yield dict(img=img, label=None) + if idx >= max_idx-1: + break + except: + print(sys.exc_info()[1]) + + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_cifar10(tarball: str, *, max_images: Optional[int]): + images = [] + labels = [] + + with tarfile.open(tarball, 'r:gz') as tar: + for batch in range(1, 6): + member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}') + with tar.extractfile(member) as file: + data = pickle.load(file, encoding='latin1') + images.append(data['data'].reshape(-1, 3, 32, 32)) + labels.append(data['labels']) + + images = np.concatenate(images) + labels = np.concatenate(labels) + images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC + assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 + assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + + max_idx = maybe_min(len(images), max_images) + + def iterate_images(): + for idx, img in enumerate(images): + yield dict(img=img, label=int(labels[idx])) + if idx >= max_idx-1: + break + + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def open_mnist(images_gz: str, *, max_images: Optional[int]): + labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') + assert labels_gz != images_gz + images = [] + labels = [] + + with gzip.open(images_gz, 'rb') as f: + images = np.frombuffer(f.read(), np.uint8, offset=16) + with gzip.open(labels_gz, 'rb') as f: + labels = np.frombuffer(f.read(), np.uint8, offset=8) + + images = images.reshape(-1, 28, 28) + images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) + assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (60000,) and labels.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + + max_idx = maybe_min(len(images), max_images) 
+ + def iterate_images(): + for idx, img in enumerate(images): + yield dict(img=img, label=int(labels[idx])) + if idx >= max_idx-1: + break + + return max_idx, iterate_images() + +#---------------------------------------------------------------------------- + +def make_transform( + transform: Optional[str], + output_width: Optional[int], + output_height: Optional[int] +) -> Callable[[np.ndarray], Optional[np.ndarray]]: + def scale(width, height, img): + w = img.shape[1] + h = img.shape[0] + if width == w and height == h: + return img + img = PIL.Image.fromarray(img) + ww = width if width is not None else w + hh = height if height is not None else h + img = img.resize((ww, hh), PIL.Image.LANCZOS) + return np.array(img) + + def center_crop(width, height, img): + crop = np.min(img.shape[:2]) + img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((width, height), PIL.Image.LANCZOS) + return np.array(img) + + def center_crop_wide(width, height, img): + ch = int(np.round(width * img.shape[0] / img.shape[1])) + if img.shape[1] < width or ch < height: + return None + + img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((width, height), PIL.Image.LANCZOS) + img = np.array(img) + + canvas = np.zeros([width, width, 3], dtype=np.uint8) + canvas[(width - height) // 2 : (width + height) // 2, :] = img + return canvas + + if transform is None: + return functools.partial(scale, output_width, output_height) + if transform == 'center-crop': + if (output_width is None) or (output_height is None): + error ('must specify --resolution=WxH when using ' + transform + 'transform') + return functools.partial(center_crop, output_width, output_height) + if transform == 'center-crop-wide': + if (output_width is None) or (output_height is None): + error ('must specify --resolution=WxH when using ' + transform + ' transform') + return functools.partial(center_crop_wide, output_width, output_height) + assert False, 'unknown transform' + +#---------------------------------------------------------------------------- + +def open_dataset(source, *, max_images: Optional[int]): + if os.path.isdir(source): + if source.rstrip('/').endswith('_lmdb'): + return open_lmdb(source, max_images=max_images) + else: + return open_image_folder(source, max_images=max_images) + elif os.path.isfile(source): + if os.path.basename(source) == 'cifar-10-python.tar.gz': + return open_cifar10(source, max_images=max_images) + elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': + return open_mnist(source, max_images=max_images) + elif file_ext(source) == 'zip': + return open_image_zip(source, max_images=max_images) + else: + assert False, 'unknown archive type' + else: + error(f'Missing input file or directory: {source}') + +#---------------------------------------------------------------------------- + +def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]: + dest_ext = file_ext(dest) + + if dest_ext == 'zip': + if os.path.dirname(dest) != '': + os.makedirs(os.path.dirname(dest), exist_ok=True) + zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED) + def zip_write_bytes(fname: str, data: Union[bytes, str]): + zf.writestr(fname, data) + return '', zip_write_bytes, zf.close + else: + # If the output folder already exists, check that is is + # empty. 
+ # + # Note: creating the output directory is not strictly + # necessary as folder_write_bytes() also mkdirs, but it's better + # to give an error message earlier in case the dest folder + # somehow cannot be created. + if os.path.isdir(dest) and len(os.listdir(dest)) != 0: + error('--dest folder must be empty') + os.makedirs(dest, exist_ok=True) + + def folder_write_bytes(fname: str, data: Union[bytes, str]): + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, 'wb') as fout: + if isinstance(data, str): + data = data.encode('utf8') + fout.write(data) + return dest, folder_write_bytes, lambda: None + +#---------------------------------------------------------------------------- + +@click.command() +@click.pass_context +@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH') +@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH') +@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None) +@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide'])) +@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple) +def convert_dataset( + ctx: click.Context, + source: str, + dest: str, + max_images: Optional[int], + transform: Optional[str], + resolution: Optional[Tuple[int, int]] +): + """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch. + + The input dataset format is guessed from the --source argument: + + \b + --source *_lmdb/ Load LSUN dataset + --source cifar-10-python.tar.gz Load CIFAR-10 dataset + --source train-images-idx3-ubyte.gz Load MNIST dataset + --source path/ Recursively load all images from path/ + --source dataset.zip Recursively load all images from dataset.zip + + Specifying the output format and path: + + \b + --dest /path/to/dir Save output files under /path/to/dir + --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip + + The output dataset format can be either an image folder or an uncompressed zip archive. + Zip archives make it easier to move datasets around file servers and clusters, and may + offer better training performance on network file systems. + + Images within the dataset archive will be stored as uncompressed PNG. + Uncompressed PNGs can be efficiently decoded in the training loop. + + Class labels are stored in a file called 'dataset.json' that is stored at the + dataset root folder. This file has the following structure: + + \b + { + "labels": [ + ["00000/img00000000.png",6], + ["00000/img00000001.png",9], + ... repeated for every image in the dataset + ["00049/img00049999.png",1] + ] + } + + If the 'dataset.json' file cannot be found, the dataset is interpreted as + not containing class labels. + + Image scale/crop and resolution requirements: + + Output images must be square-shaped and they must all have the same power-of-two + dimensions. + + To scale an arbitrary input image size to a specific width and height, use the + --resolution option. Output resolution will be either the original + input resolution (if resolution was not specified) or the one specified with the + --resolution option. + + Use the --transform=center-crop or --transform=center-crop-wide options to apply a + center crop transform on the input image. These options should be used with the + --resolution option.
For example: + + \b + python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\ + --transform=center-crop-wide --resolution=512x384 + """ + + PIL.Image.init() # type: ignore + + if dest == '': + ctx.fail('--dest output filename or directory must not be an empty string') + + num_files, input_iter = open_dataset(source, max_images=max_images) + archive_root_dir, save_bytes, close_dest = open_dest(dest) + + if resolution is None: resolution = (None, None) + transform_image = make_transform(transform, *resolution) + + dataset_attrs = None + + labels = [] + for idx, image in tqdm(enumerate(input_iter), total=num_files): + idx_str = f'{idx:08d}' + archive_fname = f'{idx_str[:5]}/img{idx_str}.png' + + # Apply crop and resize. + img = transform_image(image['img']) + + # Transform may drop images. + if img is None: + continue + + # Error check to require uniform image attributes across + # the whole dataset. + channels = img.shape[2] if img.ndim == 3 else 1 + cur_image_attrs = { + 'width': img.shape[1], + 'height': img.shape[0], + 'channels': channels + } + if dataset_attrs is None: + dataset_attrs = cur_image_attrs + width = dataset_attrs['width'] + height = dataset_attrs['height'] + if width != height: + error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}') + if dataset_attrs['channels'] not in [1, 3]: + error('Input images must be stored as RGB or grayscale') + if width != 2 ** int(np.floor(np.log2(width))): + error('Image width/height after scale and crop are required to be power-of-two') + elif dataset_attrs != cur_image_attrs: + err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object + error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err)) + + # Save the image as an uncompressed PNG. + img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels]) + image_bits = io.BytesIO() + img.save(image_bits, format='png', compress_level=0, optimize=False) + save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer()) + labels.append([archive_fname, image['label']] if image['label'] is not None else None) + + metadata = { + 'labels': labels if all(x is not None for x in labels) else None + } + save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) + close_dest() + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + convert_dataset() # pylint: disable=no-value-for-parameter diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a006715176e91a5ed94a5d2362a87d53b4d889 --- /dev/null +++ b/dnnlib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
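+# Re-export the most commonly used helpers so callers can write dnnlib.EasyDict and dnnlib.make_cache_dir_path directly.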
+ +from .util import EasyDict, make_cache_dir_path diff --git a/dnnlib/util.py b/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..191b52f6ac7ad75344fb3921f03c37987047287c --- /dev/null +++ b/dnnlib/util.py @@ -0,0 +1,491 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return 
os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def format_time_brief(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. 
+ Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. 
+ Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. 
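+    # Cache entries are keyed by the MD5 hash of the URL and stored as
+    # "<cache_dir>/<md5>_<sanitized filename>"; finished downloads are moved into place
+    # with os.replace() so a partially written file is never picked up as a cache hit.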
+ if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..578a58a00e6588a0205ccdd55801824ba0ec1922 --- /dev/null +++ b/environment.yml @@ -0,0 +1,24 @@ +name: stylegan3 +channels: + - pytorch + - nvidia +dependencies: + - python >= 3.8 + - pip + - numpy>=1.20 + - click>=8.0 + - pillow=8.3.1 + - scipy=1.7.1 + - pytorch=1.9.1 + - cudatoolkit=11.1 + - requests=2.26.0 + - tqdm=4.62.2 + - ninja=1.10.2 + - matplotlib=3.4.2 + - imageio=2.9.0 + - pip: + - imgui==1.3.0 + - glfw==2.2.0 + - pyopengl==3.1.5 + - imageio-ffmpeg==0.4.3 + - pyspng diff --git a/gen_images.py b/gen_images.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a4b11b2fcdbad986a21753a29d7fee2fc26dbd --- /dev/null +++ b/gen_images.py @@ -0,0 +1,144 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
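Before moving on to the generation scripts, a short sketch of how the open_url() helper above is typically used on its own; this is also how gen_images.py and gen_video.py below resolve their --network arguments. It assumes the repository root is on the Python path and reuses the AFHQv2 checkpoint URL from the example commands further down:

import dnnlib
import legacy

url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl'

# First use downloads the pickle into the cache directory, keyed by the MD5 hash of the URL;
# subsequent calls reuse the cached copy instead of re-downloading.
with dnnlib.util.open_url(url) as f:
    G_ema = legacy.load_network_pkl(f)['G_ema']   # load_network_pkl() is defined in legacy.py below

# Alternatively, resolve the URL to the cached local filename without opening it.
local_path = dnnlib.util.open_url(url, return_filename=True)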
+ +"""Generate images using pretrained network pickle.""" + +import os +import re +from typing import List, Optional, Tuple, Union + +import click +import dnnlib +import numpy as np +import PIL.Image +import torch + +import legacy + +#---------------------------------------------------------------------------- + +def parse_range(s: Union[str, List]) -> List[int]: + '''Parse a comma separated list of numbers or ranges and return a list of ints. + + Example: '1,2,5-10' returns [1, 2, 5, 6, 7] + ''' + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + if m := range_re.match(p): + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + +#---------------------------------------------------------------------------- + +def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: + '''Parse a floating point 2-vector of syntax 'a,b'. + + Example: + '0,1' returns (0,1) + ''' + if isinstance(s, tuple): return s + parts = s.split(',') + if len(parts) == 2: + return (float(parts[0]), float(parts[1])) + raise ValueError(f'cannot parse 2-vector {s}') + +#---------------------------------------------------------------------------- + +def make_transform(translate: Tuple[float,float], angle: float): + m = np.eye(3) + s = np.sin(angle/360.0*np.pi*2) + c = np.cos(angle/360.0*np.pi*2) + m[0][0] = c + m[0][1] = s + m[0][2] = translate[0] + m[1][0] = -s + m[1][1] = c + m[1][2] = translate[1] + return m + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') +@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) +@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') +@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') +@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') +def generate_images( + network_pkl: str, + seeds: List[int], + truncation_psi: float, + noise_mode: str, + outdir: str, + translate: Tuple[float,float], + rotate: float, + class_idx: Optional[int] +): + """Generate images using pretrained network pickle. + + Examples: + + \b + # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). + python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl + + \b + # Generate uncurated images with truncation using the MetFaces-U dataset + python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl + """ + + print('Loading networks from "%s"...' 
% network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + os.makedirs(outdir, exist_ok=True) + + # Labels. + label = torch.zeros([1, G.c_dim], device=device) + if G.c_dim != 0: + if class_idx is None: + raise click.ClickException('Must specify class label with --class when using a conditional network') + label[:, class_idx] = 1 + else: + if class_idx is not None: + print ('warn: --class=lbl ignored when running on an unconditional network') + + # Generate images. + for seed_idx, seed in enumerate(seeds): + print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) + z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) + + # Construct an inverse rotation/translation matrix and pass to the generator. The + # generator expects this matrix as an inverse to avoid potentially failing numerical + # operations in the network. + if hasattr(G.synthesis, 'input'): + m = make_transform(translate, rotate) + m = np.linalg.inv(m) + G.synthesis.input.transform.copy_(torch.from_numpy(m)) + + img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) + PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') + + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/gen_video.py b/gen_video.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4bcc0ea7669530fa2727392e4c09500d8eed5e --- /dev/null +++ b/gen_video.py @@ -0,0 +1,178 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
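One detail of gen_images.py above is worth isolating: make_transform() builds the forward rotation/translation matrix, but the script hands the generator its inverse before copying it into G.synthesis.input.transform (as the in-code comment notes, the network expects the inverse form). A minimal sketch, assuming G is a generator loaded exactly as in generate_images() above:

import numpy as np
import torch
from gen_images import make_transform   # helper defined above

# Forward transform: rotate by 15 degrees and shift x by 0.3; the generator gets its inverse.
m = make_transform(translate=(0.3, 0.0), angle=15.0)
if hasattr(G.synthesis, 'input'):   # same guard as in generate_images(); older pickles lack this layer
    G.synthesis.input.transform.copy_(torch.from_numpy(np.linalg.inv(m)))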
+ +"""Generate lerp videos using pretrained network pickle.""" + +import copy +import os +import re +from typing import List, Optional, Tuple, Union + +import click +import dnnlib +import imageio +import numpy as np +import scipy.interpolate +import torch +from tqdm import tqdm + +import legacy + +#---------------------------------------------------------------------------- + +def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): + batch_size, channels, img_h, img_w = img.shape + if grid_w is None: + grid_w = batch_size // grid_h + assert batch_size == grid_w * grid_h + if float_to_uint8: + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) + img = img.reshape(grid_h, grid_w, channels, img_h, img_w) + img = img.permute(2, 0, 3, 1, 4) + img = img.reshape(channels, grid_h * img_h, grid_w * img_w) + if chw_to_hwc: + img = img.permute(1, 2, 0) + if to_numpy: + img = img.cpu().numpy() + return img + +#---------------------------------------------------------------------------- + +def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs): + grid_w = grid_dims[0] + grid_h = grid_dims[1] + + if num_keyframes is None: + if len(seeds) % (grid_w*grid_h) != 0: + raise ValueError('Number of input seeds must be divisible by grid W*H') + num_keyframes = len(seeds) // (grid_w*grid_h) + + all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) + for idx in range(num_keyframes*grid_h*grid_w): + all_seeds[idx] = seeds[idx % len(seeds)] + + if shuffle_seed is not None: + rng = np.random.RandomState(seed=shuffle_seed) + rng.shuffle(all_seeds) + + zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device) + ws = G.mapping(z=zs, c=None, truncation_psi=psi) + _ = G.synthesis(ws[:1]) # warm up + ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) + + # Interpolation. + grid = [] + for yi in range(grid_h): + row = [] + for xi in range(grid_w): + x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) + y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) + interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) + row.append(interp) + grid.append(row) + + # Render video. + video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) + for frame_idx in tqdm(range(num_keyframes * w_frames)): + imgs = [] + for yi in range(grid_h): + for xi in range(grid_w): + interp = grid[yi][xi] + w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) + img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0] + imgs.append(img) + video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) + video_out.close() + +#---------------------------------------------------------------------------- + +def parse_range(s: Union[str, List[int]]) -> List[int]: + '''Parse a comma separated list of numbers or ranges and return a list of ints. 
+ + Example: '1,2,5-10' returns [1, 2, 5, 6, 7] + ''' + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + if m := range_re.match(p): + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + +#---------------------------------------------------------------------------- + +def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]: + '''Parse a 'M,N' or 'MxN' integer tuple. + + Example: + '4x2' returns (4,2) + '0,1' returns (0,1) + ''' + if isinstance(s, tuple): return s + if m := re.match(r'^(\d+)[x,](\d+)$', s): + return (int(m.group(1)), int(m.group(2))) + raise ValueError(f'cannot parse tuple {s}') + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--seeds', type=parse_range, help='List of random seeds', required=True) +@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None) +@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1)) +@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None) +@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') +def generate_images( + network_pkl: str, + seeds: List[int], + shuffle_seed: Optional[int], + truncation_psi: float, + grid: Tuple[int,int], + num_keyframes: Optional[int], + w_frames: int, + output: str +): + """Render a latent vector interpolation video. + + Examples: + + \b + # Render a 4x2 grid of interpolations for seeds 0 through 31. + python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ + --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl + + Animation length and seed keyframes: + + The animation length is either determined based on the --seeds value or explicitly + specified using the --num-keyframes option. + + When num keyframes is specified with --num-keyframes, the output video length + will be 'num_keyframes*w_frames' frames. + + If --num-keyframes is not specified, the number of seeds given with + --seeds must be divisible by grid size W*H (--grid). In this case the + output video length will be '# seeds/(w*h)*w_frames' frames. + """ + + print('Loading networks from "%s"...' 
% network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/gui_utils/__init__.py b/gui_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/gui_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/gui_utils/gl_utils.py b/gui_utils/gl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd8bd96f4b74bdfb274f25d810ddfd43dc44068 --- /dev/null +++ b/gui_utils/gl_utils.py @@ -0,0 +1,374 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import functools +import contextlib +import numpy as np +import OpenGL.GL as gl +import OpenGL.GL.ARB.texture_float +import dnnlib + +#---------------------------------------------------------------------------- + +def init_egl(): + assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL. + import OpenGL.EGL as egl + import ctypes + + # Initialize EGL. + display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY) + assert display != egl.EGL_NO_DISPLAY + major = ctypes.c_int32() + minor = ctypes.c_int32() + ok = egl.eglInitialize(display, major, minor) + assert ok + assert major.value * 10 + minor.value >= 14 + + # Choose config. + config_attribs = [ + egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT, + egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT, + egl.EGL_NONE + ] + configs = (ctypes.c_int32 * 1)() + num_configs = ctypes.c_int32() + ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs) + assert ok + assert num_configs.value == 1 + config = configs[0] + + # Create dummy pbuffer surface. + surface_attribs = [ + egl.EGL_WIDTH, 1, + egl.EGL_HEIGHT, 1, + egl.EGL_NONE + ] + surface = egl.eglCreatePbufferSurface(display, config, surface_attribs) + assert surface != egl.EGL_NO_SURFACE + + # Setup GL context. 
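+    # Bind the desktop OpenGL API (rather than OpenGL ES), create a context for the chosen
+    # config, and make it current on the 1x1 pbuffer so GL calls work headlessly.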
+ ok = egl.eglBindAPI(egl.EGL_OPENGL_API) + assert ok + context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None) + assert context != egl.EGL_NO_CONTEXT + ok = egl.eglMakeCurrent(display, surface, surface, context) + assert ok + +#---------------------------------------------------------------------------- + +_texture_formats = { + ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8), + ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8), + ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8), + ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8), + ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB), + ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB), + ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F), + ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F), +} + +def get_texture_format(dtype, channels): + return _texture_formats[(np.dtype(dtype).name, int(channels))] + +#---------------------------------------------------------------------------- + +def prepare_texture_data(image): + image = np.asarray(image) + if image.ndim == 2: + image = image[:, :, np.newaxis] + if image.dtype.name == 'float64': + image = image.astype('float32') + return image + +#---------------------------------------------------------------------------- + +def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True): + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2]) + align = np.broadcast_to(np.asarray(align, dtype='float32'), [2]) + image = prepare_texture_data(image) + height, width, channels = image.shape + size = zoom * [width, height] + pos = pos - size * align + if rint: + pos = np.rint(pos) + fmt = get_texture_format(image.dtype, channels) + + gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT) + gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) + gl.glRasterPos2f(pos[0], pos[1]) + gl.glPixelZoom(zoom[0], -zoom[1]) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl.glDrawPixels(width, height, fmt.format, fmt.type, image) + gl.glPopClientAttrib() + gl.glPopAttrib() + +#---------------------------------------------------------------------------- + +def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3): + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + dtype = np.dtype(dtype) + fmt = get_texture_format(dtype, channels) + image = np.empty([height, width, channels], dtype=dtype) + + gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) + gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1) + gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image) + gl.glPopClientAttrib() + return np.flipud(image) + +#---------------------------------------------------------------------------- + +class Texture: + def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True): + self.gl_id = None + self.bilinear = bilinear + self.mipmap = mipmap + + # Determine size and dtype. 
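+        # Either take width/height/channels/dtype from the supplied image, or fall back to
+        # the explicit arguments (channels default to 3, dtype to uint8).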
+ if image is not None: + image = prepare_texture_data(image) + self.height, self.width, self.channels = image.shape + self.dtype = image.dtype + else: + assert width is not None and height is not None + self.width = width + self.height = height + self.channels = channels if channels is not None else 3 + self.dtype = np.dtype(dtype) if dtype is not None else np.uint8 + + # Validate size and dtype. + assert isinstance(self.width, int) and self.width >= 0 + assert isinstance(self.height, int) and self.height >= 0 + assert isinstance(self.channels, int) and self.channels >= 1 + assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype) + + # Create texture object. + self.gl_id = gl.glGenTextures(1) + with self.bind(): + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST) + gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST) + self.update(image) + + def delete(self): + if self.gl_id is not None: + gl.glDeleteTextures([self.gl_id]) + self.gl_id = None + + def __del__(self): + try: + self.delete() + except: + pass + + @contextlib.contextmanager + def bind(self): + prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D) + gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id) + yield + gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id) + + def update(self, image): + if image is not None: + image = prepare_texture_data(image) + assert self.is_compatible(image=image) + with self.bind(): + fmt = get_texture_format(self.dtype, self.channels) + gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT) + gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image) + if self.mipmap: + gl.glGenerateMipmap(gl.GL_TEXTURE_2D) + gl.glPopClientAttrib() + + def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0): + zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2]) + size = zoom * [self.width, self.height] + with self.bind(): + gl.glPushAttrib(gl.GL_ENABLE_BIT) + gl.glEnable(gl.GL_TEXTURE_2D) + draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding) + gl.glPopAttrib() + + def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements + if image is not None: + if image.ndim != 3: + return False + ih, iw, ic = image.shape + if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype): + return False + if width is not None and self.width != width: + return False + if height is not None and self.height != height: + return False + if channels is not None and self.channels != channels: + return False + if dtype is not None and self.dtype != dtype: + return False + return True + +#---------------------------------------------------------------------------- + +class Framebuffer: + def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0): + self.texture = texture + self.gl_id = None + self.gl_color = None + self.gl_depth_stencil = None + self.msaa = msaa + + # Determine size and dtype. 
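+        # As above, but for a framebuffer: size/dtype come from the attached texture if any;
+        # otherwise channels default to 4 and dtype to float32, and a (possibly multisampled)
+        # renderbuffer is created as the color attachment instead of a texture.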
+ if texture is not None: + assert isinstance(self.texture, Texture) + self.width = texture.width + self.height = texture.height + self.channels = texture.channels + self.dtype = texture.dtype + else: + assert width is not None and height is not None + self.width = width + self.height = height + self.channels = channels if channels is not None else 4 + self.dtype = np.dtype(dtype) if dtype is not None else np.float32 + + # Validate size and dtype. + assert isinstance(self.width, int) and self.width >= 0 + assert isinstance(self.height, int) and self.height >= 0 + assert isinstance(self.channels, int) and self.channels >= 1 + assert width is None or width == self.width + assert height is None or height == self.height + assert channels is None or channels == self.channels + assert dtype is None or dtype == self.dtype + + # Create framebuffer object. + self.gl_id = gl.glGenFramebuffers(1) + with self.bind(): + + # Setup color buffer. + if self.texture is not None: + assert self.msaa == 0 + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0) + else: + fmt = get_texture_format(self.dtype, self.channels) + self.gl_color = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color) + gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height) + gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color) + + # Setup depth/stencil buffer. + self.gl_depth_stencil = gl.glGenRenderbuffers(1) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil) + gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height) + gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil) + + def delete(self): + if self.gl_id is not None: + gl.glDeleteFramebuffers([self.gl_id]) + self.gl_id = None + if self.gl_color is not None: + gl.glDeleteRenderbuffers(1, [self.gl_color]) + self.gl_color = None + if self.gl_depth_stencil is not None: + gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil]) + self.gl_depth_stencil = None + + def __del__(self): + try: + self.delete() + except: + pass + + @contextlib.contextmanager + def bind(self): + prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING) + prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id) + if self.width is not None and self.height is not None: + gl.glViewport(0, 0, self.width, self.height) + yield + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo) + gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo) + + def blit(self, dst=None): + assert dst is None or isinstance(dst, Framebuffer) + with self.bind(): + gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.fbo) + gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST) + +#---------------------------------------------------------------------------- + +def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1): + assert vertices.ndim == 2 and vertices.shape[1] == 2 + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) + color = np.broadcast_to(np.asarray(color, dtype='float32'), [3]) + alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1) + + 
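+    # The fixed-function path below reuses `vertices` for both positions and texture
+    # coordinates, applies pos/size as a modelview translate+scale, and premultiplies the
+    # RGB color by alpha to match the premultiplied-alpha blending used by the GUI windows.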
gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT) + gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT) + gl.glMatrixMode(gl.GL_MODELVIEW) + gl.glPushMatrix() + + gl.glEnableClientState(gl.GL_VERTEX_ARRAY) + gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY) + gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices) + gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices) + gl.glTranslate(pos[0], pos[1], 0) + gl.glScale(size[0], size[1], 1) + gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha) + gl.glDrawArrays(mode, 0, vertices.shape[0]) + + gl.glPopMatrix() + gl.glPopAttrib() + gl.glPopClientAttrib() + +#---------------------------------------------------------------------------- + +def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0): + assert pos2 is None or size is None + pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2]) + pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None + size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None + size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32') + pos = pos - size * align + if rint: + pos = np.rint(pos) + rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2]) + rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5) + if np.min(rounding) == 0: + rounding *= 0 + vertices = _setup_rect(float(rounding[0]), float(rounding[1])) + draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha) + +@functools.lru_cache(maxsize=10000) +def _setup_rect(rx, ry): + t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64) + s = 1 - np.sin(t); c = 1 - np.cos(t) + x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx] + y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry] + v = np.stack([x, y], axis=-1).reshape(-1, 2) + return v.astype('float32') + +#---------------------------------------------------------------------------- + +def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1): + hole = np.broadcast_to(np.asarray(hole, dtype='float32'), []) + vertices = _setup_circle(float(hole)) + draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha) + +@functools.lru_cache(maxsize=10000) +def _setup_circle(hole): + t = np.linspace(0, np.pi * 2, 128) + s = np.sin(t); c = np.cos(t) + v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2) + return v.astype('float32') + +#---------------------------------------------------------------------------- diff --git a/gui_utils/glfw_window.py b/gui_utils/glfw_window.py new file mode 100644 index 0000000000000000000000000000000000000000..94c4a8dd2534b6def2607d6ff2bf29fe472dc53f --- /dev/null +++ b/gui_utils/glfw_window.py @@ -0,0 +1,229 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import time +import glfw +import OpenGL.GL as gl +from . 
import gl_utils + +#---------------------------------------------------------------------------- + +class GlfwWindow: # pylint: disable=too-many-public-methods + def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True): + self._glfw_window = None + self._drawing_frame = False + self._frame_start_time = None + self._frame_delta = 0 + self._fps_limit = None + self._vsync = None + self._skip_frames = 0 + self._deferred_show = deferred_show + self._close_on_esc = close_on_esc + self._esc_pressed = False + self._drag_and_drop_paths = None + self._capture_next_frame = False + self._captured_frame = None + + # Create window. + glfw.init() + glfw.window_hint(glfw.VISIBLE, False) + self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None) + self._attach_glfw_callbacks() + self.make_context_current() + + # Adjust window. + self.set_vsync(False) + self.set_window_size(window_width, window_height) + if not self._deferred_show: + glfw.show_window(self._glfw_window) + + def close(self): + if self._drawing_frame: + self.end_frame() + if self._glfw_window is not None: + glfw.destroy_window(self._glfw_window) + self._glfw_window = None + #glfw.terminate() # Commented out to play it nice with other glfw clients. + + def __del__(self): + try: + self.close() + except: + pass + + @property + def window_width(self): + return self.content_width + + @property + def window_height(self): + return self.content_height + self.title_bar_height + + @property + def content_width(self): + width, _height = glfw.get_window_size(self._glfw_window) + return width + + @property + def content_height(self): + _width, height = glfw.get_window_size(self._glfw_window) + return height + + @property + def title_bar_height(self): + _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window) + return top + + @property + def monitor_width(self): + _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) + return width + + @property + def monitor_height(self): + _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) + return height + + @property + def frame_delta(self): + return self._frame_delta + + def set_title(self, title): + glfw.set_window_title(self._glfw_window, title) + + def set_window_size(self, width, height): + width = min(width, self.monitor_width) + height = min(height, self.monitor_height) + glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0)) + if width == self.monitor_width and height == self.monitor_height: + self.maximize() + + def set_content_size(self, width, height): + self.set_window_size(width, height + self.title_bar_height) + + def maximize(self): + glfw.maximize_window(self._glfw_window) + + def set_position(self, x, y): + glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height) + + def center(self): + self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2) + + def set_vsync(self, vsync): + vsync = bool(vsync) + if vsync != self._vsync: + glfw.swap_interval(1 if vsync else 0) + self._vsync = vsync + + def set_fps_limit(self, fps_limit): + self._fps_limit = int(fps_limit) + + def should_close(self): + return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed) + + def skip_frame(self): + self.skip_frames(1) + + def skip_frames(self, num): # Do not update window for the next N frames. 
+ self._skip_frames = max(self._skip_frames, int(num)) + + def is_skipping_frames(self): + return self._skip_frames > 0 + + def capture_next_frame(self): + self._capture_next_frame = True + + def pop_captured_frame(self): + frame = self._captured_frame + self._captured_frame = None + return frame + + def pop_drag_and_drop_paths(self): + paths = self._drag_and_drop_paths + self._drag_and_drop_paths = None + return paths + + def draw_frame(self): # To be overridden by subclass. + self.begin_frame() + # Rendering code goes here. + self.end_frame() + + def make_context_current(self): + if self._glfw_window is not None: + glfw.make_context_current(self._glfw_window) + + def begin_frame(self): + # End previous frame. + if self._drawing_frame: + self.end_frame() + + # Apply FPS limit. + if self._frame_start_time is not None and self._fps_limit is not None: + delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit + if delay > 0: + time.sleep(delay) + cur_time = time.perf_counter() + if self._frame_start_time is not None: + self._frame_delta = cur_time - self._frame_start_time + self._frame_start_time = cur_time + + # Process events. + glfw.poll_events() + + # Begin frame. + self._drawing_frame = True + self.make_context_current() + + # Initialize GL state. + gl.glViewport(0, 0, self.content_width, self.content_height) + gl.glMatrixMode(gl.GL_PROJECTION) + gl.glLoadIdentity() + gl.glTranslate(-1, 1, 0) + gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1) + gl.glMatrixMode(gl.GL_MODELVIEW) + gl.glLoadIdentity() + gl.glEnable(gl.GL_BLEND) + gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha. + + # Clear. + gl.glClearColor(0, 0, 0, 1) + gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) + + def end_frame(self): + assert self._drawing_frame + self._drawing_frame = False + + # Skip frames if requested. + if self._skip_frames > 0: + self._skip_frames -= 1 + return + + # Capture frame if requested. + if self._capture_next_frame: + self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height) + self._capture_next_frame = False + + # Update window. + if self._deferred_show: + glfw.show_window(self._glfw_window) + self._deferred_show = False + glfw.swap_buffers(self._glfw_window) + + def _attach_glfw_callbacks(self): + glfw.set_key_callback(self._glfw_window, self._glfw_key_callback) + glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback) + + def _glfw_key_callback(self, _window, key, _scancode, action, _mods): + if action == glfw.PRESS and key == glfw.KEY_ESCAPE: + self._esc_pressed = True + + def _glfw_drop_callback(self, _window, paths): + self._drag_and_drop_paths = paths + +#---------------------------------------------------------------------------- diff --git a/gui_utils/imgui_utils.py b/gui_utils/imgui_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e5cb118d996a380d4b263762051af248ee0383c7 --- /dev/null +++ b/gui_utils/imgui_utils.py @@ -0,0 +1,169 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
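Before the imgui helpers, a minimal sketch of how GlfwWindow above is meant to be combined with gl_utils: subclass it, override draw_frame(), and drive the loop until the window is closed. DemoWindow and the rectangle are purely illustrative:

from gui_utils import glfw_window, gl_utils

class DemoWindow(glfw_window.GlfwWindow):
    def draw_frame(self):
        self.begin_frame()
        # Coordinates are window pixels; begin_frame() installs a top-left-origin projection.
        gl_utils.draw_rect(pos=[32, 32], size=[256, 128], color=[0.2, 0.6, 1.0], alpha=0.8)
        self.end_frame()

if __name__ == '__main__':
    win = DemoWindow(title='Demo', window_width=800, window_height=600, deferred_show=False)
    while not win.should_close():
        win.draw_frame()
    win.close()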
+ +import contextlib +import imgui + +#---------------------------------------------------------------------------- + +def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): + s = imgui.get_style() + s.window_padding = [spacing, spacing] + s.item_spacing = [spacing, spacing] + s.item_inner_spacing = [spacing, spacing] + s.columns_min_spacing = spacing + s.indent_spacing = indent + s.scrollbar_size = scrollbar + s.frame_padding = [4, 3] + s.window_border_size = 1 + s.child_border_size = 1 + s.popup_border_size = 1 + s.frame_border_size = 1 + s.window_rounding = 0 + s.child_rounding = 0 + s.popup_rounding = 3 + s.frame_rounding = 3 + s.scrollbar_rounding = 3 + s.grab_rounding = 3 + + getattr(imgui, f'style_colors_{color_scheme}')(s) + c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] + c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] + s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] + +#---------------------------------------------------------------------------- + +@contextlib.contextmanager +def grayed_out(cond=True): + if cond: + s = imgui.get_style() + text = s.colors[imgui.COLOR_TEXT_DISABLED] + grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] + back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] + imgui.push_style_color(imgui.COLOR_TEXT, *text) + imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) + imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) + imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) + imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) + imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) + imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) + imgui.push_style_color(imgui.COLOR_BUTTON, *back) + imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) + imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) + imgui.push_style_color(imgui.COLOR_HEADER, *back) + imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) + imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) + imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) + yield + imgui.pop_style_color(14) + else: + yield + +#---------------------------------------------------------------------------- + +@contextlib.contextmanager +def item_width(width=None): + if width is not None: + imgui.push_item_width(width) + yield + imgui.pop_item_width() + else: + yield + +#---------------------------------------------------------------------------- + +def scoped_by_object_id(method): + def decorator(self, *args, **kwargs): + imgui.push_id(str(id(self))) + res = method(self, *args, **kwargs) + imgui.pop_id() + return res + return decorator + +#---------------------------------------------------------------------------- + +def button(label, width=0, enabled=True): + with grayed_out(not enabled): + clicked = imgui.button(label, width=width) + clicked = clicked and enabled + return clicked + +#---------------------------------------------------------------------------- + +def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): + expanded = False + if show: + if default: + flags |= imgui.TREE_NODE_DEFAULT_OPEN + if not enabled: + flags |= imgui.TREE_NODE_LEAF + with grayed_out(not enabled): + expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) + expanded = expanded and enabled + return expanded, visible + +#---------------------------------------------------------------------------- + +def popup_button(label, width=0, enabled=True): + if 
button(label, width, enabled): + imgui.open_popup(label) + opened = imgui.begin_popup(label) + return opened + +#---------------------------------------------------------------------------- + +def input_text(label, value, buffer_length, flags, width=None, help_text=''): + old_value = value + color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) + if value == '': + color[-1] *= 0.5 + with item_width(width): + imgui.push_style_color(imgui.COLOR_TEXT, *color) + value = value if value != '' else help_text + changed, value = imgui.input_text(label, value, buffer_length, flags) + value = value if value != help_text else '' + imgui.pop_style_color(1) + if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: + changed = (value != old_value) + return changed, value + +#---------------------------------------------------------------------------- + +def drag_previous_control(enabled=True): + dragging = False + dx = 0 + dy = 0 + if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): + if enabled: + dragging = True + dx, dy = imgui.get_mouse_drag_delta() + imgui.reset_mouse_drag_delta() + imgui.end_drag_drop_source() + return dragging, dx, dy + +#---------------------------------------------------------------------------- + +def drag_button(label, width=0, enabled=True): + clicked = button(label, width=width, enabled=enabled) + dragging, dx, dy = drag_previous_control(enabled=enabled) + return clicked, dragging, dx, dy + +#---------------------------------------------------------------------------- + +def drag_hidden_window(label, x, y, width, height, enabled=True): + imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) + imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) + imgui.set_next_window_position(x, y) + imgui.set_next_window_size(width, height) + imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) + dragging, dx, dy = drag_previous_control(enabled=enabled) + imgui.end() + imgui.pop_style_color(2) + return dragging, dx, dy + +#---------------------------------------------------------------------------- diff --git a/gui_utils/imgui_window.py b/gui_utils/imgui_window.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf7caa76f0e1261e490bce8ef1c8267a1e5c31d --- /dev/null +++ b/gui_utils/imgui_window.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import imgui +import imgui.integrations.glfw + +from . import glfw_window +from . import imgui_utils +from . import text_utils + +#---------------------------------------------------------------------------- + +class ImguiWindow(glfw_window.GlfwWindow): + def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): + if font is None: + font = text_utils.get_default_font() + font_sizes = {int(size) for size in font_sizes} + super().__init__(title=title, **glfw_kwargs) + + # Init fields. + self._imgui_context = None + self._imgui_renderer = None + self._imgui_fonts = None + self._cur_font_size = max(font_sizes) + + # Delete leftover imgui.ini to avoid unexpected behavior. 
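+        # (Dear ImGui persists window layout to imgui.ini by default; a stale file from a
+        # previous run could silently restore old window positions and sizes.)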
+ if os.path.isfile('imgui.ini'): + os.remove('imgui.ini') + + # Init ImGui. + self._imgui_context = imgui.create_context() + self._imgui_renderer = _GlfwRenderer(self._glfw_window) + self._attach_glfw_callbacks() + imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. + imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). + self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} + self._imgui_renderer.refresh_font_texture() + + def close(self): + self.make_context_current() + self._imgui_fonts = None + if self._imgui_renderer is not None: + self._imgui_renderer.shutdown() + self._imgui_renderer = None + if self._imgui_context is not None: + #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. + self._imgui_context = None + super().close() + + def _glfw_key_callback(self, *args): + super()._glfw_key_callback(*args) + self._imgui_renderer.keyboard_callback(*args) + + @property + def font_size(self): + return self._cur_font_size + + @property + def spacing(self): + return round(self._cur_font_size * 0.4) + + def set_font_size(self, target): # Applied on next frame. + self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] + + def begin_frame(self): + # Begin glfw frame. + super().begin_frame() + + # Process imgui events. + self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 + if self.content_width > 0 and self.content_height > 0: + self._imgui_renderer.process_inputs() + + # Begin imgui frame. + imgui.new_frame() + imgui.push_font(self._imgui_fonts[self._cur_font_size]) + imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) + + def end_frame(self): + imgui.pop_font() + imgui.render() + imgui.end_frame() + self._imgui_renderer.render(imgui.get_draw_data()) + super().end_frame() + +#---------------------------------------------------------------------------- +# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. + +class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mouse_wheel_multiplier = 1 + + def scroll_callback(self, window, x_offset, y_offset): + self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier + +#---------------------------------------------------------------------------- diff --git a/gui_utils/text_utils.py b/gui_utils/text_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0c7a9b13fe63df5deb20664e22584a8240fc59 --- /dev/null +++ b/gui_utils/text_utils.py @@ -0,0 +1,123 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import functools +from typing import Optional + +import dnnlib +import numpy as np +import PIL.Image +import PIL.ImageFont +import scipy.ndimage + +from . 
import gl_utils + +#---------------------------------------------------------------------------- + +def get_default_font(): + url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular + return dnnlib.util.open_url(url, return_filename=True) + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=None) +def get_pil_font(font=None, size=32): + if font is None: + font = get_default_font() + return PIL.ImageFont.truetype(font=font, size=size) + +#---------------------------------------------------------------------------- + +def get_array(string, *, dropshadow_radius: int=None, **kwargs): + if dropshadow_radius is not None: + offset_x = int(np.ceil(dropshadow_radius*2/3)) + offset_y = int(np.ceil(dropshadow_radius*2/3)) + return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) + else: + return _get_array_priv(string, **kwargs) + +@functools.lru_cache(maxsize=10000) +def _get_array_priv( + string: str, *, + size: int = 32, + max_width: Optional[int]=None, + max_height: Optional[int]=None, + min_size=10, + shrink_coef=0.8, + dropshadow_radius: int=None, + offset_x: int=None, + offset_y: int=None, + **kwargs +): + cur_size = size + array = None + while True: + if dropshadow_radius is not None: + # separate implementation for dropshadow text rendering + array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) + else: + array = _get_array_impl(string, size=cur_size, **kwargs) + height, width, _ = array.shape + if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): + break + cur_size = max(int(cur_size * shrink_coef), min_size) + return array + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=10000) +def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): + pil_font = get_pil_font(font=font, size=size) + lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] + lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] + width = max(line.shape[1] for line in lines) + lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] + line_spacing = line_pad if line_pad is not None else size // 2 + lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] + mask = np.concatenate(lines, axis=0) + alpha = mask + if outline > 0: + mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) + alpha = mask.astype(np.float32) / 255 + alpha = scipy.ndimage.gaussian_filter(alpha, outline) + alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp + alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) + alpha = np.maximum(alpha, mask) + return np.stack([mask, alpha], axis=-1) + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=10000) +def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): + assert (offset_x > 0) and (offset_y > 0) + pil_font = get_pil_font(font=font, size=size) + lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] + lines = 
[np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] + width = max(line.shape[1] for line in lines) + lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] + line_spacing = line_pad if line_pad is not None else size // 2 + lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] + mask = np.concatenate(lines, axis=0) + alpha = mask + + mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) + alpha = mask.astype(np.float32) / 255 + alpha = scipy.ndimage.gaussian_filter(alpha, radius) + alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 + alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) + alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] + alpha = np.maximum(alpha, mask) + return np.stack([mask, alpha], axis=-1) + +#---------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=10000) +def get_texture(string, bilinear=True, mipmap=True, **kwargs): + return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) + +#---------------------------------------------------------------------------- diff --git a/legacy.py b/legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..e361fc2351c383c23aa6786f0bbc9daae047afca --- /dev/null +++ b/legacy.py @@ -0,0 +1,323 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Converting legacy network pickle into the new format.""" + +import click +import pickle +import re +import copy +import numpy as np +import torch +import dnnlib +from torch_utils import misc + +#---------------------------------------------------------------------------- + +def load_network_pkl(f, force_fp16=False): + data = _LegacyUnpickler(f).load() + + # Legacy TensorFlow pickle => convert. + if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): + tf_G, tf_D, tf_Gs = data + G = convert_tf_generator(tf_G) + D = convert_tf_discriminator(tf_D) + G_ema = convert_tf_generator(tf_Gs) + data = dict(G=G, D=D, G_ema=G_ema) + + # Add missing fields. + if 'training_set_kwargs' not in data: + data['training_set_kwargs'] = None + if 'augment_pipe' not in data: + data['augment_pipe'] = None + + # Validate contents. + assert isinstance(data['G'], torch.nn.Module) + assert isinstance(data['D'], torch.nn.Module) + assert isinstance(data['G_ema'], torch.nn.Module) + assert isinstance(data['training_set_kwargs'], (dict, type(None))) + assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) + + # Force FP16. 
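+    # Rebuild G/D/G_ema with num_fp16_res=4 and conv_clamp=256 injected into their init
+    # kwargs (placed on synthesis_kwargs when present) and copy the trained parameters over,
+    # but only when this actually changes the kwargs.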
+ if force_fp16: + for key in ['G', 'D', 'G_ema']: + old = data[key] + kwargs = copy.deepcopy(old.init_kwargs) + fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs) + fp16_kwargs.num_fp16_res = 4 + fp16_kwargs.conv_clamp = 256 + if kwargs != old.init_kwargs: + new = type(old)(**kwargs).eval().requires_grad_(False) + misc.copy_params_and_buffers(old, new, require_all=True) + data[key] = new + return data + +#---------------------------------------------------------------------------- + +class _TFNetworkStub(dnnlib.EasyDict): + pass + +class _LegacyUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module == 'dnnlib.tflib.network' and name == 'Network': + return _TFNetworkStub + return super().find_class(module, name) + +#---------------------------------------------------------------------------- + +def _collect_tf_params(tf_net): + # pylint: disable=protected-access + tf_params = dict() + def recurse(prefix, tf_net): + for name, value in tf_net.variables: + tf_params[prefix + name] = value + for name, comp in tf_net.components.items(): + recurse(prefix + name + '/', comp) + recurse('', tf_net) + return tf_params + +#---------------------------------------------------------------------------- + +def _populate_module_params(module, *patterns): + for name, tensor in misc.named_params_and_buffers(module): + found = False + value = None + for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): + match = re.fullmatch(pattern, name) + if match: + found = True + if value_fn is not None: + value = value_fn(*match.groups()) + break + try: + assert found + if value is not None: + tensor.copy_(torch.from_numpy(np.array(value))) + except: + print(name, list(tensor.shape)) + raise + +#---------------------------------------------------------------------------- + +def convert_tf_generator(tf_G): + if tf_G.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. + tf_kwargs = tf_G.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None, none=None): + known_kwargs.add(tf_name) + val = tf_kwargs.get(tf_name, default) + return val if val is not None else none + + # Convert kwargs. + from training import networks_stylegan2 + network_class = networks_stylegan2.Generator + kwargs = dnnlib.EasyDict( + z_dim = kwarg('latent_size', 512), + c_dim = kwarg('label_size', 0), + w_dim = kwarg('dlatent_size', 512), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + architecture = kwarg('architecture', 'skip'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + use_noise = kwarg('use_noise', True), + activation = kwarg('nonlinearity', 'lrelu'), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 8), + embed_features = kwarg('label_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('mapping_nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.01), + w_avg_beta = kwarg('w_avg_beta', 0.995, none=1), + ), + ) + + # Check for unknown kwargs. 
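    # The calls below mark TF-only settings (truncation, style mixing, etc.) as known so
    # that they are silently ignored; anything else left over in the pickle's
    # static_kwargs is reported as an unknown kwarg.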
+ kwarg('truncation_psi') + kwarg('truncation_cutoff') + kwarg('style_mixing_prob') + kwarg('structure') + kwarg('conditioning') + kwarg('fused_modconv') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. + tf_params = _collect_tf_params(tf_G) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value + kwargs.synthesis.kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. + G = network_class(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + # pylint: disable=f-string-without-interpolation + _populate_module_params(G, + r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], + r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], + r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], + r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], + r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], + r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], + r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), + r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], + r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], + r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], + r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], + r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], + r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], + r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, + r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], + 
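        # The remaining patterns cover the per-resolution ToRGB affine layers and skip
        # connections; buffers matched with a value_fn of None (resample/act filters)
        # are intentionally left at their initialized values.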
r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, + r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'.*\.resample_filter', None, + r'.*\.act_filter', None, + ) + return G + +#---------------------------------------------------------------------------- + +def convert_tf_discriminator(tf_D): + if tf_D.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. + tf_kwargs = tf_D.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None): + known_kwargs.add(tf_name) + return tf_kwargs.get(tf_name, default) + + # Convert kwargs. + kwargs = dnnlib.EasyDict( + c_dim = kwarg('label_size', 0), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + architecture = kwarg('architecture', 'resnet'), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + cmap_dim = kwarg('mapping_fmaps', None), + block_kwargs = dnnlib.EasyDict( + activation = kwarg('nonlinearity', 'lrelu'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + freeze_layers = kwarg('freeze_layers', 0), + ), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 0), + embed_features = kwarg('mapping_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.1), + ), + epilogue_kwargs = dnnlib.EasyDict( + mbstd_group_size = kwarg('mbstd_group_size', None), + mbstd_num_channels = kwarg('mbstd_num_features', 1), + activation = kwarg('nonlinearity', 'lrelu'), + ), + ) + + # Check for unknown kwargs. + kwarg('structure') + kwarg('conditioning') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. + tf_params = _collect_tf_params(tf_D) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value + kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. 
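    # Instantiate the PyTorch discriminator from the converted kwargs, then copy each
    # TF variable into its matching parameter; convolution weights are transposed from
    # TensorFlow's HWIO layout to PyTorch's OIHW.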
+ from training import networks_stylegan2 + D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + # pylint: disable=f-string-without-interpolation + _populate_module_params(D, + r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], + r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], + r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), + r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], + r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], + r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), + r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], + r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), + r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], + r'.*\.resample_filter', None, + ) + return D + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--source', help='Input pickle', required=True, metavar='PATH') +@click.option('--dest', help='Output pickle', required=True, metavar='PATH') +@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) +def convert_network_pickle(source, dest, force_fp16): + """Convert legacy network pickle into the native PyTorch format. + + The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. + It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. + + Example: + + \b + python legacy.py \\ + --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ + --dest=stylegan2-cat-config-f.pkl + """ + print(f'Loading "{source}"...') + with dnnlib.util.open_url(source) as f: + data = load_network_pkl(f, force_fp16=force_fp16) + print(f'Saving "{dest}"...') + with open(dest, 'wb') as f: + pickle.dump(data, f) + print('Done.') + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + convert_network_pickle() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/metrics/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/metrics/equivariance.py b/metrics/equivariance.py new file mode 100644 index 0000000000000000000000000000000000000000..c96ebed07fe478542ab56b56a6506e79f03d1388 --- /dev/null +++ b/metrics/equivariance.py @@ -0,0 +1,267 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper +"Alias-Free Generative Adversarial Networks".""" + +import copy +import numpy as np +import torch +import torch.fft +from torch_utils.ops import upfirdn2d +from . import metric_utils + +#---------------------------------------------------------------------------- +# Utilities. + +def sinc(x): + y = (x * np.pi).abs() + z = torch.sin(y) / y.clamp(1e-30, float('inf')) + return torch.where(y < 1e-30, torch.ones_like(x), z) + +def lanczos_window(x, a): + x = x.abs() / a + return torch.where(x < 1, sinc(x), torch.zeros_like(x)) + +def rotation_matrix(angle): + angle = torch.as_tensor(angle).to(torch.float32) + mat = torch.eye(3, device=angle.device) + mat[0, 0] = angle.cos() + mat[0, 1] = angle.sin() + mat[1, 0] = -angle.sin() + mat[1, 1] = angle.cos() + return mat + +#---------------------------------------------------------------------------- +# Apply integer translation to a batch of 2D images. Corresponds to the +# operator T_x in Appendix E.1. + +def apply_integer_translation(x, tx, ty): + _N, _C, H, W = x.shape + tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) + ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) + ix = tx.round().to(torch.int64) + iy = ty.round().to(torch.int64) + + z = torch.zeros_like(x) + m = torch.zeros_like(x) + if abs(ix) < W and abs(iy) < H: + y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)] + z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y + m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1 + return z, m + +#---------------------------------------------------------------------------- +# Apply integer translation to a batch of 2D images. Corresponds to the +# operator T_x in Appendix E.2. 
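# The function below handles sub-pixel offsets: the integer part of the translation is
# applied by shifting, the fractional remainder by a separable Lanczos-windowed sinc
# filter of support 2a, and the returned mask marks where the result is valid.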
+ +def apply_fractional_translation(x, tx, ty, a=3): + _N, _C, H, W = x.shape + tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) + ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) + ix = tx.floor().to(torch.int64) + iy = ty.floor().to(torch.int64) + fx = tx - ix + fy = ty - iy + b = a - 1 + + z = torch.zeros_like(x) + zx0 = max(ix - b, 0) + zy0 = max(iy - b, 0) + zx1 = min(ix + a, 0) + W + zy1 = min(iy + a, 0) + H + if zx0 < zx1 and zy0 < zy1: + taps = torch.arange(a * 2, device=x.device) - b + filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0) + filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1) + y = x + y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0]) + y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a]) + y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)] + z[:, :, zy0:zy1, zx0:zx1] = y + + m = torch.zeros_like(x) + mx0 = max(ix + a, 0) + my0 = max(iy + a, 0) + mx1 = min(ix - b, 0) + W + my1 = min(iy - b, 0) + H + if mx0 < mx1 and my0 < my1: + m[:, :, my0:my1, mx0:mx1] = 1 + return z, m + +#---------------------------------------------------------------------------- +# Construct an oriented low-pass filter that applies the appropriate +# bandlimit with respect to the input and output of the given affine 2D +# image transformation. + +def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1): + assert a <= amax < aflt + mat = torch.as_tensor(mat).to(torch.float32) + + # Construct 2D filter taps in input & output coordinate spaces. + taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up) + yi, xi = torch.meshgrid(taps, taps) + xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2) + + # Convolution of two oriented 2D sinc filters. + fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in) + fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out) + f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real + + # Convolution of two oriented 2D Lanczos windows. + wi = lanczos_window(xi, a) * lanczos_window(yi, a) + wo = lanczos_window(xo, a) * lanczos_window(yo, a) + w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real + + # Construct windowed FIR filter. + f = f * w + + # Finalize. + c = (aflt - amax) * up + f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c] + f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up) + f = f / f.sum([0,2], keepdim=True) / (up ** 2) + f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1] + return f + +#---------------------------------------------------------------------------- +# Apply the given affine transformation to a batch of 2D images. + +def apply_affine_transformation(x, mat, up=4, **filter_kwargs): + _N, _C, H, W = x.shape + mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device) + + # Construct filter. + f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs) + assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1 + p = f.shape[0] // 2 + + # Construct sampling grid. + theta = mat.inverse() + theta[:2, 2] *= 2 + theta[0, 2] += 1 / up / W + theta[1, 2] += 1 / up / H + theta[0, :] *= W / (W + p / up * 2) + theta[1, :] *= H / (H + p / up * 2) + theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1]) + g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False) + + # Resample image. 
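    # The image is first upsampled by `up` with the bandlimiting filter, then sampled
    # along the affine grid; the mask computed below marks output pixels whose samples
    # fall inside the valid (non-padded) region.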
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p) + z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False) + + # Form mask. + m = torch.zeros_like(y) + c = p * 2 + 1 + m[:, :, c:-c, c:-c] = 1 + m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False) + return z, m + +#---------------------------------------------------------------------------- +# Apply fractional rotation to a batch of 2D images. Corresponds to the +# operator R_\alpha in Appendix E.3. + +def apply_fractional_rotation(x, angle, a=3, **filter_kwargs): + angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device) + mat = rotation_matrix(angle) + return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs) + +#---------------------------------------------------------------------------- +# Modify the frequency content of a batch of 2D images as if they had undergo +# fractional rotation -- but without actually rotating them. Corresponds to +# the operator R^*_\alpha in Appendix E.3. + +def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs): + angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device) + mat = rotation_matrix(-angle) + f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs) + y = upfirdn2d.filter2d(x=x, f=f) + m = torch.zeros_like(y) + c = f.shape[0] // 2 + m[:, :, c:-c, c:-c] = 1 + return y, m + +#---------------------------------------------------------------------------- +# Compute the selected equivariance metrics for the given generator. + +def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False): + assert compute_eqt_int or compute_eqt_frac or compute_eqr + + # Setup generator and labels. + G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) + I = torch.eye(3, device=opts.device) + M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None) + if M is None: + raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations') + c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) + + # Sampling loop. + sums = None + progress = opts.progress.sub(tag='eq sampling', num_items=num_samples) + for batch_start in range(0, num_samples, batch_size * opts.num_gpus): + progress.update(batch_start) + s = [] + + # Randomize noise buffers, if any. + for name, buf in G.named_buffers(): + if name.endswith('.noise_const'): + buf.copy_(torch.randn_like(buf)) + + # Run mapping network. + z = torch.randn([batch_size, G.z_dim], device=opts.device) + c = next(c_iter) + ws = G.mapping(z=z, c=c) + + # Generate reference image. + M[:] = I + orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + + # Integer translation (EQ-T). + if compute_eqt_int: + t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max + t = (t * G.img_resolution).round() / G.img_resolution + M[:] = I + M[:2, 2] = -t + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, mask = apply_integer_translation(orig, t[0], t[1]) + s += [(ref - img).square() * mask, mask] + + # Fractional translation (EQ-T_frac). 
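        # Same protocol as EQ-T, but the offset is not snapped to the pixel grid; the
        # reference is produced with the Lanczos-based fractional translation above
        # rather than an integer shift.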
+ if compute_eqt_frac: + t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max + M[:] = I + M[:2, 2] = -t + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, mask = apply_fractional_translation(orig, t[0], t[1]) + s += [(ref - img).square() * mask, mask] + + # Rotation (EQ-R). + if compute_eqr: + angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi) + M[:] = rotation_matrix(-angle) + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, ref_mask = apply_fractional_rotation(orig, angle) + pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle) + mask = ref_mask * pseudo_mask + s += [(ref - pseudo).square() * mask, mask] + + # Accumulate results. + s = torch.stack([x.to(torch.float64).sum() for x in s]) + sums = sums + s if sums is not None else s + progress.update(num_samples) + + # Compute PSNRs. + if opts.num_gpus > 1: + torch.distributed.all_reduce(sums) + sums = sums.cpu() + mses = sums[0::2] / sums[1::2] + psnrs = np.log10(2) * 20 - mses.log10() * 10 + psnrs = tuple(psnrs.numpy()) + return psnrs[0] if len(psnrs) == 1 else psnrs + +#---------------------------------------------------------------------------- diff --git a/metrics/frechet_inception_distance.py b/metrics/frechet_inception_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdd6b6ce1f44a345c1451150634bb2fa0c7e9b3 --- /dev/null +++ b/metrics/frechet_inception_distance.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Frechet Inception Distance (FID) from the paper +"GANs trained by a two time-scale update rule converge to a local Nash +equilibrium". Matches the original implementation by Heusel et al. at +https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" + +import numpy as np +import scipy.linalg +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_fid(opts, max_real, num_gen): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
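    # FID is the Frechet distance between Gaussians fitted to the two feature sets:
    #   FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 * (Sigma_r @ Sigma_g)^(1/2))
    # The statistics are accumulated below and the closed form is evaluated on rank 0.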
+ + mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() + + mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() + + if opts.rank != 0: + return float('nan') + + m = np.square(mu_gen - mu_real).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) + return float(fid) + +#---------------------------------------------------------------------------- diff --git a/metrics/inception_score.py b/metrics/inception_score.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a7f4acbcd3c33db6d6848f61d176e0ff97295e --- /dev/null +++ b/metrics/inception_score.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Inception Score (IS) from the paper "Improved techniques for training +GANs". Matches the original implementation by Salimans et al. at +https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" + +import numpy as np +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_is(opts, num_gen, num_splits): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. + + gen_probs = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan'), float('nan') + + scores = [] + for i in range(num_splits): + part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] + kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) + kl = np.mean(np.sum(kl, axis=1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)), float(np.std(scores)) + +#---------------------------------------------------------------------------- diff --git a/metrics/kernel_inception_distance.py b/metrics/kernel_inception_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e0bd12ef56e64d8e77091aaf465891f4984d9e --- /dev/null +++ b/metrics/kernel_inception_distance.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Kernel Inception Distance (KID) from the paper "Demystifying MMD +GANs". Matches the original implementation by Binkowski et al. at +https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" + +import numpy as np +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. + + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan') + + n = real_features.shape[1] + m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) + t = 0 + for _subset_idx in range(num_subsets): + x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] + y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] + a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 + b = (x @ y.T / n + 1) ** 3 + t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m + kid = t / num_subsets / m + return float(kid) + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_main.py b/metrics/metric_main.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f7389fbe322aff06a1860d581da56f9d9ad937 --- /dev/null +++ b/metrics/metric_main.py @@ -0,0 +1,153 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Main API for computing and reporting quality metrics.""" + +import os +import time +import json +import torch +import dnnlib + +from . import metric_utils +from . import frechet_inception_distance +from . import kernel_inception_distance +from . import precision_recall +from . import perceptual_path_length +from . import inception_score +from . 
import equivariance + +#---------------------------------------------------------------------------- + +_metric_dict = dict() # name => fn + +def register_metric(fn): + assert callable(fn) + _metric_dict[fn.__name__] = fn + return fn + +def is_valid_metric(metric): + return metric in _metric_dict + +def list_valid_metrics(): + return list(_metric_dict.keys()) + +#---------------------------------------------------------------------------- + +def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. + assert is_valid_metric(metric) + opts = metric_utils.MetricOptions(**kwargs) + + # Calculate. + start_time = time.time() + results = _metric_dict[metric](opts) + total_time = time.time() - start_time + + # Broadcast results. + for key, value in list(results.items()): + if opts.num_gpus > 1: + value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) + torch.distributed.broadcast(tensor=value, src=0) + value = float(value.cpu()) + results[key] = value + + # Decorate with metadata. + return dnnlib.EasyDict( + results = dnnlib.EasyDict(results), + metric = metric, + total_time = total_time, + total_time_str = dnnlib.util.format_time(total_time), + num_gpus = opts.num_gpus, + ) + +#---------------------------------------------------------------------------- + +def report_metric(result_dict, run_dir=None, snapshot_pkl=None): + metric = result_dict['metric'] + assert is_valid_metric(metric) + if run_dir is not None and snapshot_pkl is not None: + snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) + + jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) + print(jsonl_line) + if run_dir is not None and os.path.isdir(run_dir): + with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: + f.write(jsonl_line + '\n') + +#---------------------------------------------------------------------------- +# Recommended metrics. 
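# Each function below registers under its own name via @register_metric and is invoked
# through calc_metric(). A minimal usage sketch, assuming a loaded generator `G_ema`
# and a `dataset_kwargs` dict describing the training data (both hypothetical here):
#
#   result = calc_metric(metric='fid50k_full', G=G_ema, dataset_kwargs=dataset_kwargs,
#                        num_gpus=1, rank=0, device=torch.device('cuda'))
#   report_metric(result, run_dir='out', snapshot_pkl='network-snapshot.pkl')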
+ +@register_metric +def fid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) + return dict(fid50k_full=fid) + +@register_metric +def kid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k_full=kid) + +@register_metric +def pr50k3_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) + +@register_metric +def ppl2_wend(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) + return dict(ppl2_wend=ppl) + +@register_metric +def eqt50k_int(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) + return dict(eqt50k_int=psnr) + +@register_metric +def eqt50k_frac(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) + return dict(eqt50k_frac=psnr) + +@register_metric +def eqr50k(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) + return dict(eqr50k=psnr) + +#---------------------------------------------------------------------------- +# Legacy metrics. + +@register_metric +def fid50k(opts): + opts.dataset_kwargs.update(max_size=None) + fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) + return dict(fid50k=fid) + +@register_metric +def kid50k(opts): + opts.dataset_kwargs.update(max_size=None) + kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k=kid) + +@register_metric +def pr50k3(opts): + opts.dataset_kwargs.update(max_size=None) + precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_precision=precision, pr50k3_recall=recall) + +@register_metric +def is50k(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) + return dict(is50k_mean=mean, is50k_std=std) + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_utils.py b/metrics/metric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44b67eed7b5bbf029481ecbd865457fa42f7cc89 --- /dev/null +++ b/metrics/metric_utils.py @@ -0,0 +1,279 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Miscellaneous utilities used internally by the quality metrics.""" + +import os +import time +import hashlib +import pickle +import copy +import uuid +import numpy as np +import torch +import dnnlib + +#---------------------------------------------------------------------------- + +class MetricOptions: + def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True): + assert 0 <= rank < num_gpus + self.G = G + self.G_kwargs = dnnlib.EasyDict(G_kwargs) + self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) + self.num_gpus = num_gpus + self.rank = rank + self.device = device if device is not None else torch.device('cuda', rank) + self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() + self.cache = cache + +#---------------------------------------------------------------------------- + +_feature_detector_cache = dict() + +def get_feature_detector_name(url): + return os.path.splitext(url.split('/')[-1])[0] + +def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): + assert 0 <= rank < num_gpus + key = (url, device) + if key not in _feature_detector_cache: + is_leader = (rank == 0) + if not is_leader and num_gpus > 1: + torch.distributed.barrier() # leader goes first + with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: + _feature_detector_cache[key] = pickle.load(f).to(device) + if is_leader and num_gpus > 1: + torch.distributed.barrier() # others follow + return _feature_detector_cache[key] + +#---------------------------------------------------------------------------- + +def iterate_random_labels(opts, batch_size): + if opts.G.c_dim == 0: + c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) + while True: + yield c + else: + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + while True: + c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] + c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) + yield c + +#---------------------------------------------------------------------------- + +class FeatureStats: + def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): + self.capture_all = capture_all + self.capture_mean_cov = capture_mean_cov + self.max_items = max_items + self.num_items = 0 + self.num_features = None + self.all_features = None + self.raw_mean = None + self.raw_cov = None + + def set_num_features(self, num_features): + if self.num_features is not None: + assert num_features == self.num_features + else: + self.num_features = num_features + self.all_features = [] + self.raw_mean = np.zeros([num_features], dtype=np.float64) + self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) + + def is_full(self): + return (self.max_items is not None) and (self.num_items >= self.max_items) + + def append(self, x): + x = np.asarray(x, dtype=np.float32) + assert x.ndim == 2 + if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): + if self.num_items >= self.max_items: + return + x = x[:self.max_items - self.num_items] + + self.set_num_features(x.shape[1]) + self.num_items += x.shape[0] + if self.capture_all: + self.all_features.append(x) + if self.capture_mean_cov: + x64 = x.astype(np.float64) + self.raw_mean += x64.sum(axis=0) + self.raw_cov += x64.T @ x64 + + def append_torch(self, x, num_gpus=1, rank=0): + assert isinstance(x, torch.Tensor) and x.ndim == 2 + assert 0 <= rank < num_gpus + if 
num_gpus > 1: + ys = [] + for src in range(num_gpus): + y = x.clone() + torch.distributed.broadcast(y, src=src) + ys.append(y) + x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples + self.append(x.cpu().numpy()) + + def get_all(self): + assert self.capture_all + return np.concatenate(self.all_features, axis=0) + + def get_all_torch(self): + return torch.from_numpy(self.get_all()) + + def get_mean_cov(self): + assert self.capture_mean_cov + mean = self.raw_mean / self.num_items + cov = self.raw_cov / self.num_items + cov = cov - np.outer(mean, mean) + return mean, cov + + def save(self, pkl_file): + with open(pkl_file, 'wb') as f: + pickle.dump(self.__dict__, f) + + @staticmethod + def load(pkl_file): + with open(pkl_file, 'rb') as f: + s = dnnlib.EasyDict(pickle.load(f)) + obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) + obj.__dict__.update(s) + return obj + +#---------------------------------------------------------------------------- + +class ProgressMonitor: + def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): + self.tag = tag + self.num_items = num_items + self.verbose = verbose + self.flush_interval = flush_interval + self.progress_fn = progress_fn + self.pfn_lo = pfn_lo + self.pfn_hi = pfn_hi + self.pfn_total = pfn_total + self.start_time = time.time() + self.batch_time = self.start_time + self.batch_items = 0 + if self.progress_fn is not None: + self.progress_fn(self.pfn_lo, self.pfn_total) + + def update(self, cur_items): + assert (self.num_items is None) or (cur_items <= self.num_items) + if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): + return + cur_time = time.time() + total_time = cur_time - self.start_time + time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) + if (self.verbose) and (self.tag is not None): + print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') + self.batch_time = cur_time + self.batch_items = cur_items + + if (self.progress_fn is not None) and (self.num_items is not None): + self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) + + def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): + return ProgressMonitor( + tag = tag, + num_items = num_items, + flush_interval = flush_interval, + verbose = self.verbose, + progress_fn = self.progress_fn, + pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, + pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, + pfn_total = self.pfn_total, + ) + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs): + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + if data_loader_kwargs is None: + data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) + + # Try to lookup from cache. + cache_file = None + if opts.cache: + # Choose cache file name. 
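        # The key below hashes every setting that affects the features (dataset kwargs,
        # detector URL and kwargs, stats kwargs), so cached statistics are reused only
        # when the configuration matches exactly.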
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) + md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) + cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' + cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') + + # Check if the file exists (all processes must agree). + flag = os.path.isfile(cache_file) if opts.rank == 0 else False + if opts.num_gpus > 1: + flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) + torch.distributed.broadcast(tensor=flag, src=0) + flag = (float(flag.cpu()) != 0) + + # Load. + if flag: + return FeatureStats.load(cache_file) + + # Initialize. + num_items = len(dataset) + if max_items is not None: + num_items = min(num_items, max_items) + stats = FeatureStats(max_items=num_items, **stats_kwargs) + progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] + for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector(images.to(opts.device), **detector_kwargs) + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + + # Save to cache. + if cache_file is not None and opts.rank == 0: + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + temp_file = cache_file + '.' + uuid.uuid4().hex + stats.save(temp_file) + os.replace(temp_file, cache_file) # atomic + return stats + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs): + if batch_gen is None: + batch_gen = min(batch_size, 4) + assert batch_size % batch_gen == 0 + + # Setup generator and labels. + G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) + c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen) + + # Initialize. + stats = FeatureStats(**stats_kwargs) + assert stats.max_items is not None + progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. 
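    # Images are generated in sub-batches of batch_gen, mapped from the generator's
    # [-1, 1] output range to uint8 [0, 255] for the detector, and their features are
    # accumulated until max_items is reached.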
+ while not stats.is_full(): + images = [] + for _i in range(batch_size // batch_gen): + z = torch.randn([batch_gen, G.z_dim], device=opts.device) + img = G(z=z, c=next(c_iter), **opts.G_kwargs) + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) + images.append(img) + images = torch.cat(images) + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector(images, **detector_kwargs) + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + return stats + +#---------------------------------------------------------------------------- diff --git a/metrics/perceptual_path_length.py b/metrics/perceptual_path_length.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb74396475181c3a80feb6321d3b0f45eda7000 --- /dev/null +++ b/metrics/perceptual_path_length.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator +Architecture for Generative Adversarial Networks". Matches the original +implementation by Karras et al. at +https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" + +import copy +import numpy as np +import torch +from . import metric_utils + +#---------------------------------------------------------------------------- + +# Spherical interpolation of a batch of vectors. +def slerp(a, b, t): + a = a / a.norm(dim=-1, keepdim=True) + b = b / b.norm(dim=-1, keepdim=True) + d = (a * b).sum(dim=-1, keepdim=True) + p = t * torch.acos(d) + c = b - d * a + c = c / c.norm(dim=-1, keepdim=True) + d = a * torch.cos(p) + c * torch.sin(p) + d = d / d.norm(dim=-1, keepdim=True) + return d + +#---------------------------------------------------------------------------- + +class PPLSampler(torch.nn.Module): + def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): + assert space in ['z', 'w'] + assert sampling in ['full', 'end'] + super().__init__() + self.G = copy.deepcopy(G) + self.G_kwargs = G_kwargs + self.epsilon = epsilon + self.space = space + self.sampling = sampling + self.crop = crop + self.vgg16 = copy.deepcopy(vgg16) + + def forward(self, c): + # Generate random latents and interpolation t-values. + t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) + z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) + + # Interpolate in W or Z. + if self.space == 'w': + w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) + wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) + wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) + else: # space == 'z' + zt0 = slerp(z0, z1, t.unsqueeze(1)) + zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) + wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) + + # Randomize noise buffers. + for name, buf in self.G.named_buffers(): + if name.endswith('.noise_const'): + buf.copy_(torch.randn_like(buf)) + + # Generate images. 
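        # Both endpoints of every perturbed latent pair are synthesized in a single
        # batch at full precision (force_fp32), which avoids FP16 rounding error on the
        # epsilon-sized perturbation.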
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) + + # Center crop. + if self.crop: + assert img.shape[2] == img.shape[3] + c = img.shape[2] // 8 + img = img[:, :, c*3 : c*7, c*2 : c*6] + + # Downsample to 256x256. + factor = self.G.img_resolution // 256 + if factor > 1: + img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) + + # Scale dynamic range from [-1,1] to [0,255]. + img = (img + 1) * (255 / 2) + if self.G.img_channels == 1: + img = img.repeat([1, 3, 1, 1]) + + # Evaluate differential LPIPS. + lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) + dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 + return dist + +#---------------------------------------------------------------------------- + +def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): + vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' + vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) + + # Setup sampler and labels. + sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) + sampler.eval().requires_grad_(False).to(opts.device) + c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) + + # Sampling loop. + dist = [] + progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) + for batch_start in range(0, num_samples, batch_size * opts.num_gpus): + progress.update(batch_start) + x = sampler(next(c_iter)) + for src in range(opts.num_gpus): + y = x.clone() + if opts.num_gpus > 1: + torch.distributed.broadcast(y, src=src) + dist.append(y) + progress.update(num_samples) + + # Compute PPL. + if opts.rank != 0: + return float('nan') + dist = torch.cat(dist)[:num_samples].cpu().numpy() + lo = np.percentile(dist, 1, interpolation='lower') + hi = np.percentile(dist, 99, interpolation='higher') + ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() + return float(ppl) + +#---------------------------------------------------------------------------- diff --git a/metrics/precision_recall.py b/metrics/precision_recall.py new file mode 100644 index 0000000000000000000000000000000000000000..17e5b4286b43e2d09aeba19d2521869a6cbe7ea1 --- /dev/null +++ b/metrics/precision_recall.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Precision/Recall (PR) from the paper "Improved Precision and Recall +Metric for Assessing Generative Models". Matches the original implementation +by Kynkaanniemi et al. at +https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" + +import torch +from . 
import metric_utils + +#---------------------------------------------------------------------------- + +def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): + assert 0 <= rank < num_gpus + num_cols = col_features.shape[0] + num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus + col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) + dist_batches = [] + for col_batch in col_batches[rank :: num_gpus]: + dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] + for src in range(num_gpus): + dist_broadcast = dist_batch.clone() + if num_gpus > 1: + torch.distributed.broadcast(dist_broadcast, src=src) + dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) + return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None + +#---------------------------------------------------------------------------- + +def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' + detector_kwargs = dict(return_features=True) + + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) + + results = dict() + for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: + kth = [] + for manifold_batch in manifold.split(row_batch_size): + dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) + kth = torch.cat(kth) if opts.rank == 0 else None + pred = [] + for probes_batch in probes.split(row_batch_size): + dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) + results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') + return results['precision'], results['recall'] + +#---------------------------------------------------------------------------- diff --git a/torch_utils/__init__.py b/torch_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd34882519598c472f1224cfe68c9ff6952ce69 --- /dev/null +++ b/torch_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +# empty diff --git a/torch_utils/custom_ops.py b/torch_utils/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dffd4bd8a75b3495862954546ea04bb7ef39c998 --- /dev/null +++ b/torch_utils/custom_ops.py @@ -0,0 +1,157 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import glob +import hashlib +import importlib +import os +import re +import shutil +import uuid + +import torch +import torch.utils.cpp_extension +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- + +def _get_mangled_gpu_name(): + name = torch.cuda.get_device_name().lower() + out = [] + for c in name: + if re.match('[a-z0-9_-]+', c): + out.append(c) + else: + out.append('-') + return ''.join(out) + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + if headers is None: + headers = [] + if source_dir is not None: + sources = [os.path.join(source_dir, fname) for fname in sources] + headers = [os.path.join(source_dir, fname) for fname in headers] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + verbose_build = (verbosity == 'full') + + # Compile and load. + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either + # break the build or unnecessarily restrict what's available to nvcc. + # Unset it to let nvcc decide based on what's available on the + # machine. 
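        # Note that the override applies only to this Python process and its children;
        # the caller's shell environment is left untouched.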
+ os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. + source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/torch_utils/misc.py b/torch_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..02f97e276e756bb1c4140ac16e7e4bcc63da628a --- /dev/null +++ b/torch_utils/misc.py @@ -0,0 +1,266 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
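#----------------------------------------------------------------------------
# Illustrative sketch, not part of the original diff: how a caller might use
# custom_ops.get_plugin() above to compile and load a fused CUDA op. The module
# name, source/header file names, and build flags below are placeholders.
#
#   from torch_utils import custom_ops
#   _plugin = custom_ops.get_plugin(
#       module_name='bias_act_plugin',
#       sources=['bias_act.cpp', 'bias_act.cu'],
#       headers=['bias_act.h'],
#       source_dir=os.path.dirname(__file__),
#       extra_cuda_cflags=['--use_fast_math'])
#
# The first call compiles into a cache directory named after the md5 digest of
# the sources and the GPU name; later calls with unchanged sources reuse it.
#----------------------------------------------------------------------------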
+ +import re +import contextlib +import numpy as np +import torch +import warnings +import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). 
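# For example (illustrative, not in the original): assert_shape(x, [None, 3, 256, 256])
# accepts any batch size but requires exactly 3 channels at 256x256 resolution,
# and the check remains usable under torch.jit.trace() via symbolic_assert.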
+ +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. 
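# For example (illustrative, not in the original), gradient accumulation usually
# enables the all-reduce only on the last micro-batch:
#
#   for i in range(num_rounds):                                      # num_rounds: placeholder
#       with ddp_sync(ddp_module, sync=(i == num_rounds - 1)):
#           compute_loss(ddp_module, micro_batches[i]).backward()    # placeholder helpers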
+ +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '